Add UpdateSet, Relaxed in client, Fix tests

This commit is contained in:
George Suntres
2026-04-22 10:22:23 -04:00
parent 45f9ac558f
commit 188c5a1be1
9 changed files with 163 additions and 31 deletions

View File

@@ -6,6 +6,8 @@ import (
"context"
"go.mongodb.org/mongo-driver/v2/bson"
"git.gsuntres.com/general/commons"
)
func (c *MongoClient) DiscriminatorCheckAndApplyToData(ctx context.Context, name string, data map[string]any) error {
@@ -31,6 +33,24 @@ func (c *MongoClient) DiscriminatorCheckAndApplyToData(ctx context.Context, name
return nil
}
func (c *MongoClient) DiscriminatorOmitInData(name string, data bson.M) error {
cdef, ok := c.Registry[name]
if ok && cdef.Discriminator != nil {
if data == nil {
data = map[string]any{}
}
log.Printf("Making sure discriminator is not in data for %s", name)
_, ok := data[cdef.Discriminator.Field]
if ok {
delete(data, cdef.Discriminator.Field)
}
}
return nil
}
func (c *MongoClient) DiscriminatorCheckAndApplyToFilter(ctx context.Context, name string, filter bson.M) error {
cdef, ok := c.Registry[name]
if ok && cdef.Discriminator != nil {
@@ -48,8 +68,13 @@ func (c *MongoClient) DiscriminatorCheckAndApplyToFilter(ctx context.Context, na
// update payload
v := vAny.(string)
filter[cdef.Discriminator.Field] = bson.M{"$eq": v}
if commons.StringIsBlank(v) {
return fmt.Errorf("discriminator field required for %s", name)
}
filter[cdef.Discriminator.Field] = bson.M{"eq": v}
}
return nil
}

View File

@@ -1,6 +1,10 @@
package mongo
import "log"
import (
"log"
"go.mongodb.org/mongo-driver/v2/bson"
)
type Query struct {
Filter map[string]any `json:"filter"`
@@ -17,12 +21,27 @@ func makeFilter(name string, value any) *Filter {
var v string
log.Printf("TYPE %T", value)
switch value.(type) {
case string:
op = "eq"
v = value.(string)
case bson.M:
vMap := value.(bson.M)
log.Printf("GEO its map[string]any %v", vMap)
for kk, vv := range vMap {
op = kk
v = vv.(string)
break
}
case map[string]any:
vMap := value.(map[string]any)
log.Printf("GEO its bson.M %v", vMap)
for kk, vv := range vMap {
op = kk
v = vv.(string)
@@ -40,6 +59,8 @@ func makeFilter(name string, value any) *Filter {
Value: v,
}
log.Printf("FILTER -> %#v", o)
return o
}
@@ -48,7 +69,7 @@ func Mongofy(q *Query) map[string]any {
conditions := make([]map[string]interface{}, 0)
logic := "and"
log.Printf("GEO q.Filter %#v", q.Filter)
for k, v := range q.Filter {
if k == "_logic" {
logic = v.(string)

19
find.go
View File

@@ -37,7 +37,11 @@ func (c *MongoClient) Find(ctx context.Context, database, name string, filter bs
return nil, err
}
pipeline := BuildPaginationPipeline(0, pageSize + 1, filter, sort)
f := Mongofy(&Query{
Filter: filter,
})
pipeline := BuildPaginationPipeline(0, pageSize + 1, f, sort)
// 2. Query
cursor, err := collection.Aggregate(ctx, pipeline)
@@ -114,6 +118,10 @@ func (c *MongoClient) FindOffset(ctx context.Context, database, name string, fil
limit = opts.Limit
}
if opts != nil {
offset = opts.Offset
}
// 1. Prepare to query.
finalName := name
if opts != nil && commons.StringIsNotBlank(opts.Alias) {
@@ -175,9 +183,13 @@ func (c *MongoClient) FindOffset(ctx context.Context, database, name string, fil
}
hasMore := false
if int64(len(data)) > finalLimit {
if total > offset + finalLimit {
hasMore = true
data = data[:finalLimit]
}
hasPrevious := false
if finalLimit - offset < 0 {
hasPrevious = true
}
out := bson.M{
@@ -185,6 +197,7 @@ func (c *MongoClient) FindOffset(ctx context.Context, database, name string, fil
"offset": offset,
"limit": finalLimit,
"has_more": hasMore,
"has_previous": hasPrevious,
"total": total,
}

View File

@@ -76,33 +76,25 @@ func TestFind_Discriminator(t *testing.T) {
client := GetMongoClient()
client.AddDefinition(store)
client.AddDefinition(offer)
// Save two offers with the similar name in different stores each.
// Store str_1234 has OSRAM 1.
ctx1 := context.Background()
ctx1 = context.WithValue(ctx1, "account", "xxxxxx")
ctx1 = context.WithValue(ctx1, "store", "str_1234")
// One offer in str_1234
offer1 := map[string]any { "name": "OSRAM 1" }
_, err = client.InsertOne(ctx1, "mydb", "offer", offer1)
if err != nil { t.Fatalf("Failed to insertOne %#v", err) }
// The other in str_4321
// Store str_4321 has OSRAM 2
ctx2 := context.Background()
ctx2 = context.WithValue(ctx2, "account", "xxxxxx")
ctx2 = context.WithValue(ctx2, "store", "str_4321")
offer2 := map[string]any { "name": "OSRAM 2" }
_, err = client.InsertOne(ctx2, "mydb", "offer", offer2)
if err != nil { t.Fatalf("Failed to insertOne %#v", err) }
// Now searching in store str_1234 for OSRAM should return only one
// Within store str_1234 searching for OSRAM should return one result
filter := bson.M{"name": bson.M{"$regex": "OSRAM*"}}
findResult, err := client.Find(ctx1, "mydb", "offer", filter, &FindOptions{
Offset: int64(0),
})
findResult, err := client.Find(ctx1, "mydb", "offer", filter, &FindOptions{ Offset: int64(0) })
if err != nil { t.Fatalf("Failed to find %#v", err) }
dataAny, hasData := findResult["data"]

52
main.go
View File

@@ -40,6 +40,10 @@ type IMongoClient interface {
// GetCollection returns the requested collection within an account. It will create it if it doesn't exist.
GetCollection(database, name string) *mongo.Collection
SetRelaxed()
SetStrict()
}
type Timeseries struct {
@@ -81,6 +85,16 @@ type MongoClient struct {
DBPrefix string
// Registry holds critical information about collection's structure, like schema and indexes
Registry map[string]*CollectionDefinition
// Relaxed if set to true will not enforce schema
Relaxed bool
}
func (c *MongoClient) SetRelaxed() {
c.Relaxed = true
}
func (c *MongoClient) SetStrict() {
c.Relaxed = false
}
func (c *MongoClient) GetIdPrefix(name string) string {
@@ -114,6 +128,13 @@ func (c *MongoClient) AddDefinition(data map[string]any) {
if len(cd.Name) > 0 {
c.Registry[cd.Name] = &cd
}
if cd.Views != nil {
for k, _ := range cd.Views {
log.Printf("Registering alias %s to %s", k, cd.Name)
c.Registry[k] = &cd
}
}
}
func (c *MongoClient) DropDatabase_DANGER(database string) bool {
@@ -139,6 +160,9 @@ func (c *MongoClient) GetCollection(database, name string) *mongo.Collection {
db := c.Client.Database(database)
// To make sure we create the right collection when dealing with aliases.
actualName := name
// 1. List existing collections
names, err := db.ListCollectionNames(context.TODO(), bson.D{})
@@ -149,32 +173,41 @@ func (c *MongoClient) GetCollection(database, name string) *mongo.Collection {
}
// 2. If collection exist return it, otherwise create it and then return it
if slices.Contains(names, name) {
return db.Collection(name)
if slices.Contains(names, actualName) {
return db.Collection(actualName)
} else {
opts := options.CreateCollection()
// maybe get from schema
cdef, ok := c.Registry[name]
cdef, ok := c.Registry[actualName]
if ok {
log.Printf("Schema found for %s; will use it", name)
actualName = cdef.Name
log.Printf("Schema found for %s; will use it", actualName)
ApplyTimeSeries(cdef, opts)
ApplySchema(cdef, opts)
} else {
log.Printf("No schema for %s", name)
log.Printf("No schema for %s", actualName)
if c.Relaxed == false {
return nil
}
}
if err := db.CreateCollection(context.TODO(), name, opts); err != nil {
if err := db.CreateCollection(context.TODO(), actualName, opts); err != nil {
log.Printf("Failed to create collection: %#v", err)
return nil
}
collection := db.Collection(name)
collection := db.Collection(actualName)
c.CreateIndexes(collection, cdef)
c.CreateViews(db, cdef)
if c.CreateViews(db, cdef) {
collection = db.Collection(name)
}
return collection
}
@@ -216,6 +249,7 @@ func ApplySchema(cdef *CollectionDefinition, opts *options.CreateCollectionOptio
var client *MongoClient = &MongoClient{
Limit: 10,
Registry: make(map[string]*CollectionDefinition, 0),
Relaxed: false,
}
func GetMongoClient() *MongoClient {

View File

@@ -25,8 +25,12 @@ func TestCreateIndexes(t *testing.T) {
c := GetMongoClient()
c.SetRelaxed()
collection := c.GetCollection("mydb", "mycol")
c.SetStrict()
c.CreateIndexes(collection, cd)
indexView := collection.Indexes()

View File

@@ -47,11 +47,13 @@ func TestMain(m *testing.M) {
// 2. Get the connection string dynamically
endpoint, _ := mongoContainer.ConnectionString(ctx)
mongoDebug := os.Getenv("DEBUG") != ""
Start(&MongoStartProps{
MongoUri: endpoint,
MongoUser: user,
MongoPass: pass,
MongoDebugQuery: false,
MongoDebugQuery: mongoDebug,
})
// 3. Run tests

View File

@@ -10,13 +10,15 @@ import (
)
// CreateViews will create views for the given collection and definition.
func (c *MongoClient) CreateViews(db *mongo.Database, cdef *CollectionDefinition) {
func (c *MongoClient) CreateViews(db *mongo.Database, cdef *CollectionDefinition) bool {
if cdef == nil || cdef.Views == nil {
log.Printf("No definition for views found.")
return
return false
}
viewCreated := false
for name, defVal := range cdef.Views {
// 1. Decode definition
@@ -82,6 +84,9 @@ func (c *MongoClient) CreateViews(db *mongo.Database, cdef *CollectionDefinition
continue
}
viewCreated = true
}
return viewCreated
}

36
update_set.go Normal file
View File

@@ -0,0 +1,36 @@
package mongo
import (
"context"
"go.mongodb.org/mongo-driver/v2/bson"
)
// UpdateSet search documents using filter and updates the first it finds using the $set operator.
func (c *MongoClient) UpdateSet(ctx context.Context, database, name string, filter, data bson.M) (bool, error) {
collection := c.GetCollection(database, name)
prepareForUpdateSet(data)
if err := c.DiscriminatorCheckAndApplyToFilter(ctx, name, filter); err != nil {
return false, err
}
if err := c.DiscriminatorOmitInData(name, data); err != nil {
return false, err
}
update := bson.M{ "$set": data }
updateResult, err := collection.UpdateOne(ctx, filter, update)
if err != nil {
return false, err
}
changed := updateResult.ModifiedCount != 0
return changed, nil
}
func prepareForUpdateSet(data bson.M) {
ensureUpdatedAt(data)
}