package builder import ( "context" "errors" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "gorm.io/gorm" ) // DBProxy 数据实例结构 type DBProxy struct { db *gorm.DB mongodb *mongo.Collection // 需提前指定.Database("db_name").Collection("collection_name") // redis、elasticsearch... } // NewDBProxy 创建数据实例 func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy { return &DBProxy{ db: db, mongodb: mongodb, } } // QueryMiddleware 查询中间件类型定义 // 参数: // // ctx: 上下文 // builder: 查询构建器实例 // next: 下一个中间件或最终查询处理器 // // 返回: // // []*R: 查询结果列表 // int64: 总数 // error: 错误信息 type QueryMiddleware[R any] func( ctx context.Context, builder *builder[R], next func(context.Context) ([]*R, int64, error), ) ([]*R, int64, error) // builder 查询构建器,使用泛型支持多种实体类型 // 泛型参数: // // R: 查询结果的实体类型 type builder[R any] struct { data *DBProxy start uint32 limit uint32 needTotal bool needPagination bool strategy QueryListStrategy[R] // 查询策略 middlewares []QueryMiddleware[R] // 中间件链 filter func(context.Context) (any, error) sort func() any } // SetFilter 设置过滤条件生成函数 // 返回支持链式调用的构建器实例 func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] { b.filter = filter return b } // SetSort 设置排序条件生成函数 // 返回支持链式调用的构建器实例 func (b *builder[R]) SetSort(sort func() any) *builder[R] { b.sort = sort return b } // SetStrategy 设置查询列表策略 // 返回支持链式调用的构建器实例 func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] { b.strategy = strategy return b } // Use 添加中间件 // 返回支持链式调用的构建器实例 func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] { b.middlewares = append(b.middlewares, middleware) return b } // getQueryStrategy 获取查询列表策略 // 如果没有设置策略,则根据数据源自动选择策略 func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) { if b.strategy != nil { return b.strategy, nil } if b.data == nil { return nil, errors.New("no data source provided") } switch { case b.data.db != nil: return NewQueryGormListStrategy[R](), nil case b.data.mongodb != nil: return NewQueryMongoListStrategy[R](), nil default: return nil, errors.New("query strategy not set and no valid DB found") } } // QueryList 执行查询列表操作 // 返回值与中间件类型相同,list []R 查询结果列表 func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) { // 尝试自动推断策略类型 strategy, err := b.getQueryStrategy() if err != nil { return nil, 0, err } // 构建中间件链 next := func(ctx context.Context) ([]*R, int64, error) { return strategy.QueryList(ctx, b) } for i := len(b.middlewares) - 1; i >= 0; i-- { next = func(mw QueryMiddleware[R], fn func(context.Context) ([]*R, int64, error)) func(context.Context) ([]*R, int64, error) { return func(ctx context.Context) ([]*R, int64, error) { return mw(ctx, b, fn) } }(b.middlewares[i], next) } return next(ctx) } // QueryListStrategy 查询列表策略 type QueryListStrategy[R any] interface { QueryList(context.Context, *builder[R]) ([]*R, int64, error) } // QueryGormListStrategy GORM 查询策略实现 type QueryGormListStrategy[R any] struct{} // NewQueryGormListStrategy 创建 GORM 查询策略实例 func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] { return &QueryGormListStrategy[R]{} } // QueryList 实现 GORM 查询逻辑 func (s *QueryGormListStrategy[R]) QueryList( ctx context.Context, builder *builder[R], ) (list []*R, total int64, err error) { filterScope, err := builder.filter(ctx) if err != nil { return nil, 0, err } sortScope := builder.sort() // 验证过滤条件和排序条件的类型有效性 for _, scope := range []any{filterScope, sortScope} { if _, ok := scope.(func(*gorm.DB) *gorm.DB); !ok { return nil, 0, errors.New("invalid scope") } } // 使用 WaitAndGo 并行执行数据查询和总数统计操作 if err := WaitAndGo(func() error { query := builder.data.db.WithContext(ctx). Model(&list). Scopes(filterScope.(func(*gorm.DB) *gorm.DB), sortScope.(func(*gorm.DB) *gorm.DB)) if builder.needPagination { if builder.limit < 1 { builder.limit = defaultLimit } query = query.Offset(int(builder.start)).Limit(int(builder.limit)) } return query.Find(&list).Error }, func() error { if !builder.needTotal { return nil } return builder.data.db.WithContext(ctx). Model(&list). Scopes(filterScope.(func(*gorm.DB) *gorm.DB)). Count(&total). Error }); err != nil { return nil, 0, err } return list, total, nil } // QueryMongoListStrategy MongoDB 查询策略实现 type QueryMongoListStrategy[R any] struct{} // NewQueryMongoListStrategy 创建 MongoDB 查询策略实例 func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] { return &QueryMongoListStrategy[R]{} } // QueryList 实现 MongoDB 查询逻辑 func (s *QueryMongoListStrategy[R]) QueryList( ctx context.Context, builder *builder[R], ) (list []*R, total int64, err error) { filterOpt, err := builder.filter(ctx) if err != nil { return nil, 0, err } sortOpt := builder.sort() // 验证过滤条件和排序条件的类型有效性 for _, opt := range []any{filterOpt, sortOpt} { _, mOk := opt.(bson.M) _, dOk := opt.(bson.D) if !mOk && !dOk { return nil, 0, errors.New("invalid option") } } // 使用 WaitAndGo 并行执行数据查询和总数统计操作 if err := WaitAndGo(func() error { findOpt := options.Find().SetSort(sortOpt) if builder.needPagination { if builder.limit < 1 { builder.limit = defaultLimit } findOpt.SetSkip(int64(builder.start)).SetLimit(int64(builder.limit)) } cursor, err := builder.data.mongodb.Find(ctx, filterOpt, findOpt) if err != nil { return err } defer func(cursor *mongo.Cursor, ctx context.Context) { _ = cursor.Close(ctx) }(cursor, ctx) return cursor.All(ctx, &list) }, func() error { if !builder.needTotal { return nil } total, err = builder.data.mongodb.CountDocuments(ctx, filterOpt) if err != nil { return err } return nil }); err != nil { return nil, 0, err } return list, total, nil }