新增查询构建器功能
This commit is contained in:
parent
0dd96a1858
commit
342b25cd78
19
go.mod
19
go.mod
@ -1,17 +1,32 @@
|
||||
module go-study
|
||||
|
||||
go 1.23
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.3
|
||||
|
||||
require (
|
||||
github.com/marusama/cyclicbarrier v1.1.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
go.uber.org/mock v0.5.0
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f
|
||||
golang.org/x/sync v0.9.0
|
||||
golang.org/x/sync v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/montanaflynn/stats v0.7.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.1.2 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
go.mongodb.org/mongo-driver v1.17.3 // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gorm.io/gorm v1.26.1 // indirect
|
||||
)
|
||||
|
52
go.sum
52
go.sum
@ -1,26 +1,78 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/marusama/cyclicbarrier v1.1.0 h1:ol/AG+sjvh5yz832avbNjaowoerBuD3AgozxL+aD9u0=
|
||||
github.com/marusama/cyclicbarrier v1.1.0/go.mod h1:5u93l83cy51YXdz6eKq6kO9+9mGAooB6DHMAxcSuWwQ=
|
||||
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ=
|
||||
go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA=
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
|
||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
|
||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw=
|
||||
gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
|
260
query_builder/builder.go
Normal file
260
query_builder/builder.go
Normal file
@ -0,0 +1,260 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DBProxy 数据实例结构
|
||||
type DBProxy struct {
|
||||
db *gorm.DB
|
||||
mongodb *mongo.Collection // 需提前指定.Database("db_name").Collection("collection_name")
|
||||
// redis、elasticsearch...
|
||||
}
|
||||
|
||||
// NewDBProxy 创建数据实例
|
||||
func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy {
|
||||
return &DBProxy{
|
||||
db: db,
|
||||
mongodb: mongodb,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryMiddleware 查询中间件类型定义
|
||||
// 参数:
|
||||
//
|
||||
// ctx: 上下文
|
||||
// builder: 查询构建器实例
|
||||
// next: 下一个中间件或最终查询处理器
|
||||
//
|
||||
// 返回:
|
||||
//
|
||||
// []*R: 查询结果列表
|
||||
// int64: 总数
|
||||
// error: 错误信息
|
||||
type QueryMiddleware[R any] func(
|
||||
ctx context.Context,
|
||||
builder *builder[R],
|
||||
next func(context.Context) ([]*R, int64, error),
|
||||
) ([]*R, int64, error)
|
||||
|
||||
// builder 查询构建器,使用泛型支持多种实体类型
|
||||
// 泛型参数:
|
||||
//
|
||||
// R: 查询结果的实体类型
|
||||
type builder[R any] struct {
|
||||
data *DBProxy
|
||||
start uint32
|
||||
limit uint32
|
||||
needTotal bool
|
||||
needPagination bool
|
||||
strategy QueryListStrategy[R] // 查询策略
|
||||
middlewares []QueryMiddleware[R] // 中间件链
|
||||
|
||||
filter func(context.Context) (any, error)
|
||||
sort func() any
|
||||
}
|
||||
|
||||
// SetFilter 设置过滤条件生成函数
|
||||
// 返回支持链式调用的构建器实例
|
||||
func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] {
|
||||
b.filter = filter
|
||||
return b
|
||||
}
|
||||
|
||||
// SetSort 设置排序条件生成函数
|
||||
// 返回支持链式调用的构建器实例
|
||||
func (b *builder[R]) SetSort(sort func() any) *builder[R] {
|
||||
b.sort = sort
|
||||
return b
|
||||
}
|
||||
|
||||
// SetStrategy 设置查询列表策略
|
||||
// 返回支持链式调用的构建器实例
|
||||
func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] {
|
||||
b.strategy = strategy
|
||||
return b
|
||||
}
|
||||
|
||||
// Use 添加中间件
|
||||
// 返回支持链式调用的构建器实例
|
||||
func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] {
|
||||
b.middlewares = append(b.middlewares, middleware)
|
||||
return b
|
||||
}
|
||||
|
||||
// getQueryStrategy 获取查询列表策略
|
||||
// 如果没有设置策略,则根据数据源自动选择策略
|
||||
func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) {
|
||||
if b.strategy != nil {
|
||||
return b.strategy, nil
|
||||
}
|
||||
if b.data == nil {
|
||||
return nil, errors.New("no data source provided")
|
||||
}
|
||||
|
||||
switch {
|
||||
case b.data.db != nil:
|
||||
return NewQueryGormListStrategy[R](), nil
|
||||
case b.data.mongodb != nil:
|
||||
return NewQueryMongoListStrategy[R](), nil
|
||||
default:
|
||||
return nil, errors.New("query strategy not set and no valid DB found")
|
||||
}
|
||||
}
|
||||
|
||||
// QueryList 执行查询列表操作
|
||||
// 返回值与中间件类型相同,list []R 查询结果列表
|
||||
func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) {
|
||||
// 尝试自动推断策略类型
|
||||
strategy, err := b.getQueryStrategy()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 构建中间件链
|
||||
next := func(ctx context.Context) ([]*R, int64, error) {
|
||||
return strategy.QueryList(ctx, b)
|
||||
}
|
||||
|
||||
for i := len(b.middlewares) - 1; i >= 0; i-- {
|
||||
next = func(mw QueryMiddleware[R], fn func(context.Context) ([]*R, int64, error)) func(context.Context) ([]*R, int64, error) {
|
||||
return func(ctx context.Context) ([]*R, int64, error) {
|
||||
return mw(ctx, b, fn)
|
||||
}
|
||||
}(b.middlewares[i], next)
|
||||
}
|
||||
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
// QueryListStrategy 查询列表策略
|
||||
type QueryListStrategy[R any] interface {
|
||||
QueryList(context.Context, *builder[R]) ([]*R, int64, error)
|
||||
}
|
||||
|
||||
// QueryGormListStrategy GORM 查询策略实现
|
||||
type QueryGormListStrategy[R any] struct{}
|
||||
|
||||
// NewQueryGormListStrategy 创建 GORM 查询策略实例
|
||||
func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] {
|
||||
return &QueryGormListStrategy[R]{}
|
||||
}
|
||||
|
||||
// QueryList 实现 GORM 查询逻辑
|
||||
func (s *QueryGormListStrategy[R]) QueryList(
|
||||
ctx context.Context,
|
||||
builder *builder[R],
|
||||
) (list []*R, total int64, err error) {
|
||||
filterScope, err := builder.filter(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
sortScope := builder.sort()
|
||||
// 验证过滤条件和排序条件的类型有效性
|
||||
for _, scope := range []any{filterScope, sortScope} {
|
||||
if _, ok := scope.(func(*gorm.DB) *gorm.DB); !ok {
|
||||
return nil, 0, errors.New("invalid scope")
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 WaitAndGo 并行执行数据查询和总数统计操作
|
||||
if err := WaitAndGo(func() error {
|
||||
query := builder.data.db.WithContext(ctx).
|
||||
Model(&list).
|
||||
Scopes(filterScope.(func(*gorm.DB) *gorm.DB), sortScope.(func(*gorm.DB) *gorm.DB))
|
||||
|
||||
if builder.needPagination {
|
||||
if builder.limit < 1 {
|
||||
builder.limit = defaultLimit
|
||||
}
|
||||
query = query.Offset(int(builder.start)).Limit(int(builder.limit))
|
||||
}
|
||||
|
||||
return query.Find(&list).Error
|
||||
}, func() error {
|
||||
if !builder.needTotal {
|
||||
return nil
|
||||
}
|
||||
|
||||
return builder.data.db.WithContext(ctx).
|
||||
Model(&list).
|
||||
Scopes(filterScope.(func(*gorm.DB) *gorm.DB)).
|
||||
Count(&total).
|
||||
Error
|
||||
}); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return list, total, nil
|
||||
}
|
||||
|
||||
// QueryMongoListStrategy MongoDB 查询策略实现
|
||||
type QueryMongoListStrategy[R any] struct{}
|
||||
|
||||
// NewQueryMongoListStrategy 创建 MongoDB 查询策略实例
|
||||
func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] {
|
||||
return &QueryMongoListStrategy[R]{}
|
||||
}
|
||||
|
||||
// QueryList 实现 MongoDB 查询逻辑
|
||||
func (s *QueryMongoListStrategy[R]) QueryList(
|
||||
ctx context.Context,
|
||||
builder *builder[R],
|
||||
) (list []*R, total int64, err error) {
|
||||
filterOpt, err := builder.filter(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
sortOpt := builder.sort()
|
||||
// 验证过滤条件和排序条件的类型有效性
|
||||
for _, opt := range []any{filterOpt, sortOpt} {
|
||||
_, mOk := opt.(bson.M)
|
||||
_, dOk := opt.(bson.D)
|
||||
if !mOk && !dOk {
|
||||
return nil, 0, errors.New("invalid option")
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 WaitAndGo 并行执行数据查询和总数统计操作
|
||||
if err := WaitAndGo(func() error {
|
||||
findOpt := options.Find().SetSort(sortOpt)
|
||||
if builder.needPagination {
|
||||
if builder.limit < 1 {
|
||||
builder.limit = defaultLimit
|
||||
}
|
||||
findOpt.SetSkip(int64(builder.start)).SetLimit(int64(builder.limit))
|
||||
}
|
||||
|
||||
cursor, err := builder.data.mongodb.Find(ctx, filterOpt, findOpt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(cursor *mongo.Cursor, ctx context.Context) {
|
||||
_ = cursor.Close(ctx)
|
||||
}(cursor, ctx)
|
||||
|
||||
return cursor.All(ctx, &list)
|
||||
}, func() error {
|
||||
if !builder.needTotal {
|
||||
return nil
|
||||
}
|
||||
|
||||
total, err = builder.data.mongodb.CountDocuments(ctx, filterOpt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return list, total, nil
|
||||
}
|
56
query_builder/mock_strategy.go
Normal file
56
query_builder/mock_strategy.go
Normal file
@ -0,0 +1,56 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: query_builder/builder.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=query_builder/builder.go -destination=query_builder/mock_strategy.go -package=builder
|
||||
//
|
||||
|
||||
// Package builder is a generated GoMock package.
|
||||
package builder
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockQueryListStrategy is a mock of QueryListStrategy interface.
|
||||
type MockQueryListStrategy[R any] struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockQueryListStrategyMockRecorder[R]
|
||||
}
|
||||
|
||||
// MockQueryListStrategyMockRecorder is the mock recorder for MockQueryListStrategy.
|
||||
type MockQueryListStrategyMockRecorder[R any] struct {
|
||||
mock *MockQueryListStrategy[R]
|
||||
}
|
||||
|
||||
// NewMockQueryListStrategy creates a new mock instance.
|
||||
func NewMockQueryListStrategy[R any](ctrl *gomock.Controller) *MockQueryListStrategy[R] {
|
||||
mock := &MockQueryListStrategy[R]{ctrl: ctrl}
|
||||
mock.recorder = &MockQueryListStrategyMockRecorder[R]{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockQueryListStrategy[R]) EXPECT() *MockQueryListStrategyMockRecorder[R] {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// QueryList mocks base method.
|
||||
func (m *MockQueryListStrategy[R]) QueryList(arg0 context.Context, arg1 *builder[R]) ([]*R, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "QueryList", arg0, arg1)
|
||||
ret0, _ := ret[0].([]*R)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// QueryList indicates an expected call of QueryList.
|
||||
func (mr *MockQueryListStrategyMockRecorder[R]) QueryList(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryList", reflect.TypeOf((*MockQueryListStrategy[R])(nil).QueryList), arg0, arg1)
|
||||
}
|
137
query_builder/option.go
Normal file
137
query_builder/option.go
Normal file
@ -0,0 +1,137 @@
|
||||
package builder
|
||||
|
||||
const (
|
||||
defaultStart = 0 // 默认从第0条开始
|
||||
defaultLimit = 10 // 默认每页10条
|
||||
defaultNeedTotal = true // 默认需要总数
|
||||
defaultNeedPagination = true // 默认需要分页
|
||||
)
|
||||
|
||||
// Filter 定义过滤条件的通用接口类型
|
||||
type Filter any
|
||||
|
||||
// Sort 定义排序条件的通用接口类型
|
||||
type Sort any
|
||||
|
||||
// QueryListOptions 定义了查询列表的通用选项接口
|
||||
// 泛型参数:
|
||||
//
|
||||
// F - 过滤条件类型参数
|
||||
// S - 排序条件类型参数
|
||||
type QueryListOptions[F Filter, S Sort] interface {
|
||||
GetData() *DBProxy
|
||||
GetFilter() *F
|
||||
GetSort() S
|
||||
GetStart() uint32
|
||||
GetLimit() uint32
|
||||
GetNeedTotal() bool
|
||||
GetNeedPagination() bool
|
||||
}
|
||||
|
||||
// BaseQueryListOptions 实现了QueryListOptions接口的基础结构体
|
||||
// 包含查询列表所需的所有基本选项
|
||||
type BaseQueryListOptions[F Filter, S Sort] struct {
|
||||
data *DBProxy // 数据实例
|
||||
filter *F // 过滤条件生成函数
|
||||
sort S // 排序条件生成函数
|
||||
start uint32 // 分页起始位置
|
||||
limit uint32 // 每页数据条数
|
||||
needTotal bool // 是否需要查询总数
|
||||
needPagination bool // 是否需要分页
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetData() *DBProxy {
|
||||
return opts.data
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetFilter() *F {
|
||||
return opts.filter
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetSort() S {
|
||||
return opts.sort
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetStart() uint32 {
|
||||
return opts.start
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetLimit() uint32 {
|
||||
return opts.limit
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetNeedTotal() bool {
|
||||
return opts.needTotal
|
||||
}
|
||||
|
||||
func (opts *BaseQueryListOptions[F, S]) GetNeedPagination() bool { return opts.needPagination }
|
||||
|
||||
// QueryOption 定义用于配置查询选项的函数类型
|
||||
type QueryOption[F Filter, S Sort] func(options *BaseQueryListOptions[F, S])
|
||||
|
||||
// LoadQueryOptions 加载并应用查询选项
|
||||
// 参数:
|
||||
//
|
||||
// opts - 可变数量的查询选项函数
|
||||
//
|
||||
// 返回:
|
||||
//
|
||||
// 配置好的BaseQueryListOptions实例
|
||||
func LoadQueryOptions[F Filter, S Sort](opts ...QueryOption[F, S]) BaseQueryListOptions[F, S] {
|
||||
// 初始化默认选项
|
||||
options := BaseQueryListOptions[F, S]{
|
||||
start: defaultStart,
|
||||
limit: defaultLimit,
|
||||
needTotal: defaultNeedTotal,
|
||||
needPagination: defaultNeedPagination,
|
||||
}
|
||||
|
||||
// 应用所有选项函数
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
func WithData[F Filter, S Sort](data *DBProxy) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.data = data
|
||||
}
|
||||
}
|
||||
|
||||
func WithFilter[F Filter, S Sort](filter *F) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.filter = filter
|
||||
}
|
||||
}
|
||||
|
||||
func WithSort[F Filter, S Sort](sort S) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.sort = sort
|
||||
}
|
||||
}
|
||||
|
||||
func WithStart[F Filter, S Sort](start uint32) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.start = start
|
||||
}
|
||||
}
|
||||
|
||||
func WithLimit[F Filter, S Sort](limit uint32) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
func WithNeedTotal[F Filter, S Sort](needTotal bool) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.needTotal = needTotal
|
||||
}
|
||||
}
|
||||
|
||||
func WithNeedPagination[F Filter, S Sort](needPagination bool) QueryOption[F, S] {
|
||||
return func(o *BaseQueryListOptions[F, S]) {
|
||||
o.needPagination = needPagination
|
||||
}
|
||||
}
|
73
query_builder/service.go
Normal file
73
query_builder/service.go
Normal file
@ -0,0 +1,73 @@
|
||||
package builder
|
||||
|
||||
import "context"
|
||||
|
||||
// Service 通用查询服务接口
|
||||
type Service interface {
|
||||
GetFilter(context.Context) (any, error)
|
||||
GetSort() any
|
||||
}
|
||||
|
||||
// BaseService 基础查询服务
|
||||
// 泛型参数:
|
||||
//
|
||||
// R - 返回结果类型参数
|
||||
// F - 过滤条件类型参数
|
||||
// S - 排序条件类型参数
|
||||
type BaseService[R any, F any, S any] struct {
|
||||
builder[R]
|
||||
filter *F
|
||||
sort S
|
||||
service Service
|
||||
}
|
||||
|
||||
// QueryList 执行列表查询
|
||||
func (s *BaseService[R, F, S]) QueryList(ctx context.Context) ([]*R, int64, error) {
|
||||
if s.service == nil {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
return s.builder.
|
||||
SetFilter(s.service.GetFilter).
|
||||
SetSort(s.service.GetSort).
|
||||
QueryList(ctx)
|
||||
}
|
||||
|
||||
// List 查询列表功能结构
|
||||
// 泛型参数上同
|
||||
type List[R any, F any, S any] struct {
|
||||
strategy QueryListStrategy[R]
|
||||
}
|
||||
|
||||
// SetStrategy 设置查询策略
|
||||
// 支持不同数据源的查询实现,如MySQL、MongoDB等
|
||||
// 通过该方法可自定义查询策略,已有策略可根据数据源自动选择
|
||||
func (l *List[R, F, S]) SetStrategy(strategy QueryListStrategy[R]) *List[R, F, S] {
|
||||
l.strategy = strategy
|
||||
return l
|
||||
}
|
||||
|
||||
// Query 执行查询
|
||||
// 该方法会根据传入的Service实例和QueryOption选项执行查询
|
||||
func (l *List[R, F, S]) Query(
|
||||
ctx context.Context,
|
||||
s Service,
|
||||
opts ...QueryOption[F, S],
|
||||
) ([]*R, int64, error) {
|
||||
options := LoadQueryOptions(opts...)
|
||||
service := &BaseService[R, F, S]{
|
||||
builder: builder[R]{
|
||||
data: options.GetData(),
|
||||
start: options.GetStart(),
|
||||
limit: options.GetLimit(),
|
||||
needTotal: options.GetNeedTotal(),
|
||||
needPagination: options.GetNeedPagination(),
|
||||
strategy: l.strategy,
|
||||
},
|
||||
filter: options.GetFilter(),
|
||||
sort: options.GetSort(),
|
||||
service: s,
|
||||
}
|
||||
return service.QueryList(ctx)
|
||||
}
|
209
query_builder/service_test.go
Normal file
209
query_builder/service_test.go
Normal file
@ -0,0 +1,209 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestEntity struct {
|
||||
ID uint32
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
type TestFilter struct {
|
||||
Name string
|
||||
Age uint8
|
||||
}
|
||||
|
||||
type TestSort struct {
|
||||
Field string
|
||||
Direction string
|
||||
}
|
||||
|
||||
type TestService struct {
|
||||
filter TestFilter
|
||||
sort TestSort
|
||||
}
|
||||
|
||||
func (s *TestService) GetFilter(_ context.Context) (any, error) {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if s.filter.Name != "" {
|
||||
db.Where("name = ?", s.filter.Name)
|
||||
}
|
||||
|
||||
if s.filter.Age > 0 {
|
||||
db.Where("age >= ?", s.filter.Age)
|
||||
}
|
||||
|
||||
return db
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *TestService) GetSort() any {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
// 实际项目中的排序需要根据pb文件生成的枚举值来处理
|
||||
return db.Order(s.sort.Field + " " + s.sort.Direction)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryList(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
// Mock 策略实例
|
||||
mockStrategy := NewMockQueryListStrategy[TestEntity](ctrl)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
service Service
|
||||
mockSetup func()
|
||||
opts []QueryOption[TestFilter, TestSort]
|
||||
expectedResult []*TestEntity
|
||||
expectedTotal int64
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "无筛选查询&id升序",
|
||||
service: &TestService{},
|
||||
mockSetup: func() {
|
||||
mockStrategy.EXPECT().
|
||||
QueryList(ctx, gomock.Any()).
|
||||
Return([]*TestEntity{
|
||||
{ID: 1, Name: "Alice", Age: 25},
|
||||
{ID: 2, Name: "Bob", Age: 30},
|
||||
}, int64(2), nil)
|
||||
},
|
||||
opts: []QueryOption[TestFilter, TestSort]{
|
||||
WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)),
|
||||
WithFilter[TestFilter, TestSort](&TestFilter{}),
|
||||
WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "asc"}),
|
||||
},
|
||||
expectedResult: []*TestEntity{
|
||||
{ID: 1, Name: "Alice", Age: 25},
|
||||
{ID: 2, Name: "Bob", Age: 30},
|
||||
},
|
||||
expectedTotal: 2,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "无筛选查询&age降序",
|
||||
service: &TestService{},
|
||||
mockSetup: func() {
|
||||
mockStrategy.EXPECT().
|
||||
QueryList(ctx, gomock.Any()).
|
||||
Return([]*TestEntity{
|
||||
{ID: 2, Name: "Bob", Age: 30},
|
||||
{ID: 1, Name: "Alice", Age: 25},
|
||||
}, int64(2), nil)
|
||||
},
|
||||
opts: []QueryOption[TestFilter, TestSort]{
|
||||
WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)),
|
||||
WithFilter[TestFilter, TestSort](&TestFilter{}),
|
||||
WithSort[TestFilter, TestSort](TestSort{Field: "age", Direction: "desc"}),
|
||||
},
|
||||
expectedResult: []*TestEntity{
|
||||
{ID: 2, Name: "Bob", Age: 30},
|
||||
{ID: 1, Name: "Alice", Age: 25},
|
||||
},
|
||||
expectedTotal: 2,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "有筛选查询",
|
||||
service: &TestService{},
|
||||
mockSetup: func() {
|
||||
mockStrategy.EXPECT().
|
||||
QueryList(ctx, gomock.Any()).
|
||||
Return([]*TestEntity{
|
||||
{ID: 1, Name: "Alice", Age: 25},
|
||||
}, int64(1), nil)
|
||||
},
|
||||
opts: []QueryOption[TestFilter, TestSort]{
|
||||
WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)),
|
||||
WithFilter[TestFilter, TestSort](&TestFilter{Name: "Alice"}),
|
||||
WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "desc"}),
|
||||
},
|
||||
expectedResult: []*TestEntity{
|
||||
{ID: 1, Name: "Alice", Age: 25},
|
||||
},
|
||||
expectedTotal: 1,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "无数据实例",
|
||||
service: &TestService{},
|
||||
mockSetup: func() {
|
||||
mockStrategy.EXPECT().
|
||||
QueryList(ctx, gomock.Any()).
|
||||
Return(nil, int64(0), nil)
|
||||
},
|
||||
opts: []QueryOption[TestFilter, TestSort]{
|
||||
WithData[TestFilter, TestSort](NewDBProxy(&gorm.DB{}, nil)),
|
||||
WithFilter[TestFilter, TestSort](&TestFilter{Name: "test"}),
|
||||
WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "asc"}),
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectedTotal: 0,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "数据实例错误",
|
||||
service: &TestService{},
|
||||
mockSetup: func() {
|
||||
mockStrategy.EXPECT().
|
||||
QueryList(ctx, gomock.Any()).
|
||||
Return(nil, int64(0), errors.New("no data source provided"))
|
||||
},
|
||||
opts: []QueryOption[TestFilter, TestSort]{
|
||||
WithFilter[TestFilter, TestSort](&TestFilter{Name: "test"}),
|
||||
WithSort[TestFilter, TestSort](TestSort{Field: "id", Direction: "asc"}),
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectedTotal: 0,
|
||||
expectedErr: errors.New("no data source provided"),
|
||||
},
|
||||
}
|
||||
|
||||
list := &List[TestEntity, TestFilter, TestSort]{}
|
||||
// 这里使用 Mock 策略替代真实的策略
|
||||
// 实际使用时会根据数据源自动选择
|
||||
list.SetStrategy(mockStrategy)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.mockSetup != nil {
|
||||
tt.mockSetup()
|
||||
}
|
||||
|
||||
result, total, err := list.Query(ctx, tt.service, tt.opts...)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
if err == nil || err.Error() != tt.expectedErr.Error() {
|
||||
t.Errorf("expected error: %v, got: %v", tt.expectedErr, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if total != tt.expectedTotal {
|
||||
t.Errorf("expected total: %d, got: %d", tt.expectedTotal, total)
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expectedResult) {
|
||||
t.Errorf("expected result length: %d, got: %d", len(tt.expectedResult), len(result))
|
||||
}
|
||||
|
||||
for i, item := range result {
|
||||
if item.ID != tt.expectedResult[i].ID || item.Name != tt.expectedResult[i].Name || item.Age != tt.expectedResult[i].Age {
|
||||
t.Errorf("expected result[%d]: %+v, got: %+v", i, tt.expectedResult[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
23
query_builder/utils.go
Normal file
23
query_builder/utils.go
Normal file
@ -0,0 +1,23 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
// WaitAndGo 等待所有函数执行完毕
|
||||
func WaitAndGo(fn ...func() error) error {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
var g errgroup.Group
|
||||
for _, f := range fn {
|
||||
g.Go(func() error {
|
||||
return f()
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
Loading…
Reference in New Issue
Block a user