From 38265c15d146e7fa3339612a959b5a7f45f43e67 Mon Sep 17 00:00:00 2001 From: George Suntres Date: Sun, 29 Mar 2026 11:38:57 -0400 Subject: [PATCH] Initial import --- .gitignore | 26 ++++ convert.go | 64 ++++++++ defaults.go | 6 + delete_one.go | 22 +++ find.go | 127 ++++++++++++++++ find_cursor.go | 77 ++++++++++ find_cursor_test.go | 26 ++++ find_one.go | 23 +++ find_one_test.go | 30 ++++ find_test.go | 51 +++++++ generic.go | 63 ++++++++ get_one.go | 25 ++++ get_one_test.go | 49 ++++++ go.mod | 3 + insert.go | 83 +++++++++++ insert_test.go | 175 ++++++++++++++++++++++ main.go | 352 ++++++++++++++++++++++++++++++++++++++++++++ main_index.go | 75 ++++++++++ main_index_test.go | 56 +++++++ main_test.go | 134 +++++++++++++++++ pipeline.go | 52 +++++++ registry.go | 88 +++++++++++ session.go | 40 +++++ testing.go | 90 +++++++++++ types.go | 26 ++++ 25 files changed, 1763 insertions(+) create mode 100644 .gitignore create mode 100644 convert.go create mode 100644 defaults.go create mode 100644 delete_one.go create mode 100644 find.go create mode 100644 find_cursor.go create mode 100644 find_cursor_test.go create mode 100644 find_one.go create mode 100644 find_one_test.go create mode 100644 find_test.go create mode 100644 generic.go create mode 100644 get_one.go create mode 100644 get_one_test.go create mode 100644 go.mod create mode 100644 insert.go create mode 100644 insert_test.go create mode 100644 main.go create mode 100644 main_index.go create mode 100644 main_index_test.go create mode 100644 main_test.go create mode 100644 pipeline.go create mode 100644 registry.go create mode 100644 session.go create mode 100644 testing.go create mode 100644 types.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a66ef27 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# Allowlisting gitignore template for GO projects prevents us +# from adding various unwanted local files, such as generated +# files, developer configurations or IDE-specific files etc. +# +# Recommended: Go.AllowList.gitignore + +# Ignore everything +* + +# But these files... +!.gitignore + +!*.go +!go.sum +!go.mod + +!README.md +!LICENSE + +!Makefile + +!*.sh +!*.md + +# ...even if they are in subdirectories +!*/ diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..ed55e9a --- /dev/null +++ b/convert.go @@ -0,0 +1,64 @@ +package mongo + +import ( + "go.mongodb.org/mongo-driver/v2/bson" + + // "github.com/go-viper/mapstructure/v2" +) + +func ToMap(data any) (bson.M, error) { + // 1. Marshal the struct to BSON bytes + b, err := bson.Marshal(data) + if err != nil { + return nil, err + } + + // 2. Unmarshal the bytes back into a bson.M + var res bson.M + err = bson.Unmarshal(b, &res) + + return res, err +} + +func ToStruct(m bson.M, target any) error { + // 1. Convert map to BSON bytes + data, err := bson.Marshal(m) + if err != nil { + return err + } + + // 2. Unmarshal bytes into the target struct + return bson.Unmarshal(data, target) +} + +// func ToMap(input any) (bson.M, error) { +// var result bson.M +// config := &mapstructure.DecoderConfig{ +// TagName: "bson", // Use bson tags instead of field names +// Result: &result, +// } + +// decoder, err := mapstructure.NewDecoder(config) +// if err != nil { +// return nil, err +// } + +// err = decoder.Decode(input) +// return result, err +// } + +// func BsonToStructHook() mapstructure.DecodeHookFunc { +// return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { +// // 1. Convert primitive.ObjectID -> string +// if f == reflect.TypeOf(primitive.ObjectID{}) && t.Kind() == reflect.String { +// return data.(primitive.ObjectID).Hex(), nil +// } + +// // 2. Convert primitive.DateTime -> time.Time +// if f == reflect.TypeOf(primitive.DateTime(0)) && t == reflect.TypeOf(time.Time{}) { +// return data.(primitive.DateTime).Time(), nil +// } + +// return data, nil +// } +// } \ No newline at end of file diff --git a/defaults.go b/defaults.go new file mode 100644 index 0000000..a7574cb --- /dev/null +++ b/defaults.go @@ -0,0 +1,6 @@ +package mongo + +import _ "embed" + +//go:embed structful.json +var structfulJson []byte diff --git a/delete_one.go b/delete_one.go new file mode 100644 index 0000000..d8207ba --- /dev/null +++ b/delete_one.go @@ -0,0 +1,22 @@ +package mongo + +import ( + "context" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// DeleteOne will delete the first document that matches the filter. +func (c *MongoClient) DeleteOne(ctx context.Context, database, name string, filter bson.M) error { + + // 1. Prepare query. + collection := c.GetCollection(database, name) + + // 2. Query + _, err := collection.DeleteOne(ctx, filter) + if err != nil { + return err + } + + return nil +} diff --git a/find.go b/find.go new file mode 100644 index 0000000..3aae647 --- /dev/null +++ b/find.go @@ -0,0 +1,127 @@ +package mongo + +import ( + // "log" + "context" + // "fmt" + // "encoding/base64" + + // "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + + // "git.gsuntres.com/boxtep/boxtep/core" +) + +// Find is used to fetch the first page of data. +func (c *MongoClient) Find(ctx context.Context, database, name string, filter bson.M, limit int64) (bson.M, error) { + + // 1. Prepare to query. + collection := c.GetCollection(database, name) + + pageSize := max(limit, c.Limit) + + + // id := DecodeCursor(nextCursor) + + // filter["_id"] = bson.M{"_id": bson.M{"$gt": id}} + // opts := options.Find(). + // SetLimit(pageSize + 1). + // SetSort(bson.M{"_id": 1}) + // id := DecodeCursor(nextCursor) + sort := bson.M{"_id": 1} + // filter["_id"] = bson.M{"_id": bson.M{"$gt": id}} + + pipeline := BuildPaginationPipeline(0, pageSize + 1, filter, sort) + + // 2. Query + cursor, err := collection.Aggregate(ctx, pipeline) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + // 3. Build results + var facetResults []bson.M + if err = cursor.All(ctx, &facetResults); err != nil { + return nil, err + } + + root := facetResults[0] + + data := root["data"].(bson.A) + + metadata := root["metadata"].(bson.A) + + var totalValue any + if len(metadata) != 0 { + metadataRoot := metadata[0].(bson.M) + totalValue = metadataRoot["total"] + } + + var total int64 + switch v := totalValue.(type) { + case int32: + total = int64(v) + case int64: + total = v + default: + total = 0 + } + + hasMore := false + if int64(len(data)) > pageSize { + hasMore = true + data = data[:pageSize] + } + + out := bson.M{ + "data": data, + "has_more": hasMore, + "total": total, + } + + if hasMore { + // next cursor + var last bson.M = data[len(data) - 1].(bson.M) + var nextCursor string + lastId := last["_id"] + + nextCursor = EncodeCursor(lastId.(bson.ObjectID)) + + out["next_cursor"] = nextCursor + } + + // _data, err := bson.Marshal(out) + // if err != nil { + // return nil, err + // } + + // var r bson.M + // if err := bson.Unmarshal(_data, &r); err != nil { + // return nil, err + // } + + return out, nil +} + +// func (c *MongoClient) FindNext(ctx context.Context, database, name string, filter bson.M, nextCursor string, limit int64) ([]bson.M, error) { +// collection := c.GetCollection(database, name) + +// opts := options.Find(). +// SetLimit(max(limit, c.Limit)). +// SetSort() + +// cursor, err := collection.Find(ctx, filter) +// if err != nil { +// return nil, err +// } + +// var results []bson.M +// if err = cursor.All(ctx, &results); err != nil { +// return nil, err +// } + +// return results, err +// } + +// func EncodeCursor() diff --git a/find_cursor.go b/find_cursor.go new file mode 100644 index 0000000..2720004 --- /dev/null +++ b/find_cursor.go @@ -0,0 +1,77 @@ +package mongo + +import ( + "encoding/base64" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func EncodeCursor(v any) string { + var id string + switch v.(type) { + case bson.ObjectID: + id = v.(bson.ObjectID).Hex() + default: + id = v.(string) + } + + val := fmt.Sprintf("%s", id) + + return base64.RawURLEncoding.EncodeToString([]byte(val)) +} + +func DecodeCursor(okey string) (bson.ObjectID, error) { + data, err := base64.StdEncoding.DecodeString(okey) + + if err != nil { + return bson.NilObjectID, err + } + + oid, _ := bson.ObjectIDFromHex(string(data)) + + return oid, nil +} + + +// func EncodeCursor(name string, v1 any, v2 any) string { +// var otherValue string +// var otherType string + +// switch v1.(type) { +// case time.Time: +// otherType = "time.Time" +// otherValue = v1.(time.Time).Truncate(time.Millisecond).Format() +// case int: +// otherType = "time.Time" +// otherValue = v1.(time.Time).Truncate(time.Millisecond).Format() +// } + +// var id string +// switch v2.(type) { +// case bson.ObjectID: +// id = v2.(bson.ObjectID).Hex() +// case string: +// id = v2.(string) +// default: +// id = v2.(string) +// } + +// v := fmt.Sprintf("%v|%s", misc, id) + +// fmt.Println(v, name) + +// return base64.RawURLEncoding.EncodeToString([]byte(v)) +// } + +// func DecodeCursor(okey string) (name string, v1 any, v2 any) { +// data, err := base64.StdEncoding.DecodeString(okey) +// if err != nil { +// return "", nil, nil +// } + +// parts := strings.Split(string(data), "|") +// if len(parts) != 3 { +// return "", nil, nil +// } +// } \ No newline at end of file diff --git a/find_cursor_test.go b/find_cursor_test.go new file mode 100644 index 0000000..f15e41d --- /dev/null +++ b/find_cursor_test.go @@ -0,0 +1,26 @@ +package mongo + +import ( + "time" + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestCursor_Default(t *testing.T) { + now := time.Now() + + oid := bson.NewObjectIDFromTimestamp(now) + + okey := EncodeCursor(oid) + + id, err := DecodeCursor(okey) + + if err != nil { + t.Fatalf("Decode failed: %v", err) + } + + if id != oid { + t.Fatalf("Failed to decode id") + } +} \ No newline at end of file diff --git a/find_one.go b/find_one.go new file mode 100644 index 0000000..411fdc3 --- /dev/null +++ b/find_one.go @@ -0,0 +1,23 @@ +package mongo + +import ( + "context" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// FindOne will return at most one document based on the filter provided. +func (c *MongoClient) FindOne(ctx context.Context, database, name string, filter bson.M) (bson.M, error) { + + // 1. Prepare to query. + collection := c.GetCollection(database, name) + + // 2. Query + var out bson.M + err := collection.FindOne(ctx, filter).Decode(&out) + if err != nil { + return nil, err + } + + return out, nil +} diff --git a/find_one_test.go b/find_one_test.go new file mode 100644 index 0000000..34e0f91 --- /dev/null +++ b/find_one_test.go @@ -0,0 +1,30 @@ +package mongo + +import ( + "context" + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestFindOne_Default(t *testing.T) { + client := GetMongoClient() + + data := map[string]any { + "_id": "su_123459", + "name": "MyNameFindOne", + "age": int32(25), + } + + o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + filter := bson.M{"name": "MyNameFindOne"} + found, err := client.FindOne(context.Background(), "mydb", "mycollection", filter) + if err != nil { + t.Fatalf("Failed to findOne %#v", err) + } + + AssertSubset(t, o, found, "Should have been equal") +} diff --git a/find_test.go b/find_test.go new file mode 100644 index 0000000..d468e36 --- /dev/null +++ b/find_test.go @@ -0,0 +1,51 @@ +package mongo + +import ( + "context" + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestFind_Default(t *testing.T) { + client := GetMongoClient() + + filter := bson.M{"name": bson.M{"$regex": "OSRAM"}} + findResult, err := client.Find(context.Background(), "mydb", "mycollection", filter, 0) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + + dataAny, hasData := findResult["data"] + if !hasData { + t.Fatal("no data") + } + + data := dataAny.(bson.A) + + if len(data) != 1 { + t.Fatalf("Expected to return 1 document but got %d", len(data)) + } + + hasMoreAny, hasMoreOk := findResult["has_more"] + if !hasMoreOk { + t.Fatal("no has more") + } + + hasMore := hasMoreAny.(bool) + + if hasMore { + t.Fatalf("Expected to have reached the end of the results") + } + + totalAny, totalOk := findResult["total"] + if !totalOk { + t.Fatal("no total") + } + + total := totalAny.(int64) + + if total != 1 { + t.Fatalf("Expected total to be 1 but found %d", total) + } +} diff --git a/generic.go b/generic.go new file mode 100644 index 0000000..53a1627 --- /dev/null +++ b/generic.go @@ -0,0 +1,63 @@ +package mongo + +import ( + "errors" + "log" + "context" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +func (c *MongoClient) GenericFind(ctx context.Context, payload *FindRequest) (*DataResults, error) { + account := ctx.Value("account").(string) + if account == "" { + return nil, errors.New("account required") + } + + if payload.Entity == "" { + return nil, errors.New("entity required") + } + + name := payload.Entity + + database := c.GetName(account) + + collection := c.GetCollection(database, name) + + log.Printf("%v", collection) + + var filter bson.D + var err error + + filter, err = MapToBsonD(payload.Filter) + if err != nil { + return nil, err + } + + opts := options.Find() + + var cursor *mongo.Cursor + cursor, err = collection.Find(context.TODO(), filter, opts) + if err != nil { + return nil, err + } + + results := make([]map[string]any, 0) + if err = cursor.All(context.TODO(), &results); err != nil { + return nil, err + } + + dataResults := &DataResults{ + Data: results, + } + + return dataResults, nil +} + +func GenericInsertOne(ctx context.Context, entityType string, data any) (any, error) { + // coll := getCollection("sampledb", "dummy") + + return nil, nil +} diff --git a/get_one.go b/get_one.go new file mode 100644 index 0000000..97fa0dd --- /dev/null +++ b/get_one.go @@ -0,0 +1,25 @@ +package mongo + +import ( + "context" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// FindOne will return at most one document based on the filter provided. +func (c *MongoClient) GetOne(ctx context.Context, database, name string, id any) (bson.M, error) { + + // 1. Prepare to query. + collection := c.GetCollection(database, name) + + // 2. Query + filter := map[string]any { "_id": id } + + var out bson.M + err := collection.FindOne(ctx, filter).Decode(&out) + if err != nil { + return nil, err + } + + return out, nil +} diff --git a/get_one_test.go b/get_one_test.go new file mode 100644 index 0000000..579d2aa --- /dev/null +++ b/get_one_test.go @@ -0,0 +1,49 @@ +package mongo + +import ( + "context" + "testing" +) + +func TestGetOne_ByIdString(t *testing.T) { + client := GetMongoClient() + + data := map[string]any { + "_id": "su_156345", + "name": "MyNameGetOne", + "age": int32(25), + } + + o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + + found, err := client.GetOne(context.Background(), "mydb", "mycollection", "su_156345") + if err != nil { + t.Fatalf("Failed to findOne %#v", err) + } + + AssertSubset(t, o, found, "Should have been equal") +} + +func TestGetOne_ByObjectId(t *testing.T) { + client := GetMongoClient() + + data := map[string]any { + "name": "MyNameGetOne", + "age": int32(25), + } + + o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + + found, err := client.GetOne(context.Background(), "mydb", "mycollection", o["_id"]) + if err != nil { + t.Fatalf("Failed to findOne %#v", err) + } + + AssertSubset(t, o, found, "Should have been equal") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8e3646c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.gsuntres.com/gsuntres/mongo + +go 1.25.0 diff --git a/insert.go b/insert.go new file mode 100644 index 0000000..fbcec36 --- /dev/null +++ b/insert.go @@ -0,0 +1,83 @@ +package mongo + +import ( + "context" + "time" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "github.com/matoous/go-nanoid/v2" + + // "git.gsuntres.com/boxtep/boxtep/core" +) + +const alphabet = "0123456789abcdefghijclmnopqrstuvwxyz" + +const maxLen = 10 + +// InsertOneWithStruct can be used to insert defined structs. +func (c *MongoClient) InsertOneFromStruct(ctx context.Context, database, name string, data any) (bson.M, error) { + o, err := ToMap(data) + if err != nil { + return nil, err + } + + return c.InsertOne(ctx, database, name, o) +} + +// InsertOne will add missing ids and the created date before saving to the database. +func (c *MongoClient) InsertOne(ctx context.Context, database, name string, data bson.M) (bson.M, error) { + collection := c.GetCollection(database, name) + + prepareForInsert(data, c.GetIdPrefix(name)) + + if _, err := collection.InsertOne(ctx, data); err != nil { + return nil, err + } + + return data, nil +} + +// prepareForInsert takes a map[string]any and: +// * adds a new _id if property does not exist or is an empty string +// * adds/updates property created_at using the current timestamp. +func prepareForInsert(data bson.M, idPrefix string) { + ensureId(data, idPrefix) + ensureCreatedAt(data) + ensureUpdatedAt(data) +} + +// ensureId adds the id property when missing or when it's an empty string. +func ensureId(data bson.M, idPrefix string) string { + maybeId, hasId := data["_id"] + + var id, finalId string + if !hasId || maybeId == "" { + id, _ = gonanoid.Generate(alphabet, maxLen) + if idPrefix != "" { + finalId = fmt.Sprintf("%s_%s", idPrefix, id) + } else { + finalId = id + } + + data["_id"] = finalId + } + + return finalId +} + +func ensureCreatedAt(data bson.M) time.Time { + now := time.Now().UTC().Truncate(time.Millisecond) + + data["created_at"] = now + + return now +} + +func ensureUpdatedAt(data bson.M) time.Time { + now := time.Now().UTC().Truncate(time.Millisecond) + + data["updated_at"] = now + + return now +} diff --git a/insert_test.go b/insert_test.go new file mode 100644 index 0000000..465fa8f --- /dev/null +++ b/insert_test.go @@ -0,0 +1,175 @@ +package mongo + +import ( + "os" + "context" + "strings" + "testing" + "testing/synctest" + "encoding/json" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestInsertOne(t *testing.T) { + client := GetMongoClient() + + data := map[string]any { + "_id": "su_123456", + "name": "MyName", + "age": int32(25), + } + + o, err := client.InsertOne(context.Background(), "mydb", "mycollection", data) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + + // raw query + var results bson.M + filter := bson.M{ "name": "MyName" } + + c := client.Client.Database("mydb").Collection("mycollection") + c.FindOne(context.Background(), filter).Decode(&results) + + AssertSubset(t, o, results, "Should have been equal") +} + +func TestInsertOne_DataStruct(t *testing.T) { + client := GetMongoClient() + + data := &Sample { + Id: "su_123457", + Name: "MyName200", + Age: 25, + } + + o, err := client.InsertOneFromStruct(context.Background(), "mydb", "mycollection", data) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + + // raw query + var results bson.M + filter := map[string]any { "name": "MyName200" } + + c := client.Client.Database("mydb").Collection("mycollection") + c.FindOne(context.Background(), filter).Decode(&results) + + AssertSubset(t, o, results, "Should have been equal") +} + +func TestInsertOne_WithIdPrefix(t *testing.T) { + // Read JSON file + data, err := os.ReadFile("./.test/user.json") + if err != nil { + t.Fatal(err) + } + + var user bson.M + if err := json.Unmarshal(data, &user); err != nil { + t.Fatalf("Length: %d, First bytes: %x\n", len(data), data[:4]) + } + + client := GetMongoClient() + client.AddDefinition(user) + + in := map[string]any { + "name": "MyName112", + "age": int32(25), + } + + o, err := client.InsertOne(context.Background(), "mydb", "user", in) + if err != nil { + t.Fatalf("Failed to insertOne %#v", err) + } + + // raw query + var results bson.M + filter := map[string]any { "name": "MyName112" } + c := client.Client.Database("mydb").Collection("user") + c.FindOne(context.Background(), filter).Decode(&results) + + if !strings.HasPrefix(results["_id"].(string), "usr_") { + t.Fatal("_id should have been prefixed") + } + + AssertSubset(t, o, results, "Should have been equal") +} + +func TestPrepareForInsert_WithoutId(t *testing.T) { + data := map[string]any { + "name": "My Name Is", + } + + prepareForInsert(data, "") + + if id, okid := data["_id"]; !okid || id == "" { + t.Fatal("Failed to add Id") + } +} + +func TestPrepareForInsert_ExistingId(t *testing.T) { + data := map[string]any { + "_id": "myidxxxxxx", + "name": "My Name Is", + } + + prepareForInsert(data, "") + + if data["_id"] == "" { + t.Fatal("id was updated") + } +} + +func TestEnsureId(t *testing.T) { + data := map[string]any { + "Name": "My Name Is", + } + + ensureId(data, "") + + if id, okid := data["_id"]; !okid || id == "" { + t.Fatal("Failed to add Id") + } +} + +func TestEnsureId_EmptyId(t *testing.T) { + data := map[string]any { + "_id": "myidxxxxxx", + "name": "My Name Is", + } + + ensureId(data, "") + + if data["_id"] == "" { + t.Fatal("Id was updated") + } +} + +func TestEnsureId_ExistingId(t *testing.T) { + data := map[string]any { + "_id": "", + "name": "My Name Is", + } + + ensureId(data, "") + + if id, okid := data["_id"]; !okid || id == "" { + t.Fatal("Failed to add Id") + } +} + +func TestEnsureCreatedAt(t *testing.T) { + data := map[string]any { + "name": "My Name Is", + } + + synctest.Test(t, func(t *testing.T) { + now := ensureCreatedAt(data) + + if createdAt, _ := data["created_at"]; createdAt != now { + t.Fatal("Failed to add CreatedAt") + } + }) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..b3d825c --- /dev/null +++ b/main.go @@ -0,0 +1,352 @@ +// Package mongo provides a simple interface to the database. +package mongo + +import ( + "context" + "log" + "slices" + "strings" + "time" + "fmt" + "encoding/json" + + "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + + "git.gsuntres.com/boxtep/boxtep/core" + "git.gsuntres.com/boxtep/boxtep/sys" +) + +// WithiSessionFunc will run operations on the database within the same session. +// If this functions returns an error, the system will rollback the transaction. +// Sessions require a resplica set. +type WithinSessionFunc func(context.Context, *mongo.Session) (any, error) + +type IMongoClient interface { + // WithinSession takes a context, the database and the collection and excecutes the callback function provided. + WithinSession(context.Context, string, string, WithinSessionFunc) (any, error) + + // InsertOne will insert data to the specified namespace. + InsertOne(context.Context, string, string, bson.M) (bson.M, error) + + // Find + Find(context.Context, string, string, bson.M) ([]bson.M, error) + + // GetCollection returns the requested collection within an account. It will create it if it doesn't exist. + GetCollection(database, name string) *mongo.Collection +} + +type CollectionDefinition struct { + Name string `bson:"_name"` + Singular string `bson:"singular"` + Plural string `bson: "plural"` + IdPrefix string `bson: "idPrefix"` + IndexSpecs []map[string]any `bson: "indexSpecs"` + Schema map[string]any `bson: "schema"` +} + +// func (cd *CollectionDefinition) GetSchema(name string) + +// MongoClient +type MongoClient struct { + // Client the actual connected instance of mongo client. + Client *mongo.Client + // Debug set to true for dislaying info level logs. + Debug bool + // DebugQuery set to true to log all queries done. + DebugQuery bool + // Limit the default limit to use in queries. + Limit int64 + // DBPrefix optinal string literal to distinquise + DBPrefix string + // Registry holds critical information about collection's structure, like schema and indexes + Registry map[string]*CollectionDefinition +} + +func (c *MongoClient) GetIdPrefix(name string) string { + def, ok := c.Registry[name] + if ok { + return def.IdPrefix + } + + return "" +} + +const ADD_DEFINITION_SCHEMA = ` +{ + "type": "object", + "properties": { + "_name": { + "type": "string" + }, + "idPrefix": { + "type": "string" + }, + "indexSpecs": { + "type": "array" + }, + "singular": { + "type": "string" + }, + "system": { + "type": "boolean" + }, + "plural": { + "type": "string" + } + }, + "required": ["_name", "singular", "plural"], + "additionalProperties": true +} +` + +func (c *MongoClient) AddDefinition(data map[string]any) { + if valid := core.Validate(ADD_DEFINITION_SCHEMA, data); valid != nil { + log.Printf("failed to register data: %v", valid) + + return + } + + log.Printf("Registering %s", data["_name"]) + + b, err := bson.Marshal(data) + if err != nil { + log.Printf("failed to marshal: %v", err) + + return + } + + var cd CollectionDefinition + bson.Unmarshal(b, &cd) + + if len(cd.Name) > 0 { + c.Registry[cd.Name] = &cd + } +} + +func (c *MongoClient) GetCollection(database, name string) *mongo.Collection { + if c.Debug { + log.Printf("Using collection: %s.%s", database, name) + } + + db := c.Client.Database(database) + + // 1. List existing collections + names, err := db.ListCollectionNames(context.TODO(), bson.D{}) + + if err != nil { + log.Printf("Failed to list collections: %#v", err.Error()) + + return nil + } + + // 2. If collection exist return it, otherwise create it and then return it + if slices.Contains(names, name) { + return db.Collection(name) + } else { + opts := options.CreateCollection() + + // maybe get from schema + cdef, ok := c.Registry[name] + if ok { + log.Printf("Schema found for %s; will use it", name) + + ApplySchema(cdef, opts) + } else { + log.Printf("No schema for %s", name) + } + + if err := db.CreateCollection(context.TODO(), name, opts); err != nil { + log.Printf("Failed to create collection: %#v", err) + + return nil + } + + collection := db.Collection(name) + + c.CreateIndexes(collection, cdef) + + return collection + } +} + +func ApplySchema(cdef *CollectionDefinition, opts *options.CreateCollectionOptionsBuilder) { + // Add schema validation + if cdef.Schema != nil { + schemaBson, err := bson.Marshal(cdef.Schema) + if err != nil { + log.Printf("failed to parse schema: %v", err) + } else { + var validatorSchema bson.M + bson.Unmarshal(schemaBson, &validatorSchema) + validator := bson.M{ + "$jsonSchema": validatorSchema, + } + + opts.SetValidator(validator) + } + } else { + log.Printf("Invalid schema, do nothing.") + } +} + +var client *MongoClient = &MongoClient{ + Limit: 10, + Registry: make(map[string]*CollectionDefinition, 0), +} + +func GetMongoClient() *MongoClient { + return client +} + +const validSchema_StartProps = ` +{ + "type": "object", + "properties": { + "MongoUri": { + "type": "string" + }, + "MongoUser": { + "type": "string" + }, + "MongoPass": { + "type": "string" + } + }, + "required": ["MongoUri"] +} +` + +type MongoStartProps struct { + MongoUri string + MongoUser string + MongoPass string + MongoDebugQuery bool + MongoDBPrefix string +} + +func Start(props *MongoStartProps) error { + if err := core.Validate(validSchema_StartProps, props); err != nil { + return err + } + + uri := props.MongoUri + user := props.MongoUser + pass := props.MongoPass + + // Create a new client and connect to the server + var err error + + cOptions := options.Client(). + ApplyURI(uri). + SetAuth(options.Credential{ + Username: user, + Password: pass, + }). + SetConnectTimeout(5 * time.Second). + SetServerAPIOptions(&options.ServerAPIOptions{ + ServerAPIVersion: options.ServerAPIVersion1, + }). + SetBSONOptions(&options.BSONOptions{ + DefaultDocumentM: true, + UseJSONStructTags: false, + NilSliceAsEmpty: true, + NilMapAsEmpty: true, + }). + SetRegistry(GetCustomRegistry()) + + client.DebugQuery = props.MongoDebugQuery + if client.DebugQuery { + // Debug queries + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + log.Printf("%d@Start %s#%s %s", e.RequestID, e.DatabaseName, e.CommandName, e.Command) + }, + Succeeded: func(_ context.Context, e *event.CommandSucceededEvent) { + log.Printf("%d@OK in %s", e.RequestID, e.Reply) + }, + Failed: func(_ context.Context, e *event.CommandFailedEvent) { + log.Printf("%d@Fail in %s", e.RequestID, e.Failure) + }, + } + + cOptions = cOptions.SetMonitor(monitor) + } + + // set DBPrefix + client.DBPrefix = props.MongoDBPrefix + + if err := cOptions.Validate(); err != nil { + log.Fatalf("Failed to validate mongo options: %+v", err.Error()) + + return err + } + + client.Client, err = mongo.Connect(cOptions) + + if err != nil { + log.Fatalf("Failed to connect: %+v", err.Error()) + + return err + } + + sys.OnExit(func () { + if err := client.Client.Disconnect(context.TODO()); err != nil { + log.Fatalf("failed to disconnect: %v", err) + } + }) + + // Register defaults.go + var structfulDef bson.M + json.Unmarshal(structfulJson, &structfulDef) + + client.AddDefinition(structfulDef) + + return nil +} + +func GetClient() *mongo.Client { + return client.Client +} + +func (c *MongoClient) GetName(account string) string { + return fmt.Sprintf("%s%s", c.DBPrefix, account) +} + +func (c *MongoClient) ExtractAccount(name string) string { + v, ok := strings.CutPrefix(name, c.DBPrefix) + + if !ok { + return "" + } + + return v +} + +func Stop() error { + if err := client.Client.Disconnect(context.TODO()); err != nil { + return err + } + + log.Printf("Successfully stopped mongo.") + + return nil +} + +func MapToBsonD(m map[string]any) (bson.D, error) { + data, err := bson.Marshal(m) + if err != nil { + return nil, err + } + + var d bson.D + + err = bson.Unmarshal(data, &d) + if err != nil { + return nil, err + } + + return d, nil +} diff --git a/main_index.go b/main_index.go new file mode 100644 index 0000000..e24ac07 --- /dev/null +++ b/main_index.go @@ -0,0 +1,75 @@ +package mongo + +import ( + "context" + "log" + + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" +) + +// CreateIndexes will create indexes for the given collection and definition. +func (c *MongoClient) CreateIndexes(collection *mongo.Collection, cdef *CollectionDefinition) { + if cdef == nil { + log.Printf("No definitions will not create indexes") + + return + } + + // handle indexes + indexModels := make([]mongo.IndexModel, 0) + + for _, keyDef := range cdef.IndexSpecs { + log.Printf("Key Definition %s", keyDef["name"]) + kdb, err := bson.Marshal(keyDef) + if err != nil { + log.Printf("failed to marshal %v", err) + + continue + } + + kdRaw := bson.Raw(kdb) + if err := kdRaw.Validate(); err != nil { + log.Printf("failed to validate bson raw: %v", err) + + continue + } + // + idxModel := mongo.IndexModel{} + + opts := options.Index() + + keysVal := kdRaw.Lookup("keys") + + var keysBson bson.D + if err := bson.Unmarshal(keysVal.Value, &keysBson); err != nil { + log.Printf("failed to unmarshal keys value %v", err) + + continue + } + + idxModel.Keys = keysBson + + nameVal := kdRaw.Lookup("name") + if name, ok := nameVal.StringValueOK(); ok { + opts = opts.SetName(name) + } + + uniqueVal := kdRaw.Lookup("unique") + if unique, ok := uniqueVal.BooleanOK(); ok { + opts = opts.SetUnique(unique) + } + + partialVal := kdRaw.Lookup("partialFilterExpression") + if partialFilterExpression, ok := partialVal.BooleanOK(); ok { + opts = opts.SetPartialFilterExpression(partialFilterExpression) + } + + idxModel.Options = opts + + indexModels = append(indexModels, idxModel) + } + + collection.Indexes().CreateMany(context.Background(), indexModels) +} diff --git a/main_index_test.go b/main_index_test.go new file mode 100644 index 0000000..d9a1f07 --- /dev/null +++ b/main_index_test.go @@ -0,0 +1,56 @@ +package mongo + +import ( + "context" + "testing" + "time" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestCreateIndexes(t *testing.T) { + cd := &CollectionDefinition{ + IndexSpecs: []map[string]any{ + { + "keys": map[string]any{ + "code": 1, + }, + "name": "idx_1", + }, + }, + } + + c := GetMongoClient() + + collection := c.GetCollection("mydb", "mycol") + + c.CreateIndexes(collection, cd) + + indexView := collection.Indexes() + + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + + cursor, err := indexView.List(ctx, nil) + if err != nil { + t.Fatal(err) + } + + // Get a slice of all indexes returned and print them out. + var results []bson.M + if err = cursor.All(ctx, &results); err != nil { + t.Fatal(err) + } + + idx1 := results[1] + + if len(results) != 2 { + t.Fatal("Unexpected number of indexes") + } + + if idx1["name"] != "idx_1" { + t.Fatal("Should have register index idx_1") + } +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..4ce6dc1 --- /dev/null +++ b/main_test.go @@ -0,0 +1,134 @@ +package mongo + +import ( + "os" + "time" + "reflect" + "testing" + "context" + "fmt" + "encoding/json" + + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mongodb" +) + +type Sample struct { + Id string `json:"id" bson:"_id,omitempty"` + Name string `json:"name" bson:"name"` + Age int32 `json:"age" bson:"age"` + CreatedAt time.Time `json:"created_at" bson:"created_at"` +} + +func TestMain(m *testing.M) { + if os.Getenv("RUN_INTEGRATION") == "" { + fmt.Println("Skipping package tests: RUN_INTEGRATION is missing") + os.Exit(0) // Exit with success, but no tests ran + } + + ctx := context.Background() + + user := "admin" + pass := "1" + + // 1. Setup: Start the MongoDB container + mongoContainer, err := mongodb.Run(ctx, "mongo:8", + testcontainers.WithEnv(map[string]string{ + "MONGO_INITDB_ROOT_USERNAME": user, + "MONGO_INITDB_ROOT_PASSWORD": pass, + }),) + if err != nil { + panic("failed to start container") + } + + // 2. Get the connection string dynamically + endpoint, _ := mongoContainer.ConnectionString(ctx) + + Start(&MongoStartProps{ + MongoUri: endpoint, + MongoUser: user, + MongoPass: pass, + }) + + // 3. Run tests + LoadTestSample() + + code := m.Run() + + // 4. Teardown: Clean up resources + Stop() + + _ = testcontainers.TerminateContainer(mongoContainer) + + os.Exit(code) +} + +func TestApplySchema(t *testing.T) { + // prepare schema sample + schemaStr := ` + { + "bsonType": "object", + "properties": { + "name": { + "bsonType": "string" + }, + "active": { + "bsonType": "boolean" + } + }, + "required": ["name"], + "additionalProperties": true + } + ` + + var result bson.M + err := json.Unmarshal([]byte(schemaStr), &result) + if err != nil { + t.Fatalf("%v", err) + } + + cd := &CollectionDefinition{ + Schema: result, + } + + opts := options.CreateCollection() + + ApplySchema(cd, opts) + + options := &options.CreateCollectionOptions{} + + for _, o := range opts.List() { + o(options) + } + + expected := bson.M{"$jsonSchema":bson.M{"additionalProperties":true, "bsonType":"object", "properties":bson.D{bson.E{Key:"active", Value:bson.D{bson.E{Key:"bsonType", Value:"boolean"}}}, bson.E{Key:"name", Value:bson.D{bson.E{Key:"bsonType", Value:"string"}}}}, "required":bson.A{"name"}}} + + aj, _ := json.Marshal(options.Validator) + ej, _ := json.Marshal(expected) + + var am, em any + json.Unmarshal(aj, &am) + json.Unmarshal(ej, &em) + + if !reflect.DeepEqual(am, em) { + t.Fatalf("Validator should have been set") + } +} + +func TestGetCollection_Default(t *testing.T) { + c := GetMongoClient() + + col := c.GetCollection("mydb", "mycol") + + if col == nil { + t.Fatal("should have return collection") + } + + if col.Name() != "mycol" { + t.Fatal("Wrong name") + } + +} diff --git a/pipeline.go b/pipeline.go new file mode 100644 index 0000000..e39eda3 --- /dev/null +++ b/pipeline.go @@ -0,0 +1,52 @@ +package mongo + +import ( + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +func BuildPaginationPipeline(skip, limit int64, filter bson.M, sort bson.M) mongo.Pipeline { + return mongo.Pipeline{ + // 1. GLOBAL FILTER: Always filter first to use indexes + {{Key: "$match", Value: filter}}, + + // 2. GLOBAL SORT: Sort here so both 'total' and 'data' facets use the same order + {{Key: "$sort", Value: sort}}, + + // 3. FACET: Split the pipeline into two parallel paths + {{Key: "$facet", Value: bson.D{ + // Path A: Get the total count of documents matching the filter + {Key: "metadata", Value: mongo.Pipeline{ + {{Key: "$count", Value: "total"}}, + }}, + // Path B: Get the specific page of data + {Key: "data", Value: mongo.Pipeline{ + {{Key: "$skip", Value: skip}}, + {{Key: "$limit", Value: limit}}, + }}, + }}}, + } +} + +func BuildPaginationPipelineNext(limit int64, filter bson.M, sort bson.M) mongo.Pipeline { + + return mongo.Pipeline{ + // 1. GLOBAL FILTER: Always filter first to use indexes + {{Key: "$match", Value: filter}}, + + // 2. GLOBAL SORT: Sort here so both 'total' and 'data' facets use the same order + {{Key: "$sort", Value: sort}}, + + // 3. FACET: Split the pipeline into two parallel paths + {{Key: "$facet", Value: bson.D{ + // Path A: Get the total count of documents matching the filter + {Key: "metadata", Value: mongo.Pipeline{ + {{Key: "$count", Value: "total"}}, + }}, + // Path B: Get the specific page of data + {Key: "data", Value: mongo.Pipeline{ + {{Key: "$limit", Value: limit}}, + }}, + }}}, + } +} \ No newline at end of file diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..d84ebc6 --- /dev/null +++ b/registry.go @@ -0,0 +1,88 @@ +package mongo + +import ( + "reflect" + "time" + + "github.com/go-viper/mapstructure/v2" + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TimeToStringHook(f reflect.Type, t reflect.Type, data any) (any, error) { + + // 1. Target must be a string + if t.Kind() != reflect.String { + return data, nil + } + + // 2. Check if source is time.Time OR *time.Time + isTime := f == reflect.TypeOf(time.Time{}) + isTimePtr := f == reflect.TypeOf(&time.Time{}) + + if isTime || isTimePtr { + // Handle pointer vs value during type assertion + if isTimePtr { + if ptr, ok := data.(*time.Time); ok && ptr != nil { + + return ptr.Format(time.RFC3339), nil + } + + return "", nil + } + + return data.(time.Time).Format(time.RFC3339), nil + } + + return data, nil +} + +func StructToMap(input any) (map[string]any, error) { + var result map[string]any + + config := &mapstructure.DecoderConfig{ + DecodeHook: TimeToStringHook, + Result: &result, + TagName: "bson", + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return nil, err + } + + err = decoder.Decode(input) + + return result, err +} + +func TruncatingTimeEncoder(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error { + // 1. Check if the type is exactly time.Time + if val.Type() != reflect.TypeOf(time.Time{}) { + // Fallback: Use the encoder from the current context's registry + enc, err := ec.Registry.LookupEncoder(val.Type()) + if err != nil { + return err + } + + return enc.EncodeValue(ec, vw, val) + } + + // 2. Perform the truncation logic + t := val.Interface().(time.Time) + truncated := t.Truncate(time.Millisecond) + + // 3. Write as a standard BSON DateTime + return vw.WriteDateTime(truncated.UnixMilli()) +} + +func GetCustomRegistry() *bson.Registry { + reg := bson.NewRegistry() + + // reg.RegisterTypeEncoder( + // reflect.TypeOf(time.Time{}), + // bson.ValueEncoderFunc(TruncatingTimeEncoder), + // ) + // reg.SetRegistry(mgocompat.NewRegistry()) + + return reg +} \ No newline at end of file diff --git a/session.go b/session.go new file mode 100644 index 0000000..68a8260 --- /dev/null +++ b/session.go @@ -0,0 +1,40 @@ +package mongo + +import ( + "context" + + "go.mongodb.org/mongo-driver/v2/mongo" +) + + +func (c *MongoClient) WithinSession(ctx context.Context, cb WithinSessionFunc) (any, error) { + + // 1. Start a new session + sess, err := c.Client.StartSession() + if err != nil { + return nil, err + } + + defer sess.EndSession(ctx) + + // 2. Create a new context + ctxNew := mongo.NewSessionContext(ctx, sess) + + // 3. Start transaction + if err = sess.StartTransaction(); err != nil { + return nil, err + } + + res, err := cb(ctxNew, sess) + if err != nil { + _ = sess.AbortTransaction(context.Background()) + + return nil, err + } + + if err := sess.CommitTransaction(context.Background()); err != nil { + return nil, err + } + + return res, nil +} \ No newline at end of file diff --git a/testing.go b/testing.go new file mode 100644 index 0000000..dc452e9 --- /dev/null +++ b/testing.go @@ -0,0 +1,90 @@ +package mongo + +import ( + "fmt" + "testing" + "time" + "context" + "os" + + "go.mongodb.org/mongo-driver/v2/bson" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +const SAMPLE_DB = "mydb" +const SAMPLE_COLLECTION = "mycollection" + +// AssertSubset compares a subset map against a superset (from Mongo). +func AssertSubset(t *testing.T, subset, superset map[string]any, msgAndArgs ...any) bool { + t.Helper() + + // 1. Filter: Ignore keys in the superset that aren't in our expected subset + ignoreExtra := cmpopts.IgnoreMapEntries(func(k string, v any) bool { + _, ok := subset[k] + return !ok + }) + + // 2. Transformer: Convert Mongo v2 int64 back to time.Time for comparison + // 1. Define the logic + timeLogic := func(v any) any { + switch val := v.(type) { + case time.Time: + return val.UTC() + case bson.DateTime: + return val.Time().UTC() + case int64: + return time.UnixMilli(val).UTC() + default: + return time.Time{} + } + } + + // 2. Wrap it in a Filter so it's not an "unfiltered" option + timeTransform := cmp.FilterValues(func(x, y any) bool { + _, xisbsontime := x.(bson.DateTime) + _, yistime := y.(time.Time) + _, xistime := x.(time.Time) + _, yisbsontime := y.(bson.DateTime) + + return (xisbsontime && yistime) || (yisbsontime && xistime) + }, cmp.Transformer("BsonTime", timeLogic)) + + // 3. Execute Diff + diff := cmp.Diff(subset, superset, + ignoreExtra, + timeTransform, + cmpopts.EquateEmpty(), // Treats nil slice == empty slice + ) + + if diff != "" { + msg := "Maps do not match (subset != superset)" + if len(msgAndArgs) > 0 { + msg = fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } + t.Errorf("%s:\n%s", msg, diff) + + return false + } + + return true +} + +func LoadTestSample() error { + // Read JSON file + jsonData, _ := os.ReadFile("./.test/sample.json") + + var docs []bson.M + if err := bson.UnmarshalExtJSON(jsonData, true, &docs); err != nil { + return err + } + + // Insert into MongoDB + client := GetMongoClient() + collection := client.Client.Database(SAMPLE_DB).Collection(SAMPLE_COLLECTION) + if _, err := collection.InsertMany(context.TODO(), docs); err != nil { + return err + } + + return nil +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..f8f737b --- /dev/null +++ b/types.go @@ -0,0 +1,26 @@ +package mongo + +// TODO(me): Place it in find.go? +type FindRequest struct { + Entity string + Limit int + Skip int + Filter map[string]any +} + +// TODO(me): Remove it in favor of [FindResult]? +type DataResults struct { + Data []map[string]any +} + +// TODO(me): Remove it in favor of [FindResult]? +type Metadata struct { + Total int64 `bson:"total"` +} +// TODO(me): Remove it in favor of [FindResult]? +type PaginatedResponse struct { + Metadata []Metadata `bson:"metadata" json:"metadata"` + // Initializing this as an empty slice literal []User{} + // prevents "null" in your JSON output. + Data []map[string]any `bson:"data" json:"data"` +}