go-study/lock_free/queue.go

85 lines
1.8 KiB
Go

package lock_free
import (
"sync/atomic"
"unsafe"
)
// LkQueue 无锁队列
type LkQueue[T any] struct {
head unsafe.Pointer
tail unsafe.Pointer
}
// node 节点
type node[T any] struct {
value T
next unsafe.Pointer
}
// NewLkQueue 创建无锁队列
func NewLkQueue[T any]() *LkQueue[T] {
n := unsafe.Pointer(&node[T]{})
return &LkQueue[T]{head: n, tail: n}
}
// Enqueue 入队
func (q *LkQueue[T]) Enqueue(value T) {
n := &node[T]{value: value}
for {
tail := load[T](&q.tail)
next := load[T](&tail.next)
if tail == load[T](&q.tail) { // tail 和 next 是否一致
if next == nil {
if cas(&tail.next, next, n) {
cas(&q.tail, tail, n) // 入队完成。设置 tail
return
}
} else {
cas(&q.tail, tail, next)
}
}
}
}
// Dequeue 出队
func (q *LkQueue[T]) Dequeue() (value T, ok bool) {
for {
head := load[T](&q.head)
tail := load[T](&q.tail)
next := load[T](&head.next)
if head == load[T](&q.head) { // 检查 head、tail 和 next 是否一致
if head == tail { // 队列为空,或者 tail 还未到队尾
if next == nil { // 为空
return value, false
}
cas(&q.tail, tail, next) // 将 tail 往队尾移动
} else {
value = next.value
if cas(&q.head, head, next) {
return value, true // 出队完成
}
}
}
}
}
// Len 队列长度
func (q *LkQueue[T]) Len() int {
var count int
for node := load[T](&q.head); node != nil; node = load[T](&node.next) {
count++
}
return count - 1 // 减去头节点
}
// load 读取节点的值
func load[T any](p *unsafe.Pointer) *node[T] {
return (*node[T])(atomic.LoadPointer(p))
}
// cas 原子地修改节点的值
func cas[T any](p *unsafe.Pointer, old, new *node[T]) bool {
return atomic.CompareAndSwapPointer(p, unsafe.Pointer(old), unsafe.Pointer(new))
}