Add UpdateSet, Relaxed in client, Fix tests

This commit is contained in:
George Suntres
2026-04-22 10:22:23 -04:00
parent 45f9ac558f
commit 188c5a1be1
9 changed files with 163 additions and 31 deletions

View File

@@ -6,6 +6,8 @@ import (
"context" "context"
"go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/bson"
"git.gsuntres.com/general/commons"
) )
func (c *MongoClient) DiscriminatorCheckAndApplyToData(ctx context.Context, name string, data map[string]any) error { func (c *MongoClient) DiscriminatorCheckAndApplyToData(ctx context.Context, name string, data map[string]any) error {
@@ -31,6 +33,24 @@ func (c *MongoClient) DiscriminatorCheckAndApplyToData(ctx context.Context, name
return nil return nil
} }
func (c *MongoClient) DiscriminatorOmitInData(name string, data bson.M) error {
cdef, ok := c.Registry[name]
if ok && cdef.Discriminator != nil {
if data == nil {
data = map[string]any{}
}
log.Printf("Making sure discriminator is not in data for %s", name)
_, ok := data[cdef.Discriminator.Field]
if ok {
delete(data, cdef.Discriminator.Field)
}
}
return nil
}
func (c *MongoClient) DiscriminatorCheckAndApplyToFilter(ctx context.Context, name string, filter bson.M) error { func (c *MongoClient) DiscriminatorCheckAndApplyToFilter(ctx context.Context, name string, filter bson.M) error {
cdef, ok := c.Registry[name] cdef, ok := c.Registry[name]
if ok && cdef.Discriminator != nil { if ok && cdef.Discriminator != nil {
@@ -48,8 +68,13 @@ func (c *MongoClient) DiscriminatorCheckAndApplyToFilter(ctx context.Context, na
// update payload // update payload
v := vAny.(string) v := vAny.(string)
filter[cdef.Discriminator.Field] = bson.M{"$eq": v} if commons.StringIsBlank(v) {
return fmt.Errorf("discriminator field required for %s", name)
}
filter[cdef.Discriminator.Field] = bson.M{"eq": v}
} }
return nil return nil
} }

View File

@@ -1,6 +1,10 @@
package mongo package mongo
import "log" import (
"log"
"go.mongodb.org/mongo-driver/v2/bson"
)
type Query struct { type Query struct {
Filter map[string]any `json:"filter"` Filter map[string]any `json:"filter"`
@@ -17,12 +21,27 @@ func makeFilter(name string, value any) *Filter {
var v string var v string
log.Printf("TYPE %T", value)
switch value.(type) { switch value.(type) {
case string: case string:
op = "eq" op = "eq"
v = value.(string) v = value.(string)
case bson.M:
vMap := value.(bson.M)
log.Printf("GEO its map[string]any %v", vMap)
for kk, vv := range vMap {
op = kk
v = vv.(string)
break
}
case map[string]any: case map[string]any:
vMap := value.(map[string]any) vMap := value.(map[string]any)
log.Printf("GEO its bson.M %v", vMap)
for kk, vv := range vMap { for kk, vv := range vMap {
op = kk op = kk
v = vv.(string) v = vv.(string)
@@ -40,6 +59,8 @@ func makeFilter(name string, value any) *Filter {
Value: v, Value: v,
} }
log.Printf("FILTER -> %#v", o)
return o return o
} }
@@ -48,7 +69,7 @@ func Mongofy(q *Query) map[string]any {
conditions := make([]map[string]interface{}, 0) conditions := make([]map[string]interface{}, 0)
logic := "and" logic := "and"
log.Printf("GEO q.Filter %#v", q.Filter)
for k, v := range q.Filter { for k, v := range q.Filter {
if k == "_logic" { if k == "_logic" {
logic = v.(string) logic = v.(string)

19
find.go
View File

@@ -37,7 +37,11 @@ func (c *MongoClient) Find(ctx context.Context, database, name string, filter bs
return nil, err return nil, err
} }
pipeline := BuildPaginationPipeline(0, pageSize + 1, filter, sort) f := Mongofy(&Query{
Filter: filter,
})
pipeline := BuildPaginationPipeline(0, pageSize + 1, f, sort)
// 2. Query // 2. Query
cursor, err := collection.Aggregate(ctx, pipeline) cursor, err := collection.Aggregate(ctx, pipeline)
@@ -114,6 +118,10 @@ func (c *MongoClient) FindOffset(ctx context.Context, database, name string, fil
limit = opts.Limit limit = opts.Limit
} }
if opts != nil {
offset = opts.Offset
}
// 1. Prepare to query. // 1. Prepare to query.
finalName := name finalName := name
if opts != nil && commons.StringIsNotBlank(opts.Alias) { if opts != nil && commons.StringIsNotBlank(opts.Alias) {
@@ -175,9 +183,13 @@ func (c *MongoClient) FindOffset(ctx context.Context, database, name string, fil
} }
hasMore := false hasMore := false
if int64(len(data)) > finalLimit { if total > offset + finalLimit {
hasMore = true hasMore = true
data = data[:finalLimit] }
hasPrevious := false
if finalLimit - offset < 0 {
hasPrevious = true
} }
out := bson.M{ out := bson.M{
@@ -185,6 +197,7 @@ func (c *MongoClient) FindOffset(ctx context.Context, database, name string, fil
"offset": offset, "offset": offset,
"limit": finalLimit, "limit": finalLimit,
"has_more": hasMore, "has_more": hasMore,
"has_previous": hasPrevious,
"total": total, "total": total,
} }

View File

@@ -76,33 +76,25 @@ func TestFind_Discriminator(t *testing.T) {
client := GetMongoClient() client := GetMongoClient()
client.AddDefinition(store) client.AddDefinition(store)
client.AddDefinition(offer) client.AddDefinition(offer)
// Save two offers with the similar name in different stores each. // Store str_1234 has OSRAM 1.
ctx1 := context.Background() ctx1 := context.Background()
ctx1 = context.WithValue(ctx1, "account", "xxxxxx")
ctx1 = context.WithValue(ctx1, "store", "str_1234") ctx1 = context.WithValue(ctx1, "store", "str_1234")
// One offer in str_1234
offer1 := map[string]any { "name": "OSRAM 1" } offer1 := map[string]any { "name": "OSRAM 1" }
_, err = client.InsertOne(ctx1, "mydb", "offer", offer1) _, err = client.InsertOne(ctx1, "mydb", "offer", offer1)
if err != nil { t.Fatalf("Failed to insertOne %#v", err) } if err != nil { t.Fatalf("Failed to insertOne %#v", err) }
// The other in str_4321 // Store str_4321 has OSRAM 2
ctx2 := context.Background() ctx2 := context.Background()
ctx2 = context.WithValue(ctx2, "account", "xxxxxx")
ctx2 = context.WithValue(ctx2, "store", "str_4321") ctx2 = context.WithValue(ctx2, "store", "str_4321")
offer2 := map[string]any { "name": "OSRAM 2" } offer2 := map[string]any { "name": "OSRAM 2" }
_, err = client.InsertOne(ctx2, "mydb", "offer", offer2) _, err = client.InsertOne(ctx2, "mydb", "offer", offer2)
if err != nil { t.Fatalf("Failed to insertOne %#v", err) } if err != nil { t.Fatalf("Failed to insertOne %#v", err) }
// Now searching in store str_1234 for OSRAM should return only one // Within store str_1234 searching for OSRAM should return one result
filter := bson.M{"name": bson.M{"$regex": "OSRAM*"}} filter := bson.M{"name": bson.M{"$regex": "OSRAM*"}}
findResult, err := client.Find(ctx1, "mydb", "offer", filter, &FindOptions{ findResult, err := client.Find(ctx1, "mydb", "offer", filter, &FindOptions{ Offset: int64(0) })
Offset: int64(0),
})
if err != nil { t.Fatalf("Failed to find %#v", err) } if err != nil { t.Fatalf("Failed to find %#v", err) }
dataAny, hasData := findResult["data"] dataAny, hasData := findResult["data"]

52
main.go
View File

@@ -40,6 +40,10 @@ type IMongoClient interface {
// GetCollection returns the requested collection within an account. It will create it if it doesn't exist. // GetCollection returns the requested collection within an account. It will create it if it doesn't exist.
GetCollection(database, name string) *mongo.Collection GetCollection(database, name string) *mongo.Collection
SetRelaxed()
SetStrict()
} }
type Timeseries struct { type Timeseries struct {
@@ -81,6 +85,16 @@ type MongoClient struct {
DBPrefix string DBPrefix string
// Registry holds critical information about collection's structure, like schema and indexes // Registry holds critical information about collection's structure, like schema and indexes
Registry map[string]*CollectionDefinition Registry map[string]*CollectionDefinition
// Relaxed if set to true will not enforce schema
Relaxed bool
}
func (c *MongoClient) SetRelaxed() {
c.Relaxed = true
}
func (c *MongoClient) SetStrict() {
c.Relaxed = false
} }
func (c *MongoClient) GetIdPrefix(name string) string { func (c *MongoClient) GetIdPrefix(name string) string {
@@ -114,6 +128,13 @@ func (c *MongoClient) AddDefinition(data map[string]any) {
if len(cd.Name) > 0 { if len(cd.Name) > 0 {
c.Registry[cd.Name] = &cd c.Registry[cd.Name] = &cd
} }
if cd.Views != nil {
for k, _ := range cd.Views {
log.Printf("Registering alias %s to %s", k, cd.Name)
c.Registry[k] = &cd
}
}
} }
func (c *MongoClient) DropDatabase_DANGER(database string) bool { func (c *MongoClient) DropDatabase_DANGER(database string) bool {
@@ -139,6 +160,9 @@ func (c *MongoClient) GetCollection(database, name string) *mongo.Collection {
db := c.Client.Database(database) db := c.Client.Database(database)
// To make sure we create the right collection when dealing with aliases.
actualName := name
// 1. List existing collections // 1. List existing collections
names, err := db.ListCollectionNames(context.TODO(), bson.D{}) names, err := db.ListCollectionNames(context.TODO(), bson.D{})
@@ -149,32 +173,41 @@ func (c *MongoClient) GetCollection(database, name string) *mongo.Collection {
} }
// 2. If collection exist return it, otherwise create it and then return it // 2. If collection exist return it, otherwise create it and then return it
if slices.Contains(names, name) { if slices.Contains(names, actualName) {
return db.Collection(name) return db.Collection(actualName)
} else { } else {
opts := options.CreateCollection() opts := options.CreateCollection()
// maybe get from schema // maybe get from schema
cdef, ok := c.Registry[name] cdef, ok := c.Registry[actualName]
if ok { if ok {
log.Printf("Schema found for %s; will use it", name)
actualName = cdef.Name
log.Printf("Schema found for %s; will use it", actualName)
ApplyTimeSeries(cdef, opts) ApplyTimeSeries(cdef, opts)
ApplySchema(cdef, opts) ApplySchema(cdef, opts)
} else { } else {
log.Printf("No schema for %s", name) log.Printf("No schema for %s", actualName)
if c.Relaxed == false {
return nil
}
} }
if err := db.CreateCollection(context.TODO(), name, opts); err != nil { if err := db.CreateCollection(context.TODO(), actualName, opts); err != nil {
log.Printf("Failed to create collection: %#v", err) log.Printf("Failed to create collection: %#v", err)
return nil return nil
} }
collection := db.Collection(name) collection := db.Collection(actualName)
c.CreateIndexes(collection, cdef) c.CreateIndexes(collection, cdef)
c.CreateViews(db, cdef)
if c.CreateViews(db, cdef) {
collection = db.Collection(name)
}
return collection return collection
} }
@@ -216,6 +249,7 @@ func ApplySchema(cdef *CollectionDefinition, opts *options.CreateCollectionOptio
var client *MongoClient = &MongoClient{ var client *MongoClient = &MongoClient{
Limit: 10, Limit: 10,
Registry: make(map[string]*CollectionDefinition, 0), Registry: make(map[string]*CollectionDefinition, 0),
Relaxed: false,
} }
func GetMongoClient() *MongoClient { func GetMongoClient() *MongoClient {

View File

@@ -25,8 +25,12 @@ func TestCreateIndexes(t *testing.T) {
c := GetMongoClient() c := GetMongoClient()
c.SetRelaxed()
collection := c.GetCollection("mydb", "mycol") collection := c.GetCollection("mydb", "mycol")
c.SetStrict()
c.CreateIndexes(collection, cd) c.CreateIndexes(collection, cd)
indexView := collection.Indexes() indexView := collection.Indexes()

View File

@@ -47,11 +47,13 @@ func TestMain(m *testing.M) {
// 2. Get the connection string dynamically // 2. Get the connection string dynamically
endpoint, _ := mongoContainer.ConnectionString(ctx) endpoint, _ := mongoContainer.ConnectionString(ctx)
mongoDebug := os.Getenv("DEBUG") != ""
Start(&MongoStartProps{ Start(&MongoStartProps{
MongoUri: endpoint, MongoUri: endpoint,
MongoUser: user, MongoUser: user,
MongoPass: pass, MongoPass: pass,
MongoDebugQuery: false, MongoDebugQuery: mongoDebug,
}) })
// 3. Run tests // 3. Run tests

View File

@@ -10,13 +10,15 @@ import (
) )
// CreateViews will create views for the given collection and definition. // CreateViews will create views for the given collection and definition.
func (c *MongoClient) CreateViews(db *mongo.Database, cdef *CollectionDefinition) { func (c *MongoClient) CreateViews(db *mongo.Database, cdef *CollectionDefinition) bool {
if cdef == nil || cdef.Views == nil { if cdef == nil || cdef.Views == nil {
log.Printf("No definition for views found.") log.Printf("No definition for views found.")
return return false
} }
viewCreated := false
for name, defVal := range cdef.Views { for name, defVal := range cdef.Views {
// 1. Decode definition // 1. Decode definition
@@ -82,6 +84,9 @@ func (c *MongoClient) CreateViews(db *mongo.Database, cdef *CollectionDefinition
continue continue
} }
viewCreated = true
} }
return viewCreated
} }

36
update_set.go Normal file
View File

@@ -0,0 +1,36 @@
package mongo
import (
"context"
"go.mongodb.org/mongo-driver/v2/bson"
)
// UpdateSet search documents using filter and updates the first it finds using the $set operator.
func (c *MongoClient) UpdateSet(ctx context.Context, database, name string, filter, data bson.M) (bool, error) {
collection := c.GetCollection(database, name)
prepareForUpdateSet(data)
if err := c.DiscriminatorCheckAndApplyToFilter(ctx, name, filter); err != nil {
return false, err
}
if err := c.DiscriminatorOmitInData(name, data); err != nil {
return false, err
}
update := bson.M{ "$set": data }
updateResult, err := collection.UpdateOne(ctx, filter, update)
if err != nil {
return false, err
}
changed := updateResult.ModifiedCount != 0
return changed, nil
}
func prepareForUpdateSet(data bson.M) {
ensureUpdatedAt(data)
}