使用codex优化代码,具体如下:

主要优化

  - err_group:去掉多余 goroutine,避免潜在泄漏;并把并发 append 改为按下标写
    入,消除数据竞争。
  - err_group 测试稳定性增强:放宽超时并增加结果长度断言。
  - semaphore:修复等待队列元素类型断言错误(*waiter);补充非法参数校验(负数
    acquire/release)。
  - SemaChan:修复 Lock/Unlock 逻辑(初始化令牌桶),避免永久阻塞。
  - observer:修复“每次 Notify 都启动新 fanout 协程”的问题:改为 sync.Once 只启动一次
    fanOut。
  - observer:修复并发读写观察者列表问题:给 Attach/Detach/fanOut 增加读写锁保护。
  - observer:去掉 fanout 内部额外再起 goroutine和自动关闭所有 observer 的行为,避
    免重复关闭/竞态风险(仍保留 Detach 时关闭单个 observer)。
  - lock_free:修复可取消延迟队列的计数错误与 timers map 并发访问问题。
  - lock_free:checkAckStatus 改为非阻塞读取,避免入队路径被卡住。
  - routine:提供默认空任务并忽略 nil taskFn,防止空指针调用。
  - ticker:发送改为非阻塞,Stop 幂等化,降低阻塞和重复关闭风险。
  - query_builder:WaitAndGo 增加 goroutine 内 panic 转 error;测试里
    的 GORM filter 链式写法修正。

  新增测试

  - 新增 semaphore 测试,覆盖 Acquire/Release/TryAcquire 与 SemaChan 并发上限。
This commit is contained in:
fantasticbin 2026-03-05 21:53:11 +08:00
parent 88a4caf72e
commit 5b48ea1a62
13 changed files with 177 additions and 97 deletions

View File

