使用codex优化代码,具体如下:

主要优化

  - err_group:去掉多余 goroutine,避免潜在泄漏;并把并发 append 改为按下标写
    入,消除数据竞争。
  - err_group 测试稳定性增强:放宽超时并增加结果长度断言。
  - semaphore:修复等待队列元素类型断言错误(*waiter);补充非法参数校验(负数
    acquire/release)。
  - SemaChan:修复 Lock/Unlock 逻辑(初始化令牌桶),避免永久阻塞。
  - observer:修复“每次 Notify 都启动新 fanout 协程”的问题:改为 sync.Once 只启动一次
    fanOut。
  - observer:修复并发读写观察者列表问题:给 Attach/Detach/fanOut 增加读写锁保护。
  - observer:去掉 fanout 内部额外再起 goroutine和自动关闭所有 observer 的行为,避
    免重复关闭/竞态风险(仍保留 Detach 时关闭单个 observer)。
  - lock_free:修复可取消延迟队列的计数错误与 timers map 并发访问问题。
  - lock_free:checkAckStatus 改为非阻塞读取,避免入队路径被卡住。
  - routine:提供默认空任务并忽略 nil taskFn,防止空指针调用。
  - ticker:发送改为非阻塞,Stop 幂等化,降低阻塞和重复关闭风险。
  - query_builder:WaitAndGo 增加 goroutine 内 panic 转 error;测试里
    的 GORM filter 链式写法修正。

  新增测试

  - 新增 semaphore 测试,覆盖 Acquire/Release/TryAcquire 与 SemaChan 并发上限。
This commit is contained in:
fantasticbin 2026-03-05 21:53:11 +08:00
parent 88a4caf72e
commit 5b48ea1a62
13 changed files with 177 additions and 97 deletions

View File

