完善查询构建器中间件传递

This commit is contained in:
fantasticbin 2025-05-11 18:53:28 +08:00
parent 342b25cd78
commit e02cf114e7
2 changed files with 32 additions and 1 deletions

View File

@ -38,6 +38,7 @@ func (s *BaseService[R, F, S]) QueryList(ctx context.Context) ([]*R, int64, erro
// 泛型参数上同 // 泛型参数上同
type List[R any, F any, S any] struct { type List[R any, F any, S any] struct {
strategy QueryListStrategy[R] strategy QueryListStrategy[R]
middlewares []QueryMiddleware[R]
} }
// SetStrategy 设置查询策略 // SetStrategy 设置查询策略
@ -48,6 +49,12 @@ func (l *List[R, F, S]) SetStrategy(strategy QueryListStrategy[R]) *List[R, F, S
return l return l
} }
// Use 添加查询中间件
func (l *List[R, F, S]) Use(middlewares ...QueryMiddleware[R]) *List[R, F, S] {
l.middlewares = append(l.middlewares, middlewares...)
return l
}
// Query 执行查询 // Query 执行查询
// 该方法会根据传入的Service实例和QueryOption选项执行查询 // 该方法会根据传入的Service实例和QueryOption选项执行查询
func (l *List[R, F, S]) Query( func (l *List[R, F, S]) Query(
@ -69,5 +76,10 @@ func (l *List[R, F, S]) Query(
sort: options.GetSort(), sort: options.GetSort(),
service: s, service: s,
} }
for _, middleware := range l.middlewares {
service.Use(middleware)
}
return service.QueryList(ctx) return service.QueryList(ctx)
} }

View File

@ -3,9 +3,11 @@ package builder
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gorm.io/gorm" "gorm.io/gorm"
"testing" "testing"
"time"
) )
type TestEntity struct { type TestEntity struct {
@ -173,6 +175,23 @@ func TestQueryList(t *testing.T) {
// 这里使用 Mock 策略替代真实的策略 // 这里使用 Mock 策略替代真实的策略
// 实际使用时会根据数据源自动选择 // 实际使用时会根据数据源自动选择
list.SetStrategy(mockStrategy) list.SetStrategy(mockStrategy)
// 添加耗时监控
list.Use(func(
ctx context.Context,
builder *builder[TestEntity],
next func(context.Context,
) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) {
defer func() func() {
pre := time.Now()
return func() {
elapsed := time.Since(pre)
fmt.Println("elapsed:", elapsed)
}
}()()
result, total, err := next(ctx)
return result, total, err
})
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if tt.mockSetup != nil { if tt.mockSetup != nil {