go-study/query_builder/builder.go

261 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package builder
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"gorm.io/gorm"
)
// DBProxy bundles the data-source handles a query builder can run against.
// When no strategy is set explicitly, getQueryStrategy picks GORM if db is
// non-nil, otherwise MongoDB if mongodb is non-nil.
type DBProxy struct {
	// db is the GORM database handle.
	db *gorm.DB
	// mongodb must already be bound to a concrete collection, e.g.
	// client.Database("db_name").Collection("collection_name").
	mongodb *mongo.Collection
	// Further back ends (redis, elasticsearch, ...) could be added here.
}

// NewDBProxy wraps a GORM handle and a MongoDB collection in a DBProxy.
// Either argument may be nil when the corresponding back end is unused.
func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy {
	proxy := new(DBProxy)
	proxy.db = db
	proxy.mongodb = mongodb
	return proxy
}
// QueryMiddleware is a hook that wraps the execution of a list query.
//
// Parameters:
//
//	ctx     - the request context
//	builder - the builder whose query is being executed
//	next    - the rest of the chain: the next middleware, or finally the strategy call
//
// It returns the same triple that next produces: the result rows, the total
// count, and an error. The middleware itself decides whether and how next is
// invoked.
type QueryMiddleware[R any] func(ctx context.Context, builder *builder[R], next func(context.Context) ([]*R, int64, error)) ([]*R, int64, error)
// builder is a generic list-query builder.
//
// Type parameters:
//
//	R: the entity type each result row is decoded into
type builder[R any] struct {
	// data holds the data sources the query may run against.
	data *DBProxy

	// start and limit are the pagination offset and page size; when
	// pagination is on, the built-in strategies replace a limit < 1 with
	// defaultLimit.
	start, limit uint32

	// needTotal requests a total count; needPagination applies start/limit.
	needTotal, needPagination bool

	// strategy overrides the automatic strategy choice when non-nil.
	strategy QueryListStrategy[R]
	// middlewares wrap the query; index 0 ends up outermost.
	middlewares []QueryMiddleware[R]

	// filter produces the filter condition (a GORM scope or a bson document).
	filter func(context.Context) (any, error)
	// sort produces the sort condition (a GORM scope or a bson document).
	sort func() any
}
// SetFilter stores the callback that builds the filter condition when the
// query runs. For the built-in strategies its result must be a GORM scope
// (func(*gorm.DB) *gorm.DB) or a bson.M / bson.D document. It returns b so
// calls can be chained.
func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] {
	b.filter = filter
	return b
}
// SetSort stores the callback that builds the sort condition when the query
// runs. The built-in strategies accept the same result types as for the
// filter. It returns b so calls can be chained.
func (b *builder[R]) SetSort(sort func() any) *builder[R] {
	b.sort = sort
	return b
}
// SetStrategy pins the strategy QueryList will use, bypassing the automatic
// choice based on the configured data source. It returns b so calls can be
// chained.
func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] {
	b.strategy = strategy
	return b
}
// Use appends a middleware to the chain. Middlewares wrap the query in
// registration order, so the first one added runs outermost. It returns b so
// calls can be chained.
func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] {
	b.middlewares = append(b.middlewares, middleware)
	return b
}
// getQueryStrategy returns the strategy QueryList should run.
// An explicitly configured strategy always wins. Otherwise the data source
// decides, with a GORM handle taking precedence over a MongoDB collection.
// An error is returned when there is neither a strategy nor a usable source.
func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) {
	if s := b.strategy; s != nil {
		return s, nil
	}
	if b.data == nil {
		return nil, errors.New("no data source provided")
	}
	if b.data.db != nil {
		return NewQueryGormListStrategy[R](), nil
	}
	if b.data.mongodb != nil {
		return NewQueryMongoListStrategy[R](), nil
	}
	return nil, errors.New("query strategy not set and no valid DB found")
}
// QueryList runs the list query through the middleware chain and then the
// selected strategy.
//
// It returns the result rows, the total count, and an error: the same triple
// a QueryMiddleware produces. The first middleware passed to Use is the
// outermost wrapper, so it sees the call first.
func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) {
	// Resolve the strategy, either set explicitly or inferred from the data source.
	strategy, err := b.getQueryStrategy()
	if err != nil {
		return nil, 0, err
	}

	// Innermost handler: the strategy itself.
	handler := func(c context.Context) ([]*R, int64, error) {
		return strategy.QueryList(c, b)
	}
	// Wrap from the last middleware outwards so that middlewares[0] ends up
	// outermost. mw and inner are fresh per iteration, so each closure keeps
	// its own pair.
	for idx := len(b.middlewares) - 1; idx >= 0; idx-- {
		mw, inner := b.middlewares[idx], handler
		handler = func(c context.Context) ([]*R, int64, error) {
			return mw(c, b, inner)
		}
	}
	return handler(ctx)
}
// QueryListStrategy executes a builder's list query against one kind of
// data source.
type QueryListStrategy[R any] interface {
	// QueryList returns the result rows, the total count, and any error.
	QueryList(ctx context.Context, b *builder[R]) ([]*R, int64, error)
}
// QueryGormListStrategy is the list-query strategy backed by GORM.
// It carries no state of its own.
type QueryGormListStrategy[R any] struct{}

