261 lines
6.5 KiB
Go
261 lines
6.5 KiB
Go
package builder
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"go.mongodb.org/mongo-driver/bson"
|
||
"go.mongodb.org/mongo-driver/mongo"
|
||
"go.mongodb.org/mongo-driver/mongo/options"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// DBProxy 数据实例结构
|
||
type DBProxy struct {
|
||
db *gorm.DB
|
||
mongodb *mongo.Collection // 需提前指定.Database("db_name").Collection("collection_name")
|
||
// redis、elasticsearch...
|
||
}
|
||
|
||
// NewDBProxy 创建数据实例
|
||
func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy {
|
||
return &DBProxy{
|
||
db: db,
|
||
mongodb: mongodb,
|
||
}
|
||
}
|
||
|
||
// QueryMiddleware 查询中间件类型定义
|
||
// 参数:
|
||
//
|
||
// ctx: 上下文
|
||
// builder: 查询构建器实例
|
||
// next: 下一个中间件或最终查询处理器
|
||
//
|
||
// 返回:
|
||
//
|
||
// []*R: 查询结果列表
|
||
// int64: 总数
|
||
// error: 错误信息
|
||
type QueryMiddleware[R any] func(
|
||
ctx context.Context,
|
||
builder *builder[R],
|
||
next func(context.Context) ([]*R, int64, error),
|
||
) ([]*R, int64, error)
|
||
|
||
// builder 查询构建器,使用泛型支持多种实体类型
|
||
// 泛型参数:
|
||
//
|
||
// R: 查询结果的实体类型
|
||
type builder[R any] struct {
|
||
data *DBProxy
|
||
start uint32
|
||
limit uint32
|
||
needTotal bool
|
||
needPagination bool
|
||
strategy QueryListStrategy[R] // 查询策略
|
||
middlewares []QueryMiddleware[R] // 中间件链
|
||
|
||
filter func(context.Context) (any, error)
|
||
sort func() any
|
||
}
|
||
|
||
// SetFilter 设置过滤条件生成函数
|
||
// 返回支持链式调用的构建器实例
|
||
func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] {
|
||
b.filter = filter
|
||
return b
|
||
}
|
||
|
||
// SetSort 设置排序条件生成函数
|
||
// 返回支持链式调用的构建器实例
|
||
func (b *builder[R]) SetSort(sort func() any) *builder[R] {
|
||
b.sort = sort
|
||
return b
|
||
}
|
||
|
||
// SetStrategy 设置查询列表策略
|
||
// 返回支持链式调用的构建器实例
|
||
func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] {
|
||
b.strategy = strategy
|
||
return b
|
||
}
|
||
|
||
// Use 添加中间件
|
||
// 返回支持链式调用的构建器实例
|
||
func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] {
|
||
b.middlewares = append(b.middlewares, middleware)
|
||
return b
|
||
}
|
||
|
||
// getQueryStrategy 获取查询列表策略
|
||
// 如果没有设置策略,则根据数据源自动选择策略
|
||
func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) {
|
||
if b.strategy != nil {
|
||
return b.strategy, nil
|
||
}
|
||
if b.data == nil {
|
||
return nil, errors.New("no data source provided")
|
||
}
|
||
|
||
switch {
|
||
case b.data.db != nil:
|
||
return NewQueryGormListStrategy[R](), nil
|
||
case b.data.mongodb != nil:
|
||
return NewQueryMongoListStrategy[R](), nil
|
||
default:
|
||
return nil, errors.New("query strategy not set and no valid DB found")
|
||
}
|
||
}
|
||
|
||
// QueryList 执行查询列表操作
|
||
// 返回值与中间件类型相同,list []R 查询结果列表
|
||
func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) {
|
||
// 尝试自动推断策略类型
|
||
strategy, err := b.getQueryStrategy()
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 构建中间件链
|
||
next := func(ctx context.Context) ([]*R, int64, error) {
|
||
return strategy.QueryList(ctx, b)
|
||
}
|
||
|
||
for i := len(b.middlewares) - 1; i >= 0; i-- {
|
||
next = func(mw QueryMiddleware[R], fn func(context.Context) ([]*R, int64, error)) func(context.Context) ([]*R, int64, error) {
|
||
return func(ctx context.Context) ([]*R, int64, error) {
|
||
return mw(ctx, b, fn)
|
||
}
|
||
}(b.middlewares[i], next)
|
||
}
|
||
|
||
return next(ctx)
|
||
}
|
||
|
||
// QueryListStrategy 查询列表策略
|
||
type QueryListStrategy[R any] interface {
|
||
QueryList(context.Context, *builder[R]) ([]*R, int64, error)
|
||
}
|
||
|
||
// QueryGormListStrategy is the list-query strategy backed by GORM.
// It carries no state, so the zero value is ready to use.
type QueryGormListStrategy[R any] struct{}

