Initial import
This commit is contained in:
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
# Allowlisting gitignore template for GO projects prevents us
|
||||
# from adding various unwanted local files, such as generated
|
||||
# files, developer configurations or IDE-specific files etc.
|
||||
#
|
||||
# Recommended: Go.AllowList.gitignore
|
||||
|
||||
# Ignore everything
|
||||
*
|
||||
|
||||
# But these files...
|
||||
!.gitignore
|
||||
|
||||
!*.go
|
||||
!go.sum
|
||||
!go.mod
|
||||
|
||||
!README.md
|
||||
!LICENSE
|
||||
|
||||
!Makefile
|
||||
|
||||
!*.sh
|
||||
!*.md
|
||||
|
||||
# ...even if they are in subdirectories
|
||||
!*/
|
||||
64
convert.go
Normal file
64
convert.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
|
||||
// "github.com/go-viper/mapstructure/v2"
|
||||
)
|
||||
|
||||
func ToMap(data any) (bson.M, error) {
|
||||
// 1. Marshal the struct to BSON bytes
|
||||
b, err := bson.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Unmarshal the bytes back into a bson.M
|
||||
var res bson.M
|
||||
err = bson.Unmarshal(b, &res)
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func ToStruct(m bson.M, target any) error {
|
||||
// 1. Convert map to BSON bytes
|
||||
data, err := bson.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Unmarshal bytes into the target struct
|
||||
return bson.Unmarshal(data, target)
|
||||
}
|
||||
|
||||
// func ToMap(input any) (bson.M, error) {
|
||||
// var result bson.M
|
||||
// config := &mapstructure.DecoderConfig{
|
||||
// TagName: "bson", // Use bson tags instead of field names
|
||||
// Result: &result,
|
||||
// }
|
||||
|
||||
// decoder, err := mapstructure.NewDecoder(config)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// err = decoder.Decode(input)
|
||||
// return result, err
|
||||
// }
|
||||
|
||||
// func BsonToStructHook() mapstructure.DecodeHookFunc {
|
||||
// return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
// // 1. Convert primitive.ObjectID -> string
|
||||
// if f == reflect.TypeOf(primitive.ObjectID{}) && t.Kind() == reflect.String {
|
||||
// return data.(primitive.ObjectID).Hex(), nil
|
||||
// }
|
||||
|
||||
// // 2. Convert primitive.DateTime -> time.Time
|
||||
// if f == reflect.TypeOf(primitive.DateTime(0)) && t == reflect.TypeOf(time.Time{}) {
|
||||
// return data.(primitive.DateTime).Time(), nil
|
||||
// }
|
||||
|
||||
// return data, nil
|
||||
// }
|
||||
// }
|
||||
6
defaults.go
Normal file
6
defaults.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package mongo
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed structful.json
|
||||
var structfulJson []byte
|
||||
22
delete_one.go
Normal file
22
delete_one.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
// DeleteOne will delete the first document that matches the filter.
|
||||
func (c *MongoClient) DeleteOne(ctx context.Context, database, name string, filter bson.M) error {
|
||||
|
||||
// 1. Prepare query.
|
||||
collection := c.GetCollection(database, name)
|
||||
|
||||
// 2. Query
|
||||
_, err := collection.DeleteOne(ctx, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
127
find.go
Normal file
127
find.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
// "log"
|
||||
"context"
|
||||
// "fmt"
|
||||
// "encoding/base64"
|
||||
|
||||
// "go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
|
||||
// "git.gsuntres.com/boxtep/boxtep/core"
|
||||
)
|
||||
|
||||
// Find is used to fetch the first page of data.
|
||||
func (c *MongoClient) Find(ctx context.Context, database, name string, filter bson.M, limit int64) (bson.M, error) {
|
||||
|
||||
// 1. Prepare to query.
|
||||
collection := c.GetCollection(database, name)
|
||||
|
||||
pageSize := max(limit, c.Limit)
|
||||
|
||||
|
||||
// id := DecodeCursor(nextCursor)
|
||||
|
||||
// filter["_id"] = bson.M{"_id": bson.M{"$gt": id}}
|
||||
// opts := options.Find().
|
||||
// SetLimit(pageSize + 1).
|
||||
// SetSort(bson.M{"_id": 1})
|
||||
// id := DecodeCursor(nextCursor)
|
||||
sort := bson.M{"_id": 1}
|
||||
// filter["_id"] = bson.M{"_id": bson.M{"$gt": id}}
|
||||
|
||||
pipeline := BuildPaginationPipeline(0, pageSize + 1, filter, sort)
|
||||
|
||||
// 2. Query
|
||||
cursor, err := collection.Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
// 3. Build results
|
||||
var facetResults []bson.M
|
||||
if err = cursor.All(ctx, &facetResults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := facetResults[0]
|
||||
|
||||
data := root["data"].(bson.A)
|
||||
|
||||
metadata := root["metadata"].(bson.A)
|
||||
|
||||
var totalValue any
|
||||
if len(metadata) != 0 {
|
||||
metadataRoot := metadata[0].(bson.M)
|
||||
totalValue = metadataRoot["total"]
|
||||
}
|
||||
|
||||
var total int64
|
||||
switch v := totalValue.(type) {
|
||||
case int32:
|
||||
total = int64(v)
|
||||
case int64:
|
||||
total = v
|
||||
default:
|
||||
total = 0
|
||||
}
|
||||
|
||||
hasMore := false
|
||||
if int64(len(data)) > pageSize {
|
||||
hasMore = true
|
||||
data = data[:pageSize]
|
||||
}
|
||||
|
||||
out := bson.M{
|
||||
"data": data,
|
||||
"has_more": hasMore,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
// next cursor
|
||||
var last bson.M = data[len(data) - 1].(bson.M)
|
||||
var nextCursor string
|
||||
lastId := last["_id"]
|
||||
|
||||
nextCursor = EncodeCursor(lastId.(bson.ObjectID))
|
||||
|
||||
out["next_cursor"] = nextCursor
|
||||
}
|
||||
|
||||
// _data, err := bson.Marshal(out)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var r bson.M
|
||||
// if err := bson.Unmarshal(_data, &r); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// func (c *MongoClient) FindNext(ctx context.Context, database, name string, filter bson.M, nextCursor string, limit int64) ([]bson.M, error) {
|
||||
// collection := c.GetCollection(database, name)
|
||||
|
||||
// opts := options.Find().
|
||||
// SetLimit(max(limit, c.Limit)).
|
||||
// SetSort()
|
||||
|
||||
// cursor, err := collection.Find(ctx, filter)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var results []bson.M
|
||||
// if err = cursor.All(ctx, &results); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// return results, err
|
||||
// }
|
||||
|
||||
// func EncodeCursor()
|
||||
77
find_cursor.go
Normal file
77
find_cursor.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func EncodeCursor(v any) string {
|
||||
var id string
|
||||
switch v.(type) {
|
||||
case bson.ObjectID:
|
||||
id = v.(bson.ObjectID).Hex()
|
||||
default:
|
||||
id = v.(string)
|
||||
}
|
||||
|
||||
val := fmt.Sprintf("%s", id)
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(val))
|
||||
}
|
||||
|
||||
func DecodeCursor(okey string) (bson.ObjectID, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(okey)
|
||||
|
||||
if err != nil {
|
||||
return bson.NilObjectID, err
|
||||
}
|
||||
|
||||
oid, _ := bson.ObjectIDFromHex(string(data))
|
||||
|
||||
return oid, nil
|
||||
}
|
||||
|
||||
|
||||
// func EncodeCursor(name string, v1 any, v2 any) string {
|
||||
// var otherValue string
|
||||
// var otherType string
|
||||
|
||||
// switch v1.(type) {
|
||||
// case time.Time:
|
||||
// otherType = "time.Time"
|
||||
// otherValue = v1.(time.Time).Truncate(time.Millisecond).Format()
|
||||
// case int:
|
||||
// otherType = "time.Time"
|
||||
// otherValue = v1.(time.Time).Truncate(time.Millisecond).Format()
|
||||
// }
|
||||
|
||||
// var id string
|
||||
// switch v2.(type) {
|
||||
// case bson.ObjectID:
|
||||
// id = v2.(bson.ObjectID).Hex()
|
||||
// case string:
|
||||
// id = v2.(string)
|
||||
// default:
|
||||
// id = v2.(string)
|
||||
// }
|
||||
|
||||
// v := fmt.Sprintf("%v|%s", misc, id)
|
||||
|
||||
// fmt.Println(v, name)
|
||||
|
||||
// return base64.RawURLEncoding.EncodeToString([]byte(v))
|
||||
// }
|
||||
|
||||
// func DecodeCursor(okey string) (name string, v1 any, v2 any) {
|
||||
// data, err := base64.StdEncoding.DecodeString(okey)
|
||||
// if err != nil {
|
||||
// return "", nil, nil
|
||||
// }
|
||||
|
||||
// parts := strings.Split(string(data), "|")
|
||||
// if len(parts) != 3 {
|
||||
// return "", nil, nil
|
||||
// }
|
||||
// }
|
||||
26
find_cursor_test.go
Normal file
26
find_cursor_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func TestCursor_Default(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
oid := bson.NewObjectIDFromTimestamp(now)
|
||||
|
||||
okey := EncodeCursor(oid)
|
||||
|
||||
id, err := DecodeCursor(okey)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Decode failed: %v", err)
|
||||
}
|
||||
|
||||
if id != oid {
|
||||
t.Fatalf("Failed to decode id")
|
||||
}
|
||||
}
|
||||
23
find_one.go
Normal file
23
find_one.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
// FindOne will return at most one document based on the filter provided.
|
||||
func (c *MongoClient) FindOne(ctx context.Context, database, name string, filter bson.M) (bson.M, error) {
|
||||
|
||||
// 1. Prepare to query.
|
||||
collection := c.GetCollection(database, name)
|
||||
|
||||
// 2. Query
|
||||
var out bson.M
|
||||
err := collection.FindOne(ctx, filter).Decode(&out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
30
find_one_test.go
Normal file
30
find_one_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func TestFindOne_Default(t *testing.T) {
|
||||
client := GetMongoClient()
|
||||
|
||||
data := map[string]any {
|
||||
"_id": "su_123459",
|
||||
"name": "MyNameFindOne",
|
||||
"age": int32(25),
|
||||
}
|
||||
|
||||
o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
filter := bson.M{"name": "MyNameFindOne"}
|
||||
found, err := client.FindOne(context.Background(), "mydb", "mycollection", filter)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to findOne %#v", err)
|
||||
}
|
||||
|
||||
AssertSubset(t, o, found, "Should have been equal")
|
||||
}
|
||||
51
find_test.go
Normal file
51
find_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func TestFind_Default(t *testing.T) {
|
||||
client := GetMongoClient()
|
||||
|
||||
filter := bson.M{"name": bson.M{"$regex": "OSRAM"}}
|
||||
findResult, err := client.Find(context.Background(), "mydb", "mycollection", filter, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
|
||||
dataAny, hasData := findResult["data"]
|
||||
if !hasData {
|
||||
t.Fatal("no data")
|
||||
}
|
||||
|
||||
data := dataAny.(bson.A)
|
||||
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("Expected to return 1 document but got %d", len(data))
|
||||
}
|
||||
|
||||
hasMoreAny, hasMoreOk := findResult["has_more"]
|
||||
if !hasMoreOk {
|
||||
t.Fatal("no has more")
|
||||
}
|
||||
|
||||
hasMore := hasMoreAny.(bool)
|
||||
|
||||
if hasMore {
|
||||
t.Fatalf("Expected to have reached the end of the results")
|
||||
}
|
||||
|
||||
totalAny, totalOk := findResult["total"]
|
||||
if !totalOk {
|
||||
t.Fatal("no total")
|
||||
}
|
||||
|
||||
total := totalAny.(int64)
|
||||
|
||||
if total != 1 {
|
||||
t.Fatalf("Expected total to be 1 but found %d", total)
|
||||
}
|
||||
}
|
||||
63
generic.go
Normal file
63
generic.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
func (c *MongoClient) GenericFind(ctx context.Context, payload *FindRequest) (*DataResults, error) {
|
||||
account := ctx.Value("account").(string)
|
||||
if account == "" {
|
||||
return nil, errors.New("account required")
|
||||
}
|
||||
|
||||
if payload.Entity == "" {
|
||||
return nil, errors.New("entity required")
|
||||
}
|
||||
|
||||
name := payload.Entity
|
||||
|
||||
database := c.GetName(account)
|
||||
|
||||
collection := c.GetCollection(database, name)
|
||||
|
||||
log.Printf("%v", collection)
|
||||
|
||||
var filter bson.D
|
||||
var err error
|
||||
|
||||
filter, err = MapToBsonD(payload.Filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := options.Find()
|
||||
|
||||
var cursor *mongo.Cursor
|
||||
cursor, err = collection.Find(context.TODO(), filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results := make([]map[string]any, 0)
|
||||
if err = cursor.All(context.TODO(), &results); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataResults := &DataResults{
|
||||
Data: results,
|
||||
}
|
||||
|
||||
return dataResults, nil
|
||||
}
|
||||
|
||||
func GenericInsertOne(ctx context.Context, entityType string, data any) (any, error) {
|
||||
// coll := getCollection("sampledb", "dummy")
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
25
get_one.go
Normal file
25
get_one.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
// FindOne will return at most one document based on the filter provided.
|
||||
func (c *MongoClient) GetOne(ctx context.Context, database, name string, id any) (bson.M, error) {
|
||||
|
||||
// 1. Prepare to query.
|
||||
collection := c.GetCollection(database, name)
|
||||
|
||||
// 2. Query
|
||||
filter := map[string]any { "_id": id }
|
||||
|
||||
var out bson.M
|
||||
err := collection.FindOne(ctx, filter).Decode(&out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
49
get_one_test.go
Normal file
49
get_one_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetOne_ByIdString(t *testing.T) {
|
||||
client := GetMongoClient()
|
||||
|
||||
data := map[string]any {
|
||||
"_id": "su_156345",
|
||||
"name": "MyNameGetOne",
|
||||
"age": int32(25),
|
||||
}
|
||||
|
||||
o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
|
||||
found, err := client.GetOne(context.Background(), "mydb", "mycollection", "su_156345")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to findOne %#v", err)
|
||||
}
|
||||
|
||||
AssertSubset(t, o, found, "Should have been equal")
|
||||
}
|
||||
|
||||
func TestGetOne_ByObjectId(t *testing.T) {
|
||||
client := GetMongoClient()
|
||||
|
||||
data := map[string]any {
|
||||
"name": "MyNameGetOne",
|
||||
"age": int32(25),
|
||||
}
|
||||
|
||||
o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
|
||||
found, err := client.GetOne(context.Background(), "mydb", "mycollection", o["_id"])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to findOne %#v", err)
|
||||
}
|
||||
|
||||
AssertSubset(t, o, found, "Should have been equal")
|
||||
}
|
||||
83
insert.go
Normal file
83
insert.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"github.com/matoous/go-nanoid/v2"
|
||||
|
||||
// "git.gsuntres.com/boxtep/boxtep/core"
|
||||
)
|
||||
|
||||
const alphabet = "0123456789abcdefghijclmnopqrstuvwxyz"
|
||||
|
||||
const maxLen = 10
|
||||
|
||||
// InsertOneWithStruct can be used to insert defined structs.
|
||||
func (c *MongoClient) InsertOneFromStruct(ctx context.Context, database, name string, data any) (bson.M, error) {
|
||||
o, err := ToMap(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.InsertOne(ctx, database, name, o)
|
||||
}
|
||||
|
||||
// InsertOne will add missing ids and the created date before saving to the database.
|
||||
func (c *MongoClient) InsertOne(ctx context.Context, database, name string, data bson.M) (bson.M, error) {
|
||||
collection := c.GetCollection(database, name)
|
||||
|
||||
prepareForInsert(data, c.GetIdPrefix(name))
|
||||
|
||||
if _, err := collection.InsertOne(ctx, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// prepareForInsert takes a map[string]any and:
|
||||
// * adds a new _id if property does not exist or is an empty string
|
||||
// * adds/updates property created_at using the current timestamp.
|
||||
func prepareForInsert(data bson.M, idPrefix string) {
|
||||
ensureId(data, idPrefix)
|
||||
ensureCreatedAt(data)
|
||||
ensureUpdatedAt(data)
|
||||
}
|
||||
|
||||
// ensureId adds the id property when missing or when it's an empty string.
|
||||
func ensureId(data bson.M, idPrefix string) string {
|
||||
maybeId, hasId := data["_id"]
|
||||
|
||||
var id, finalId string
|
||||
if !hasId || maybeId == "" {
|
||||
id, _ = gonanoid.Generate(alphabet, maxLen)
|
||||
if idPrefix != "" {
|
||||
finalId = fmt.Sprintf("%s_%s", idPrefix, id)
|
||||
} else {
|
||||
finalId = id
|
||||
}
|
||||
|
||||
data["_id"] = finalId
|
||||
}
|
||||
|
||||
return finalId
|
||||
}
|
||||
|
||||
func ensureCreatedAt(data bson.M) time.Time {
|
||||
now := time.Now().UTC().Truncate(time.Millisecond)
|
||||
|
||||
data["created_at"] = now
|
||||
|
||||
return now
|
||||
}
|
||||
|
||||
func ensureUpdatedAt(data bson.M) time.Time {
|
||||
now := time.Now().UTC().Truncate(time.Millisecond)
|
||||
|
||||
data["updated_at"] = now
|
||||
|
||||
return now
|
||||
}
|
||||
175
insert_test.go
Normal file
175
insert_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"os"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"encoding/json"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func TestInsertOne(t *testing.T) {
|
||||
client := GetMongoClient()
|
||||
|
||||
data := map[string]any {
|
||||
"_id": "su_123456",
|
||||
"name": "MyName",
|
||||
"age": int32(25),
|
||||
}
|
||||
|
||||
o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
|
||||
// raw query
|
||||
var results bson.M
|
||||
filter := bson.M{ "name": "MyName" }
|
||||
|
||||
c := client.Client.Database("mydb").Collection("mycollection")
|
||||
c.FindOne(context.Background(), filter).Decode(&results)
|
||||
|
||||
AssertSubset(t, o, results, "Should have been equal")
|
||||
}
|
||||
|
||||
func TestInsertOne_DataStruct(t *testing.T) {
|
||||
client := GetMongoClient()
|
||||
|
||||
data := &Sample {
|
||||
Id: "su_123457",
|
||||
Name: "MyName200",
|
||||
Age: 25,
|
||||
}
|
||||
|
||||
o, err := client.InsertOneFromStruct(context.Background(), "mydb", "mycollection", data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
|
||||
// raw query
|
||||
var results bson.M
|
||||
filter := map[string]any { "name": "MyName200" }
|
||||
|
||||
c := client.Client.Database("mydb").Collection("mycollection")
|
||||
c.FindOne(context.Background(), filter).Decode(&results)
|
||||
|
||||
AssertSubset(t, o, results, "Should have been equal")
|
||||
}
|
||||
|
||||
func TestInsertOne_WithIdPrefix(t *testing.T) {
|
||||
// Read JSON file
|
||||
data, err := os.ReadFile("./.test/user.json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var user bson.M
|
||||
if err := json.Unmarshal(data, &user); err != nil {
|
||||
t.Fatalf("Length: %d, First bytes: %x\n", len(data), data[:4])
|
||||
}
|
||||
|
||||
client := GetMongoClient()
|
||||
client.AddDefinition(user)
|
||||
|
||||
in := map[string]any {
|
||||
"name": "MyName112",
|
||||
"age": int32(25),
|
||||
}
|
||||
|
||||
o, err := client.InsertOne(context.Background(), "mydb", "user", in)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insertOne %#v", err)
|
||||
}
|
||||
|
||||
// raw query
|
||||
var results bson.M
|
||||
filter := map[string]any { "name": "MyName112" }
|
||||
c := client.Client.Database("mydb").Collection("user")
|
||||
c.FindOne(context.Background(), filter).Decode(&results)
|
||||
|
||||
if !strings.HasPrefix(results["_id"].(string), "usr_") {
|
||||
t.Fatal("_id should have been prefixed")
|
||||
}
|
||||
|
||||
AssertSubset(t, o, results, "Should have been equal")
|
||||
}
|
||||
|
||||
func TestPrepareForInsert_WithoutId(t *testing.T) {
|
||||
data := map[string]any {
|
||||
"name": "My Name Is",
|
||||
}
|
||||
|
||||
prepareForInsert(data, "")
|
||||
|
||||
if id, okid := data["_id"]; !okid || id == "" {
|
||||
t.Fatal("Failed to add Id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareForInsert_ExistingId(t *testing.T) {
|
||||
data := map[string]any {
|
||||
"_id": "myidxxxxxx",
|
||||
"name": "My Name Is",
|
||||
}
|
||||
|
||||
prepareForInsert(data, "")
|
||||
|
||||
if data["_id"] == "" {
|
||||
t.Fatal("id was updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureId(t *testing.T) {
|
||||
data := map[string]any {
|
||||
"Name": "My Name Is",
|
||||
}
|
||||
|
||||
ensureId(data, "")
|
||||
|
||||
if id, okid := data["_id"]; !okid || id == "" {
|
||||
t.Fatal("Failed to add Id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureId_EmptyId(t *testing.T) {
|
||||
data := map[string]any {
|
||||
"_id": "myidxxxxxx",
|
||||
"name": "My Name Is",
|
||||
}
|
||||
|
||||
ensureId(data, "")
|
||||
|
||||
if data["_id"] == "" {
|
||||
t.Fatal("Id was updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureId_ExistingId(t *testing.T) {
|
||||
data := map[string]any {
|
||||
"_id": "",
|
||||
"name": "My Name Is",
|
||||
}
|
||||
|
||||
ensureId(data, "")
|
||||
|
||||
if id, okid := data["_id"]; !okid || id == "" {
|
||||
t.Fatal("Failed to add Id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureCreatedAt(t *testing.T) {
|
||||
data := map[string]any {
|
||||
"name": "My Name Is",
|
||||
}
|
||||
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
now := ensureCreatedAt(data)
|
||||
|
||||
if createdAt, _ := data["created_at"]; createdAt != now {
|
||||
t.Fatal("Failed to add CreatedAt")
|
||||
}
|
||||
})
|
||||
}
|
||||
352
main.go
Normal file
352
main.go
Normal file
@@ -0,0 +1,352 @@
|
||||
// Package mongo provides a simple interface to the database.
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/event"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
|
||||
"git.gsuntres.com/boxtep/boxtep/core"
|
||||
"git.gsuntres.com/boxtep/boxtep/sys"
|
||||
)
|
||||
|
||||
// WithiSessionFunc will run operations on the database within the same session.
|
||||
// If this functions returns an error, the system will rollback the transaction.
|
||||
// Sessions require a resplica set.
|
||||
type WithinSessionFunc func(context.Context, *mongo.Session) (any, error)
|
||||
|
||||
type IMongoClient interface {
|
||||
// WithinSession takes a context, the database and the collection and excecutes the callback function provided.
|
||||
WithinSession(context.Context, string, string, WithinSessionFunc) (any, error)
|
||||
|
||||
// InsertOne will insert data to the specified namespace.
|
||||
InsertOne(context.Context, string, string, bson.M) (bson.M, error)
|
||||
|
||||
// Find
|
||||
Find(context.Context, string, string, bson.M) ([]bson.M, error)
|
||||
|
||||
// GetCollection returns the requested collection within an account. It will create it if it doesn't exist.
|
||||
GetCollection(database, name string) *mongo.Collection
|
||||
}
|
||||
|
||||
type CollectionDefinition struct {
|
||||
Name string `bson:"_name"`
|
||||
Singular string `bson:"singular"`
|
||||
Plural string `bson: "plural"`
|
||||
IdPrefix string `bson: "idPrefix"`
|
||||
IndexSpecs []map[string]any `bson: "indexSpecs"`
|
||||
Schema map[string]any `bson: "schema"`
|
||||
}
|
||||
|
||||
// func (cd *CollectionDefinition) GetSchema(name string)
|
||||
|
||||
// MongoClient
|
||||
type MongoClient struct {
|
||||
// Client the actual connected instance of mongo client.
|
||||
Client *mongo.Client
|
||||
// Debug set to true for dislaying info level logs.
|
||||
Debug bool
|
||||
// DebugQuery set to true to log all queries done.
|
||||
DebugQuery bool
|
||||
// Limit the default limit to use in queries.
|
||||
Limit int64
|
||||
// DBPrefix optinal string literal to distinquise
|
||||
DBPrefix string
|
||||
// Registry holds critical information about collection's structure, like schema and indexes
|
||||
Registry map[string]*CollectionDefinition
|
||||
}
|
||||
|
||||
func (c *MongoClient) GetIdPrefix(name string) string {
|
||||
def, ok := c.Registry[name]
|
||||
if ok {
|
||||
return def.IdPrefix
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
const ADD_DEFINITION_SCHEMA = `
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"idPrefix": {
|
||||
"type": "string"
|
||||
},
|
||||
"indexSpecs": {
|
||||
"type": "array"
|
||||
},
|
||||
"singular": {
|
||||
"type": "string"
|
||||
},
|
||||
"system": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"plural": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["_name", "singular", "plural"],
|
||||
"additionalProperties": true
|
||||
}
|
||||
`
|
||||
|
||||
func (c *MongoClient) AddDefinition(data map[string]any) {
|
||||
if valid := core.Validate(ADD_DEFINITION_SCHEMA, data); valid != nil {
|
||||
log.Printf("failed to register data: %v", valid)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Registering %s", data["_name"])
|
||||
|
||||
b, err := bson.Marshal(data)
|
||||
if err != nil {
|
||||
log.Printf("failed to marshal: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var cd CollectionDefinition
|
||||
bson.Unmarshal(b, &cd)
|
||||
|
||||
if len(cd.Name) > 0 {
|
||||
c.Registry[cd.Name] = &cd
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MongoClient) GetCollection(database, name string) *mongo.Collection {
|
||||
if c.Debug {
|
||||
log.Printf("Using collection: %s.%s", database, name)
|
||||
}
|
||||
|
||||
db := c.Client.Database(database)
|
||||
|
||||
// 1. List existing collections
|
||||
names, err := db.ListCollectionNames(context.TODO(), bson.D{})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to list collections: %#v", err.Error())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. If collection exist return it, otherwise create it and then return it
|
||||
if slices.Contains(names, name) {
|
||||
return db.Collection(name)
|
||||
} else {
|
||||
opts := options.CreateCollection()
|
||||
|
||||
// maybe get from schema
|
||||
cdef, ok := c.Registry[name]
|
||||
if ok {
|
||||
log.Printf("Schema found for %s; will use it", name)
|
||||
|
||||
ApplySchema(cdef, opts)
|
||||
} else {
|
||||
log.Printf("No schema for %s", name)
|
||||
}
|
||||
|
||||
if err := db.CreateCollection(context.TODO(), name, opts); err != nil {
|
||||
log.Printf("Failed to create collection: %#v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
collection := db.Collection(name)
|
||||
|
||||
c.CreateIndexes(collection, cdef)
|
||||
|
||||
return collection
|
||||
}
|
||||
}
|
||||
|
||||
func ApplySchema(cdef *CollectionDefinition, opts *options.CreateCollectionOptionsBuilder) {
|
||||
// Add schema validation
|
||||
if cdef.Schema != nil {
|
||||
schemaBson, err := bson.Marshal(cdef.Schema)
|
||||
if err != nil {
|
||||
log.Printf("failed to parse schema: %v", err)
|
||||
} else {
|
||||
var validatorSchema bson.M
|
||||
bson.Unmarshal(schemaBson, &validatorSchema)
|
||||
validator := bson.M{
|
||||
"$jsonSchema": validatorSchema,
|
||||
}
|
||||
|
||||
opts.SetValidator(validator)
|
||||
}
|
||||
} else {
|
||||
log.Printf("Invalid schema, do nothing.")
|
||||
}
|
||||
}
|
||||
|
||||
var client *MongoClient = &MongoClient{
|
||||
Limit: 10,
|
||||
Registry: make(map[string]*CollectionDefinition, 0),
|
||||
}
|
||||
|
||||
func GetMongoClient() *MongoClient {
|
||||
return client
|
||||
}
|
||||
|
||||
const validSchema_StartProps = `
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"MongoUri": {
|
||||
"type": "string"
|
||||
},
|
||||
"MongoUser": {
|
||||
"type": "string"
|
||||
},
|
||||
"MongoPass": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["MongoUri"]
|
||||
}
|
||||
`
|
||||
|
||||
type MongoStartProps struct {
|
||||
MongoUri string
|
||||
MongoUser string
|
||||
MongoPass string
|
||||
MongoDebugQuery bool
|
||||
MongoDBPrefix string
|
||||
}
|
||||
|
||||
func Start(props *MongoStartProps) error {
|
||||
if err := core.Validate(validSchema_StartProps, props); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uri := props.MongoUri
|
||||
user := props.MongoUser
|
||||
pass := props.MongoPass
|
||||
|
||||
// Create a new client and connect to the server
|
||||
var err error
|
||||
|
||||
cOptions := options.Client().
|
||||
ApplyURI(uri).
|
||||
SetAuth(options.Credential{
|
||||
Username: user,
|
||||
Password: pass,
|
||||
}).
|
||||
SetConnectTimeout(5 * time.Second).
|
||||
SetServerAPIOptions(&options.ServerAPIOptions{
|
||||
ServerAPIVersion: options.ServerAPIVersion1,
|
||||
}).
|
||||
SetBSONOptions(&options.BSONOptions{
|
||||
DefaultDocumentM: true,
|
||||
UseJSONStructTags: false,
|
||||
NilSliceAsEmpty: true,
|
||||
NilMapAsEmpty: true,
|
||||
}).
|
||||
SetRegistry(GetCustomRegistry())
|
||||
|
||||
client.DebugQuery = props.MongoDebugQuery
|
||||
if client.DebugQuery {
|
||||
// Debug queries
|
||||
monitor := &event.CommandMonitor{
|
||||
Started: func(_ context.Context, e *event.CommandStartedEvent) {
|
||||
log.Printf("%d@Start %s#%s %s", e.RequestID, e.DatabaseName, e.CommandName, e.Command)
|
||||
},
|
||||
Succeeded: func(_ context.Context, e *event.CommandSucceededEvent) {
|
||||
log.Printf("%d@OK in %s", e.RequestID, e.Reply)
|
||||
},
|
||||
Failed: func(_ context.Context, e *event.CommandFailedEvent) {
|
||||
log.Printf("%d@Fail in %s", e.RequestID, e.Failure)
|
||||
},
|
||||
}
|
||||
|
||||
cOptions = cOptions.SetMonitor(monitor)
|
||||
}
|
||||
|
||||
// set DBPrefix
|
||||
client.DBPrefix = props.MongoDBPrefix
|
||||
|
||||
if err := cOptions.Validate(); err != nil {
|
||||
log.Fatalf("Failed to validate mongo options: %+v", err.Error())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
client.Client, err = mongo.Connect(cOptions)
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect: %+v", err.Error())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
sys.OnExit(func () {
|
||||
if err := client.Client.Disconnect(context.TODO()); err != nil {
|
||||
log.Fatalf("failed to disconnect: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Register defaults.go
|
||||
var structfulDef bson.M
|
||||
json.Unmarshal(structfulJson, &structfulDef)
|
||||
|
||||
client.AddDefinition(structfulDef)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetClient() *mongo.Client {
|
||||
return client.Client
|
||||
}
|
||||
|
||||
func (c *MongoClient) GetName(account string) string {
|
||||
return fmt.Sprintf("%s%s", c.DBPrefix, account)
|
||||
}
|
||||
|
||||
func (c *MongoClient) ExtractAccount(name string) string {
|
||||
v, ok := strings.CutPrefix(name, c.DBPrefix)
|
||||
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func Stop() error {
|
||||
if err := client.Client.Disconnect(context.TODO()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Successfully stopped mongo.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func MapToBsonD(m map[string]any) (bson.D, error) {
|
||||
data, err := bson.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var d bson.D
|
||||
|
||||
err = bson.Unmarshal(data, &d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
75
main_index.go
Normal file
75
main_index.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
// CreateIndexes will create indexes for the given collection and definition.
|
||||
func (c *MongoClient) CreateIndexes(collection *mongo.Collection, cdef *CollectionDefinition) {
|
||||
if cdef == nil {
|
||||
log.Printf("No definitions will not create indexes")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// handle indexes
|
||||
indexModels := make([]mongo.IndexModel, 0)
|
||||
|
||||
for _, keyDef := range cdef.IndexSpecs {
|
||||
log.Printf("Key Definition %s", keyDef["name"])
|
||||
kdb, err := bson.Marshal(keyDef)
|
||||
if err != nil {
|
||||
log.Printf("failed to marshal %v", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
kdRaw := bson.Raw(kdb)
|
||||
if err := kdRaw.Validate(); err != nil {
|
||||
log.Printf("failed to validate bson raw: %v", err)
|
||||
|
||||
continue
|
||||
}
|
||||
//
|
||||
idxModel := mongo.IndexModel{}
|
||||
|
||||
opts := options.Index()
|
||||
|
||||
keysVal := kdRaw.Lookup("keys")
|
||||
|
||||
var keysBson bson.D
|
||||
if err := bson.Unmarshal(keysVal.Value, &keysBson); err != nil {
|
||||
log.Printf("failed to unmarshal keys value %v", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
idxModel.Keys = keysBson
|
||||
|
||||
nameVal := kdRaw.Lookup("name")
|
||||
if name, ok := nameVal.StringValueOK(); ok {
|
||||
opts = opts.SetName(name)
|
||||
}
|
||||
|
||||
uniqueVal := kdRaw.Lookup("unique")
|
||||
if unique, ok := uniqueVal.BooleanOK(); ok {
|
||||
opts = opts.SetUnique(unique)
|
||||
}
|
||||
|
||||
partialVal := kdRaw.Lookup("partialFilterExpression")
|
||||
if partialFilterExpression, ok := partialVal.BooleanOK(); ok {
|
||||
opts = opts.SetPartialFilterExpression(partialFilterExpression)
|
||||
}
|
||||
|
||||
idxModel.Options = opts
|
||||
|
||||
indexModels = append(indexModels, idxModel)
|
||||
}
|
||||
|
||||
collection.Indexes().CreateMany(context.Background(), indexModels)
|
||||
}
|
||||
56
main_index_test.go
Normal file
56
main_index_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func TestCreateIndexes(t *testing.T) {
|
||||
cd := &CollectionDefinition{
|
||||
IndexSpecs: []map[string]any{
|
||||
{
|
||||
"keys": map[string]any{
|
||||
"code": 1,
|
||||
},
|
||||
"name": "idx_1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c := GetMongoClient()
|
||||
|
||||
collection := c.GetCollection("mydb", "mycol")
|
||||
|
||||
c.CreateIndexes(collection, cd)
|
||||
|
||||
indexView := collection.Indexes()
|
||||
|
||||
// Specify a timeout to limit the amount of time the operation can run on
|
||||
// the server.
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
cursor, err := indexView.List(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get a slice of all indexes returned and print them out.
|
||||
var results []bson.M
|
||||
if err = cursor.All(ctx, &results); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
idx1 := results[1]
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatal("Unexpected number of indexes")
|
||||
}
|
||||
|
||||
if idx1["name"] != "idx_1" {
|
||||
t.Fatal("Should have register index idx_1")
|
||||
}
|
||||
}
|
||||
134
main_test.go
Normal file
134
main_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
"reflect"
|
||||
"testing"
|
||||
"context"
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
)
|
||||
|
||||
type Sample struct {
|
||||
Id string `json:"id" bson:"_id,omitempty"`
|
||||
Name string `json:"name" bson:"name"`
|
||||
Age int32 `json:"age" bson:"age"`
|
||||
CreatedAt time.Time `json:"created_at" bson:"created_at"`
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if os.Getenv("RUN_INTEGRATION") == "" {
|
||||
fmt.Println("Skipping package tests: RUN_INTEGRATION is missing")
|
||||
os.Exit(0) // Exit with success, but no tests ran
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
user := "admin"
|
||||
pass := "1"
|
||||
|
||||
// 1. Setup: Start the MongoDB container
|
||||
mongoContainer, err := mongodb.Run(ctx, "mongo:8",
|
||||
testcontainers.WithEnv(map[string]string{
|
||||
"MONGO_INITDB_ROOT_USERNAME": user,
|
||||
"MONGO_INITDB_ROOT_PASSWORD": pass,
|
||||
}),)
|
||||
if err != nil {
|
||||
panic("failed to start container")
|
||||
}
|
||||
|
||||
// 2. Get the connection string dynamically
|
||||
endpoint, _ := mongoContainer.ConnectionString(ctx)
|
||||
|
||||
Start(&MongoStartProps{
|
||||
MongoUri: endpoint,
|
||||
MongoUser: user,
|
||||
MongoPass: pass,
|
||||
})
|
||||
|
||||
// 3. Run tests
|
||||
LoadTestSample()
|
||||
|
||||
code := m.Run()
|
||||
|
||||
// 4. Teardown: Clean up resources
|
||||
Stop()
|
||||
|
||||
_ = testcontainers.TerminateContainer(mongoContainer)
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestApplySchema(t *testing.T) {
|
||||
// prepare schema sample
|
||||
schemaStr := `
|
||||
{
|
||||
"bsonType": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"bsonType": "string"
|
||||
},
|
||||
"active": {
|
||||
"bsonType": "boolean"
|
||||
}
|
||||
},
|
||||
"required": ["name"],
|
||||
"additionalProperties": true
|
||||
}
|
||||
`
|
||||
|
||||
var result bson.M
|
||||
err := json.Unmarshal([]byte(schemaStr), &result)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
|
||||
cd := &CollectionDefinition{
|
||||
Schema: result,
|
||||
}
|
||||
|
||||
opts := options.CreateCollection()
|
||||
|
||||
ApplySchema(cd, opts)
|
||||
|
||||
options := &options.CreateCollectionOptions{}
|
||||
|
||||
for _, o := range opts.List() {
|
||||
o(options)
|
||||
}
|
||||
|
||||
expected := bson.M{"$jsonSchema":bson.M{"additionalProperties":true, "bsonType":"object", "properties":bson.D{bson.E{Key:"active", Value:bson.D{bson.E{Key:"bsonType", Value:"boolean"}}}, bson.E{Key:"name", Value:bson.D{bson.E{Key:"bsonType", Value:"string"}}}}, "required":bson.A{"name"}}}
|
||||
|
||||
aj, _ := json.Marshal(options.Validator)
|
||||
ej, _ := json.Marshal(expected)
|
||||
|
||||
var am, em any
|
||||
json.Unmarshal(aj, &am)
|
||||
json.Unmarshal(ej, &em)
|
||||
|
||||
if !reflect.DeepEqual(am, em) {
|
||||
t.Fatalf("Validator should have been set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCollection_Default(t *testing.T) {
|
||||
c := GetMongoClient()
|
||||
|
||||
col := c.GetCollection("mydb", "mycol")
|
||||
|
||||
if col == nil {
|
||||
t.Fatal("should have return collection")
|
||||
}
|
||||
|
||||
if col.Name() != "mycol" {
|
||||
t.Fatal("Wrong name")
|
||||
}
|
||||
|
||||
}
|
||||
52
pipeline.go
Normal file
52
pipeline.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
)
|
||||
|
||||
func BuildPaginationPipeline(skip, limit int64, filter bson.M, sort bson.M) mongo.Pipeline {
|
||||
return mongo.Pipeline{
|
||||
// 1. GLOBAL FILTER: Always filter first to use indexes
|
||||
{{Key: "$match", Value: filter}},
|
||||
|
||||
// 2. GLOBAL SORT: Sort here so both 'total' and 'data' facets use the same order
|
||||
{{Key: "$sort", Value: sort}},
|
||||
|
||||
// 3. FACET: Split the pipeline into two parallel paths
|
||||
{{Key: "$facet", Value: bson.D{
|
||||
// Path A: Get the total count of documents matching the filter
|
||||
{Key: "metadata", Value: mongo.Pipeline{
|
||||
{{Key: "$count", Value: "total"}},
|
||||
}},
|
||||
// Path B: Get the specific page of data
|
||||
{Key: "data", Value: mongo.Pipeline{
|
||||
{{Key: "$skip", Value: skip}},
|
||||
{{Key: "$limit", Value: limit}},
|
||||
}},
|
||||
}}},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildPaginationPipelineNext(limit int64, filter bson.M, sort bson.M) mongo.Pipeline {
|
||||
|
||||
return mongo.Pipeline{
|
||||
// 1. GLOBAL FILTER: Always filter first to use indexes
|
||||
{{Key: "$match", Value: filter}},
|
||||
|
||||
// 2. GLOBAL SORT: Sort here so both 'total' and 'data' facets use the same order
|
||||
{{Key: "$sort", Value: sort}},
|
||||
|
||||
// 3. FACET: Split the pipeline into two parallel paths
|
||||
{{Key: "$facet", Value: bson.D{
|
||||
// Path A: Get the total count of documents matching the filter
|
||||
{Key: "metadata", Value: mongo.Pipeline{
|
||||
{{Key: "$count", Value: "total"}},
|
||||
}},
|
||||
// Path B: Get the specific page of data
|
||||
{Key: "data", Value: mongo.Pipeline{
|
||||
{{Key: "$limit", Value: limit}},
|
||||
}},
|
||||
}}},
|
||||
}
|
||||
}
|
||||
88
registry.go
Normal file
88
registry.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
func TimeToStringHook(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||
|
||||
// 1. Target must be a string
|
||||
if t.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 2. Check if source is time.Time OR *time.Time
|
||||
isTime := f == reflect.TypeOf(time.Time{})
|
||||
isTimePtr := f == reflect.TypeOf(&time.Time{})
|
||||
|
||||
if isTime || isTimePtr {
|
||||
// Handle pointer vs value during type assertion
|
||||
if isTimePtr {
|
||||
if ptr, ok := data.(*time.Time); ok && ptr != nil {
|
||||
|
||||
return ptr.Format(time.RFC3339), nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return data.(time.Time).Format(time.RFC3339), nil
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func StructToMap(input any) (map[string]any, error) {
|
||||
var result map[string]any
|
||||
|
||||
config := &mapstructure.DecoderConfig{
|
||||
DecodeHook: TimeToStringHook,
|
||||
Result: &result,
|
||||
TagName: "bson",
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = decoder.Decode(input)
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func TruncatingTimeEncoder(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error {
|
||||
// 1. Check if the type is exactly time.Time
|
||||
if val.Type() != reflect.TypeOf(time.Time{}) {
|
||||
// Fallback: Use the encoder from the current context's registry
|
||||
enc, err := ec.Registry.LookupEncoder(val.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return enc.EncodeValue(ec, vw, val)
|
||||
}
|
||||
|
||||
// 2. Perform the truncation logic
|
||||
t := val.Interface().(time.Time)
|
||||
truncated := t.Truncate(time.Millisecond)
|
||||
|
||||
// 3. Write as a standard BSON DateTime
|
||||
return vw.WriteDateTime(truncated.UnixMilli())
|
||||
}
|
||||
|
||||
func GetCustomRegistry() *bson.Registry {
|
||||
reg := bson.NewRegistry()
|
||||
|
||||
// reg.RegisterTypeEncoder(
|
||||
// reflect.TypeOf(time.Time{}),
|
||||
// bson.ValueEncoderFunc(TruncatingTimeEncoder),
|
||||
// )
|
||||
// reg.SetRegistry(mgocompat.NewRegistry())
|
||||
|
||||
return reg
|
||||
}
|
||||
40
session.go
Normal file
40
session.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
)
|
||||
|
||||
|
||||
func (c *MongoClient) WithinSession(ctx context.Context, cb WithinSessionFunc) (any, error) {
|
||||
|
||||
// 1. Start a new session
|
||||
sess, err := c.Client.StartSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer sess.EndSession(ctx)
|
||||
|
||||
// 2. Create a new context
|
||||
ctxNew := mongo.NewSessionContext(ctx, sess)
|
||||
|
||||
// 3. Start transaction
|
||||
if err = sess.StartTransaction(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := cb(ctxNew, sess)
|
||||
if err != nil {
|
||||
_ = sess.AbortTransaction(context.Background())
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sess.CommitTransaction(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
90
testing.go
Normal file
90
testing.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
)
|
||||
|
||||
const SAMPLE_DB = "mydb"
|
||||
const SAMPLE_COLLECTION = "mycollection"
|
||||
|
||||
// AssertSubset compares a subset map against a superset (from Mongo).
|
||||
func AssertSubset(t *testing.T, subset, superset map[string]any, msgAndArgs ...any) bool {
|
||||
t.Helper()
|
||||
|
||||
// 1. Filter: Ignore keys in the superset that aren't in our expected subset
|
||||
ignoreExtra := cmpopts.IgnoreMapEntries(func(k string, v any) bool {
|
||||
_, ok := subset[k]
|
||||
return !ok
|
||||
})
|
||||
|
||||
// 2. Transformer: Convert Mongo v2 int64 back to time.Time for comparison
|
||||
// 1. Define the logic
|
||||
timeLogic := func(v any) any {
|
||||
switch val := v.(type) {
|
||||
case time.Time:
|
||||
return val.UTC()
|
||||
case bson.DateTime:
|
||||
return val.Time().UTC()
|
||||
case int64:
|
||||
return time.UnixMilli(val).UTC()
|
||||
default:
|
||||
return time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Wrap it in a Filter so it's not an "unfiltered" option
|
||||
timeTransform := cmp.FilterValues(func(x, y any) bool {
|
||||
_, xisbsontime := x.(bson.DateTime)
|
||||
_, yistime := y.(time.Time)
|
||||
_, xistime := x.(time.Time)
|
||||
_, yisbsontime := y.(bson.DateTime)
|
||||
|
||||
return (xisbsontime && yistime) || (yisbsontime && xistime)
|
||||
}, cmp.Transformer("BsonTime", timeLogic))
|
||||
|
||||
// 3. Execute Diff
|
||||
diff := cmp.Diff(subset, superset,
|
||||
ignoreExtra,
|
||||
timeTransform,
|
||||
cmpopts.EquateEmpty(), // Treats nil slice == empty slice
|
||||
)
|
||||
|
||||
if diff != "" {
|
||||
msg := "Maps do not match (subset != superset)"
|
||||
if len(msgAndArgs) > 0 {
|
||||
msg = fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...)
|
||||
}
|
||||
t.Errorf("%s:\n%s", msg, diff)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func LoadTestSample() error {
|
||||
// Read JSON file
|
||||
jsonData, _ := os.ReadFile("./.test/sample.json")
|
||||
|
||||
var docs []bson.M
|
||||
if err := bson.UnmarshalExtJSON(jsonData, true, &docs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert into MongoDB
|
||||
client := GetMongoClient()
|
||||
collection := client.Client.Database(SAMPLE_DB).Collection(SAMPLE_COLLECTION)
|
||||
if _, err := collection.InsertMany(context.TODO(), docs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
26
types.go
Normal file
26
types.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package mongo
|
||||
|
||||
// TODO(me): Place it in find.go?
|
||||
type FindRequest struct {
|
||||
Entity string
|
||||
Limit int
|
||||
Skip int
|
||||
Filter map[string]any
|
||||
}
|
||||
|
||||
// TODO(me): Remove it in favor of [FindResult]?
|
||||
type DataResults struct {
|
||||
Data []map[string]any
|
||||
}
|
||||
|
||||
// TODO(me): Remove it in favor of [FindResult]?
|
||||
type Metadata struct {
|
||||
Total int64 `bson:"total"`
|
||||
}
|
||||
// TODO(me): Remove it in favor of [FindResult]?
|
||||
type PaginatedResponse struct {
|
||||
Metadata []Metadata `bson:"metadata" json:"metadata"`
|
||||
// Initializing this as an empty slice literal []User{}
|
||||
// prevents "null" in your JSON output.
|
||||
Data []map[string]any `bson:"data" json:"data"`
|
||||
}
|
||||
Reference in New Issue
Block a user