143 lines
2.8 KiB
Go
143 lines
2.8 KiB
Go
package semaphore
|
|
|
|
import (
|
|
"container/list"
|
|
"context"
|
|
"sync"
|
|
)
|
|
|
|
type waiter struct {
|
|
n int64
|
|
ready chan<- struct{} // 唤醒信号
|
|
}
|
|
|
|
type Semaphore struct {
|
|
size int64 // 资源数量
|
|
cur int64 // 当前已使用的资源数量
|
|
mu sync.Mutex
|
|
waiters list.List // 等待队列
|
|
}
|
|
|
|
func NewSemaphore(n int64) *Semaphore {
|
|
return &Semaphore{
|
|
size: n,
|
|
}
|
|
}
|
|
|
|
func (s *Semaphore) Acquire(ctx context.Context, n int64) error {
|
|
done := ctx.Done()
|
|
|
|
s.mu.Lock()
|
|
// 保证 ctx.Done() happened before Semaphore.Acquire()
|
|
select {
|
|
case <-done:
|
|
s.mu.Unlock()
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
|
|
// 快速路径:如果当前可用资源足够,直接分配
|
|
if s.size-s.cur >= n && s.waiters.Len() == 0 {
|
|
s.cur += n
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// 慢路径:需要等待释放资源
|
|
return s.acquireSlow(ctx, n)
|
|
}
|
|
|
|
func (s *Semaphore) acquireSlow(ctx context.Context, n int64) error {
|
|
done := ctx.Done()
|
|
|
|
// 如果请求的资源数量超过了所能提供的资源数量,则只能依靠 ctx.Done() 来退出
|
|
if n > s.size {
|
|
s.mu.Unlock()
|
|
<-done
|
|
return ctx.Err()
|
|
}
|
|
|
|
// 资源不足,将调用者加入等待队列
|
|
// 同时创建一个信号通道,用于唤醒等待的调用者
|
|
ready := make(chan struct{})
|
|
w := &waiter{n: n, ready: ready}
|
|
elem := s.waiters.PushBack(w)
|
|
s.mu.Unlock()
|
|
|
|
select {
|
|
case <-done:
|
|
s.mu.Lock()
|
|
select {
|
|
case <-ready:
|
|
// 如果已经被唤醒,假装已经成功获取资源
|
|
s.cur -= n
|
|
s.notifyWaiters()
|
|
default:
|
|
// 如果还没有被唤醒,从等待队列中移除调用者自己
|
|
isFront := s.waiters.Front() == elem
|
|
s.waiters.Remove(elem)
|
|
// 如果当前调用者是队列的第一个且有多余资源,唤醒下一个等待者
|
|
if isFront && s.size > s.cur {
|
|
s.notifyWaiters()
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
return ctx.Err()
|
|
|
|
case <-ready:
|
|
// 成功获取资源,唤醒信号已发送
|
|
select {
|
|
case <-done:
|
|
s.Release(n)
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (s *Semaphore) notifyWaiters() {
|
|
for {
|
|
next := s.waiters.Front()
|
|
if next == nil {
|
|
break // 没有等待者
|
|
}
|
|
|
|
w := next.Value.(waiter)
|
|
if s.size-s.cur < w.n {
|
|
// 没有足够的资源满足下一个等待者
|
|
// 没有必要继续唤醒后续等待者,防止有等待者处于饥饿状态
|
|
break
|
|
}
|
|
|
|
s.cur += w.n
|
|
s.waiters.Remove(next)
|
|
close(w.ready) // 唤醒等待者
|
|
}
|
|
}
|
|
|
|
func (s *Semaphore) Release(n int64) {
|
|
s.mu.Lock()
|
|
s.cur -= n // 释放 n 个资源
|
|
|
|
if s.cur < 0 {
|
|
s.mu.Unlock()
|
|
panic("semaphore: released more than held")
|
|
}
|
|
|
|
s.notifyWaiters() // 唤醒等待者
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *Semaphore) TryAcquire(n int64) bool {
|
|
s.mu.Lock()
|
|
// 检查当前可用资源是否足够,并且还没有等待者
|
|
success := s.size-s.cur >= n && s.waiters.Len() == 0
|
|
if success {
|
|
s.cur += n // 分配资源
|
|
}
|
|
|
|
s.mu.Unlock()
|
|
return success
|
|
}
|