@ -20,7 +20,6 @@ type H2O struct {
semaH *semaphore.Weighted
semaO *semaphore.Weighted
cb cyclicbarrier.CyclicBarrier
wg sync.WaitGroup
}
func NewH2O() *H2O {
@ -69,26 +68,27 @@ func (h *H2O) Gen(num uint) <-chan string {
releaseOxygen := func() {
ch <- "O"
}
var wg sync.WaitGroup
// 使用 WaitGroup 等待所有的 goroutine 完成
h.wg.Add(numInt * sum)
wg.Add(numInt * sum)
for i := 0; i < numInt*H2OHydrogenNum; i++ {
go func() {
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
h.hydrogen(releaseHydrogen)
h.wg.Done()
wg.Done()
}()
}
for i := 0; i < numInt*H2OOxygenNum; i++ {
go func() {
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
h.oxygen(releaseOxygen)
h.wg.Done()
wg.Done()
}()
}
go func() {
h.wg.Wait()
wg.Wait()
close(ch)
}()

View File

@ -2,9 +2,10 @@ package err_group
import (
"context"
"golang.org/x/sync/errgroup"
"math/rand"
"time"
"golang.org/x/sync/errgroup"
)
type WeatherData struct {
@ -17,30 +18,22 @@ var WeatherList = []string{"晴", "阴", "多云", "小雨", "大雨"}
// fetchWeatherData 获取天气数据
func fetchWeatherData(ctx context.Context, city string) (*WeatherData, error) {
done := make(chan struct{}, 1) // 用于指示数据获取成功完成
data := new(WeatherData)
// 启动一个goroutine来获取天气数据
go func() {
// 这里仅为示例实际中这里会是API调用
// 为了模拟可能出现的超时或错误情况,这里随机地休眠一段时间
time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
data.City = city
// 随机天气
data.Weather = WeatherList[rand.Intn(len(WeatherList))]
// 随机温度
data.Temp = rand.Intn(40)
done <- struct{}{}
}()
// 这里仅为示例,实际中这里会是 API 调用
// 为了模拟可能出现的超时或错误情况,这里随机休眠一段时间
delay := time.Duration(rand.Intn(500)) * time.Millisecond
select {
case <-ctx.Done():
// 如果上下文被取消,返回错误
return nil, ctx.Err()
case <-done:
// 如果数据获取成功完成,返回数据
return data, nil
case <-time.After(delay):
}
return &WeatherData{
City: city,
// 随机天气
Weather: WeatherList[rand.Intn(len(WeatherList))],
// 随机温度
Temp: rand.Intn(40),
}, nil
}
func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
@ -53,11 +46,10 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
// 城市列表
cities := []string{"New York", "Tokyo", "Berlin", "Paris", "Beijing"}
list := make([]*WeatherData, 0, len(cities))
list := make([]*WeatherData, len(cities))
// 循环启动goroutine来获取每个城市的天气数据
for _, city := range cities {
city := city
for i, city := range cities {
g.Go(func() error {
// FetchWeatherData使用了上下文如果上下文被取消它应该立刻尝试返回
data, err := fetchWeatherData(ctx, city)
@ -65,8 +57,8 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
return err
}
// 如果没有错误,将数据添加到列表中
list = append(list, data)
// 每个协程写入固定下标,避免并发 append 数据竞争
list[i] = data
return nil
})
}

View File

@ -6,11 +6,15 @@ import (
)
func TestGetAllCityWeatherData(t *testing.T) {
list, err := GetAllCityWeatherData(500 * time.Millisecond)
list, err := GetAllCityWeatherData(2 * time.Second)
if err != nil {
t.Fatal(err)
}
if len(list) != 5 {
t.Fatalf("expected 5 weather records, got %d", len(list))
}
for _, data := range list {
t.Logf("%s: %s, %d", data.City, data.Weather, data.Temp)
}

View File

@ -43,12 +43,18 @@ func (q *DelayLkQueue[TKey, TValue]) CancellableDelayEnqueue(key TKey, value TVa
q.m.Lock()
defer q.m.Unlock()
if timer, ok := q.timers[key]; ok {
timer.Stop()
if timer.Stop() {
q.delayCount.Add(^uint64(0))
}
}
q.delayCount.Add(1)
q.timers[key] = time.AfterFunc(duration, func() {
q.delayCount.Add(^uint64(0))
q.m.Lock()
delete(q.timers, key)
q.m.Unlock()
q.Enqueue(value)
})
}
@ -58,9 +64,10 @@ func (q *DelayLkQueue[TKey, TValue]) CancelDelayEnqueue(key TKey) {
q.m.Lock()
defer q.m.Unlock()
if timer, ok := q.timers[key]; ok {
q.delayCount.Add(^uint64(0))
if timer.Stop() {
q.delayCount.Add(^uint64(0))
}
delete(q.timers, key)
timer.Stop()
}
}

View File

@ -34,11 +34,21 @@ func (q *Queues[TKey, TValue, TRoute]) SetExpireAutoFail(expireAutoFail time.Dur
// checkAckStatus 检查确认状态
func (q *Queues[TKey, TValue, TRoute]) checkAckStatus(route TRoute) {
if ack, ok := q.ack.Load(route); ok && !q.autoAck {
if status, exist := <-ack.(chan bool); !status && exist {
if msg, ok := q.msgs.LoadAndDelete(route); ok {
// 重新入队
q.Enqueue(route, msg.(TValue))
ackChan := ack.(chan bool)
select {
case status, exist := <-ackChan:
if !exist {
q.ack.Delete(route)
return
}
if !status {
if msg, loaded := q.msgs.LoadAndDelete(route); loaded {
// 重新入队
q.Enqueue(route, msg.(TValue))
}
}
default:
}
}
}

View File

@ -1,11 +1,16 @@
package observer
import "reflect"
import (
"reflect"
"sync"
)
// Subject 被观察者
type Subject struct {
observers []Observer
in chan any
mu sync.RWMutex
once sync.Once
}
// Observer 观察者 chan
@ -18,11 +23,17 @@ func NewSubject() Subject {
// Attach 观察者绑定
func (s *Subject) Attach(obs ...Observer) {
s.mu.Lock()
defer s.mu.Unlock()
s.observers = append(s.observers, obs...)
}
// Detach 观察者解绑
func (s *Subject) Detach(obs Observer) {
s.mu.Lock()
defer s.mu.Unlock()
for i, o := range s.observers {
if o == obs {
s.observers = append(s.observers[:i], s.observers[i+1:]...)
@ -34,37 +45,33 @@ func (s *Subject) Detach(obs Observer) {
// Notify 通知观察者
func (s *Subject) Notify(data any) {
// 这里可考虑不需要使用协程运行
go s.fanOut(s.in, s.observers)
// fanout 只启动一次,避免每次通知重复启动协程
s.once.Do(func() {
go s.fanOut(s.in)
})
s.in <- data
}
// fanOut 扇出模式实现
func (s *Subject) fanOut(ch <-chan interface{}, out []Observer) {
func (s *Subject) fanOut(ch <-chan interface{}) {
// 绑定输入 chan 的 reflect.SelectCase
cases := []reflect.SelectCase{
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)},
}
go func() {
defer func() {
// 退出时关闭所有的输出 chan
for _, o := range out {
close(o)
}
}()
for {
_, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据
if !ok {
// 输入 channel 被关闭
return
}
// 输入 channel 接收到数据
for _, o := range out {
o <- value.Interface() // 放入到输出 chan 中,同步方式
}
for {
_, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据
if !ok {
// 输入 channel 被关闭
return
}
}()
// 输入 channel 接收到数据
s.mu.RLock()
for _, o := range s.observers {
o <- value.Interface() // 放入到输出 chan 中,同步方式
}
s.mu.RUnlock()
}
}

View File

@ -4,10 +4,11 @@ import (
"context"
"errors"
"fmt"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"testing"
"time"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
)
type TestEntity struct {
@ -34,11 +35,11 @@ type TestService struct {
func (s *TestService) GetFilter(_ context.Context) (any, error) {
return func(db *gorm.DB) *gorm.DB {
if s.filter.Name != "" {
db.Where("name = ?", s.filter.Name)
db = db.Where("name = ?", s.filter.Name)
}
if s.filter.Age > 0 {
db.Where("age >= ?", s.filter.Age)
db = db.Where("age >= ?", s.filter.Age)
}
return db
@ -180,7 +181,7 @@ func TestQueryList(t *testing.T) {
ctx context.Context,
builder *builder[TestEntity],
next func(context.Context,
) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) {
) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) {
defer func() func() {
pre := time.Now()
return func() {

View File

@ -8,14 +8,19 @@ import (
// WaitAndGo 等待所有函数执行完毕
func WaitAndGo(fn ...func() error) error {
defer func() {
if err := recover(); err != nil {
fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack()))
}
}()
var g errgroup.Group
for _, f := range fn {
g.Go(func() error {
if f == nil {
continue
}
g.Go(func() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %+v\n%s", r, string(debug.Stack()))
}
}()
return f()
})
}

View File

@ -17,6 +17,7 @@ func loadOptions(opt ...Option) options {
opts := options{
workers: 1,
capacity: 1,
taskFn: func(T) {},
}
for _, o := range opt {
@ -49,7 +50,9 @@ func WithCapacity(capacity int) Option {
// WithTaskFn 设置任务函数
func WithTaskFn(taskFn func(T)) Option {
return func(o *options) {
o.taskFn = taskFn
if taskFn != nil {
o.taskFn = taskFn
}
}
}

View File

@ -3,6 +3,7 @@ package semaphore
import (
"container/list"
"context"
"errors"
"sync"
)
@ -25,6 +26,13 @@ func NewSemaphore(n int64) *Semaphore {
}
func (s *Semaphore) Acquire(ctx context.Context, n int64) error {
if n <= 0 {
if n == 0 {
return nil
}
return errors.New("semaphore: negative acquire")
}
done := ctx.Done()
s.mu.Lock()
@ -103,7 +111,7 @@ func (s *Semaphore) notifyWaiters() {
break // 没有等待者
}
w := next.Value.(waiter)
w := next.Value.(*waiter)
if s.size-s.cur < w.n {
// 没有足够的资源满足下一个等待者
// 没有必要继续唤醒后续等待者,防止有等待者处于饥饿状态
@ -117,6 +125,13 @@ func (s *Semaphore) notifyWaiters() {
}
func (s *Semaphore) Release(n int64) {
if n <= 0 {
if n == 0 {
return
}
panic("semaphore: negative release")
}
s.mu.Lock()
s.cur -= n // 释放 n 个资源
@ -130,6 +145,10 @@ func (s *Semaphore) Release(n int64) {
}
func (s *Semaphore) TryAcquire(n int64) bool {
if n <= 0 {
return n == 0
}
s.mu.Lock()
// 检查当前可用资源是否足够,并且还没有等待者
success := s.size-s.cur >= n && s.waiters.Len() == 0

View File

@ -1,9 +1,6 @@
package semaphore
import "sync"
type SemaChan struct {
sync.Locker
sem chan struct{} // 信号量通道
}
@ -11,13 +8,18 @@ func NewSemaChan(n int) *SemaChan {
if n <= 0 {
n = 1 // 确保信号量至少为1直接变成一个互斥锁
}
sem := make(chan struct{}, n) // 初始化信号量通道容量为n
for i := 0; i < n; i++ {
sem <- struct{}{}
}
return &SemaChan{
sem: make(chan struct{}, n), // 初始化信号量通道容量为n
sem: sem,
}
}
func (s *SemaChan) Lock() {
<-s.sem // 使用接收的方式阻塞,用来与 sync.Mutex 的内存模型保持一致
<-s.sem
}
func (s *SemaChan) Unlock() {

View File

@ -1,11 +1,16 @@
package ticker
import "time"
import (
"sync"
"time"
)
type Ticker struct {
C chan time.Time
ticker *time.Ticker
close chan struct{}
start sync.Once
once sync.Once
}
func NewTicker(d time.Duration) *Ticker {
@ -17,20 +22,29 @@ func NewTicker(d time.Duration) *Ticker {
}
func (t *Ticker) Start() {
// 首次直接触发
t.C <- time.Now()
go func() {
for {
select {
case <-t.close:
// 关闭时停止goroutine
return
case tc := <-t.ticker.C:
// 把go原生定时器 push 的时间推送到我们定义的 time channel 中
t.C <- tc
}
t.start.Do(func() {
// 首次直接触发
select {
case t.C <- time.Now():
default:
}
}()
go func() {
for {
select {
case <-t.close:
// 关闭时停止goroutine
return
case tc := <-t.ticker.C:
// 把go原生定时器 push 的时间推送到我们定义的 time channel 中
select {
case t.C <- tc:
default:
}
}
}
}()
})
}
func (t *Ticker) Reset(d time.Duration) {
@ -38,6 +52,8 @@ func (t *Ticker) Reset(d time.Duration) {
}
func (t *Ticker) Stop() {
t.ticker.Stop()
close(t.close)
t.once.Do(func() {
t.ticker.Stop()
close(t.close)
})
}

View File

@ -40,3 +40,17 @@ func TestTicker(t *testing.T) {
time.Sleep(3 * time.Second)
assert.Equal(t, expected, data)
}
func TestTickerStartTwice(t *testing.T) {
ticker := NewTicker(100 * time.Millisecond)
ticker.Start()
ticker.Start()
defer ticker.Stop()
select {
case <-ticker.C:
case <-time.After(time.Second):
t.Fatal("ticker did not emit tick")
}
}