From e02cf114e721bdcdd17fdafe5742e229203fdb6d Mon Sep 17 00:00:00 2001 From: fantasticbin Date: Sun, 11 May 2025 18:53:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=9F=A5=E8=AF=A2=E6=9E=84?= =?UTF-8?q?=E5=BB=BA=E5=99=A8=E4=B8=AD=E9=97=B4=E4=BB=B6=E4=BC=A0=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- query_builder/service.go | 14 +++++++++++++- query_builder/service_test.go | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/query_builder/service.go b/query_builder/service.go index f3b3284..b2de0e8 100644 --- a/query_builder/service.go +++ b/query_builder/service.go @@ -37,7 +37,8 @@ func (s *BaseService[R, F, S]) QueryList(ctx context.Context) ([]*R, int64, erro // List 查询列表功能结构 // 泛型参数上同 type List[R any, F any, S any] struct { - strategy QueryListStrategy[R] + strategy QueryListStrategy[R] + middlewares []QueryMiddleware[R] } // SetStrategy 设置查询策略 @@ -48,6 +49,12 @@ func (l *List[R, F, S]) SetStrategy(strategy QueryListStrategy[R]) *List[R, F, S return l } +// Use 添加查询中间件 +func (l *List[R, F, S]) Use(middlewares ...QueryMiddleware[R]) *List[R, F, S] { + l.middlewares = append(l.middlewares, middlewares...) + return l +} + // Query 执行查询 // 该方法会根据传入的Service实例和QueryOption选项执行查询 func (l *List[R, F, S]) Query( @@ -69,5 +76,10 @@ func (l *List[R, F, S]) Query( sort: options.GetSort(), service: s, } + + for _, middleware := range l.middlewares { + service.Use(middleware) + } + return service.QueryList(ctx) } diff --git a/query_builder/service_test.go b/query_builder/service_test.go index 25294c4..4158474 100644 --- a/query_builder/service_test.go +++ b/query_builder/service_test.go @@ -3,9 +3,11 @@ package builder import ( "context" "errors" + "fmt" "go.uber.org/mock/gomock" "gorm.io/gorm" "testing" + "time" ) type TestEntity struct { @@ -173,6 +175,23 @@ func TestQueryList(t *testing.T) { // 这里使用 Mock 策略替代真实的策略 // 实际使用时会根据数据源自动选择 list.SetStrategy(mockStrategy) + // 添加耗时监控 + list.Use(func( + ctx context.Context, + builder *builder[TestEntity], + next func(context.Context, + ) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) { + defer func() func() { + pre := time.Now() + return func() { + elapsed := time.Since(pre) + fmt.Println("elapsed:", elapsed) + } + }()() + result, total, err := next(ctx) + return result, total, err + }) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.mockSetup != nil {