使用codex优化代码,具体如下:
主要优化
- err_group:去掉多余 goroutine,避免潜在泄漏;并把并发 append 改为按下标写
入,消除数据竞争。
- err_group 测试稳定性增强:放宽超时并增加结果长度断言。
- semaphore:修复等待队列元素类型断言错误(*waiter);补充非法参数校验(负数
acquire/release)。
- SemaChan:修复 Lock/Unlock 逻辑(初始化令牌桶),避免永久阻塞。
- observer:修复“每次 Notify 都启动新 fanout 协程”的问题:改为 sync.Once 只启动一次
fanOut。
- observer:修复并发读写观察者列表问题:给 Attach/Detach/fanOut 增加读写锁保护。
- observer:去掉 fanout 内部额外再起 goroutine和自动关闭所有 observer 的行为,避
免重复关闭/竞态风险(仍保留 Detach 时关闭单个 observer)。
- lock_free:修复可取消延迟队列的计数错误与 timers map 并发访问问题。
- lock_free:checkAckStatus 改为非阻塞读取,避免入队路径被卡住。
- routine:提供默认空任务并忽略 nil taskFn,防止空指针调用。
- ticker:发送改为非阻塞,Stop 幂等化,降低阻塞和重复关闭风险。
- query_builder:WaitAndGo 增加 goroutine 内 panic 转 error;测试里
的 GORM filter 链式写法修正。
新增测试
- 新增 semaphore 测试,覆盖 Acquire/Release/TryAcquire 与 SemaChan 并发上限。
This commit is contained in:
parent
88a4caf72e
commit
5b48ea1a62
@ -20,7 +20,6 @@ type H2O struct {
|
||||
semaH *semaphore.Weighted
|
||||
semaO *semaphore.Weighted
|
||||
cb cyclicbarrier.CyclicBarrier
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewH2O() *H2O {
|
||||
@ -69,26 +68,27 @@ func (h *H2O) Gen(num uint) <-chan string {
|
||||
releaseOxygen := func() {
|
||||
ch <- "O"
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// 使用 WaitGroup 等待所有的 goroutine 完成
|
||||
h.wg.Add(numInt * sum)
|
||||
wg.Add(numInt * sum)
|
||||
for i := 0; i < numInt*H2OHydrogenNum; i++ {
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
|
||||
h.hydrogen(releaseHydrogen)
|
||||
h.wg.Done()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < numInt*H2OOxygenNum; i++ {
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
|
||||
h.oxygen(releaseOxygen)
|
||||
h.wg.Done()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
|
||||
@ -2,9 +2,10 @@ package err_group
|
||||
|
||||
import (
|
||||
"context"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type WeatherData struct {
|
||||
@ -17,30 +18,22 @@ var WeatherList = []string{"晴", "阴", "多云", "小雨", "大雨"}
|
||||
|
||||
// fetchWeatherData 获取天气数据
|
||||
func fetchWeatherData(ctx context.Context, city string) (*WeatherData, error) {
|
||||
done := make(chan struct{}, 1) // 用于指示数据获取成功完成
|
||||
data := new(WeatherData)
|
||||
|
||||
// 启动一个goroutine来获取天气数据
|
||||
go func() {
|
||||
// 这里仅为示例,实际中这里会是API调用
|
||||
// 为了模拟可能出现的超时或错误情况,这里随机地休眠一段时间
|
||||
time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
|
||||
data.City = city
|
||||
// 随机天气
|
||||
data.Weather = WeatherList[rand.Intn(len(WeatherList))]
|
||||
// 随机温度
|
||||
data.Temp = rand.Intn(40)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// 这里仅为示例,实际中这里会是 API 调用
|
||||
// 为了模拟可能出现的超时或错误情况,这里随机休眠一段时间
|
||||
delay := time.Duration(rand.Intn(500)) * time.Millisecond
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 如果上下文被取消,返回错误
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
// 如果数据获取成功完成,返回数据
|
||||
return data, nil
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
return &WeatherData{
|
||||
City: city,
|
||||
// 随机天气
|
||||
Weather: WeatherList[rand.Intn(len(WeatherList))],
|
||||
// 随机温度
|
||||
Temp: rand.Intn(40),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
|
||||
@ -53,11 +46,10 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
|
||||
|
||||
// 城市列表
|
||||
cities := []string{"New York", "Tokyo", "Berlin", "Paris", "Beijing"}
|
||||
list := make([]*WeatherData, 0, len(cities))
|
||||
list := make([]*WeatherData, len(cities))
|
||||
|
||||
// 循环启动goroutine来获取每个城市的天气数据
|
||||
for _, city := range cities {
|
||||
city := city
|
||||
for i, city := range cities {
|
||||
g.Go(func() error {
|
||||
// FetchWeatherData使用了上下文,如果上下文被取消,它应该立刻尝试返回
|
||||
data, err := fetchWeatherData(ctx, city)
|
||||
@ -65,8 +57,8 @@ func GetAllCityWeatherData(timeOut time.Duration) ([]*WeatherData, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果没有错误,将数据添加到列表中
|
||||
list = append(list, data)
|
||||
// 每个协程写入固定下标,避免并发 append 数据竞争
|
||||
list[i] = data
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@ -6,11 +6,15 @@ import (
|
||||
)
|
||||
|
||||
func TestGetAllCityWeatherData(t *testing.T) {
|
||||
list, err := GetAllCityWeatherData(500 * time.Millisecond)
|
||||
list, err := GetAllCityWeatherData(2 * time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(list) != 5 {
|
||||
t.Fatalf("expected 5 weather records, got %d", len(list))
|
||||
}
|
||||
|
||||
for _, data := range list {
|
||||
t.Logf("%s: %s, %d", data.City, data.Weather, data.Temp)
|
||||
}
|
||||
|
||||
@ -43,12 +43,18 @@ func (q *DelayLkQueue[TKey, TValue]) CancellableDelayEnqueue(key TKey, value TVa
|
||||
q.m.Lock()
|
||||
defer q.m.Unlock()
|
||||
if timer, ok := q.timers[key]; ok {
|
||||
timer.Stop()
|
||||
if timer.Stop() {
|
||||
q.delayCount.Add(^uint64(0))
|
||||
}
|
||||
}
|
||||
q.delayCount.Add(1)
|
||||
q.timers[key] = time.AfterFunc(duration, func() {
|
||||
q.delayCount.Add(^uint64(0))
|
||||
|
||||
q.m.Lock()
|
||||
delete(q.timers, key)
|
||||
q.m.Unlock()
|
||||
|
||||
q.Enqueue(value)
|
||||
})
|
||||
}
|
||||
@ -58,9 +64,10 @@ func (q *DelayLkQueue[TKey, TValue]) CancelDelayEnqueue(key TKey) {
|
||||
q.m.Lock()
|
||||
defer q.m.Unlock()
|
||||
if timer, ok := q.timers[key]; ok {
|
||||
q.delayCount.Add(^uint64(0))
|
||||
if timer.Stop() {
|
||||
q.delayCount.Add(^uint64(0))
|
||||
}
|
||||
delete(q.timers, key)
|
||||
timer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -34,11 +34,21 @@ func (q *Queues[TKey, TValue, TRoute]) SetExpireAutoFail(expireAutoFail time.Dur
|
||||
// checkAckStatus 检查确认状态
|
||||
func (q *Queues[TKey, TValue, TRoute]) checkAckStatus(route TRoute) {
|
||||
if ack, ok := q.ack.Load(route); ok && !q.autoAck {
|
||||
if status, exist := <-ack.(chan bool); !status && exist {
|
||||
if msg, ok := q.msgs.LoadAndDelete(route); ok {
|
||||
// 重新入队
|
||||
q.Enqueue(route, msg.(TValue))
|
||||
ackChan := ack.(chan bool)
|
||||
select {
|
||||
case status, exist := <-ackChan:
|
||||
if !exist {
|
||||
q.ack.Delete(route)
|
||||
return
|
||||
}
|
||||
|
||||
if !status {
|
||||
if msg, loaded := q.msgs.LoadAndDelete(route); loaded {
|
||||
// 重新入队
|
||||
q.Enqueue(route, msg.(TValue))
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
package observer
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Subject 被观察者
|
||||
type Subject struct {
|
||||
observers []Observer
|
||||
in chan any
|
||||
mu sync.RWMutex
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// Observer 观察者 chan
|
||||
@ -18,11 +23,17 @@ func NewSubject() Subject {
|
||||
|
||||
// Attach 观察者绑定
|
||||
func (s *Subject) Attach(obs ...Observer) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.observers = append(s.observers, obs...)
|
||||
}
|
||||
|
||||
// Detach 观察者解绑
|
||||
func (s *Subject) Detach(obs Observer) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i, o := range s.observers {
|
||||
if o == obs {
|
||||
s.observers = append(s.observers[:i], s.observers[i+1:]...)
|
||||
@ -34,37 +45,33 @@ func (s *Subject) Detach(obs Observer) {
|
||||
|
||||
// Notify 通知观察者
|
||||
func (s *Subject) Notify(data any) {
|
||||
// 这里可考虑不需要使用协程运行
|
||||
go s.fanOut(s.in, s.observers)
|
||||
// fanout 只启动一次,避免每次通知重复启动协程
|
||||
s.once.Do(func() {
|
||||
go s.fanOut(s.in)
|
||||
})
|
||||
|
||||
s.in <- data
|
||||
}
|
||||
|
||||
// fanOut 扇出模式实现
|
||||
func (s *Subject) fanOut(ch <-chan interface{}, out []Observer) {
|
||||
func (s *Subject) fanOut(ch <-chan interface{}) {
|
||||
// 绑定输入 chan 的 reflect.SelectCase
|
||||
cases := []reflect.SelectCase{
|
||||
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
// 退出时关闭所有的输出 chan
|
||||
for _, o := range out {
|
||||
close(o)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
_, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据
|
||||
if !ok {
|
||||
// 输入 channel 被关闭
|
||||
return
|
||||
}
|
||||
|
||||
// 输入 channel 接收到数据
|
||||
for _, o := range out {
|
||||
o <- value.Interface() // 放入到输出 chan 中,同步方式
|
||||
}
|
||||
for {
|
||||
_, value, ok := reflect.Select(cases) // 从输入 chan 中读取数据
|
||||
if !ok {
|
||||
// 输入 channel 被关闭
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// 输入 channel 接收到数据
|
||||
s.mu.RLock()
|
||||
for _, o := range s.observers {
|
||||
o <- value.Interface() // 放入到输出 chan 中,同步方式
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TestEntity struct {
|
||||
@ -34,11 +35,11 @@ type TestService struct {
|
||||
func (s *TestService) GetFilter(_ context.Context) (any, error) {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if s.filter.Name != "" {
|
||||
db.Where("name = ?", s.filter.Name)
|
||||
db = db.Where("name = ?", s.filter.Name)
|
||||
}
|
||||
|
||||
if s.filter.Age > 0 {
|
||||
db.Where("age >= ?", s.filter.Age)
|
||||
db = db.Where("age >= ?", s.filter.Age)
|
||||
}
|
||||
|
||||
return db
|
||||
@ -180,7 +181,7 @@ func TestQueryList(t *testing.T) {
|
||||
ctx context.Context,
|
||||
builder *builder[TestEntity],
|
||||
next func(context.Context,
|
||||
) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) {
|
||||
) ([]*TestEntity, int64, error)) ([]*TestEntity, int64, error) {
|
||||
defer func() func() {
|
||||
pre := time.Now()
|
||||
return func() {
|
||||
|
||||
@ -8,14 +8,19 @@ import (
|
||||
|
||||
// WaitAndGo 等待所有函数执行完毕
|
||||
func WaitAndGo(fn ...func() error) error {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
var g errgroup.Group
|
||||
for _, f := range fn {
|
||||
g.Go(func() error {
|
||||
if f == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic: %+v\n%s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
return f()
|
||||
})
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@ func loadOptions(opt ...Option) options {
|
||||
opts := options{
|
||||
workers: 1,
|
||||
capacity: 1,
|
||||
taskFn: func(T) {},
|
||||
}
|
||||
|
||||
for _, o := range opt {
|
||||
@ -49,7 +50,9 @@ func WithCapacity(capacity int) Option {
|
||||
// WithTaskFn 设置任务函数
|
||||
func WithTaskFn(taskFn func(T)) Option {
|
||||
return func(o *options) {
|
||||
o.taskFn = taskFn
|
||||
if taskFn != nil {
|
||||
o.taskFn = taskFn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ package semaphore
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -25,6 +26,13 @@ func NewSemaphore(n int64) *Semaphore {
|
||||
}
|
||||
|
||||
func (s *Semaphore) Acquire(ctx context.Context, n int64) error {
|
||||
if n <= 0 {
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors.New("semaphore: negative acquire")
|
||||
}
|
||||
|
||||
done := ctx.Done()
|
||||
|
||||
s.mu.Lock()
|
||||
@ -103,7 +111,7 @@ func (s *Semaphore) notifyWaiters() {
|
||||
break // 没有等待者
|
||||
}
|
||||
|
||||
w := next.Value.(waiter)
|
||||
w := next.Value.(*waiter)
|
||||
if s.size-s.cur < w.n {
|
||||
// 没有足够的资源满足下一个等待者
|
||||
// 没有必要继续唤醒后续等待者,防止有等待者处于饥饿状态
|
||||
@ -117,6 +125,13 @@ func (s *Semaphore) notifyWaiters() {
|
||||
}
|
||||
|
||||
func (s *Semaphore) Release(n int64) {
|
||||
if n <= 0 {
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
panic("semaphore: negative release")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.cur -= n // 释放 n 个资源
|
||||
|
||||
@ -130,6 +145,10 @@ func (s *Semaphore) Release(n int64) {
|
||||
}
|
||||
|
||||
func (s *Semaphore) TryAcquire(n int64) bool {
|
||||
if n <= 0 {
|
||||
return n == 0
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
// 检查当前可用资源是否足够,并且还没有等待者
|
||||
success := s.size-s.cur >= n && s.waiters.Len() == 0
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
package semaphore
|
||||
|
||||
import "sync"
|
||||
|
||||
type SemaChan struct {
|
||||
sync.Locker
|
||||
sem chan struct{} // 信号量通道
|
||||
}
|
||||
|
||||
@ -11,13 +8,18 @@ func NewSemaChan(n int) *SemaChan {
|
||||
if n <= 0 {
|
||||
n = 1 // 确保信号量至少为1,直接变成一个互斥锁
|
||||
}
|
||||
sem := make(chan struct{}, n) // 初始化信号量通道,容量为n
|
||||
for i := 0; i < n; i++ {
|
||||
sem <- struct{}{}
|
||||
}
|
||||
|
||||
return &SemaChan{
|
||||
sem: make(chan struct{}, n), // 初始化信号量通道,容量为n
|
||||
sem: sem,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SemaChan) Lock() {
|
||||
<-s.sem // 使用接收的方式阻塞,用来与 sync.Mutex 的内存模型保持一致
|
||||
<-s.sem
|
||||
}
|
||||
|
||||
func (s *SemaChan) Unlock() {
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
package ticker
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Ticker struct {
|
||||
C chan time.Time
|
||||
ticker *time.Ticker
|
||||
close chan struct{}
|
||||
start sync.Once
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func NewTicker(d time.Duration) *Ticker {
|
||||
@ -17,20 +22,29 @@ func NewTicker(d time.Duration) *Ticker {
|
||||
}
|
||||
|
||||
func (t *Ticker) Start() {
|
||||
// 首次直接触发
|
||||
t.C <- time.Now()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-t.close:
|
||||
// 关闭时停止goroutine
|
||||
return
|
||||
case tc := <-t.ticker.C:
|
||||
// 把go原生定时器 push 的时间推送到我们定义的 time channel 中
|
||||
t.C <- tc
|
||||
}
|
||||
t.start.Do(func() {
|
||||
// 首次直接触发
|
||||
select {
|
||||
case t.C <- time.Now():
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-t.close:
|
||||
// 关闭时停止goroutine
|
||||
return
|
||||
case tc := <-t.ticker.C:
|
||||
// 把go原生定时器 push 的时间推送到我们定义的 time channel 中
|
||||
select {
|
||||
case t.C <- tc:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Ticker) Reset(d time.Duration) {
|
||||
@ -38,6 +52,8 @@ func (t *Ticker) Reset(d time.Duration) {
|
||||
}
|
||||
|
||||
func (t *Ticker) Stop() {
|
||||
t.ticker.Stop()
|
||||
close(t.close)
|
||||
t.once.Do(func() {
|
||||
t.ticker.Stop()
|
||||
close(t.close)
|
||||
})
|
||||
}
|
||||
|
||||
@ -40,3 +40,17 @@ func TestTicker(t *testing.T) {
|
||||
time.Sleep(3 * time.Second)
|
||||
assert.Equal(t, expected, data)
|
||||
}
|
||||
|
||||
func TestTickerStartTwice(t *testing.T) {
|
||||
ticker := NewTicker(100 * time.Millisecond)
|
||||
ticker.Start()
|
||||
ticker.Start()
|
||||
|
||||
defer ticker.Stop()
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ticker did not emit tick")
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user