diff --git a/go.mod b/go.mod index d4318a0..760a67c 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,32 @@ module go-study -go 1.23 +go 1.23.0 + +toolchain go1.23.3 require ( github.com/marusama/cyclicbarrier v1.1.0 github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.5.0 golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f - golang.org/x/sync v0.9.0 + golang.org/x/sync v0.14.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang/snappy v1.0.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/montanaflynn/stats v0.7.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + go.mongodb.org/mongo-driver v1.17.3 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/text v0.25.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/gorm v1.26.1 // indirect ) diff --git a/go.sum b/go.sum index 9951d05..1703aa1 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,78 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/marusama/cyclicbarrier v1.1.0 h1:ol/AG+sjvh5yz832avbNjaowoerBuD3AgozxL+aD9u0= github.com/marusama/cyclicbarrier v1.1.0/go.mod h1:5u93l83cy51YXdz6eKq6kO9+9mGAooB6DHMAxcSuWwQ= +github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ= +go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw= +gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/query_builder/builder.go b/query_builder/builder.go new file mode 100644 index 0000000..566a187 --- /dev/null +++ b/query_builder/builder.go @@ -0,0 +1,260 @@ +package builder + +import ( + "context" + "errors" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "gorm.io/gorm" +) + +// DBProxy 数据实例结构 +type DBProxy struct { + db *gorm.DB + mongodb *mongo.Collection // 需提前指定.Database("db_name").Collection("collection_name") + // redis、elasticsearch... +} + +// NewDBProxy 创建数据实例 +func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy { + return &DBProxy{ + db: db, + mongodb: mongodb, + } +} + +// QueryMiddleware 查询中间件类型定义 +// 参数: +// +// ctx: 上下文 +// builder: 查询构建器实例 +// next: 下一个中间件或最终查询处理器 +// +// 返回: +// +// []*R: 查询结果列表 +// int64: 总数 +// error: 错误信息 +type QueryMiddleware[R any] func( + ctx context.Context, + builder *builder[R], + next func(context.Context) ([]*R, int64, error), +) ([]*R, int64, error) + +// builder 查询构建器,使用泛型支持多种实体类型 +// 泛型参数: +// +// R: 查询结果的实体类型 +type builder[R any] struct { + data *DBProxy + start uint32 + limit uint32 + needTotal bool + needPagination bool + strategy QueryListStrategy[R] // 查询策略 + middlewares []QueryMiddleware[R] // 中间件链 + + filter func(context.Context) (any, error) + sort func() any +} + +// SetFilter 设置过滤条件生成函数 +// 返回支持链式调用的构建器实例 +func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] { + b.filter = filter + return b +} + +// SetSort 设置排序条件生成函数 +// 返回支持链式调用的构建器实例 +func (b *builder[R]) SetSort(sort func() any) *builder[R] { + b.sort = sort + return b +} + +// SetStrategy 设置查询列表策略 +// 返回支持链式调用的构建器实例 +func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] { + b.strategy = strategy + return b +} + +// Use 添加中间件 +// 返回支持链式调用的构建器实例 +func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] { + b.middlewares = append(b.middlewares, middleware) + return b +} + +// getQueryStrategy 获取查询列表策略 +// 如果没有设置策略,则根据数据源自动选择策略 +func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) { + if b.strategy != nil { + return b.strategy, nil + } + if b.data == nil { + return nil, errors.New("no data source provided") + } + + switch { + case b.data.db != nil: + return NewQueryGormListStrategy[R](), nil + case b.data.mongodb != nil: + return NewQueryMongoListStrategy[R](), nil + default: + return nil, errors.New("query strategy not set and no valid DB found") + } +} + +// QueryList 执行查询列表操作 +// 返回值与中间件类型相同,list []R 查询结果列表 +func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) { + // 尝试自动推断策略类型 + strategy, err := b.getQueryStrategy() + if err != nil { + return nil, 0, err + } + + // 构建中间件链 + next := func(ctx context.Context) ([]*R, int64, error) { + return strategy.QueryList(ctx, b) + } + + for i := len(b.middlewares) - 1; i >= 0; i-- { + next = func(mw QueryMiddleware[R], fn func(context.Context) ([]*R, int64, error)) func(context.Context) ([]*R, int64, error) { + return func(ctx context.Context) ([]*R, int64, error) { + return mw(ctx, b, fn) + } + }(b.middlewares[i], next) + } + + return next(ctx) +} + +// QueryListStrategy 查询列表策略 +type QueryListStrategy[R any] interface { + QueryList(context.Context, *builder[R]) ([]*R, int64, error) +} + +// QueryGormListStrategy GORM 查询策略实现 +type QueryGormListStrategy[R any] struct{} + +// NewQueryGormListStrategy 创建 GORM 查询策略实例 +func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] { + return &QueryGormListStrategy[R]{} +} + +// QueryList 实现 GORM 查询逻辑 +func (s *QueryGormListStrategy[R]) QueryList( + ctx context.Context, + builder *builder[R], +) (list []*R, total int64, err error) { + filterScope, err := builder.filter(ctx) + if err != nil { + return nil, 0, err + } + + sortScope := builder.sort() + // 验证过滤条件和排序条件的类型有效性 + for _, scope := range []any{filterScope, sortScope} { + if _, ok := scope.(func(*gorm.DB) *gorm.DB); !ok { + return nil, 0, errors.New("invalid scope") + } + } + + // 使用 WaitAndGo 并行执行数据查询和总数统计操作 + if err := WaitAndGo(func() error { + query := builder.data.db.WithContext(ctx). + Model(&list). + Scopes(filterScope.(func(*gorm.DB) *gorm.DB), sortScope.(func(*gorm.DB) *gorm.DB)) + + if builder.needPagination { + if builder.limit < 1 { + builder.limit = defaultLimit + } + query = query.Offset(int(builder.start)).Limit(int(builder.limit)) + } + + return query.Find(&list).Error + }, func() error { + if !builder.needTotal { + return nil + } + + return builder.data.db.WithContext(ctx). + Model(&list). + Scopes(filterScope.(func(*gorm.DB) *gorm.DB)). + Count(&total). + Error + }); err != nil { + return nil, 0, err + } + + return list, total, nil +} + +// QueryMongoListStrategy MongoDB 查询策略实现 +type QueryMongoListStrategy[R any] struct{} + +// NewQueryMongoListStrategy 创建 MongoDB 查询策略实例 +func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] { + return &QueryMongoListStrategy[R]{} +} + +// QueryList 实现 MongoDB 查询逻辑 +func (s *QueryMongoListStrategy[R]) QueryList( + ctx context.Context, + builder *builder[R], +) (list []*R, total int64, err error) { + filterOpt, err := builder.filter(ctx) + if err != nil { + return nil, 0, err + } + + sortOpt := builder.sort() + // 验证过滤条件和排序条件的类型有效性 + for _, opt := range []any{filterOpt, sortOpt} { + _, mOk := opt.(bson.M) + _, dOk := opt.(bson.D) + if !mOk && !dOk { + return nil, 0, errors.New("invalid option") + } + } + + // 使用 WaitAndGo 并行执行数据查询和总数统计操作 + if err := WaitAndGo(func() error { + findOpt := options.Find().SetSort(sortOpt) + if builder.needPagination { + if builder.limit < 1 { + builder.limit = defaultLimit + } + findOpt.SetSkip(int64(builder.start)).SetLimit(int64(builder.limit)) + } + + cursor, err := builder.data.mongodb.Find(ctx, filterOpt, findOpt) + if err != nil { + return err + } + defer func(cursor *mongo.Cursor, ctx context.Context) { + _ = cursor.Close(ctx) + }(cursor, ctx) + + return cursor.All(ctx, &list) + }, func() error { + if !builder.needTotal { + return nil + } + + total, err = builder.data.mongodb.CountDocuments(ctx, filterOpt) + if err != nil { + return err + } + + return nil + }); err != nil { + return nil, 0, err + } + + return list, total, nil +} diff --git a/query_builder/mock_strategy.go b/query_builder/mock_strategy.go new file mode 100644 index 0000000..7e4aa92 --- /dev/null +++ b/query_builder/mock_strategy.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: query_builder/builder.go +// +// Generated by this command: +// +// mockgen -source=query_builder/builder.go -destination=query_builder/mock_strategy.go -package=builder +// + +// Package builder is a generated GoMock package. +package builder + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockQueryListStrategy is a mock of QueryListStrategy interface. +type MockQueryListStrategy[R any] struct { + ctrl *gomock.Controller + recorder *MockQueryListStrategyMockRecorder[R] +} + +// MockQueryListStrategyMockRecorder is the mock recorder for MockQueryListStrategy. +type MockQueryListStrategyMockRecorder[R any] struct { + mock *MockQueryListStrategy[R] +} + +// NewMockQueryListStrategy creates a new mock instance. +func NewMockQueryListStrategy[R any](ctrl *gomock.Controller) *MockQueryListStrategy[R] { + mock := &MockQueryListStrategy[R]{ctrl: ctrl} + mock.recorder = &MockQueryListStrategyMockRecorder[R]{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockQueryListStrategy[R]) EXPECT() *MockQueryListStrategyMockRecorder[R] { + return m.recorder +} + +// QueryList mocks base method. +func (m *MockQueryListStrategy[R]) QueryList(arg0 context.Context, arg1 *builder[R]) ([]*R, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryList", arg0, arg1) + ret0, _ := ret[0].([]*R) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// QueryList indicates an expected call of QueryList. +func (mr *MockQueryListStrategyMockRecorder[R]) QueryList(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryList", reflect.TypeOf((*MockQueryListStrategy[R])(nil).QueryList), arg0, arg1) +} diff --git a/query_builder/option.go b/query_builder/option.go new file mode 100644 index 0000000..59dd22a --- /dev/null +++ b/query_builder/option.go @@ -0,0 +1,137 @@ +package builder + +const ( + defaultStart = 0 // 默认从第0条开始 + defaultLimit = 10 // 默认每页10条 + defaultNeedTotal = true // 默认需要总数 + defaultNeedPagination = true // 默认需要分页 +) + +// Filter 定义过滤条件的通用接口类型 +type Filter any + +// Sort 定义排序条件的通用接口类型 +type Sort any + +// QueryListOptions 定义了查询列表的通用选项接口 +// 泛型参数: +// +// F - 过滤条件类型参数 +// S - 排序条件类型参数 +type QueryListOptions[F Filter, S Sort] interface { + GetData() *DBProxy + GetFilter() *F + GetSort() S + GetStart() uint32 + GetLimit() uint32 + GetNeedTotal() bool + GetNeedPagination() bool +} + +// BaseQueryListOptions 实现了QueryListOptions接口的基础结构体 +// 包含查询列表所需的所有基本选项 +type BaseQueryListOptions[F Filter, S Sort] struct { + data *DBProxy // 数据实例 + filter *F // 过滤条件生成函数 + sort S // 排序条件生成函数 + start uint32 // 分页起始位置 + limit uint32 // 每页数据条数 + needTotal bool // 是否需要查询总数 + needPagination bool // 是否需要分页 +} + +func (opts *BaseQueryListOptions[F, S]) GetData() *DBProxy { + return opts.data +} + +func (opts *BaseQueryListOptions[F, S]) GetFilter() *F { + return opts.filter +} + +func (opts *BaseQueryListOptions[F, S]) GetSort() S { + return opts.sort +} + +func (opts *BaseQueryListOptions[F, S]) GetStart() uint32 { + return opts.start +} + +func (opts *BaseQueryListOptions[F, S]) GetLimit() uint32 { + return opts.limit +} + +func (opts *BaseQueryListOptions[F, S]) GetNeedTotal() bool { + return opts.needTotal +} + +func (opts *BaseQueryListOptions[F, S]) GetNeedPagination() bool { return opts.needPagination } + +// QueryOption 定义用于配置查询选项的函数类型 +type QueryOption[F Filter, S Sort] func(options *BaseQueryListOptions[F, S]) + +// LoadQueryOptions 加载并应用查询选项 +// 参数: +// +// opts - 可变数量的查询选项函数 +// +// 返回: +// +// 配置好的BaseQueryListOptions实例 +func LoadQueryOptions[F Filter, S Sort](opts ...QueryOption[F, S]) BaseQueryListOptions[F, S] { + // 初始化默认选项 + options := BaseQueryListOptions[F, S]{ + start: defaultStart, + limit: defaultLimit, + needTotal: defaultNeedTotal, + needPagination: defaultNeedPagination, + } + + // 应用所有选项函数 + for _, opt := range opts { + opt(&options) + } + + return options +} + +func WithData[F Filter, S Sort](data *DBProxy) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.data = data + } +} + +func WithFilter[F Filter, S Sort](filter *F) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.filter = filter + } +} + +func WithSort[F Filter, S Sort](sort S) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.sort = sort + } +} + +func WithStart[F Filter, S Sort](start uint32) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.start = start + } +} + +func WithLimit[F Filter, S Sort](limit uint32) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.limit = limit + } +} + +func WithNeedTotal[F Filter, S Sort](needTotal bool) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.needTotal = needTotal + } +} + +func WithNeedPagination[F Filter, S Sort](needPagination bool) QueryOption[F, S] { + return func(o *BaseQueryListOptions[F, S]) { + o.needPagination = needPagination + } +} diff --git a/query_builder/service.go b/query_builder/service.go new file mode 100644 index 0000000..f3b3284 --- /dev/null +++ b/query_builder/service.go @@ -0,0 +1,73 @@ +package builder + +import "context" + +// Service 通用查询服务接口 +type Service interface { + GetFilter(context.Context) (any, error) + GetSort() any +} + +// BaseService 基础查询服务 +// 泛型参数: +// +// R - 返回结果类型参数 +// F - 过滤条件类型参数 +// S - 排序条件类型参数 +type BaseService[R any, F any, S any] struct { + builder[R] + filter *F + sort S + service Service +} + +// QueryList 执行列表查询 +func (s *BaseService[R, F, S]) QueryList(ctx context.Context) ([]*R, int64, error) { + if s.service == nil { + return nil, 0, nil + } + + // 执行查询 + return s.builder. + SetFilter(s.service.GetFilter). + SetSort(s.service.GetSort). + QueryList(ctx) +} + +// List 查询列表功能结构 +// 泛型参数上同 +type List[R any, F any, S any] struct { + strategy QueryListStrategy[R] +} + +// SetStrategy 设置查询策略 +// 支持不同数据源的查询实现,如MySQL、MongoDB等 +// 通过该方法可自定义查询策略,已有策略可根据数据源自动选择 +func (l *List[R, F, S]) SetStrategy(strategy QueryListStrategy[R]) *List[R, F, S] { + l.strategy = strategy + return l +} + +// Query 执行查询 +// 该方法会根据传入的Service实例和QueryOption选项执行查询 +func (l *List[R, F, S]) Query( + ctx context.Context, + s Service, + opts ...QueryOption[F, S], +) ([]*R, int64, error) { + options := LoadQueryOptions(opts...) + service := &BaseService[R, F, S]{ + builder: builder[R]{ + data: options.GetData(), + start: options.GetStart(), + limit: options.GetLimit(), + needTotal: options.GetNeedTotal(), + needPagination: options.GetNeedPagination(), + strategy: l.strategy, + }, + filter: options.GetFilter(), + sort: options.GetSort(), + service: s, + } + return service.QueryList(ctx) +} diff --git a/query_builder/service_test.go b/query_builder/service_test.go new file mode 100644 index 0000000..25294c4 --- /dev/null +++ b/query_builder/service_test.go @@ -0,0 +1,209 @@ +package builder + +import ( + "context" + "errors" + "go.uber.org/mock/gomock" + "gorm.io/gorm" + "testing" +) + +type TestEntity struct { + ID uint32 + Name string + Age int +} + +type TestFilter struct { + Name string + Age uint8 +} + +type TestSort struct { + Field string + Direction string +} + +type TestService struct { + filter TestFilter + sort TestSort +} + +func (s *TestService) GetFilter(_ context.Context) (any, error) { + return func(db *gorm.DB) *gorm.DB { + if s.filter.Name != "" { + db.Where("name = ?", s.filter.Name) + } + + if s.filter.Age > 0 { + db.Where("age >= ?", s.filter.Age) + } + + return db + }, nil +} + +func (s *TestService) GetSort() any { + return func(db *gorm.DB) *gorm.DB { + // 实际项目中的排序需要根据pb文件生成的枚举值来处理 + return db.Order(s.sort.Field + " " + s.sort.Direction) + } +} + +func TestQueryList(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Mock 策略实例 + mockStrategy := NewMockQueryListStrategy[TestEntity](ctrl) + + tests := []struct { + name string + service Service + mockSetup func() + opts []QueryOption[TestFilter, TestSort] + expectedResult []*TestEntity + expectedTotal int64 + expectedErr error + }{ + { + name: "无筛选查询&id升序", + service: &TestService{}, + mockSetup: func() { + mockStrategy.EXPECT(). + QueryList(ctx, gomock.Any()). + Return([]*TestEntity{ + {ID: 1, Name: "Alice", Age: 25}, + {ID: 2, Name: "Bob", Age: 30}, + }, int64(2), nil) + }, + opts: []QueryOption[TestFilter, TestSort]{ + WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)), + WithFilter[TestFilter, TestSort](&TestFilter{}), + WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "asc"}), + }, + expectedResult: []*TestEntity{ + {ID: 1, Name: "Alice", Age: 25}, + {ID: 2, Name: "Bob", Age: 30}, + }, + expectedTotal: 2, + expectedErr: nil, + }, + { + name: "无筛选查询&age降序", + service: &TestService{}, + mockSetup: func() { + mockStrategy.EXPECT(). + QueryList(ctx, gomock.Any()). + Return([]*TestEntity{ + {ID: 2, Name: "Bob", Age: 30}, + {ID: 1, Name: "Alice", Age: 25}, + }, int64(2), nil) + }, + opts: []QueryOption[TestFilter, TestSort]{ + WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)), + WithFilter[TestFilter, TestSort](&TestFilter{}), + WithSort[TestFilter, TestSort](TestSort{Field: "age", Direction: "desc"}), + }, + expectedResult: []*TestEntity{ + {ID: 2, Name: "Bob", Age: 30}, + {ID: 1, Name: "Alice", Age: 25}, + }, + expectedTotal: 2, + expectedErr: nil, + }, + { + name: "有筛选查询", + service: &TestService{}, + mockSetup: func() { + mockStrategy.EXPECT(). + QueryList(ctx, gomock.Any()). + Return([]*TestEntity{ + {ID: 1, Name: "Alice", Age: 25}, + }, int64(1), nil) + }, + opts: []QueryOption[TestFilter, TestSort]{ + WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)), + WithFilter[TestFilter, TestSort](&TestFilter{Name: "Alice"}), + WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "desc"}), + }, + expectedResult: []*TestEntity{ + {ID: 1, Name: "Alice", Age: 25}, + }, + expectedTotal: 1, + expectedErr: nil, + }, + { + name: "无数据实例", + service: &TestService{}, + mockSetup: func() { + mockStrategy.EXPECT(). + QueryList(ctx, gomock.Any()). + Return(nil, int64(0), nil) + }, + opts: []QueryOption[TestFilter, TestSort]{ + WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)), + WithFilter[TestFilter, TestSort](&TestFilter{Name: "test"}), + WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "asc"}), + }, + expectedResult: nil, + expectedTotal: 0, + expectedErr: nil, + }, + { + name: "数据实例错误", + service: &TestService{}, + mockSetup: func() { + mockStrategy.EXPECT(). + QueryList(ctx, gomock.Any()). + Return(nil, int64(0), errors.New("no data source provided")) + }, + opts: []QueryOption[TestFilter, TestSort]{ + WithFilter[TestFilter, TestSort](&TestFilter{Name: "test"}), + WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "asc"}), + }, + expectedResult: nil, + expectedTotal: 0, + expectedErr: errors.New("no data source provided"), + }, + } + + list := &List[TestEntity, TestFilter, TestSort]{} + // 这里使用 Mock 策略替代真实的策略 + // 实际使用时会根据数据源自动选择 + list.SetStrategy(mockStrategy) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mockSetup != nil { + tt.mockSetup() + } + + result, total, err := list.Query(ctx, tt.service, tt.opts...) + + if tt.expectedErr != nil { + if err == nil || err.Error() != tt.expectedErr.Error() { + t.Errorf("expected error: %v, got: %v", tt.expectedErr, err) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if total != tt.expectedTotal { + t.Errorf("expected total: %d, got: %d", tt.expectedTotal, total) + } + + if len(result) != len(tt.expectedResult) { + t.Errorf("expected result length: %d, got: %d", len(tt.expectedResult), len(result)) + } + + for i, item := range result { + if item.ID != tt.expectedResult[i].ID || item.Name != tt.expectedResult[i].Name || item.Age != tt.expectedResult[i].Age { + t.Errorf("expected result[%d]: %+v, got: %+v", i, tt.expectedResult[i], item) + } + } + } + }) + } +} diff --git a/query_builder/utils.go b/query_builder/utils.go new file mode 100644 index 0000000..0261e3c --- /dev/null +++ b/query_builder/utils.go @@ -0,0 +1,23 @@ +package builder + +import ( + "fmt" + "golang.org/x/sync/errgroup" + "runtime/debug" +) + +// WaitAndGo 等待所有函数执行完毕 +func WaitAndGo(fn ...func() error) error { + defer func() { + if err := recover(); err != nil { + fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack())) + } + }() + var g errgroup.Group + for _, f := range fn { + g.Go(func() error { + return f() + }) + } + return g.Wait() +}