// NewQueryGormListStrategy returns a GORM-backed list-query strategy.
func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] {
	return new(QueryGormListStrategy[R])
}
// QueryList implements QueryListStrategy on top of GORM.
//
// The builder's filter and sort callbacks must produce GORM scopes of type
// func(*gorm.DB) *gorm.DB. A callback that was never set (nil) is treated as
// a scope that leaves the query unchanged.
//
// The page query and the optional count query are handed to WaitAndGo to run
// in parallel. WaitAndGo is defined outside this file; it is assumed to wait
// for both functions and return a non-nil error if either fails.
// NOTE(review): confirm this.
//
// Side effect: when pagination is enabled and builder.limit < 1,
// builder.limit is set to defaultLimit before the queries run.
//
// Returns:
//
//	list:  rows of the requested page (all matching rows when pagination is off)
//	total: number of rows matching the filter; 0 when builder.needTotal is false
//	err:   "gorm db not set" when no GORM handle is configured, the filter
//	       callback's error, "invalid scope" for a wrong scope type, or the
//	       error reported by WaitAndGo
func (s *QueryGormListStrategy[R]) QueryList(
	ctx context.Context,
	builder *builder[R],
) (list []*R, total int64, err error) {
	// A strategy set explicitly via SetStrategy bypasses getQueryStrategy's
	// data-source checks. Without this guard a missing handle would panic
	// with a nil dereference below.
	if builder.data == nil || builder.data.db == nil {
		return nil, 0, errors.New("gorm db not set")
	}

	// Unset callbacks mean "no filter" / "no explicit ordering" rather than a
	// nil-func panic.
	identity := func(db *gorm.DB) *gorm.DB { return db }
	var filterScope, sortScope any = identity, identity
	if builder.filter != nil {
		if filterScope, err = builder.filter(ctx); err != nil {
			return nil, 0, err
		}
	}
	if builder.sort != nil {
		sortScope = builder.sort()
	}

	// Validate and unwrap both scopes exactly once.
	filterFn, ok := filterScope.(func(*gorm.DB) *gorm.DB)
	if !ok {
		return nil, 0, errors.New("invalid scope")
	}
	sortFn, ok := sortScope.(func(*gorm.DB) *gorm.DB)
	if !ok {
		return nil, 0, errors.New("invalid scope")
	}

	// Apply the default page size here, on the calling goroutine, instead of
	// inside the concurrently running page query. The builder is then never
	// written from another goroutine, and it still ends up holding the
	// defaulted value.
	if builder.needPagination && builder.limit < 1 {
		builder.limit = defaultLimit
	}

	err = WaitAndGo(func() error {
		query := builder.data.db.WithContext(ctx).
			Model(&list).
			Scopes(filterFn, sortFn)
		if builder.needPagination {
			query = query.Offset(int(builder.start)).Limit(int(builder.limit))
		}
		return query.Find(&list).Error
	}, func() error {
		if !builder.needTotal {
			return nil
		}
		// Count against a fresh model value so this closure never touches
		// `list`, which the page query writes concurrently. Ordering does not
		// affect a count, so only the filter scope is applied.
		return builder.data.db.WithContext(ctx).
			Model(new(R)).
			Scopes(filterFn).
			Count(&total).
			Error
	})
	if err != nil {
		return nil, 0, err
	}
	return list, total, nil
}
// QueryMongoListStrategy is the list-query strategy backed by the MongoDB
// driver. It carries no state of its own.
type QueryMongoListStrategy[R any] struct{}