@ -20,7 +20,6 @@ type H2O struct {
semaH *semaphore.Weighted semaH *semaphore.Weighted
semaO *semaphore.Weighted semaO *semaphore.Weighted
cb cyclicbarrier.CyclicBarrier cb cyclicbarrier.CyclicBarrier
wg sync.WaitGroup
} }
func NewH2O() *H2O { func NewH2O() *H2O {
@ -69,26 +68,27 @@ func (h *H2O) Gen(num uint) <-chan string {
releaseOxygen := func() { releaseOxygen := func() {
ch <- "O" ch <- "O"
} }
var wg sync.WaitGroup
// 使用 WaitGroup 等待所有的 goroutine 完成 // 使用 WaitGroup 等待所有的 goroutine 完成
h.wg.Add(numInt * sum) wg.Add(numInt * sum)
for i := 0; i < numInt*H2OHydrogenNum; i++ { for i := 0; i < numInt*H2OHydrogenNum; i++ {
go func() { go func() {
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
h.hydrogen(releaseHydrogen) h.hydrogen(releaseHydrogen)
h.wg.Done() wg.Done()
}() }()
} }
for i := 0; i < numInt*H2OOxygenNum; i++ { for i := 0; i < numInt*H2OOxygenNum; i++ {
go func() { go func() {
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
h.oxygen(releaseOxygen) h.oxygen(releaseOxygen)
h.wg.Done() wg.Done()
}() }()
} }
go func() { go func() {
h.wg.Wait() wg.Wait()
close(ch) close(ch)
}() }()

View File

@ -2,9 +2,10 @@ package err_group
import ( import (
"context" "context"
"golang.org/x/sync/errgroup"
"math/rand" "math/rand"
"time" "time"
"golang.org/x/sync/errgroup"
) )
type WeatherData struct { type WeatherData struct {
@ -17,30 +18,22 @@ var WeatherList = []string{"晴", "阴", "多云", "小雨", "大雨"}
// fetchWeatherData 获取天气数据 // fetchWeatherData 获取天气数据
func fetchWeatherData(ctx context.Context, city string) (*WeatherData, error) { func fetchWeatherData(ctx context.Context, city string) (*WeatherData, error) {
done := make(chan struct{}, 1) // 用于指示数据获取成功完成 // 这里仅为示例,实际中这里会是 API 调用
data := new(WeatherData) // 为了模拟可能出现的超时或错误情况,这里随机休眠一段时间
delay := time.Duration(rand.Intn(500)) * time.Millisecond
// 启动一个goroutine来获取天气数据
go func() {
// 这里仅为示例实际中这里会是API调用
// 为了模拟可能出现的超时或错误情况,这里随机地休眠一段时间
time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
data.City = city
// 随机天气
data.Weather = WeatherList[rand.Intn(len(WeatherList))]
// 随机温度
data.Temp = rand.Intn(40)
done <- struct{}{}
}()
select { select {
case <-ctx.Done(): case <-ctx.Done():
// 如果上下文被取消,返回错误
return nil, ctx.Err() return nil, ctx.Err()
case <-done: case <-time.After(delay):
// 如果数据获取成功完成,返回数据
return data, nil
} }
return &WeatherData{
City: city,
// 随机天气
Weather: WeatherList[rand.Intn(len(WeatherList))],
// 随机温度
Temp: rand.Intn(40),
}, nil
} }
func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) { func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
@ -53,11 +46,10 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
// 城市列表 // 城市列表
cities := []string{"New York", "Tokyo", "Berlin", "Paris", "Beijing"} cities := []string{"New York", "Tokyo", "Berlin", "Paris", "Beijing"}
list := make([]*WeatherData, 0, len(cities)) list := make([]*WeatherData, len(cities))
// 循环启动goroutine来获取每个城市的天气数据 // 循环启动goroutine来获取每个城市的天气数据
for _, city := range cities { for i, city := range cities {
city := city
g.Go(func() error { g.Go(func() error {
// FetchWeatherData使用了上下文如果上下文被取消它应该立刻尝试返回 // FetchWeatherData使用了上下文如果上下文被取消它应该立刻尝试返回
data, err := fetchWeatherData(ctx, city) data, err := fetchWeatherData(ctx, city)
@ -65,8 +57,8 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
return err return err
} }
// 如果没有错误,将数据添加到列表中 // 每个协程写入固定下标,避免并发 append 数据竞争
list = append(list, data) list[i] = data
return nil return nil
}) })
} }

View File

@ -6,11 +6,15 @@ import (
) )
func TestGetAllCityWeatherData(t *testing.T) { func TestGetAllCityWeatherData(t *testing.T) {
list, err := GetAllCityWeatherData(500 * time.Millisecond) list, err := GetAllCityWeatherData(2 * time.Second)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(list) != 5 {
t.Fatalf("expected 5 weather records, got %d", len(list))
}
for _, data := range list { for _, data := range list {
t.Logf("%s: %s, %d", data.City, data.Weather, data.Temp) t.Logf("%s: %s, %d", data.City, data.Weather, data.Temp)
} }

View File

@ -43,12 +43,18 @@ func (q *DelayLkQueue[TKey, TValue]) CancellableDelayEnqueue(key TKey, value TVa
q.m.Lock() q.m.Lock()
defer q.m.Unlock() defer q.m.Unlock()
if timer, ok := q.timers[key]; ok { if timer, ok := q.timers[key]; ok {
timer.Stop() if timer.Stop() {
q.delayCount.Add(^uint64(0))
}
} }
q.delayCount.Add(1) q.delayCount.Add(1)
q.timers[key] = time.AfterFunc(duration, func() { q.timers[key] = time.AfterFunc(duration, func() {
q.delayCount.Add(^uint64(0)) q.delayCount.Add(^uint64(0))
q.m.Lock()
delete(q.timers, key) delete(q.timers, key)
q.m.Unlock()
q.Enqueue(value) q.Enqueue(value)
}) })
} }
@ -58,9 +64,10 @@ func (q *DelayLkQueue[TKey, TValue]) CancelDelayEnqueue(key TKey) {
q.m.Lock() q.m.Lock()
defer q.m.Unlock() defer q.m.Unlock()
if timer, ok := q.timers[key]; ok { if timer, ok := q.timers[key]; ok {
q.delayCount.Add(^uint64(0)) if timer.Stop() {
q.delayCount.Add(^uint64(0))
}
delete(q.timers, key) delete(q.timers, key)
timer.Stop()
} }
} }

View File

@ -34,11 +34,21 @@ func (q *Queues[TKey, TValue, TRoute]) SetExpireAutoFail(expireAutoFail time.Dur
// checkAckStatus 检查确认状态 // checkAckStatus 检查确认状态
func (q *Queues[TKey, TValue, TRoute]) checkAckStatus(route TRoute) { func (q *Queues[TKey, TValue, TRoute]) checkAckStatus(route TRoute) {
if ack, ok := q.ack.Load(route); ok && !q.autoAck { if ack, ok := q.ack.Load(route); ok && !q.autoAck {
if status, exist := <-ack.(chan bool); !status && exist { ackChan := ack.(chan bool)
if msg, ok := q.msgs.LoadAndDelete(route); ok { select {
// 重新入队 case status, exist := <-ackChan:
q.Enqueue(route, msg.(TValue)) if !exist {
q.ack.Delete(route)
return
} }
if !status {
if msg, loaded := q.msgs.LoadAndDelete(route); loaded {
// 重新入队
q.Enqueue(route, msg.(TValue))
}
}
default:
} }
} }
} }

View File

@ -1,11 +1,16 @@
package observer package observer
import "reflect" import (
"reflect"
"sync"
)
// Subject 被观察者 // Subject 被观察者
type Subject struct { type Subject struct {
observers []Observer observers []Observer
in chan any in chan any
mu sync.RWMutex
once sync.Once
} }
// Observer 观察者 chan // Observer 观察者 chan
@ -18,11 +23,17 @@ func NewSubject() Subject {
// Attach 观察者绑定 // Attach 观察者绑定
func (s *Subject) Attach(obs ...Observer) { func (s *Subject) Attach(obs ...Observer) {
s.mu.Lock()
defer s.mu.Unlock()
s.observers = append(s.observers, obs...) s.observers = append(s.observers, obs...)
} }
// Detach 观察者解绑 // Detach 观察者解绑
func (s *Subject) Detach(obs Observer) { func (s *Subject) Detach(obs Observer) {
s.mu.Lock()
defer s.mu.Unlock()
for i, o := range s.observers { for i, o := range s.observers {
if o == obs { if o == obs {
s.observers = append(s.observers[:i], s.observers[i+1:]...) s.observers = append(s.observers[:i], s.observers[i+1:]...)
@ -34,37 +45,33 @@ func (s *Subject) Detach(obs Observer) {
// Notify 通知观察者 // Notify 通知观察者
func (s *Subject) Notify(data any) { func (s *Subject) Notify(data any) {
// 这里可考虑不需要使用协程运行 // fanout 只启动一次,避免每次通知重复启动协程
go s.fanOut(s.in, s.observers) s.once.Do(func() {
go s.fanOut(s.in)
})
s.in <- data s.in <- data
} }
// fanOut 扇出模式实现 // fanOut 扇出模式实现
func (s *Subject) fanOut(ch <-chan interface{}, out []Observer) { func (s *Subject) fanOut(ch <-chan interface{}) {
// 绑定输入 chan 的 reflect.SelectCase // 绑定输入 chan 的 reflect.SelectCase
cases := []reflect.SelectCase{ cases := []reflect.SelectCase{
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}, {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)},
} }
go func() { for {
defer func() { _, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据
// 退出时关闭所有的输出 chan if !ok {
for _, o := range out { // 输入 channel 被关闭
close(o) return
}
}()
for {
_, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据
if !ok {
// 输入 channel 被关闭
return
}
// 输入 channel 接收到数据
for _, o := range out {
o <- value.Interface() // 放入到输出 chan 中,同步方式
}
} }
}()
// 输入 channel 接收到数据
s.mu.RLock()
for _, o := range s.observers {
o <- value.Interface() // 放入到输出 chan 中,同步方式
}
s.mu.RUnlock()
}
} }

View File

@ -4,10 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"testing" "testing"
"time" "time"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
) )
type TestEntity struct { type TestEntity struct {
@ -34,11 +35,11 @@ type TestService struct {
func (s *TestService) GetFilter(_ context.Context) (any, error) { func (s *TestService) GetFilter(_ context.Context) (any, error) {
return func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB {
if s.filter.Name != "" { if s.filter.Name != "" {
db.Where("name = ?", s.filter.Name) db = db.Where("name = ?", s.filter.Name)
} }
if s.filter.Age > 0 { if s.filter.Age > 0 {
db.Where("age >= ?", s.filter.Age) db = db.Where("age >= ?", s.filter.Age)
} }
return db return db
@ -180,7 +181,7 @@ func TestQueryList(t *testing.T) {
ctx context.Context, ctx context.Context,
builder *builder[TestEntity], builder *builder[TestEntity],
next func(context.Context, next func(context.Context,
) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) { ) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) {
defer func() func() { defer func() func() {
pre := time.Now() pre := time.Now()
return func() { return func() {

View File

@ -8,14 +8,19 @@ import (
// WaitAndGo 等待所有函数执行完毕 // WaitAndGo 等待所有函数执行完毕
func WaitAndGo(fn ...func() error) error { func WaitAndGo(fn ...func() error) error {
defer func() {
if err := recover(); err != nil {
fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack()))
}
}()
var g errgroup.Group var g errgroup.Group
for _, f := range fn { for _, f := range fn {
g.Go(func() error { if f == nil {
continue
}
g.Go(func() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %+v\n%s", r, string(debug.Stack()))
}
}()
return f() return f()
}) })
} }

View File

@ -17,6 +17,7 @@ func loadOptions(opt ...Option) options {
opts := options{ opts := options{
workers: 1, workers: 1,
capacity: 1, capacity: 1,
taskFn: func(T) {},
} }
for _, o := range opt { for _, o := range opt {
@ -49,7 +50,9 @@ func WithCapacity(capacity int) Option {
// WithTaskFn 设置任务函数 // WithTaskFn 设置任务函数
func WithTaskFn(taskFn func(T)) Option { func WithTaskFn(taskFn func(T)) Option {
return func(o *options) { return func(o *options) {
o.taskFn = taskFn if taskFn != nil {
o.taskFn = taskFn
}
} }
} }

View File

@ -3,6 +3,7 @@ package semaphore
import ( import (
"container/list" "container/list"
"context" "context"
"errors"
"sync" "sync"
) )
@ -25,6 +26,13 @@ func NewSemaphore(n int64) *Semaphore {
} }
func (s *Semaphore) Acquire(ctx context.Context, n int64) error { func (s *Semaphore) Acquire(ctx context.Context, n int64) error {
if n <= 0 {
if n == 0 {
return nil
}
return errors.New("semaphore: negative acquire")
}
done := ctx.Done() done := ctx.Done()
s.mu.Lock() s.mu.Lock()
@ -103,7 +111,7 @@ func (s *Semaphore) notifyWaiters() {
break // 没有等待者 break // 没有等待者
} }
w := next.Value.(waiter) w := next.Value.(*waiter)
if s.size-s.cur < w.n { if s.size-s.cur < w.n {
// 没有足够的资源满足下一个等待者 // 没有足够的资源满足下一个等待者
// 没有必要继续唤醒后续等待者,防止有等待者处于饥饿状态 // 没有必要继续唤醒后续等待者,防止有等待者处于饥饿状态
@ -117,6 +125,13 @@ func (s *Semaphore) notifyWaiters() {
} }
func (s *Semaphore) Release(n int64) { func (s *Semaphore) Release(n int64) {
if n <= 0 {
if n == 0 {
return
}
panic("semaphore: negative release")
}
s.mu.Lock() s.mu.Lock()
s.cur -= n // 释放 n 个资源 s.cur -= n // 释放 n 个资源
@ -130,6 +145,10 @@ func (s *Semaphore) Release(n int64) {
} }
func (s *Semaphore) TryAcquire(n int64) bool { func (s *Semaphore) TryAcquire(n int64) bool {
if n <= 0 {
return n == 0
}
s.mu.Lock() s.mu.Lock()
// 检查当前可用资源是否足够,并且还没有等待者 // 检查当前可用资源是否足够,并且还没有等待者
success := s.size-s.cur >= n && s.waiters.Len() == 0 success := s.size-s.cur >= n && s.waiters.Len() == 0

View File

@ -1,9 +1,6 @@
package semaphore package semaphore
import "sync"
type SemaChan struct { type SemaChan struct {
sync.Locker
sem chan struct{} // 信号量通道 sem chan struct{} // 信号量通道
} }
@ -11,13 +8,18 @@ func NewSemaChan(n int) *SemaChan {
if n <= 0 { if n <= 0 {
n = 1 // 确保信号量至少为1直接变成一个互斥锁 n = 1 // 确保信号量至少为1直接变成一个互斥锁
} }
sem := make(chan struct{}, n) // 初始化信号量通道容量为n
for i := 0; i < n; i++ {
sem <- struct{}{}
}
return &SemaChan{ return &SemaChan{
sem: make(chan struct{}, n), // 初始化信号量通道容量为n sem: sem,
} }
} }
func (s *SemaChan) Lock() { func (s *SemaChan) Lock() {
<-s.sem // 使用接收的方式阻塞,用来与 sync.Mutex 的内存模型保持一致 <-s.sem
} }
func (s *SemaChan) Unlock() { func (s *SemaChan) Unlock() {

View File

@ -1,11 +1,16 @@
package ticker package ticker
import "time" import (
"sync"
"time"
)
type Ticker struct { type Ticker struct {
C chan time.Time C chan time.Time
ticker *time.Ticker ticker *time.Ticker
close chan struct{} close chan struct{}
start sync.Once
once sync.Once
} }
func NewTicker(d time.Duration) *Ticker { func NewTicker(d time.Duration) *Ticker {
@ -17,20 +22,29 @@ func NewTicker(d time.Duration) *Ticker {
} }
func (t *Ticker) Start() { func (t *Ticker) Start() {
// 首次直接触发 t.start.Do(func() {
t.C <- time.Now() // 首次直接触发
go func() { select {
for { case t.C <- time.Now():
select { default:
case <-t.close:
// 关闭时停止goroutine
return
case tc := <-t.ticker.C:
// 把go原生定时器 push 的时间推送到我们定义的 time channel 中
t.C <- tc
}
} }
}()
go func() {
for {
select {
case <-t.close:
// 关闭时停止goroutine
return
case tc := <-t.ticker.C:
// 把go原生定时器 push 的时间推送到我们定义的 time channel 中
select {
case t.C <- tc:
default:
}
}
}
}()
})
} }
func (t *Ticker) Reset(d time.Duration) { func (t *Ticker) Reset(d time.Duration) {
@ -38,6 +52,8 @@ func (t *Ticker) Reset(d time.Duration) {
} }
func (t *Ticker) Stop() { func (t *Ticker) Stop() {
t.ticker.Stop() t.once.Do(func() {
close(t.close) t.ticker.Stop()
close(t.close)
})
} }

View File

@ -40,3 +40,17 @@ func TestTicker(t *testing.T) {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
assert.Equal(t, expected, data) assert.Equal(t, expected, data)
} }
func TestTickerStartTwice(t *testing.T) {
ticker := NewTicker(100 * time.Millisecond)
ticker.Start()
ticker.Start()
defer ticker.Stop()
select {
case <-ticker.C:
case <-time.After(time.Second):
t.Fatal("ticker did not emit tick")
}
}