协程池泛型优化

main
wangwenbin 9 months ago
parent c3525558de
commit 3d730904fa

@ -1,6 +1,8 @@
package routine package routine
type options[T any] struct { type T any
type options struct {
workers int workers int
capacity int capacity int
taskFn func(T) taskFn func(T)
@ -8,11 +10,11 @@ type options[T any] struct {
} }
// Option function // Option function
type Option[T any] func(*options[T]) type Option func(*options)
// setOptions 设置默认值 // loadOptions 设置默认值
func setOptions[T any](opt ...Option[T]) options[T] { func loadOptions(opt ...Option) options {
opts := options[T]{ opts := options{
workers: 1, workers: 1,
capacity: 1, capacity: 1,
} }
@ -25,35 +27,35 @@ func setOptions[T any](opt ...Option[T]) options[T] {
} }
// WithWorkers 设置协程数 // WithWorkers 设置协程数
func WithWorkers[T any](workers int) Option[T] { func WithWorkers(workers int) Option {
if workers <= 0 { if workers <= 0 {
workers = 1 workers = 1
} }
return func(o *options[T]) { return func(o *options) {
o.workers = workers o.workers = workers
} }
} }
// WithCapacity 设置任务队列容量 // WithCapacity 设置任务队列容量
func WithCapacity[T any](capacity int) Option[T] { func WithCapacity(capacity int) Option {
if capacity <= 0 { if capacity <= 0 {
capacity = 1 capacity = 1
} }
return func(o *options[T]) { return func(o *options) {
o.capacity = capacity o.capacity = capacity
} }
} }
// WithTaskFn 设置任务函数 // WithTaskFn 设置任务函数
func WithTaskFn[T any](taskFn func(T)) Option[T] { func WithTaskFn(taskFn func(T)) Option {
return func(o *options[T]) { return func(o *options) {
o.taskFn = taskFn o.taskFn = taskFn
} }
} }
// WithPanicHandler 设置panic处理函数 // WithPanicHandler 设置panic处理函数
func WithPanicHandler[T any](panicHandler func(any)) Option[T] { func WithPanicHandler(panicHandler func(any)) Option {
return func(o *options[T]) { return func(o *options) {
o.panicHandler = panicHandler o.panicHandler = panicHandler
} }
} }

@ -4,7 +4,7 @@ import (
"sync" "sync"
) )
type Pool[T any] struct { type Pool struct {
taskQueue chan T taskQueue chan T
taskFn func(T) taskFn func(T)
workers int workers int
@ -12,9 +12,9 @@ type Pool[T any] struct {
wg sync.WaitGroup wg sync.WaitGroup
} }
func NewPool[T any](opt ...Option[T]) *Pool[T] { func NewPool(opt ...Option) *Pool {
opts := setOptions(opt...) opts := loadOptions(opt...)
pool := &Pool[T]{ pool := &Pool{
taskQueue: make(chan T, opts.capacity), taskQueue: make(chan T, opts.capacity),
panicHandler: opts.panicHandler, panicHandler: opts.panicHandler,
workers: opts.workers, workers: opts.workers,
@ -37,7 +37,7 @@ func NewPool[T any](opt ...Option[T]) *Pool[T] {
} }
// Start 启动任务 // Start 启动任务
func (p *Pool[T]) Start() { func (p *Pool) Start() {
for i := 0; i < p.workers; i++ { for i := 0; i < p.workers; i++ {
go func() { go func() {
defer p.wg.Done() defer p.wg.Done()
@ -55,12 +55,12 @@ func (p *Pool[T]) Start() {
} }
// Push 提交任务 // Push 提交任务
func (p *Pool[T]) Push(task T) { func (p *Pool) Push(task T) {
p.taskQueue <- task p.taskQueue <- task
} }
// Wait 挂起当前协程 // Wait 挂起当前协程
func (p *Pool[T]) Wait() { func (p *Pool) Wait() {
close(p.taskQueue) close(p.taskQueue)
p.wg.Wait() p.wg.Wait()
} }

@ -14,7 +14,8 @@ func TestPool(t *testing.T) {
num := runtime.NumCPU() num := runtime.NumCPU()
var sum atomic.Int32 var sum atomic.Int32
task := func(num int32) { task := func(t T) {
num := t.(int32)
if num < 0 { if num < 0 {
panic("unable to handle negative numbers") panic("unable to handle negative numbers")
} }
@ -25,10 +26,10 @@ func TestPool(t *testing.T) {
fmt.Printf("Panic: %v\n %s", r, string(debug.Stack())) fmt.Printf("Panic: %v\n %s", r, string(debug.Stack()))
} }
pool := NewPool( pool := NewPool(
WithWorkers[int32](num), WithWorkers(num),
WithCapacity[int32](num), WithCapacity(num),
WithTaskFn(task), WithTaskFn(task),
WithPanicHandler[int32](handler), WithPanicHandler(handler),
) )
pool.Start() pool.Start()

Loading…
Cancel
Save