diff --git a/routine/options.go b/routine/options.go new file mode 100644 index 0000000..8ca94bc --- /dev/null +++ b/routine/options.go @@ -0,0 +1,59 @@ +package routine + +type options[T any] struct { + workers int + capacity int + taskFn func(T) + panicHandler func(any) +} + +// Option function +type Option[T any] func(*options[T]) + +// setOptions 设置默认值 +func setOptions[T any](opt ...Option[T]) options[T] { + opts := options[T]{ + workers: 1, + capacity: 1, + } + + for _, o := range opt { + o(&opts) + } + + return opts +} + +// WithWorkers 设置协程数 +func WithWorkers[T any](workers int) Option[T] { + if workers <= 0 { + workers = 1 + } + return func(o *options[T]) { + o.workers = workers + } +} + +// WithCapacity 设置任务队列容量 +func WithCapacity[T any](capacity int) Option[T] { + if capacity <= 0 { + capacity = 1 + } + return func(o *options[T]) { + o.capacity = capacity + } +} + +// WithTaskFn 设置任务函数 +func WithTaskFn[T any](taskFn func(T)) Option[T] { + return func(o *options[T]) { + o.taskFn = taskFn + } +} + +// WithPanicHandler 设置panic处理函数 +func WithPanicHandler[T any](panicHandler func(any)) Option[T] { + return func(o *options[T]) { + o.panicHandler = panicHandler + } +} diff --git a/routine/pool.go b/routine/pool.go index 0bb2809..d7bed9c 100644 --- a/routine/pool.go +++ b/routine/pool.go @@ -1,34 +1,37 @@ package routine import ( - "fmt" - "runtime/debug" "sync" ) type Pool[T any] struct { - taskQueue chan T - taskFn func(T) - workers int - wg sync.WaitGroup + taskQueue chan T + taskFn func(T) + workers int + panicHandler func(any) + wg sync.WaitGroup } -func NewPool[T any](workers, capacity int, taskFn func(T)) *Pool[T] { +func NewPool[T any](opt ...Option[T]) *Pool[T] { + opts := setOptions(opt...) pool := &Pool[T]{ - taskQueue: make(chan T, capacity), - taskFn: func(t T) { - defer func() { - // 处理协程运行中出现panic的情况 - if r := recover(); r != nil { - fmt.Printf("Panic: %v\n %s", r, string(debug.Stack())) + taskQueue: make(chan T, opts.capacity), + panicHandler: opts.panicHandler, + workers: opts.workers, + } + pool.taskFn = func(t T) { + defer func() { + // 处理协程运行中出现panic的情况 + if r := recover(); r != nil { + if pool.panicHandler != nil { + pool.panicHandler(r) } - }() + } + }() - taskFn(t) - }, - workers: workers, + opts.taskFn(t) } - pool.wg.Add(workers) + pool.wg.Add(opts.workers) return pool } diff --git a/routine/pool_test.go b/routine/pool_test.go index 09cb3ee..f29a52e 100644 --- a/routine/pool_test.go +++ b/routine/pool_test.go @@ -1,7 +1,9 @@ package routine import ( + "fmt" "runtime" + "runtime/debug" "sync/atomic" "testing" @@ -11,13 +13,23 @@ import ( func TestPool(t *testing.T) { num := runtime.NumCPU() var sum atomic.Int32 - pool := NewPool(num, num, func(num int32) { + + task := func(num int32) { if num < 0 { panic("unable to handle negative numbers") } sum.Add(num) - }) + } + handler := func(r any) { + fmt.Printf("Panic: %v\n %s", r, string(debug.Stack())) + } + pool := NewPool( + WithWorkers[int32](num), + WithCapacity[int32](num), + WithTaskFn(task), + WithPanicHandler[int32](handler), + ) pool.Start() for i := int32(1000); i >= -1; i-- {