// NewQueryGormListStrategy returns a new GORM-backed list strategy.
func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] {
	return new(QueryGormListStrategy[R])
}
|
||
|
||
// QueryList 实现 GORM 查询逻辑
|
||
func (s *QueryGormListStrategy[R]) QueryList(
|
||
ctx context.Context,
|
||
builder *builder[R],
|
||
) (list []*R, total int64, err error) {
|
||
filterScope, err := builder.filter(ctx)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
sortScope := builder.sort()
|
||
// 验证过滤条件和排序条件的类型有效性
|
||
for _, scope := range []any{filterScope, sortScope} {
|
||
if _, ok := scope.(func(*gorm.DB) *gorm.DB); !ok {
|
||
return nil, 0, errors.New("invalid scope")
|
||
}
|
||
}
|
||
|
||
// 使用 WaitAndGo 并行执行数据查询和总数统计操作
|
||
if err := WaitAndGo(func() error {
|
||
query := builder.data.db.WithContext(ctx).
|
||
Model(&list).
|
||
Scopes(filterScope.(func(*gorm.DB) *gorm.DB), sortScope.(func(*gorm.DB) *gorm.DB))
|
||
|
||
if builder.needPagination {
|
||
if builder.limit < 1 {
|
||
builder.limit = defaultLimit
|
||
}
|
||
query = query.Offset(int(builder.start)).Limit(int(builder.limit))
|
||
}
|
||
|
||
return query.Find(&list).Error
|
||
}, func() error {
|
||
if !builder.needTotal {
|
||
return nil
|
||
}
|
||
|
||
return builder.data.db.WithContext(ctx).
|
||
Model(&list).
|
||
Scopes(filterScope.(func(*gorm.DB) *gorm.DB)).
|
||
Count(&total).
|
||
Error
|
||
}); err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return list, total, nil
|
||
}
|
||
|
||
// QueryMongoListStrategy is the list-query strategy backed by MongoDB.
// It carries no state, so the zero value is ready to use.
type QueryMongoListStrategy[R any] struct{}

// NewQueryMongoListStrategy returns a new MongoDB-backed list strategy.
func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] {
	return new(QueryMongoListStrategy[R])
}
|
||
|
||
// QueryList 实现 MongoDB 查询逻辑
|
||
func (s *QueryMongoListStrategy[R]) QueryList(
|
||
ctx context.Context,
|
||
builder *builder[R],
|
||
) (list []*R, total int64, err error) {
|
||
filterOpt, err := builder.filter(ctx)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
sortOpt := builder.sort()
|
||
// 验证过滤条件和排序条件的类型有效性
|
||
for _, opt := range []any{filterOpt, sortOpt} {
|
||
_, mOk := opt.(bson.M)
|
||
_, dOk := opt.(bson.D)
|
||
if !mOk && !dOk {
|
||
return nil, 0, errors.New("invalid option")
|
||
}
|
||
}
|
||
|
||
// 使用 WaitAndGo 并行执行数据查询和总数统计操作
|
||
if err := WaitAndGo(func() error {
|
||
findOpt := options.Find().SetSort(sortOpt)
|
||
if builder.needPagination {
|
||
if builder.limit < 1 {
|
||
builder.limit = defaultLimit
|
||
}
|
||
findOpt.SetSkip(int64(builder.start)).SetLimit(int64(builder.limit))
|
||
}
|
||
|
||
cursor, err := builder.data.mongodb.Find(ctx, filterOpt, findOpt)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer func(cursor *mongo.Cursor, ctx context.Context) {
|
||
_ = cursor.Close(ctx)
|
||
}(cursor, ctx)
|
||
|
||
return cursor.All(ctx, &list)
|
||
}, func() error {
|
||
if !builder.needTotal {
|
||
return nil
|
||
}
|
||
|
||
total, err = builder.data.mongodb.CountDocuments(ctx, filterOpt)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}); err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return list, total, nil
|
||
}
|