// NewQueryMongoListStrategy returns a MongoDB-backed list-query strategy.
func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] {
	return new(QueryMongoListStrategy[R])
}
// QueryList implements QueryListStrategy on top of the MongoDB Go driver.
//
// The builder's filter and sort callbacks must produce a bson.M or bson.D.
// A callback that was never set (nil) is treated as an empty document, i.e.
// "match every document" / "no explicit ordering".
//
// The page query and the optional count are handed to WaitAndGo to run in
// parallel. WaitAndGo is defined outside this file; it is assumed to wait for
// both functions and return a non-nil error if either fails.
// NOTE(review): confirm this.
//
// Side effect: when pagination is enabled and builder.limit < 1,
// builder.limit is set to defaultLimit before the queries run.
//
// Returns:
//
//	list:  documents of the requested page (all matches when pagination is off)
//	total: number of documents matching the filter; 0 when builder.needTotal is false
//	err:   "mongodb collection not set" when no collection is configured, the
//	       filter callback's error, "invalid option" for an unsupported
//	       filter/sort type, or the error reported by WaitAndGo
func (s *QueryMongoListStrategy[R]) QueryList(
	ctx context.Context,
	builder *builder[R],
) (list []*R, total int64, err error) {
	// A strategy set explicitly via SetStrategy bypasses getQueryStrategy's
	// data-source checks. Without this guard a missing collection would panic
	// with a nil dereference below.
	if builder.data == nil || builder.data.mongodb == nil {
		return nil, 0, errors.New("mongodb collection not set")
	}

	// Unset callbacks fall back to an empty document instead of a nil-func panic.
	var filterOpt, sortOpt any = bson.D{}, bson.D{}
	if builder.filter != nil {
		if filterOpt, err = builder.filter(ctx); err != nil {
			return nil, 0, err
		}
	}
	if builder.sort != nil {
		sortOpt = builder.sort()
	}

	// Only bson.M and bson.D are accepted for both the filter and the sort.
	for _, opt := range []any{filterOpt, sortOpt} {
		switch opt.(type) {
		case bson.M, bson.D:
			// accepted
		default:
			return nil, 0, errors.New("invalid option")
		}
	}

	// Apply the default page size on the calling goroutine, not inside the
	// concurrently running find. The builder still ends up holding the
	// defaulted value.
	if builder.needPagination && builder.limit < 1 {
		builder.limit = defaultLimit
	}

	err = WaitAndGo(func() error {
		findOpt := options.Find().SetSort(sortOpt)
		if builder.needPagination {
			findOpt.SetSkip(int64(builder.start)).SetLimit(int64(builder.limit))
		}
		cursor, err := builder.data.mongodb.Find(ctx, filterOpt, findOpt)
		if err != nil {
			return err
		}
		// cursor.All closes the cursor itself. This deferred Close is only a
		// safety net, so its error is deliberately discarded.
		defer func() { _ = cursor.Close(ctx) }()
		return cursor.All(ctx, &list)
	}, func() error {
		if !builder.needTotal {
			return nil
		}
		// Count into a local so the named result err is never written from
		// inside this concurrently running closure.
		n, err := builder.data.mongodb.CountDocuments(ctx, filterOpt)
		if err != nil {
			return err
		}
		total = n
		return nil
	})
	if err != nil {
		return nil, 0, err
	}
	return list, total, nil
}