From ceff1433aa511a7319288734752ec069606e0f2a Mon Sep 17 00:00:00 2001 From: Uzziah Date: Fri, 22 Sep 2023 16:49:16 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20=E5=9F=BA=E4=BA=8E=E4=BC=98?= =?UTF-8?q?=E5=85=88=E7=BA=A7=E6=B7=98=E6=B1=B0=E7=9A=84=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E8=AE=BE=E8=AE=A1=E4=B8=8E=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/priority/cache.go | 281 +++++++++++++++ memory/priority/cache_test.go | 474 +++++++++++++++++++++++++ memory/priority/priority_queue.go | 167 +++++++++ memory/priority/priority_queue_test.go | 352 ++++++++++++++++++ 4 files changed, 1274 insertions(+) create mode 100644 memory/priority/cache.go create mode 100644 memory/priority/cache_test.go create mode 100644 memory/priority/priority_queue.go create mode 100644 memory/priority/priority_queue_test.go diff --git a/memory/priority/cache.go b/memory/priority/cache.go new file mode 100644 index 0000000..31b63ab --- /dev/null +++ b/memory/priority/cache.go @@ -0,0 +1,281 @@ +package priority + +import ( + "context" + "github.com/ecodeclub/ecache" + "github.com/ecodeclub/ecache/internal/errs" + "github.com/ecodeclub/ekit" + "sync" + "time" +) + +type Option func(c *Cache) + +func WithCapacity(cap int) Option { + return func(c *Cache) { + c.cap = cap + } +} + +func WithComparator(comparator Comparator[Node]) Option { + return func(c *Cache) { + c.comparator = comparator + } +} + +func WithCleanInterval(interval time.Duration) Option { + return func(c *Cache) { + c.cleanInterval = interval + } +} + +func NewCache(opts ...Option) ecache.Cache { + defaultCap := 1024 + defaultCleanInterval := 10 * time.Second + defaultScanCount := 1000 + defaultExpiration := 30 * time.Second + + cache := &Cache{ + index: make(map[string]*Node), + comparator: defaultComparator{}, + cap: defaultCap, + cleanInterval: defaultCleanInterval, + scanCount: defaultScanCount, + defaultExpiration: defaultExpiration, + } + + for _, opt := range opts { + opt(cache) + } + + cache.pq = NewQueueWithHeap[Node](cache.comparator) + + go cache.clean() + + return cache +} + +// defaultComparator 默认比较器 按节点的过期时间进行比较 +type defaultComparator struct { +} + +func (d defaultComparator) Compare(src, dest *Node) int { + if src.Dl.Before(dest.Dl) { + return -1 + } + + if src.Dl.After(dest.Dl) { + return 1 + } + + return 0 +} + +type Cache struct { + index map[string]*Node // 用于存储数据的索引,方便快速查找 + pq Queue[Node] // 优先级队列,用于存储数据 + comparator Comparator[Node] // 比较器 + mu sync.RWMutex // 读写锁 + cap int // 容量 + len int // 当前队列长度 + cleanInterval time.Duration // 清理过期数据的时间间隔 + scanCount int // 扫描次数 + closeC chan struct{} // 关闭信号 + defaultExpiration time.Duration +} + +func (c *Cache) Set(ctx context.Context, key string, val any, expiration time.Duration) error { + c.mu.Lock() + defer c.mu.Unlock() + // 如果存在,则更新 + if node, ok := c.index[key]; ok { + node.Val = val + node.Dl = time.Now().Add(expiration) // 更新过期时间 + return nil + } + // 如果不存在,则插入 + // 插入之前校验容量是否已满,如果已满,需要淘汰优先级最低的数据 + c.add(ctx, key, val, expiration) + + return nil +} + +func (c *Cache) add(ctx context.Context, key string, val any, expiration time.Duration) { + c.checkCapacityAndDisuse(ctx) + + node := &Node{ + Key: key, + Val: val, + Dl: time.Now().Add(expiration), + } + + _ = c.pq.Push(ctx, node) + + c.index[key] = node + c.len++ +} + +func (c *Cache) checkCapacityAndDisuse(ctx context.Context) { + if c.len >= c.cap { + // 淘汰优先级最低的数据 + node, _ := c.pq.Pop(ctx) + // 删除索引 + delete(c.index, node.Key) + c.len-- + } +} + +func (c *Cache) SetNX(ctx context.Context, key string, val any, expiration time.Duration) (bool, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if node, ok := c.index[key]; ok { + node.Dl = time.Now().Add(expiration) // 更新过期时间 + return false, nil + } + + c.add(ctx, key, val, expiration) + + return true, nil +} + +func (c *Cache) Get(ctx context.Context, key string) ecache.Value { + c.mu.Lock() + defer c.mu.Unlock() + + node, ok := c.index[key] + + if ok && node.Dl.After(time.Now()) { + return ecache.Value{ + AnyValue: ekit.AnyValue{ + Val: node.Val, + }, + } + } + + // 过期删除 + if ok { + _ = c.pq.Remove(ctx, node) + delete(c.index, key) + c.len-- + } + + return ecache.Value{ + AnyValue: ekit.AnyValue{ + Err: errs.ErrKeyNotExist, + }, + } +} + +func (c *Cache) GetSet(ctx context.Context, key string, val string) ecache.Value { + c.mu.Lock() + defer c.mu.Unlock() + + node, ok := c.index[key] + + if ok && node.Dl.After(time.Now()) { + old := node.Val + node.Val = val + return ecache.Value{ + AnyValue: ekit.AnyValue{ + Val: old, + }, + } + } + + if ok { + node.Val = val + node.Dl = time.Now().Add(c.defaultExpiration) + } else { + c.add(ctx, key, val, c.defaultExpiration) + } + + return ecache.Value{ + AnyValue: ekit.AnyValue{ + Err: errs.ErrKeyNotExist, + }, + } + +} + +func (c *Cache) LPush(ctx context.Context, key string, val ...any) (int64, error) { + //TODO implement me + panic("implement me") +} + +func (c *Cache) LPop(ctx context.Context, key string) ecache.Value { + //TODO implement me + panic("implement me") +} + +func (c *Cache) SAdd(ctx context.Context, key string, members ...any) (int64, error) { + //TODO implement me + panic("implement me") +} + +func (c *Cache) SRem(ctx context.Context, key string, members ...any) ecache.Value { + //TODO implement me + panic("implement me") +} + +func (c *Cache) IncrBy(ctx context.Context, key string, value int64) (int64, error) { + //TODO implement me + panic("implement me") +} + +func (c *Cache) DecrBy(ctx context.Context, key string, value int64) (int64, error) { + //TODO implement me + panic("implement me") +} + +func (c *Cache) IncrByFloat(ctx context.Context, key string, value float64) (float64, error) { + //TODO implement me + panic("implement me") +} + +func (c *Cache) Close() error { + close(c.closeC) + return nil +} + +func (c *Cache) clean() { + + ticker := time.NewTicker(c.cleanInterval) + + for { + select { + case <-ticker.C: + c.mu.Lock() + count := 0 + for k, v := range c.index { + if v.Dl.Before(time.Now()) { + _ = c.pq.Remove(context.Background(), v) + delete(c.index, k) + c.len-- + } + count++ + if count >= c.scanCount { + break + } + } + c.mu.Unlock() + case <-c.closeC: + return + } + } +} + +type Node struct { + Key string + Val any + Dl time.Time // 过期时间 + idx int +} + +func (n Node) Index() int { + return n.idx +} + +func (n Node) SetIndex(idx int) { + n.idx = idx +} diff --git a/memory/priority/cache_test.go b/memory/priority/cache_test.go new file mode 100644 index 0000000..6c98b7b --- /dev/null +++ b/memory/priority/cache_test.go @@ -0,0 +1,474 @@ +package priority + +import ( + "context" + "github.com/ecodeclub/ecache" + "github.com/ecodeclub/ecache/internal/errs" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestCache_Set(t *testing.T) { + ctx := context.TODO() + + testCases := []struct { + name string + + cache ecache.Cache + + key string + val any + expiration time.Duration + + before func(cache ecache.Cache) + + wantIndex map[string]*Node + }{ + { + // 测试正常情况 + name: "test normal set", + cache: NewCache(), + key: "k1", + val: "v1", + expiration: 30 * time.Second, + before: func(cache ecache.Cache) { + + }, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试key已存在的情况 + name: "test key exists set", + cache: NewCache(), + key: "k1", + val: "v1", + expiration: 30 * time.Second, + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v2", 10*time.Second) + }, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试淘汰策略 + name: "test eviction set", + cache: NewCache(WithCapacity(3)), + key: "k1", + val: "v1", + expiration: 30 * time.Second, + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k2", "v2", 10*time.Second) + _ = cache.Set(ctx, "k3", "v3", 3*time.Second) + _ = cache.Set(ctx, "k4", "v4", 5*time.Second) + }, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(10 * time.Second), + }, + "k4": { + Key: "k4", + Val: "v4", + Dl: time.Now().Add(5 * time.Second), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.cache) + + _ = tc.cache.Set(ctx, tc.key, tc.val, tc.expiration) + + assert.Equal(t, len(tc.wantIndex), len(tc.cache.(*Cache).index)) + + for k, v := range tc.wantIndex { + assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + + assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 1) + } + }) + } +} + +func TestCache_SetNX(t *testing.T) { + ctx := context.TODO() + + testCases := []struct { + name string + + cache ecache.Cache + + key string + val any + expiration time.Duration + + before func(cache ecache.Cache) + + wantIndex map[string]*Node + wantRes bool + }{ + { + // 测试正常情况 + name: "test normal setnx", + cache: NewCache(), + key: "k1", + val: "v1", + expiration: 30 * time.Second, + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k2", "v2", 10*time.Second) + }, + wantRes: true, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(10 * time.Second), + }, + }, + }, + { + // 测试key已存在的情况 + name: "test key exists setnx", + cache: NewCache(), + key: "k1", + val: "v1", + expiration: 30 * time.Second, + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v2", 10*time.Second) + }, + wantRes: false, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试淘汰策略 + name: "test eviction set", + cache: NewCache(WithCapacity(3)), + key: "k1", + val: "v1", + expiration: 30 * time.Second, + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k2", "v2", 10*time.Second) + _ = cache.Set(ctx, "k3", "v3", 3*time.Second) + _ = cache.Set(ctx, "k4", "v4", 5*time.Second) + }, + wantRes: true, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(10 * time.Second), + }, + "k4": { + Key: "k4", + Val: "v4", + Dl: time.Now().Add(5 * time.Second), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.cache) + + ok, _ := tc.cache.SetNX(ctx, tc.key, tc.val, tc.expiration) + + assert.Equal(t, tc.wantRes, ok) + + assert.Equal(t, len(tc.wantIndex), len(tc.cache.(*Cache).index)) + + for k, v := range tc.wantIndex { + assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + + assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 1) + } + }) + } +} + +func TestCache_Get(t *testing.T) { + ctx := context.TODO() + + testCases := []struct { + name string + + cache ecache.Cache + + key string + + before func(cache ecache.Cache) + + wantVal any + wantErr error + beforeGetIndex map[string]*Node + wantIndex map[string]*Node + }{ + { + // 测试正常情况 + name: "test normal get", + cache: NewCache(), + key: "k1", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v1", 30*time.Second) + }, + wantVal: "v1", + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试key不存在的情况 + name: "test key not exists get", + cache: NewCache(), + key: "k1", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k2", "v1", 30*time.Second) + }, + wantErr: errs.ErrKeyNotExist, + wantIndex: map[string]*Node{ + "k2": { + Key: "k2", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试key已存在的情况, 但是key已经过期,并且惰性删除 + name: "test key exists but expired get and lazy delete", + cache: NewCache(), + key: "k1", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v2", 1*time.Second) + _ = cache.Set(ctx, "k2", "v2", 30*time.Second) + time.Sleep(2 * time.Second) + }, + wantErr: errs.ErrKeyNotExist, + beforeGetIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v2", + Dl: time.Now().Add(1 * time.Second), + }, + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + wantIndex: map[string]*Node{ + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试key已存在的情况, 但是key已经过期,并且被扫描删除 + name: "test key exists but expired get and scan delete", + cache: NewCache(WithCleanInterval(2 * time.Second)), + key: "k1", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v2", 1*time.Second) + _ = cache.Set(ctx, "k2", "v2", 30*time.Second) + time.Sleep(3 * time.Second) + }, + wantErr: errs.ErrKeyNotExist, + beforeGetIndex: map[string]*Node{ + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + wantIndex: map[string]*Node{ + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.cache) + + for k, v := range tc.beforeGetIndex { + assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + + assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 2) + } + + res := tc.cache.Get(ctx, tc.key) + + assert.Equal(t, len(tc.wantIndex), len(tc.cache.(*Cache).index)) + + for k, v := range tc.wantIndex { + assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + + assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 2) + } + + assert.Equal(t, tc.wantErr, res.Err) + + if res.Err != nil { + return + } + + assert.Equal(t, tc.wantVal, res.Val) + }) + } +} + +func TestCache_GetSet(t *testing.T) { + ctx := context.TODO() + + testCases := []struct { + name string + + cache ecache.Cache + + key string + val string + + before func(cache ecache.Cache) + + wantVal any + wantErr error + wantIndex map[string]*Node + }{ + { + // 测试正常情况 + name: "test normal getset", + cache: NewCache(), + key: "k1", + val: "v2", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v1", 30*time.Second) + }, + wantVal: "v1", + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试key不存在的情况 + name: "test key not exists getset", + cache: NewCache(), + key: "k1", + val: "v2", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k2", "v1", 30*time.Second) + }, + wantErr: errs.ErrKeyNotExist, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + "k2": { + Key: "k2", + Val: "v1", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + { + // 测试key已存在的情况, 但是key已经过期 + name: "test key exists but expired getset", + cache: NewCache(), + key: "k1", + val: "v3", + before: func(cache ecache.Cache) { + _ = cache.Set(ctx, "k1", "v2", 1*time.Second) + _ = cache.Set(ctx, "k2", "v2", 30*time.Second) + time.Sleep(2 * time.Second) + }, + wantErr: errs.ErrKeyNotExist, + wantIndex: map[string]*Node{ + "k1": { + Key: "k1", + Val: "v3", + Dl: time.Now().Add(32 * time.Second), + }, + "k2": { + Key: "k2", + Val: "v2", + Dl: time.Now().Add(30 * time.Second), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.cache) + + res := tc.cache.GetSet(ctx, tc.key, tc.val) + + assert.Equal(t, len(tc.wantIndex), len(tc.cache.(*Cache).index)) + + for k, v := range tc.wantIndex { + assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + + assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 1) + } + + assert.Equal(t, tc.wantErr, res.Err) + + if res.Err != nil { + return + } + + assert.Equal(t, tc.wantVal, res.Val) + }) + } +} diff --git a/memory/priority/priority_queue.go b/memory/priority/priority_queue.go new file mode 100644 index 0000000..dc584d1 --- /dev/null +++ b/memory/priority/priority_queue.go @@ -0,0 +1,167 @@ +package priority + +import ( + "context" + "errors" +) + +type Queue[T any] interface { + Push(ctx context.Context, t *T) error + Pop(ctx context.Context) (*T, error) + Peek(ctx context.Context) (*T, error) + Remove(ctx context.Context, t *T) error // 为了支持随机删除而引入的接口,如果不需要随机删除,可以不实现 +} + +type Comparator[T any] interface { + Compare(src, dest *T) int +} + +type Indexable interface { + Index() int + SetIndex(idx int) +} + +func NewQueueWithHeap[T any](comparator Comparator[T]) Queue[T] { + // 这里可以考虑给一个默认的堆容量 + return &QueueWithHeap[T]{ + heap: make([]*T, 0), + comparator: comparator, + len: 0, + } +} + +type QueueWithHeap[T any] struct { + heap []*T + comparator Comparator[T] + len int +} + +func (q *QueueWithHeap[T]) Push(ctx context.Context, t *T) error { + if len(q.heap) > q.len { + q.heap[q.len] = t + } else { + q.heap = append(q.heap, t) + } + + // 如果是可索引的,需要为这个类型设置索引 + if idx, ok := checkIndexable(t); ok { + idx.SetIndex(q.len) + } + + q.len++ + + q.heapifyUp(q.len - 1) + + return nil +} + +func (q *QueueWithHeap[T]) Pop(ctx context.Context) (*T, error) { + if q.len == 0 { + return nil, errors.New("队列为空") + } + res := q.heap[0] + q.heap[0] = q.heap[q.len-1] + q.heap[q.len-1] = nil // let GC do its work + q.len-- + + q.heapifyDown(0) + return res, nil +} + +func (q *QueueWithHeap[T]) Peek(ctx context.Context) (*T, error) { + if q.len == 0 { + return nil, errors.New("队列为空") + } + + return q.heap[0], nil +} + +// Remove 随机删除一个元素 +// 但是要确保这个元素是在堆里的 +func (q *QueueWithHeap[T]) Remove(ctx context.Context, t *T) error { + idx, ok := checkIndexable(t) + if !ok { + return errors.New("只有实现Indexable的数据才能随机删除") + } + + if idx.Index() >= q.len { + return errors.New("这个元素不在堆里") + } + + q.heap[idx.Index()] = q.heap[q.len-1] + q.heap[q.len-1] = nil // let GC do its work + q.len-- + q.heapifyDown(idx.Index()) + return nil +} + +// 0 +// 1 2 +// 3 4 5 6 +// 7 8 9 10 +// +// root -> left(2n+1) root -> right(2n+2) left/right -> root(n-1/2) +// 堆化应该从最后一个有子节点的节点开始堆化(len(heap)-2/2) +func (q *QueueWithHeap[T]) heapify() { + n := q.len + + if n <= 1 { + return + } + + cur := n - 2/2 + + for i := cur; i >= 0; i-- { + q.heapifyDown(cur) + } +} + +// heapifyDown 从上往下进行堆化 +func (q *QueueWithHeap[T]) heapifyDown(cur int) { + n := q.len + + // 如果满足 idx <= n - 2 / 2 说明有子节点,需要往下进行堆化 + for cur <= (n-2)>>1 { + l, r := 2*cur+1, 2*cur+2 + min := l + + if r < n && q.comparator.Compare(q.heap[l], q.heap[r]) > 0 { + min = r + } + + // 说明已经满足堆化条件,直接返回 + if q.comparator.Compare(q.heap[cur], q.heap[min]) < 0 { + return + } + + // swap + q.swap(cur, min) + + cur = min + } +} + +// heapifyUp 从下往上进行堆化 +func (q *QueueWithHeap[T]) heapifyUp(cur int) { + for p := (cur - 1) >> 1; cur > 0 && q.comparator.Compare(q.heap[cur], q.heap[p]) < 0; cur, p = p, (p-1)>>1 { + q.swap(cur, p) + } +} + +// swap 交换下标值为src和dest位置的值,如果实现了Indexable接口,则更新以下索引 +func (q *QueueWithHeap[T]) swap(src, dest int) { + q.heap[src], q.heap[dest] = q.heap[dest], q.heap[src] + + if idx, ok := checkIndexable(q.heap[src]); ok { + idx.SetIndex(src) + } + + if idx, ok := checkIndexable(q.heap[dest]); ok { + idx.SetIndex(dest) + } +} + +func checkIndexable(val any) (Indexable, bool) { + idx, ok := val.(Indexable) + return idx, ok +} diff --git a/memory/priority/priority_queue_test.go b/memory/priority/priority_queue_test.go new file mode 100644 index 0000000..6d52d7a --- /dev/null +++ b/memory/priority/priority_queue_test.go @@ -0,0 +1,352 @@ +package priority + +import ( + "context" + "errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestQueueWithHeap_Push(t *testing.T) { + + ctx := context.TODO() + + testCases := []struct { + name string + + q Queue[testNode] + + t *testNode + + before func(q Queue[testNode]) + + wantLen int + wantRes []*testNode + }{ + { + // 队列为空,插入一个元素 + name: "insert one element, queue is empty", + q: NewQueueWithHeap[testNode](&testComparator{}), + t: &testNode{data: 1}, + before: func(q Queue[testNode]) { + + }, + wantLen: 1, + wantRes: []*testNode{ + {data: 1, index: 0}, + }, + }, + { + // 队列不为空,插入一个元素 + name: "insert one element, and no heapify", + q: NewQueueWithHeap[testNode](&testComparator{}), + t: &testNode{data: 5}, + before: func(q Queue[testNode]) { + _ = q.Push(ctx, &testNode{data: 2}) + _ = q.Push(ctx, &testNode{data: 3}) + _ = q.Push(ctx, &testNode{data: 4}) + _ = q.Push(ctx, &testNode{data: 6}) + _ = q.Push(ctx, &testNode{data: 7}) + }, + wantLen: 6, + wantRes: []*testNode{ + {data: 2, index: 0}, + {data: 3, index: 1}, + {data: 4, index: 2}, + {data: 6, index: 3}, + {data: 7, index: 4}, + {data: 5, index: 5}, + }, + }, + { + // 队列不为空,插入一个元素 + name: "insert one element, and heapify", + q: NewQueueWithHeap[testNode](&testComparator{}), + t: &testNode{data: 1}, + before: func(q Queue[testNode]) { + _ = q.Push(ctx, &testNode{data: 2}) + _ = q.Push(ctx, &testNode{data: 3}) + _ = q.Push(ctx, &testNode{data: 4}) + _ = q.Push(ctx, &testNode{data: 6}) + _ = q.Push(ctx, &testNode{data: 7}) + }, + wantLen: 6, + wantRes: []*testNode{ + {data: 1, index: 0}, + {data: 3, index: 1}, + {data: 2, index: 2}, + {data: 6, index: 3}, + {data: 7, index: 4}, + {data: 4, index: 5}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.q) + + err := tc.q.Push(ctx, tc.t) + if err != nil { + assert.NoError(t, err) + } + + assert.Equal(t, tc.wantLen, tc.q.(*QueueWithHeap[testNode]).len) + + for i, v := range tc.wantRes { + assert.Equal(t, v, tc.q.(*QueueWithHeap[testNode]).heap[i]) + } + }) + } +} + +func TestQueueWithHeap_Pop(t *testing.T) { + + ctx := context.TODO() + + testCases := []struct { + name string + + q Queue[testNode] + + before func(q Queue[testNode]) + + wantLen int + wantRes *testNode + wantHeap []*testNode + wantErr error + }{ + { + // 队列为空,弹出一个元素 + name: "pop one element, queue is empty", + q: NewQueueWithHeap[testNode](&testComparator{}), + before: func(q Queue[testNode]) { + + }, + wantErr: errors.New("队列为空"), + }, + { + // 当队列只有一个元素,弹出一个元素 + name: "pop one element, queue has one element", + q: NewQueueWithHeap[testNode](&testComparator{}), + before: func(q Queue[testNode]) { + _ = q.Push(ctx, &testNode{data: 2}) + }, + wantLen: 0, + wantRes: &testNode{data: 2, index: 0}, + wantHeap: []*testNode{}, + }, + { + // 堆里多个元素,弹出一个元素 + name: "pop one element, queue has many elements", + q: NewQueueWithHeap[testNode](&testComparator{}), + before: func(q Queue[testNode]) { + _ = q.Push(ctx, &testNode{data: 2}) + _ = q.Push(ctx, &testNode{data: 3}) + _ = q.Push(ctx, &testNode{data: 4}) + _ = q.Push(ctx, &testNode{data: 6}) + _ = q.Push(ctx, &testNode{data: 7}) + }, + wantLen: 4, + wantRes: &testNode{data: 2, index: 0}, + wantHeap: []*testNode{ + {data: 3, index: 0}, + {data: 6, index: 1}, + {data: 4, index: 2}, + {data: 7, index: 3}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.q) + + res, err := tc.q.Pop(ctx) + assert.Equal(t, tc.wantErr, err) + + if err != nil { + return + } + + assert.Equal(t, tc.wantRes, res) + + assert.Equal(t, tc.wantLen, tc.q.(*QueueWithHeap[testNode]).len) + + for i, v := range tc.wantHeap { + assert.Equal(t, v, tc.q.(*QueueWithHeap[testNode]).heap[i]) + } + }) + } +} + +func TestQueueWithHeap_Peek(t *testing.T) { + + ctx := context.TODO() + + testCases := []struct { + name string + + q Queue[testNode] + + before func(q Queue[testNode]) + + wantRes *testNode + wantErr error + }{ + { + // 队列为空,peek一个元素 + name: "peek one element, queue is empty", + q: NewQueueWithHeap[testNode](&testComparator{}), + before: func(q Queue[testNode]) { + + }, + wantErr: errors.New("队列为空"), + }, + { + // 堆里多个元素,peek一个元素 + name: "peek one element, queue has many elements", + q: NewQueueWithHeap[testNode](&testComparator{}), + before: func(q Queue[testNode]) { + _ = q.Push(ctx, &testNode{data: 2}) + _ = q.Push(ctx, &testNode{data: 3}) + _ = q.Push(ctx, &testNode{data: 4}) + _ = q.Push(ctx, &testNode{data: 6}) + _ = q.Push(ctx, &testNode{data: 7}) + }, + wantRes: &testNode{data: 2, index: 0}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.q) + + res, err := tc.q.Peek(ctx) + assert.Equal(t, tc.wantErr, err) + + if err != nil { + return + } + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestQueueWithHeap_Remove(t *testing.T) { + + ctx := context.TODO() + + testCases := []struct { + name string + + q Queue[testNode] + t *testNode + + before func(q Queue[testNode]) + + wantLen int + wantHeap []*testNode + wantErr error + }{ + { + // 删除一个不在队列中的元素 + name: "remove one element, element not in queue", + q: NewQueueWithHeap[testNode](&testComparator{}), + t: &testNode{data: 2}, + before: func(q Queue[testNode]) { + + }, + wantErr: errors.New("这个元素不在堆里"), + }, + { + // 删除一个在队列中的元素 + name: "remove one element, element in queue", + q: NewQueueWithHeap[testNode](&testComparator{}), + t: &testNode{data: 3, index: 1}, + before: func(q Queue[testNode]) { + _ = q.Push(ctx, &testNode{data: 2}) + _ = q.Push(ctx, &testNode{data: 3}) + _ = q.Push(ctx, &testNode{data: 4}) + _ = q.Push(ctx, &testNode{data: 6}) + _ = q.Push(ctx, &testNode{data: 7}) + _ = q.Push(ctx, &testNode{data: 9}) + _ = q.Push(ctx, &testNode{data: 10}) + }, + wantLen: 6, + wantHeap: []*testNode{ + {data: 2, index: 0}, + {data: 6, index: 1}, + {data: 4, index: 2}, + {data: 10, index: 3}, + {data: 7, index: 4}, + {data: 9, index: 5}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before(tc.q) + + err := tc.q.Remove(ctx, tc.t) + assert.Equal(t, tc.wantErr, err) + + assert.Equal(t, tc.wantLen, tc.q.(*QueueWithHeap[testNode]).len) + + for i, v := range tc.wantHeap { + assert.Equal(t, v, tc.q.(*QueueWithHeap[testNode]).heap[i]) + } + }) + } +} + +func TestQueueWithHeap_Remove_Not_Indexable(t *testing.T) { + + q := NewQueueWithHeap[int](&testIntComparator{}) + + arg := 1 + + err := q.Remove(context.Background(), &arg) + + assert.Equal(t, errors.New("只有实现Indexable的数据才能随机删除"), err) +} + +type testIntComparator struct{} + +func (t *testIntComparator) Compare(src, dest *int) int { + if *src > *dest { + return 1 + } else if *src < *dest { + return -1 + } else { + return 0 + } +} + +type testComparator struct{} + +func (t *testComparator) Compare(src, dest *testNode) int { + if src.data > dest.data { + return 1 + } else if src.data < dest.data { + return -1 + } else { + return 0 + } +} + +type testNode struct { + index int + + data int +} + +func (t *testNode) Index() int { + return t.index +} + +func (t *testNode) SetIndex(idx int) { + t.index = idx +} From 132ba49652213da74b31100f8754d5bf6818129e Mon Sep 17 00:00:00 2001 From: Uzziah Date: Sat, 23 Sep 2023 11:07:46 +0800 Subject: [PATCH 2/6] =?UTF-8?q?fix:=20=E9=87=8D=E6=9E=84=E8=BF=87=E6=9C=9F?= =?UTF-8?q?=E6=B8=85=E9=99=A4=E6=89=AB=E6=8F=8F=E6=96=B9=E6=B3=95=E4=BB=A5?= =?UTF-8?q?=E5=8F=8A=E5=88=A0=E9=99=A4=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/priority/cache.go | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/memory/priority/cache.go b/memory/priority/cache.go index 31b63ab..a97b5b1 100644 --- a/memory/priority/cache.go +++ b/memory/priority/cache.go @@ -155,8 +155,7 @@ func (c *Cache) Get(ctx context.Context, key string) ecache.Value { // 过期删除 if ok { - _ = c.pq.Remove(ctx, node) - delete(c.index, key) + c.delete(node) c.len-- } @@ -245,26 +244,34 @@ func (c *Cache) clean() { for { select { case <-ticker.C: - c.mu.Lock() - count := 0 - for k, v := range c.index { - if v.Dl.Before(time.Now()) { - _ = c.pq.Remove(context.Background(), v) - delete(c.index, k) - c.len-- - } - count++ - if count >= c.scanCount { - break - } - } - c.mu.Unlock() + c.scan() case <-c.closeC: return } } } +func (c *Cache) scan() { + c.mu.Lock() + defer c.mu.Unlock() + count := 0 + for _, v := range c.index { + if v.Dl.Before(time.Now()) { + c.delete(v) + c.len-- + } + count++ + if count >= c.scanCount { + break + } + } +} + +func (c *Cache) delete(n *Node) { + _ = c.pq.Remove(context.Background(), n) + delete(c.index, n.Key) +} + type Node struct { Key string Val any From 520adf0ef8194e2623d0b6d2a9b6fca64c3fa368 Mon Sep 17 00:00:00 2001 From: Uzziah Date: Sat, 23 Sep 2023 20:57:18 +0800 Subject: [PATCH 3/6] =?UTF-8?q?fix:=20node=20receiver=20=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E6=8C=87=E9=92=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/priority/cache.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/memory/priority/cache.go b/memory/priority/cache.go index a97b5b1..0437ddd 100644 --- a/memory/priority/cache.go +++ b/memory/priority/cache.go @@ -279,10 +279,10 @@ type Node struct { idx int } -func (n Node) Index() int { +func (n *Node) Index() int { return n.idx } -func (n Node) SetIndex(idx int) { +func (n *Node) SetIndex(idx int) { n.idx = idx } From f2a9057cb22d2e8fcbc950d2c066decd938281ed Mon Sep 17 00:00:00 2001 From: Uzziah Date: Sat, 23 Sep 2023 21:00:08 +0800 Subject: [PATCH 4/6] =?UTF-8?q?fix:=20=E5=8E=BB=E6=8E=89=E4=B8=8D=E9=9C=80?= =?UTF-8?q?=E8=A6=81=E7=9A=84=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/priority/priority_queue.go | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/memory/priority/priority_queue.go b/memory/priority/priority_queue.go index dc584d1..8ac39aa 100644 --- a/memory/priority/priority_queue.go +++ b/memory/priority/priority_queue.go @@ -95,27 +95,6 @@ func (q *QueueWithHeap[T]) Remove(ctx context.Context, t *T) error { return nil } -// 0 -// 1 2 -// 3 4 5 6 -// 7 8 9 10 -// -// root -> left(2n+1) root -> right(2n+2) left/right -> root(n-1/2) -// 堆化应该从最后一个有子节点的节点开始堆化(len(heap)-2/2) -func (q *QueueWithHeap[T]) heapify() { - n := q.len - - if n <= 1 { - return - } - - cur := n - 2/2 - - for i := cur; i >= 0; i-- { - q.heapifyDown(cur) - } -} - // heapifyDown 从上往下进行堆化 func (q *QueueWithHeap[T]) heapifyDown(cur int) { n := q.len From eb89dc69ebce16d7cf6bb443de522a9a66272d6c Mon Sep 17 00:00:00 2001 From: Uzziah Date: Sat, 14 Oct 2023 10:21:29 +0800 Subject: [PATCH 5/6] =?UTF-8?q?fix:=20=E6=94=B9=E7=94=A8ekit=E7=9A=84?= =?UTF-8?q?=E4=BC=98=E5=85=88=E7=BA=A7=E9=98=9F=E5=88=97=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9Edelete=E6=96=B9=E6=B3=95=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/priority/cache.go | 121 +++++---- memory/priority/priority_queue.go | 146 ---------- memory/priority/priority_queue_test.go | 352 ------------------------- 3 files changed, 71 insertions(+), 548 deletions(-) delete mode 100644 memory/priority/priority_queue.go delete mode 100644 memory/priority/priority_queue_test.go diff --git a/memory/priority/cache.go b/memory/priority/cache.go index 0437ddd..d1f0cac 100644 --- a/memory/priority/cache.go +++ b/memory/priority/cache.go @@ -5,6 +5,7 @@ import ( "github.com/ecodeclub/ecache" "github.com/ecodeclub/ecache/internal/errs" "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/queue" "sync" "time" ) @@ -17,7 +18,7 @@ func WithCapacity(cap int) Option { } } -func WithComparator(comparator Comparator[Node]) Option { +func WithComparator(comparator ekit.Comparator[*Node]) Option { return func(c *Cache) { c.comparator = comparator } @@ -35,9 +36,22 @@ func NewCache(opts ...Option) ecache.Cache { defaultScanCount := 1000 defaultExpiration := 30 * time.Second + // defaultComparator 默认比较器 按节点的过期时间进行比较 + defaultComparator := func(src, dest *Node) int { + if src.Dl.Before(dest.Dl) { + return -1 + } + + if src.Dl.After(dest.Dl) { + return 1 + } + + return 0 + } + cache := &Cache{ index: make(map[string]*Node), - comparator: defaultComparator{}, + comparator: defaultComparator, cap: defaultCap, cleanInterval: defaultCleanInterval, scanCount: defaultScanCount, @@ -48,39 +62,23 @@ func NewCache(opts ...Option) ecache.Cache { opt(cache) } - cache.pq = NewQueueWithHeap[Node](cache.comparator) + cache.pq = queue.NewPriorityQueue[*Node](defaultCap, cache.comparator) go cache.clean() return cache } -// defaultComparator 默认比较器 按节点的过期时间进行比较 -type defaultComparator struct { -} - -func (d defaultComparator) Compare(src, dest *Node) int { - if src.Dl.Before(dest.Dl) { - return -1 - } - - if src.Dl.After(dest.Dl) { - return 1 - } - - return 0 -} - type Cache struct { - index map[string]*Node // 用于存储数据的索引,方便快速查找 - pq Queue[Node] // 优先级队列,用于存储数据 - comparator Comparator[Node] // 比较器 - mu sync.RWMutex // 读写锁 - cap int // 容量 - len int // 当前队列长度 - cleanInterval time.Duration // 清理过期数据的时间间隔 - scanCount int // 扫描次数 - closeC chan struct{} // 关闭信号 + index map[string]*Node // 用于存储数据的索引,方便快速查找 + pq *queue.PriorityQueue[*Node] // 优先级队列,用于存储数据 + comparator ekit.Comparator[*Node] // 比较器 + mu sync.RWMutex // 读写锁 + cap int // 容量 + len int // 当前队列长度 + cleanInterval time.Duration // 清理过期数据的时间间隔 + scanCount int // 扫描次数 + closeC chan struct{} // 关闭信号 defaultExpiration time.Duration } @@ -101,7 +99,7 @@ func (c *Cache) Set(ctx context.Context, key string, val any, expiration time.Du } func (c *Cache) add(ctx context.Context, key string, val any, expiration time.Duration) { - c.checkCapacityAndDisuse(ctx) + c.checkCapacityAndDisuse() node := &Node{ Key: key, @@ -109,22 +107,33 @@ func (c *Cache) add(ctx context.Context, key string, val any, expiration time.Du Dl: time.Now().Add(expiration), } - _ = c.pq.Push(ctx, node) + _ = c.pq.Enqueue(node) c.index[key] = node c.len++ } -func (c *Cache) checkCapacityAndDisuse(ctx context.Context) { +func (c *Cache) checkCapacityAndDisuse() { if c.len >= c.cap { - // 淘汰优先级最低的数据 - node, _ := c.pq.Pop(ctx) - // 删除索引 - delete(c.index, node.Key) - c.len-- + // 先淘汰堆顶元素,保证有足够的空间插入新数据 + c.disuse() + + // 看下堆顶元素是否是否被标记删除,如果是,则删除 + for top, _ := c.pq.Peek(); top.isDel; top, _ = c.pq.Peek() { + c.disuse() + } + } } +func (c *Cache) disuse() { + // 淘汰优先级最低的数据 + node, _ := c.pq.Dequeue() + // 删除索引 + delete(c.index, node.Key) + c.len-- +} + func (c *Cache) SetNX(ctx context.Context, key string, val any, expiration time.Duration) (bool, error) { c.mu.Lock() defer c.mu.Unlock() @@ -197,6 +206,25 @@ func (c *Cache) GetSet(ctx context.Context, key string, val string) ecache.Value } +func (c *Cache) Delete(ctx context.Context, key ...string) (int64, error) { + c.mu.Lock() + defer c.mu.Unlock() + + var count int64 + + for _, k := range key { + // 这里其实还要考虑过期的情况,如果过期了,是否要计入删除的数量 + // 这里暂时不考虑过期的情况 + if node, ok := c.index[k]; ok { + c.delete(node) + c.len-- + count++ + } + } + + return count, nil +} + func (c *Cache) LPush(ctx context.Context, key string, val ...any) (int64, error) { //TODO implement me panic("implement me") @@ -212,7 +240,7 @@ func (c *Cache) SAdd(ctx context.Context, key string, members ...any) (int64, er panic("implement me") } -func (c *Cache) SRem(ctx context.Context, key string, members ...any) ecache.Value { +func (c *Cache) SRem(ctx context.Context, key string, members ...any) (int64, error) { //TODO implement me panic("implement me") } @@ -268,21 +296,14 @@ func (c *Cache) scan() { } func (c *Cache) delete(n *Node) { - _ = c.pq.Remove(context.Background(), n) + // 标记删除 + n.isDel = true delete(c.index, n.Key) } type Node struct { - Key string - Val any - Dl time.Time // 过期时间 - idx int -} - -func (n *Node) Index() int { - return n.idx -} - -func (n *Node) SetIndex(idx int) { - n.idx = idx + Key string + Val any + Dl time.Time // 过期时间 + isDel bool } diff --git a/memory/priority/priority_queue.go b/memory/priority/priority_queue.go deleted file mode 100644 index 8ac39aa..0000000 --- a/memory/priority/priority_queue.go +++ /dev/null @@ -1,146 +0,0 @@ -package priority - -import ( - "context" - "errors" -) - -type Queue[T any] interface { - Push(ctx context.Context, t *T) error - Pop(ctx context.Context) (*T, error) - Peek(ctx context.Context) (*T, error) - Remove(ctx context.Context, t *T) error // 为了支持随机删除而引入的接口,如果不需要随机删除,可以不实现 -} - -type Comparator[T any] interface { - Compare(src, dest *T) int -} - -type Indexable interface { - Index() int - SetIndex(idx int) -} - -func NewQueueWithHeap[T any](comparator Comparator[T]) Queue[T] { - // 这里可以考虑给一个默认的堆容量 - return &QueueWithHeap[T]{ - heap: make([]*T, 0), - comparator: comparator, - len: 0, - } -} - -type QueueWithHeap[T any] struct { - heap []*T - comparator Comparator[T] - len int -} - -func (q *QueueWithHeap[T]) Push(ctx context.Context, t *T) error { - if len(q.heap) > q.len { - q.heap[q.len] = t - } else { - q.heap = append(q.heap, t) - } - - // 如果是可索引的,需要为这个类型设置索引 - if idx, ok := checkIndexable(t); ok { - idx.SetIndex(q.len) - } - - q.len++ - - q.heapifyUp(q.len - 1) - - return nil -} - -func (q *QueueWithHeap[T]) Pop(ctx context.Context) (*T, error) { - if q.len == 0 { - return nil, errors.New("队列为空") - } - res := q.heap[0] - q.heap[0] = q.heap[q.len-1] - q.heap[q.len-1] = nil // let GC do its work - q.len-- - - q.heapifyDown(0) - return res, nil -} - -func (q *QueueWithHeap[T]) Peek(ctx context.Context) (*T, error) { - if q.len == 0 { - return nil, errors.New("队列为空") - } - - return q.heap[0], nil -} - -// Remove 随机删除一个元素 -// 但是要确保这个元素是在堆里的 -func (q *QueueWithHeap[T]) Remove(ctx context.Context, t *T) error { - idx, ok := checkIndexable(t) - if !ok { - return errors.New("只有实现Indexable的数据才能随机删除") - } - - if idx.Index() >= q.len { - return errors.New("这个元素不在堆里") - } - - q.heap[idx.Index()] = q.heap[q.len-1] - q.heap[q.len-1] = nil // let GC do its work - q.len-- - q.heapifyDown(idx.Index()) - return nil -} - -// heapifyDown 从上往下进行堆化 -func (q *QueueWithHeap[T]) heapifyDown(cur int) { - n := q.len - - // 如果满足 idx <= n - 2 / 2 说明有子节点,需要往下进行堆化 - for cur <= (n-2)>>1 { - l, r := 2*cur+1, 2*cur+2 - min := l - - if r < n && q.comparator.Compare(q.heap[l], q.heap[r]) > 0 { - min = r - } - - // 说明已经满足堆化条件,直接返回 - if q.comparator.Compare(q.heap[cur], q.heap[min]) < 0 { - return - } - - // swap - q.swap(cur, min) - - cur = min - } -} - -// heapifyUp 从下往上进行堆化 -func (q *QueueWithHeap[T]) heapifyUp(cur int) { - for p := (cur - 1) >> 1; cur > 0 && q.comparator.Compare(q.heap[cur], q.heap[p]) < 0; cur, p = p, (p-1)>>1 { - q.swap(cur, p) - } -} - -// swap 交换下标值为src和dest位置的值,如果实现了Indexable接口,则更新以下索引 -func (q *QueueWithHeap[T]) swap(src, dest int) { - q.heap[src], q.heap[dest] = q.heap[dest], q.heap[src] - - if idx, ok := checkIndexable(q.heap[src]); ok { - idx.SetIndex(src) - } - - if idx, ok := checkIndexable(q.heap[dest]); ok { - idx.SetIndex(dest) - } -} - -func checkIndexable(val any) (Indexable, bool) { - idx, ok := val.(Indexable) - return idx, ok -} diff --git a/memory/priority/priority_queue_test.go b/memory/priority/priority_queue_test.go deleted file mode 100644 index 6d52d7a..0000000 --- a/memory/priority/priority_queue_test.go +++ /dev/null @@ -1,352 +0,0 @@ -package priority - -import ( - "context" - "errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestQueueWithHeap_Push(t *testing.T) { - - ctx := context.TODO() - - testCases := []struct { - name string - - q Queue[testNode] - - t *testNode - - before func(q Queue[testNode]) - - wantLen int - wantRes []*testNode - }{ - { - // 队列为空,插入一个元素 - name: "insert one element, queue is empty", - q: NewQueueWithHeap[testNode](&testComparator{}), - t: &testNode{data: 1}, - before: func(q Queue[testNode]) { - - }, - wantLen: 1, - wantRes: []*testNode{ - {data: 1, index: 0}, - }, - }, - { - // 队列不为空,插入一个元素 - name: "insert one element, and no heapify", - q: NewQueueWithHeap[testNode](&testComparator{}), - t: &testNode{data: 5}, - before: func(q Queue[testNode]) { - _ = q.Push(ctx, &testNode{data: 2}) - _ = q.Push(ctx, &testNode{data: 3}) - _ = q.Push(ctx, &testNode{data: 4}) - _ = q.Push(ctx, &testNode{data: 6}) - _ = q.Push(ctx, &testNode{data: 7}) - }, - wantLen: 6, - wantRes: []*testNode{ - {data: 2, index: 0}, - {data: 3, index: 1}, - {data: 4, index: 2}, - {data: 6, index: 3}, - {data: 7, index: 4}, - {data: 5, index: 5}, - }, - }, - { - // 队列不为空,插入一个元素 - name: "insert one element, and heapify", - q: NewQueueWithHeap[testNode](&testComparator{}), - t: &testNode{data: 1}, - before: func(q Queue[testNode]) { - _ = q.Push(ctx, &testNode{data: 2}) - _ = q.Push(ctx, &testNode{data: 3}) - _ = q.Push(ctx, &testNode{data: 4}) - _ = q.Push(ctx, &testNode{data: 6}) - _ = q.Push(ctx, &testNode{data: 7}) - }, - wantLen: 6, - wantRes: []*testNode{ - {data: 1, index: 0}, - {data: 3, index: 1}, - {data: 2, index: 2}, - {data: 6, index: 3}, - {data: 7, index: 4}, - {data: 4, index: 5}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.before(tc.q) - - err := tc.q.Push(ctx, tc.t) - if err != nil { - assert.NoError(t, err) - } - - assert.Equal(t, tc.wantLen, tc.q.(*QueueWithHeap[testNode]).len) - - for i, v := range tc.wantRes { - assert.Equal(t, v, tc.q.(*QueueWithHeap[testNode]).heap[i]) - } - }) - } -} - -func TestQueueWithHeap_Pop(t *testing.T) { - - ctx := context.TODO() - - testCases := []struct { - name string - - q Queue[testNode] - - before func(q Queue[testNode]) - - wantLen int - wantRes *testNode - wantHeap []*testNode - wantErr error - }{ - { - // 队列为空,弹出一个元素 - name: "pop one element, queue is empty", - q: NewQueueWithHeap[testNode](&testComparator{}), - before: func(q Queue[testNode]) { - - }, - wantErr: errors.New("队列为空"), - }, - { - // 当队列只有一个元素,弹出一个元素 - name: "pop one element, queue has one element", - q: NewQueueWithHeap[testNode](&testComparator{}), - before: func(q Queue[testNode]) { - _ = q.Push(ctx, &testNode{data: 2}) - }, - wantLen: 0, - wantRes: &testNode{data: 2, index: 0}, - wantHeap: []*testNode{}, - }, - { - // 堆里多个元素,弹出一个元素 - name: "pop one element, queue has many elements", - q: NewQueueWithHeap[testNode](&testComparator{}), - before: func(q Queue[testNode]) { - _ = q.Push(ctx, &testNode{data: 2}) - _ = q.Push(ctx, &testNode{data: 3}) - _ = q.Push(ctx, &testNode{data: 4}) - _ = q.Push(ctx, &testNode{data: 6}) - _ = q.Push(ctx, &testNode{data: 7}) - }, - wantLen: 4, - wantRes: &testNode{data: 2, index: 0}, - wantHeap: []*testNode{ - {data: 3, index: 0}, - {data: 6, index: 1}, - {data: 4, index: 2}, - {data: 7, index: 3}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.before(tc.q) - - res, err := tc.q.Pop(ctx) - assert.Equal(t, tc.wantErr, err) - - if err != nil { - return - } - - assert.Equal(t, tc.wantRes, res) - - assert.Equal(t, tc.wantLen, tc.q.(*QueueWithHeap[testNode]).len) - - for i, v := range tc.wantHeap { - assert.Equal(t, v, tc.q.(*QueueWithHeap[testNode]).heap[i]) - } - }) - } -} - -func TestQueueWithHeap_Peek(t *testing.T) { - - ctx := context.TODO() - - testCases := []struct { - name string - - q Queue[testNode] - - before func(q Queue[testNode]) - - wantRes *testNode - wantErr error - }{ - { - // 队列为空,peek一个元素 - name: "peek one element, queue is empty", - q: NewQueueWithHeap[testNode](&testComparator{}), - before: func(q Queue[testNode]) { - - }, - wantErr: errors.New("队列为空"), - }, - { - // 堆里多个元素,peek一个元素 - name: "peek one element, queue has many elements", - q: NewQueueWithHeap[testNode](&testComparator{}), - before: func(q Queue[testNode]) { - _ = q.Push(ctx, &testNode{data: 2}) - _ = q.Push(ctx, &testNode{data: 3}) - _ = q.Push(ctx, &testNode{data: 4}) - _ = q.Push(ctx, &testNode{data: 6}) - _ = q.Push(ctx, &testNode{data: 7}) - }, - wantRes: &testNode{data: 2, index: 0}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.before(tc.q) - - res, err := tc.q.Peek(ctx) - assert.Equal(t, tc.wantErr, err) - - if err != nil { - return - } - - assert.Equal(t, tc.wantRes, res) - }) - } -} - -func TestQueueWithHeap_Remove(t *testing.T) { - - ctx := context.TODO() - - testCases := []struct { - name string - - q Queue[testNode] - t *testNode - - before func(q Queue[testNode]) - - wantLen int - wantHeap []*testNode - wantErr error - }{ - { - // 删除一个不在队列中的元素 - name: "remove one element, element not in queue", - q: NewQueueWithHeap[testNode](&testComparator{}), - t: &testNode{data: 2}, - before: func(q Queue[testNode]) { - - }, - wantErr: errors.New("这个元素不在堆里"), - }, - { - // 删除一个在队列中的元素 - name: "remove one element, element in queue", - q: NewQueueWithHeap[testNode](&testComparator{}), - t: &testNode{data: 3, index: 1}, - before: func(q Queue[testNode]) { - _ = q.Push(ctx, &testNode{data: 2}) - _ = q.Push(ctx, &testNode{data: 3}) - _ = q.Push(ctx, &testNode{data: 4}) - _ = q.Push(ctx, &testNode{data: 6}) - _ = q.Push(ctx, &testNode{data: 7}) - _ = q.Push(ctx, &testNode{data: 9}) - _ = q.Push(ctx, &testNode{data: 10}) - }, - wantLen: 6, - wantHeap: []*testNode{ - {data: 2, index: 0}, - {data: 6, index: 1}, - {data: 4, index: 2}, - {data: 10, index: 3}, - {data: 7, index: 4}, - {data: 9, index: 5}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.before(tc.q) - - err := tc.q.Remove(ctx, tc.t) - assert.Equal(t, tc.wantErr, err) - - assert.Equal(t, tc.wantLen, tc.q.(*QueueWithHeap[testNode]).len) - - for i, v := range tc.wantHeap { - assert.Equal(t, v, tc.q.(*QueueWithHeap[testNode]).heap[i]) - } - }) - } -} - -func TestQueueWithHeap_Remove_Not_Indexable(t *testing.T) { - - q := NewQueueWithHeap[int](&testIntComparator{}) - - arg := 1 - - err := q.Remove(context.Background(), &arg) - - assert.Equal(t, errors.New("只有实现Indexable的数据才能随机删除"), err) -} - -type testIntComparator struct{} - -func (t *testIntComparator) Compare(src, dest *int) int { - if *src > *dest { - return 1 - } else if *src < *dest { - return -1 - } else { - return 0 - } -} - -type testComparator struct{} - -func (t *testComparator) Compare(src, dest *testNode) int { - if src.data > dest.data { - return 1 - } else if src.data < dest.data { - return -1 - } else { - return 0 - } -} - -type testNode struct { - index int - - data int -} - -func (t *testNode) Index() int { - return t.index -} - -func (t *testNode) SetIndex(idx int) { - t.index = idx -} From d7bae5decf9ea9330f9e1130b1fcf0deb920f641 Mon Sep 17 00:00:00 2001 From: uzziah Date: Sun, 5 Nov 2023 18:39:06 +0800 Subject: [PATCH 6/6] =?UTF-8?q?fix:=20=E8=A7=A3=E5=86=B3=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8Bdata=20race=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/priority/cache.go | 19 ++++++++++++++++-- memory/priority/cache_test.go | 37 +++++++++++++++++++++++++++++------ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/memory/priority/cache.go b/memory/priority/cache.go index d1f0cac..c54fae0 100644 --- a/memory/priority/cache.go +++ b/memory/priority/cache.go @@ -1,13 +1,28 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package priority import ( "context" + "sync" + "time" + "github.com/ecodeclub/ecache" "github.com/ecodeclub/ecache/internal/errs" "github.com/ecodeclub/ekit" "github.com/ecodeclub/ekit/queue" - "sync" - "time" ) type Option func(c *Cache) diff --git a/memory/priority/cache_test.go b/memory/priority/cache_test.go index 6c98b7b..dd60bed 100644 --- a/memory/priority/cache_test.go +++ b/memory/priority/cache_test.go @@ -1,12 +1,27 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package priority import ( "context" + "testing" + "time" + "github.com/ecodeclub/ecache" "github.com/ecodeclub/ecache/internal/errs" "github.com/stretchr/testify/assert" - "testing" - "time" ) func TestCache_Set(t *testing.T) { @@ -337,9 +352,11 @@ func TestCache_Get(t *testing.T) { tc.before(tc.cache) for k, v := range tc.beforeGetIndex { - assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + node := tc.cache.(*Cache).getNode(k) + + assert.Equal(t, v.Val, node.Val) - assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 2) + assert.InDelta(t, v.Dl.Unix(), node.Dl.Unix(), 2) } res := tc.cache.Get(ctx, tc.key) @@ -347,9 +364,11 @@ func TestCache_Get(t *testing.T) { assert.Equal(t, len(tc.wantIndex), len(tc.cache.(*Cache).index)) for k, v := range tc.wantIndex { - assert.Equal(t, v.Val, tc.cache.(*Cache).index[k].Val) + node := tc.cache.(*Cache).getNode(k) - assert.InDelta(t, v.Dl.Unix(), tc.cache.(*Cache).index[k].Dl.Unix(), 2) + assert.Equal(t, v.Val, node.Val) + + assert.InDelta(t, v.Dl.Unix(), node.Dl.Unix(), 2) } assert.Equal(t, tc.wantErr, res.Err) @@ -363,6 +382,12 @@ func TestCache_Get(t *testing.T) { } } +func (c *Cache) getNode(key string) *Node { + c.mu.Lock() + defer c.mu.Unlock() + return c.index[key] +} + func TestCache_GetSet(t *testing.T) { ctx := context.TODO()