From c9632e7ab1d9aa6e15b69275a8a23acacdac1910 Mon Sep 17 00:00:00 2001 From: fantasticbin Date: Tue, 10 Dec 2024 08:53:30 +0800 Subject: [PATCH] =?UTF-8?q?=E6=97=A0=E9=94=81=E9=98=9F=E5=88=97=E5=8F=8A?= =?UTF-8?q?=E9=98=9F=E5=88=97=E9=9B=86=E5=90=88=E5=A2=9E=E5=8A=A0=E9=95=BF?= =?UTF-8?q?=E5=BA=A6=E8=8E=B7=E5=8F=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lock_free/queue.go | 9 +++++++++ lock_free/queues.go | 8 ++++++++ lock_free/queues_test.go | 12 ++++++++++++ 3 files changed, 29 insertions(+) diff --git a/lock_free/queue.go b/lock_free/queue.go index 5df4299..3a3bae0 100644 --- a/lock_free/queue.go +++ b/lock_free/queue.go @@ -64,6 +64,15 @@ func (q *LkQueue[T]) Dequeue() (value T, ok bool) { } } +// Len 队列长度 +func (q *LkQueue[T]) Len() int { + var count int + for node := load[T](&q.head); node != nil; node = load[T](&node.next) { + count++ + } + return count - 1 // 减去头节点 +} + // load 读取节点的值 func load[T any](p *unsafe.Pointer) *node[T] { return (*node[T])(atomic.LoadPointer(p)) diff --git a/lock_free/queues.go b/lock_free/queues.go index ffbecf5..1ba6372 100644 --- a/lock_free/queues.go +++ b/lock_free/queues.go @@ -35,6 +35,14 @@ func (q *Queues[TKey, TValue, TRoute]) Dequeue(route TRoute) (value TValue, ok b return value, false } +// Len 队列长度 +func (q *Queues[TKey, TValue, TRoute]) Len(route TRoute) int { + if queue, ok := q.queues.Load(route); ok { + return queue.(*DelayLkQueue[TKey, TValue]).Len() + } + return 0 +} + // DelayEnqueue 延迟入队 func (q *Queues[TKey, TValue, TRoute]) DelayEnqueue(route TRoute, value TValue, duration time.Duration) { if queue, ok := q.queues.Load(route); ok { diff --git a/lock_free/queues_test.go b/lock_free/queues_test.go index 6e9e4e0..a56300c 100644 --- a/lock_free/queues_test.go +++ b/lock_free/queues_test.go @@ -16,6 +16,18 @@ func TestQueues(t *testing.T) { route := "test" q := NewQueues[struct{}, int, string]() + for _, c := range cases { + q.Enqueue(route, c.value) + } + + if q.Len(route) != len(cases) { + t.Errorf("queue length error, want %d, got %d", len(cases), q.Len(route)) + } + + for range cases { + q.Dequeue(route) + } + for _, c := range cases { q.DelayEnqueue(route, c.value, c.duration) }