diff --git a/cyclic_barrier/h2o.go b/cyclic_barrier/h2o.go index 9bcda64..ce6a082 100644 --- a/cyclic_barrier/h2o.go +++ b/cyclic_barrier/h2o.go @@ -20,7 +20,6 @@ type H2O struct { semaH *semaphore.Weighted semaO *semaphore.Weighted cb cyclicbarrier.CyclicBarrier - wg sync.WaitGroup } func NewH2O() *H2O { @@ -69,26 +68,27 @@ func (h *H2O) Gen(num uint) <-chan string { releaseOxygen := func() { ch <- "O" } + var wg sync.WaitGroup // 使用 WaitGroup 等待所有的 goroutine 完成 - h.wg.Add(numInt * sum) + wg.Add(numInt * sum) for i := 0; i < numInt*H2OHydrogenNum; i++ { go func() { time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) h.hydrogen(releaseHydrogen) - h.wg.Done() + wg.Done() }() } for i := 0; i < numInt*H2OOxygenNum; i++ { go func() { time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) h.oxygen(releaseOxygen) - h.wg.Done() + wg.Done() }() } go func() { - h.wg.Wait() + wg.Wait() close(ch) }() diff --git a/err_group/weather.go b/err_group/weather.go index 6793db1..e3e9ca7 100644 --- a/err_group/weather.go +++ b/err_group/weather.go @@ -2,9 +2,10 @@ package err_group import ( "context" - "golang.org/x/sync/errgroup" "math/rand" "time" + + "golang.org/x/sync/errgroup" ) type WeatherData struct { @@ -17,30 +18,22 @@ var WeatherList = []string{"晴", "阴", "多云", "小雨", "大雨"} // fetchWeatherData 获取天气数据 func fetchWeatherData(ctx context.Context, city string) (*WeatherData, error) { - done := make(chan struct{}, 1) // 用于指示数据获取成功完成 - data := new(WeatherData) - - // 启动一个goroutine来获取天气数据 - go func() { - // 这里仅为示例,实际中这里会是API调用 - // 为了模拟可能出现的超时或错误情况,这里随机地休眠一段时间 - time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) - data.City = city - // 随机天气 - data.Weather = WeatherList[rand.Intn(len(WeatherList))] - // 随机温度 - data.Temp = rand.Intn(40) - done <- struct{}{} - }() - + // 这里仅为示例,实际中这里会是 API 调用 + // 为了模拟可能出现的超时或错误情况,这里随机休眠一段时间 + delay := time.Duration(rand.Intn(500)) * time.Millisecond select { case <-ctx.Done(): - // 如果上下文被取消,返回错误 return nil, ctx.Err() - case <-done: - // 如果数据获取成功完成,返回数据 - return data, nil + case <-time.After(delay): } + + return &WeatherData{ + City: city, + // 随机天气 + Weather: WeatherList[rand.Intn(len(WeatherList))], + // 随机温度 + Temp: rand.Intn(40), + }, nil } func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) { @@ -53,11 +46,10 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) { // 城市列表 cities := []string{"New York", "Tokyo", "Berlin", "Paris", "Beijing"} - list := make([]*WeatherData, 0, len(cities)) + list := make([]*WeatherData, len(cities)) // 循环启动goroutine来获取每个城市的天气数据 - for _, city := range cities { - city := city + for i, city := range cities { g.Go(func() error { // FetchWeatherData使用了上下文,如果上下文被取消,它应该立刻尝试返回 data, err := fetchWeatherData(ctx, city) @@ -65,8 +57,8 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) { return err } - // 如果没有错误,将数据添加到列表中 - list = append(list, data) + // 每个协程写入固定下标,避免并发 append 数据竞争 + list[i] = data return nil }) } diff --git a/err_group/weather_test.go b/err_group/weather_test.go index e1245c2..15e2e2a 100644 --- a/err_group/weather_test.go +++ b/err_group/weather_test.go @@ -6,11 +6,15 @@ import ( ) func TestGetAllCityWeatherData(t *testing.T) { - list, err := GetAllCityWeatherData(500 * time.Millisecond) + list, err := GetAllCityWeatherData(2 * time.Second) if err != nil { t.Fatal(err) } + if len(list) != 5 { + t.Fatalf("expected 5 weather records, got %d", len(list)) + } + for _, data := range list { t.Logf("%s: %s, %d", data.City, data.Weather, data.Temp) } diff --git a/lock_free/delay_queue.go b/lock_free/delay_queue.go index 3a279e6..76e1b55 100644 --- a/lock_free/delay_queue.go +++ b/lock_free/delay_queue.go @@ -43,12 +43,18 @@ func (q *DelayLkQueue[TKey, TValue]) CancellableDelayEnqueue(key TKey, value TVa q.m.Lock() defer q.m.Unlock() if timer, ok := q.timers[key]; ok { - timer.Stop() + if timer.Stop() { + q.delayCount.Add(^uint64(0)) + } } q.delayCount.Add(1) q.timers[key] = time.AfterFunc(duration, func() { q.delayCount.Add(^uint64(0)) + + q.m.Lock() delete(q.timers, key) + q.m.Unlock() + q.Enqueue(value) }) } @@ -58,9 +64,10 @@ func (q *DelayLkQueue[TKey, TValue]) CancelDelayEnqueue(key TKey) { q.m.Lock() defer q.m.Unlock() if timer, ok := q.timers[key]; ok { - q.delayCount.Add(^uint64(0)) + if timer.Stop() { + q.delayCount.Add(^uint64(0)) + } delete(q.timers, key) - timer.Stop() } } diff --git a/lock_free/queues.go b/lock_free/queues.go index 62f41cc..34808f8 100644 --- a/lock_free/queues.go +++ b/lock_free/queues.go @@ -34,11 +34,21 @@ func (q *Queues[TKey, TValue, TRoute]) SetExpireAutoFail(expireAutoFail time.Dur // checkAckStatus 检查确认状态 func (q *Queues[TKey, TValue, TRoute]) checkAckStatus(route TRoute) { if ack, ok := q.ack.Load(route); ok && !q.autoAck { - if status, exist := <-ack.(chan bool); !status && exist { - if msg, ok := q.msgs.LoadAndDelete(route); ok { - // 重新入队 - q.Enqueue(route, msg.(TValue)) + ackChan := ack.(chan bool) + select { + case status, exist := <-ackChan: + if !exist { + q.ack.Delete(route) + return } + + if !status { + if msg, loaded := q.msgs.LoadAndDelete(route); loaded { + // 重新入队 + q.Enqueue(route, msg.(TValue)) + } + } + default: } } } diff --git a/observer/observer.go b/observer/observer.go index de5c125..c5217fe 100644 --- a/observer/observer.go +++ b/observer/observer.go @@ -1,11 +1,16 @@ package observer -import "reflect" +import ( + "reflect" + "sync" +) // Subject 被观察者 type Subject struct { observers []Observer in chan any + mu sync.RWMutex + once sync.Once } // Observer 观察者 chan @@ -18,11 +23,17 @@ func NewSubject() Subject { // Attach 观察者绑定 func (s *Subject) Attach(obs ...Observer) { + s.mu.Lock() + defer s.mu.Unlock() + s.observers = append(s.observers, obs...) } // Detach 观察者解绑 func (s *Subject) Detach(obs Observer) { + s.mu.Lock() + defer s.mu.Unlock() + for i, o := range s.observers { if o == obs { s.observers = append(s.observers[:i], s.observers[i+1:]...) @@ -34,37 +45,33 @@ func (s *Subject) Detach(obs Observer) { // Notify 通知观察者 func (s *Subject) Notify(data any) { - // 这里可考虑不需要使用协程运行 - go s.fanOut(s.in, s.observers) + // fanout 只启动一次,避免每次通知重复启动协程 + s.once.Do(func() { + go s.fanOut(s.in) + }) + s.in <- data } // fanOut 扇出模式实现 -func (s *Subject) fanOut(ch <-chan interface{}, out []Observer) { +func (s *Subject) fanOut(ch <-chan interface{}) { // 绑定输入 chan 的 reflect.SelectCase cases := []reflect.SelectCase{ {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}, } - go func() { - defer func() { - // 退出时关闭所有的输出 chan - for _, o := range out { - close(o) - } - }() - - for { - _, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据 - if !ok { - // 输入 channel 被关闭 - return - } - - // 输入 channel 接收到数据 - for _, o := range out { - o <- value.Interface() // 放入到输出 chan 中,同步方式 - } + for { + _, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据 + if !ok { + // 输入 channel 被关闭 + return } - }() + + // 输入 channel 接收到数据 + s.mu.RLock() + for _, o := range s.observers { + o <- value.Interface() // 放入到输出 chan 中,同步方式 + } + s.mu.RUnlock() + } } diff --git a/query_builder/service_test.go b/query_builder/service_test.go index 4158474..1f74fc4 100644 --- a/query_builder/service_test.go +++ b/query_builder/service_test.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" - "go.uber.org/mock/gomock" - "gorm.io/gorm" "testing" "time" + + "go.uber.org/mock/gomock" + "gorm.io/gorm" ) type TestEntity struct { @@ -34,11 +35,11 @@ type TestService struct { func (s *TestService) GetFilter(_ context.Context) (any, error) { return func(db *gorm.DB) *gorm.DB { if s.filter.Name != "" { - db.Where("name = ?", s.filter.Name) + db = db.Where("name = ?", s.filter.Name) } if s.filter.Age > 0 { - db.Where("age >= ?", s.filter.Age) + db = db.Where("age >= ?", s.filter.Age) } return db @@ -180,7 +181,7 @@ func TestQueryList(t *testing.T) { ctx context.Context, builder *builder[TestEntity], next func(context.Context, - ) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) { + ) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) { defer func() func() { pre := time.Now() return func() { diff --git a/query_builder/utils.go b/query_builder/utils.go index 0261e3c..a589085 100644 --- a/query_builder/utils.go +++ b/query_builder/utils.go @@ -8,14 +8,19 @@ import ( // WaitAndGo 等待所有函数执行完毕 func WaitAndGo(fn ...func() error) error { - defer func() { - if err := recover(); err != nil { - fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack())) - } - }() var g errgroup.Group for _, f := range fn { - g.Go(func() error { + if f == nil { + continue + } + + g.Go(func() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic: %+v\n%s", r, string(debug.Stack())) + } + }() + return f() }) } diff --git a/routine/options.go b/routine/options.go index 2b9850a..a66a8fc 100644 --- a/routine/options.go +++ b/routine/options.go @@ -17,6 +17,7 @@ func loadOptions(opt ...Option) options { opts := options{ workers: 1, capacity: 1, + taskFn: func(T) {}, } for _, o := range opt { @@ -49,7 +50,9 @@ func WithCapacity(capacity int) Option { // WithTaskFn 设置任务函数 func WithTaskFn(taskFn func(T)) Option { return func(o *options) { - o.taskFn = taskFn + if taskFn != nil { + o.taskFn = taskFn + } } } diff --git a/semaphore/semaphore.go b/semaphore/semaphore.go index 88cf5f6..57d3d3c 100644 --- a/semaphore/semaphore.go +++ b/semaphore/semaphore.go @@ -3,6 +3,7 @@ package semaphore import ( "container/list" "context" + "errors" "sync" ) @@ -25,6 +26,13 @@ func NewSemaphore(n int64) *Semaphore { } func (s *Semaphore) Acquire(ctx context.Context, n int64) error { + if n <= 0 { + if n == 0 { + return nil + } + return errors.New("semaphore: negative acquire") + } + done := ctx.Done() s.mu.Lock() @@ -103,7 +111,7 @@ func (s *Semaphore) notifyWaiters() { break // 没有等待者 } - w := next.Value.(waiter) + w := next.Value.(*waiter) if s.size-s.cur < w.n { // 没有足够的资源满足下一个等待者 // 没有必要继续唤醒后续等待者,防止有等待者处于饥饿状态 @@ -117,6 +125,13 @@ func (s *Semaphore) notifyWaiters() { } func (s *Semaphore) Release(n int64) { + if n <= 0 { + if n == 0 { + return + } + panic("semaphore: negative release") + } + s.mu.Lock() s.cur -= n // 释放 n 个资源 @@ -130,6 +145,10 @@ func (s *Semaphore) Release(n int64) { } func (s *Semaphore) TryAcquire(n int64) bool { + if n <= 0 { + return n == 0 + } + s.mu.Lock() // 检查当前可用资源是否足够,并且还没有等待者 success := s.size-s.cur >= n && s.waiters.Len() == 0 diff --git a/semaphore/semaphore_chan.go b/semaphore/semaphore_chan.go index 0082301..522e6e3 100644 --- a/semaphore/semaphore_chan.go +++ b/semaphore/semaphore_chan.go @@ -1,9 +1,6 @@ package semaphore -import "sync" - type SemaChan struct { - sync.Locker sem chan struct{} // 信号量通道 } @@ -11,13 +8,18 @@ func NewSemaChan(n int) *SemaChan { if n <= 0 { n = 1 // 确保信号量至少为1,直接变成一个互斥锁 } + sem := make(chan struct{}, n) // 初始化信号量通道,容量为n + for i := 0; i < n; i++ { + sem <- struct{}{} + } + return &SemaChan{ - sem: make(chan struct{}, n), // 初始化信号量通道,容量为n + sem: sem, } } func (s *SemaChan) Lock() { - <-s.sem // 使用接收的方式阻塞,用来与 sync.Mutex 的内存模型保持一致 + <-s.sem } func (s *SemaChan) Unlock() { diff --git a/ticker/ticker.go b/ticker/ticker.go index 2bd9197..2d07727 100644 --- a/ticker/ticker.go +++ b/ticker/ticker.go @@ -1,11 +1,16 @@ package ticker -import "time" +import ( + "sync" + "time" +) type Ticker struct { C chan time.Time ticker *time.Ticker close chan struct{} + start sync.Once + once sync.Once } func NewTicker(d time.Duration) *Ticker { @@ -17,20 +22,29 @@ func NewTicker(d time.Duration) *Ticker { } func (t *Ticker) Start() { - // 首次直接触发 - t.C <- time.Now() - go func() { - for { - select { - case <-t.close: - // 关闭时停止goroutine - return - case tc := <-t.ticker.C: - // 把go原生定时器 push 的时间推送到我们定义的 time channel 中 - t.C <- tc - } + t.start.Do(func() { + // 首次直接触发 + select { + case t.C <- time.Now(): + default: } - }() + + go func() { + for { + select { + case <-t.close: + // 关闭时停止goroutine + return + case tc := <-t.ticker.C: + // 把go原生定时器 push 的时间推送到我们定义的 time channel 中 + select { + case t.C <- tc: + default: + } + } + } + }() + }) } func (t *Ticker) Reset(d time.Duration) { @@ -38,6 +52,8 @@ func (t *Ticker) Reset(d time.Duration) { } func (t *Ticker) Stop() { - t.ticker.Stop() - close(t.close) + t.once.Do(func() { + t.ticker.Stop() + close(t.close) + }) } diff --git a/ticker/ticker_test.go b/ticker/ticker_test.go index a6114e7..9f7bb2e 100644 --- a/ticker/ticker_test.go +++ b/ticker/ticker_test.go @@ -40,3 +40,17 @@ func TestTicker(t *testing.T) { time.Sleep(3 * time.Second) assert.Equal(t, expected, data) } + +func TestTickerStartTwice(t *testing.T) { + ticker := NewTicker(100 * time.Millisecond) + ticker.Start() + ticker.Start() + + defer ticker.Stop() + + select { + case <-ticker.C: + case <-time.After(time.Second): + t.Fatal("ticker did not emit tick") + } +}