From ba3ae3f4ee23a56a05aa7b1831dd0e7d9e17d598 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 6 Dec 2022 19:00:44 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=87=86=E5=A4=87=20Release=200.0.5=20(#12?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sqlx 加密列 key长度校验 (#102) * sqlx 加密列 key长度校验 * sqlx 加密列 key长度校验 补单元测试 * 修改加密列key长度错误提示 * atomicx: 泛型封装 atomic.Value (#101) * atomicx: 泛型封装 atomic.Value * 添加 CHANGELOG * syncx/atomicx: 增加 Swap 和 CAS 的泛型包装 * 添加 swap nil 的测试 * 添加更加多的 benchmark 测试,同时保证 NewValue 和 NewValueOf 的语义在 nil 上一致 * 优化单元测试 * queue: API 定义 (#109) * queue: API 定义 * 补充 API 说明 * 实现优先级队列和并发安全优先级队列 (#110) 基于小顶堆和切片的实现 * queue: 延时队列 (#115) * 延迟队列: 优化唤醒入队元素逻辑 (#117) * 修改CHANGELOG链接;添加测试用例修复bug Signed-off-by: longyue0521 * 修改cond的SignalCh为signalCh;理清注释 Signed-off-by: longyue0521 Signed-off-by: longyue0521 * value: AnyValue 设计 (#120) (#121) * value: AnyValue 设计 (#120) * 修复ci检测问题 * 1.fix cr问题 2.add changelog对该pr的引用 3.add license 头部 * 1.修改ChangeLog,加入新特性描述 2.挪出value包,放在根目录 3.统一error格式打印 * 断言方式.Name改为.String Co-authored-by: vividwei * queue: 基于切片的并发阻塞队列和基于 CAS 的并发队列设计 (#119) * queue:使用list包中的LinkedList实现并发阻塞链式队列 (#122) * queue:增加链式并发阻塞队列 Co-authored-by: kangdan Co-authored-by: dan.kang * ConcurrentLinkBlockingQueue 改成ConcurrentLinkedBlockingQueue (#123) * ConcurrentLinkBlockingQueue 改成ConcurrentLinkedBlockingQueue * modify .CHANGELOG.md * modify .CHANGELOG.md Co-authored-by: kangdan Co-authored-by: dan.kang * queue: ConcurrentLinkedQueue增加超时控制逻辑 (#124) Co-authored-by: kangdan Co-authored-by: dan.kang * queue: 添加例子 - 添加队列例子 - 去除 ConcurrentLinkedQueue 的超时机制 * queue: 添加例子 (#126) Signed-off-by: longyue0521 Co-authored-by: hookokoko <648646891@qq.com> Co-authored-by: Gevin Co-authored-by: Longyue Li Co-authored-by: 韦佳栋 <353470469@qq.com> Co-authored-by: vividwei Co-authored-by: kangdan666 <95063166+kangdan6@users.noreply.github.com> Co-authored-by: kangdan Co-authored-by: dan.kang --- .CHANGELOG.md | 8 + internal/errs/error.go | 9 +- internal/queue/priority_queue.go | 136 +++ internal/queue/priority_queue_test.go | 496 ++++++++++ internal/slice/shink_test.go | 72 ++ internal/slice/shrink.go | 40 + list/array_list.go | 17 +- queue/concurrent_array_blocking_queue.go | 133 +++ queue/concurrent_array_blocking_queue_test.go | 433 +++++++++ queue/concurrent_linked_blocking_queue.go | 112 +++ .../concurrent_linked_blocking_queue_test.go | 319 +++++++ queue/concurrent_linked_queue.go | 85 ++ queue/concurrent_linked_queue_test.go | 197 ++++ queue/concurrent_priority_queue.go | 65 ++ queue/concurrent_priority_queue_test.go | 311 ++++++ queue/delay_queue.go | 190 ++++ queue/delay_queue_test.go | 351 +++++++ queue/types.go | 50 + sqlx/encrypt.go | 4 + sqlx/encrypt_test.go | 5 + syncx/atomicx/atomic.go | 64 ++ syncx/atomicx/atomic_test.go | 233 +++++ syncx/atomicx/example_test.go | 74 ++ types.go | 29 + value.go | 236 +++++ value_test.go | 886 ++++++++++++++++++ 26 files changed, 4538 insertions(+), 17 deletions(-) create mode 100644 internal/queue/priority_queue.go create mode 100644 internal/queue/priority_queue_test.go create mode 100644 internal/slice/shink_test.go create mode 100644 internal/slice/shrink.go create mode 100644 queue/concurrent_array_blocking_queue.go create mode 100644 queue/concurrent_array_blocking_queue_test.go create mode 100644 queue/concurrent_linked_blocking_queue.go create mode 100644 queue/concurrent_linked_blocking_queue_test.go create mode 100644 queue/concurrent_linked_queue.go create mode 100644 queue/concurrent_linked_queue_test.go create mode 100644 queue/concurrent_priority_queue.go create mode 100644 queue/concurrent_priority_queue_test.go create mode 100644 queue/delay_queue.go create mode 100644 queue/delay_queue_test.go create mode 100644 queue/types.go create mode 100644 syncx/atomicx/atomic.go create mode 100644 syncx/atomicx/atomic_test.go create mode 100644 syncx/atomicx/example_test.go create mode 100644 types.go create mode 100644 value.go create mode 100644 value_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index f9b24201..abf05386 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,4 +1,12 @@ # 开发中 +- [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) +- [queue: API 定义](https://github.com/gotomicro/ekit/pull/109) +- [queue: 基于堆和切片的优先级队列](https://github.com/gotomicro/ekit/pull/110) +- [queue: 延时队列](https://github.com/gotomicro/ekit/pull/115) +- [ekit: AnyValue 设计](https://github.com/gotomicro/ekit/pull/121) +- [queue: 基于切片的并发阻塞队列和基于 CAS 的并发队列设计](https://github.com/gotomicro/ekit/pull/119) +- [queue: 基于链表实现的有界/无界阻塞队列](https://github.com/gotomicro/ekit/pull/122) + - [queue: ConcurrentLinkBlockingQueue重命名为ConcurrentLinkedBlockingQueue](https://github.com/gotomicro/ekit/pull/123) # v0.0.4 - [slice: 重构 index 和 contains 的方法,直接调用对应Func 版本](https://github.com/gotomicro/ekit/pull/87) diff --git a/internal/errs/error.go b/internal/errs/error.go index 28dd4e3f..d8779df9 100644 --- a/internal/errs/error.go +++ b/internal/errs/error.go @@ -14,9 +14,16 @@ package errs -import "fmt" +import ( + "fmt" +) // NewErrIndexOutOfRange 创建一个代表下标超出范围的错误 func NewErrIndexOutOfRange(length int, index int) error { return fmt.Errorf("ekit: 下标超出范围,长度 %d, 下标 %d", length, index) } + +// NewErrInvalidType 创建一个代表类型转换失败的错误 +func NewErrInvalidType(want, got string) error { + return fmt.Errorf("ekit: 类型转换失败,want:%s, got:%s", want, got) +} diff --git a/internal/queue/priority_queue.go b/internal/queue/priority_queue.go new file mode 100644 index 00000000..28a6be72 --- /dev/null +++ b/internal/queue/priority_queue.go @@ -0,0 +1,136 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "errors" + + "github.com/gotomicro/ekit/internal/slice" + + "github.com/gotomicro/ekit" +) + +var ( + ErrOutOfCapacity = errors.New("ekit: 超出最大容量限制") + ErrEmptyQueue = errors.New("ekit: 队列为空") +) + +// PriorityQueue 是一个基于小顶堆的优先队列 +// 当capacity <= 0时,为无界队列,切片容量会动态扩缩容 +// 当capacity > 0 时,为有界队列,初始化后就固定容量,不会扩缩容 +type PriorityQueue[T any] struct { + // 用于比较前一个元素是否小于后一个元素 + compare ekit.Comparator[T] + // 队列容量 + capacity int + // 队列中的元素,为便于计算父子节点的index,0位置留空,根节点从1开始 + data []T +} + +func (p *PriorityQueue[T]) Len() int { + return len(p.data) - 1 +} + +// Cap 无界队列返回0,有界队列返回创建队列时设置的值 +func (p *PriorityQueue[T]) Cap() int { + return p.capacity +} + +func (p *PriorityQueue[T]) IsBoundless() bool { + return p.capacity <= 0 +} + +func (p *PriorityQueue[T]) isFull() bool { + return p.capacity > 0 && len(p.data)-1 == p.capacity +} + +func (p *PriorityQueue[T]) isEmpty() bool { + return len(p.data) < 2 +} + +func (p *PriorityQueue[T]) Peek() (T, error) { + if p.isEmpty() { + var t T + return t, ErrEmptyQueue + } + return p.data[1], nil +} + +func (p *PriorityQueue[T]) Enqueue(t T) error { + if p.isFull() { + return ErrOutOfCapacity + } + + p.data = append(p.data, t) + node, parent := len(p.data)-1, (len(p.data)-1)/2 + for parent > 0 && p.compare(p.data[node], p.data[parent]) < 0 { + p.data[parent], p.data[node] = p.data[node], p.data[parent] + node = parent + parent = parent / 2 + } + + return nil +} + +func (p *PriorityQueue[T]) Dequeue() (T, error) { + if p.isEmpty() { + var t T + return t, ErrEmptyQueue + } + + pop := p.data[1] + p.data[1] = p.data[len(p.data)-1] + p.data = p.data[:len(p.data)-1] + p.shrinkIfNecessary() + p.heapify(p.data, len(p.data)-1, 1) + return pop, nil +} + +func (p *PriorityQueue[T]) shrinkIfNecessary() { + if p.IsBoundless() { + p.data = slice.Shrink[T](p.data) + } +} + +func (p *PriorityQueue[T]) heapify(data []T, n, i int) { + minPos := i + for { + if left := i * 2; left <= n && p.compare(data[left], data[minPos]) < 0 { + minPos = left + } + if right := i*2 + 1; right <= n && p.compare(data[right], data[minPos]) < 0 { + minPos = right + } + if minPos == i { + break + } + data[i], data[minPos] = data[minPos], data[i] + i = minPos + } +} + +// NewPriorityQueue 创建优先队列 capacity <= 0 时,为无界队列,否则有有界队列 +func NewPriorityQueue[T any](capacity int, compare ekit.Comparator[T]) *PriorityQueue[T] { + sliceCap := capacity + 1 + if capacity < 1 { + capacity = 0 + sliceCap = 64 + } + return &PriorityQueue[T]{ + capacity: capacity, + data: make([]T, 1, sliceCap), + compare: compare, + } +} diff --git a/internal/queue/priority_queue_test.go b/internal/queue/priority_queue_test.go new file mode 100644 index 00000000..40fb2240 --- /dev/null +++ b/internal/queue/priority_queue_test.go @@ -0,0 +1,496 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gotomicro/ekit" + + "github.com/stretchr/testify/assert" +) + +func TestNewPriorityQueue(t *testing.T) { + data := []int{6, 5, 4, 3, 2, 1} + testCases := []struct { + name string + q *PriorityQueue[int] + capacity int + data []int + expected []int + }{ + { + name: "无边界", + q: NewPriorityQueue(0, compare()), + capacity: 0, + data: data, + expected: []int{1, 2, 3, 4, 5, 6}, + }, + { + name: "有边界 ", + q: NewPriorityQueue(len(data), compare()), + capacity: len(data), + data: data, + expected: []int{1, 2, 3, 4, 5, 6}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, 0, tc.q.Len()) + for _, d := range data { + err := tc.q.Enqueue(d) + assert.NoError(t, err) + if err != nil { + return + } + } + assert.Equal(t, tc.capacity, tc.q.Cap()) + assert.Equal(t, len(data), tc.q.Len()) + res := make([]int, 0, len(data)) + for tc.q.Len() > 0 { + el, err := tc.q.Dequeue() + assert.NoError(t, err) + if err != nil { + return + } + res = append(res, el) + } + assert.Equal(t, tc.expected, res) + }) + + } + +} + +func TestPriorityQueue_Peek(t *testing.T) { + testCases := []struct { + name string + capacity int + data []int + wantErr error + }{ + { + name: "有数据", + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + wantErr: ErrEmptyQueue, + }, + { + name: "无数据", + capacity: 0, + data: []int{}, + wantErr: ErrEmptyQueue, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewPriorityQueue[int](tc.capacity, compare()) + for _, el := range tc.data { + err := q.Enqueue(el) + require.NoError(t, err) + } + for q.Len() > 0 { + peek, err := q.Peek() + assert.NoError(t, err) + el, _ := q.Dequeue() + assert.Equal(t, el, peek) + } + _, err := q.Peek() + assert.Equal(t, tc.wantErr, err) + }) + + } +} + +func TestPriorityQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + capacity int + data []int + element int + wantErr error + }{ + { + name: "有界空队列", + capacity: 10, + data: []int{}, + element: 10, + }, + { + name: "有界满队列", + capacity: 6, + data: []int{6, 5, 4, 3, 2, 1}, + element: 10, + wantErr: ErrOutOfCapacity, + }, + { + name: "有界非空不满队列", + capacity: 12, + data: []int{6, 5, 4, 3, 2, 1}, + element: 10, + }, + { + name: "无界空队列", + capacity: 0, + data: []int{}, + element: 10, + }, + { + name: "无界非空队列", + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + element: 10, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(tc.capacity, tc.data, compare()) + require.NotNil(t, q) + err := q.Enqueue(tc.element) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.capacity, q.Cap()) + }) + + } +} + +func TestPriorityQueue_EnqueueElement(t *testing.T) { + testCases := []struct { + name string + data []int + element int + wantSlice []int + }{ + { + name: "新加入的元素是最大的", + data: []int{10, 8, 7, 6, 2}, + element: 20, + wantSlice: []int{0, 2, 6, 8, 10, 7, 20}, + }, + { + name: "新加入的元素是最小的", + data: []int{10, 8, 7, 6, 2}, + element: 1, + wantSlice: []int{0, 1, 6, 2, 10, 7, 8}, + }, + { + name: "新加入的元素子区间中", + data: []int{10, 8, 7, 6, 2}, + element: 5, + wantSlice: []int{0, 2, 6, 5, 10, 7, 8}, + }, + { + name: "新加入的元素与已有元素相同", + data: []int{10, 8, 7, 6, 2}, + element: 6, + wantSlice: []int{0, 2, 6, 6, 10, 7, 8}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(0, tc.data, compare()) + require.NotNil(t, q) + err := q.Enqueue(tc.element) + require.NoError(t, err) + assert.Equal(t, tc.wantSlice, q.data) + }) + + } +} + +func TestPriorityQueue_EnqueueHeapStruct(t *testing.T) { + data := []int{6, 5, 4, 3, 2, 1} + testCases := []struct { + name string + capacity int + data []int + wantSlice []int + pivot int + pivotData []int + }{ + { + name: "队列满", + capacity: len(data), + data: data, + pivot: 2, + pivotData: []int{0, 4, 6, 5}, + wantSlice: []int{0, 1, 3, 2, 6, 4, 5}, + }, + { + name: "队列不满", + capacity: len(data) * 2, + data: data, + pivot: 3, + pivotData: []int{0, 3, 4, 5, 6}, + wantSlice: []int{0, 1, 3, 2, 6, 4, 5}, + }, + { + name: "无界队列", + capacity: 0, + data: data, + pivot: 3, + pivotData: []int{0, 3, 4, 5, 6}, + wantSlice: []int{0, 1, 3, 2, 6, 4, 5}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewPriorityQueue[int](tc.capacity, compare()) + for i, el := range tc.data { + require.NoError(t, q.Enqueue(el)) + // 检查中途堆结构堆调整,是否符合预期 + if i == tc.pivot { + assert.Equal(t, tc.pivotData, q.data) + } + } + // 检查最终堆结构,是否符合预期 + assert.Equal(t, tc.wantSlice, q.data) + }) + + } +} + +func TestPriorityQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + data []int + wantErr error + wantVal int + wantSlice []int + }{ + { + name: "空队列", + data: []int{}, + wantErr: ErrEmptyQueue, + }, + { + name: "只有一个元素", + data: []int{10}, + wantVal: 10, + wantSlice: []int{0}, + }, + { + name: "many", + data: []int{6, 5, 4, 3, 2, 1}, + wantVal: 1, + wantSlice: []int{0, 2, 3, 5, 6, 4}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(0, tc.data, compare()) + require.NotNil(t, q) + val, err := q.Dequeue() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantSlice, q.data) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func TestPriorityQueue_DequeueComplexCheck(t *testing.T) { + testCases := []struct { + name string + capacity int + data []int + pivot int + want []int + }{ + { + name: "无边界", + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + pivot: 2, + want: []int{0, 4, 6, 5}, + }, + { + name: "有边界", + capacity: 6, + data: []int{6, 5, 4, 3, 2, 1}, + pivot: 3, + want: []int{0, 5, 6}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(tc.capacity, tc.data, compare()) + require.NotNil(t, q) + i := 0 + prev := -1 + for q.Len() > 0 { + el, err := q.Dequeue() + require.NoError(t, err) + // 检查中途出队后,堆结构堆调整是否符合预期 + if i == tc.pivot { + assert.Equal(t, tc.want, q.data) + } + // 检查出队是否有序 + assert.LessOrEqual(t, prev, el) + prev = el + i++ + } + }) + + } +} + +func TestPriorityQueue_Shrink(t *testing.T) { + var compare ekit.Comparator[int] = func(a, b int) int { + if a < b { + return -1 + } + if a == b { + return 0 + } + return 1 + } + testCases := []struct { + name string + originCap int + enqueueLoop int + dequeueLoop int + expectCap int + sliceCap int + }{ + { + name: "有界,小于64", + originCap: 32, + enqueueLoop: 6, + dequeueLoop: 5, + expectCap: 32, + sliceCap: 33, + }, + { + name: "有界,小于2048, 不足1/4", + originCap: 1000, + enqueueLoop: 20, + dequeueLoop: 5, + expectCap: 1000, + sliceCap: 1001, + }, + { + name: "有界,小于2048, 超过1/4", + originCap: 1000, + enqueueLoop: 400, + dequeueLoop: 5, + expectCap: 1000, + sliceCap: 1001, + }, + { + name: "有界,大于2048,不足一半", + originCap: 3000, + enqueueLoop: 400, + dequeueLoop: 40, + expectCap: 3000, + sliceCap: 3001, + }, + { + name: "有界,大于2048,大于一半", + originCap: 3000, + enqueueLoop: 2000, + dequeueLoop: 5, + expectCap: 3000, + sliceCap: 3001, + }, + { + name: "无界,小于64", + originCap: 0, + enqueueLoop: 30, + dequeueLoop: 5, + expectCap: 0, + sliceCap: 64, + }, + { + name: "无界,小于2048, 不足1/4", + originCap: 0, + enqueueLoop: 2000, + dequeueLoop: 1990, + expectCap: 0, + sliceCap: 50, + }, + { + name: "无界,小于2048, 超过1/4", + originCap: 0, + enqueueLoop: 2000, + dequeueLoop: 600, + expectCap: 0, + sliceCap: 2560, + }, + { + name: "无界,大于2048,不足一半", + originCap: 0, + enqueueLoop: 3000, + dequeueLoop: 2000, + expectCap: 0, + sliceCap: 1331, + }, + { + name: "无界,大于2048,大于一半", + originCap: 0, + enqueueLoop: 3000, + dequeueLoop: 5, + expectCap: 0, + sliceCap: 3408, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewPriorityQueue[int](tc.originCap, compare) + for i := 0; i < tc.enqueueLoop; i++ { + err := q.Enqueue(i) + if err != nil { + return + } + } + for i := 0; i < tc.dequeueLoop; i++ { + _, err := q.Dequeue() + if err != nil { + return + } + } + assert.Equal(t, tc.expectCap, q.Cap()) + assert.Equal(t, tc.sliceCap, cap(q.data)) + }) + } +} + +func priorityQueueOf(capacity int, data []int, compare ekit.Comparator[int]) *PriorityQueue[int] { + q := NewPriorityQueue[int](capacity, compare) + for _, el := range data { + err := q.Enqueue(el) + if err != nil { + return nil + } + } + return q +} + +func compare() ekit.Comparator[int] { + return func(a, b int) int { + if a < b { + return -1 + } + if a == b { + return 0 + } + return 1 + } +} diff --git a/internal/slice/shink_test.go b/internal/slice/shink_test.go new file mode 100644 index 00000000..366fd207 --- /dev/null +++ b/internal/slice/shink_test.go @@ -0,0 +1,72 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShrink(t *testing.T) { + testCases := []struct { + name string + originCap int + enqueueLoop int + expectCap int + }{ + { + name: "小于64", + originCap: 32, + enqueueLoop: 6, + expectCap: 32, + }, + { + name: "小于2048, 不足1/4", + originCap: 1000, + enqueueLoop: 20, + expectCap: 500, + }, + { + name: "小于2048, 超过1/4", + originCap: 1000, + enqueueLoop: 400, + expectCap: 1000, + }, + { + name: "大于2048,不足一半", + originCap: 3000, + enqueueLoop: 60, + expectCap: 1875, + }, + { + name: "大于2048,大于一半", + originCap: 3000, + enqueueLoop: 2000, + expectCap: 3000, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + l := make([]int, 0, tc.originCap) + + for i := 0; i < tc.enqueueLoop; i++ { + l = append(l, i) + } + l = Shrink[int](l) + assert.Equal(t, tc.expectCap, cap(l)) + }) + } +} diff --git a/internal/slice/shrink.go b/internal/slice/shrink.go new file mode 100644 index 00000000..911adf3a --- /dev/null +++ b/internal/slice/shrink.go @@ -0,0 +1,40 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +func calCapacity(c, l int) (int, bool) { + if c <= 64 { + return c, false + } + if c > 2048 && (c/l >= 2) { + factor := 0.625 + return int(float32(c) * float32(factor)), true + } + if c <= 2048 && (c/l >= 4) { + return c / 2, true + } + return c, false +} + +func Shrink[T any](src []T) []T { + c, l := cap(src), len(src) + n, changed := calCapacity(c, l) + if !changed { + return src + } + s := make([]T, 0, n) + s = append(s, src...) + return s +} diff --git a/list/array_list.go b/list/array_list.go index 90c9d084..bf03ce02 100644 --- a/list/array_list.go +++ b/list/array_list.go @@ -92,22 +92,7 @@ func (a *ArrayList[T]) Delete(index int) (T, error) { // shrink 数组缩容 func (a *ArrayList[T]) shrink() { - var newCap int - c, l := a.Cap(), a.Len() - if c <= 64 { - return - } - if c > 2048 && (c/l >= 2) { - newCap = int(float32(c) * float32(0.625)) - } else if c <= 2048 && (c/l >= 4) { - newCap = c / 2 - } else { - // 不满足缩容 - return - } - newSlice := make([]T, 0, newCap) - newSlice = append(newSlice, a.vals...) - a.vals = newSlice + a.vals = slice.Shrink(a.vals) } func (a *ArrayList[T]) Len() int { diff --git a/queue/concurrent_array_blocking_queue.go b/queue/concurrent_array_blocking_queue.go new file mode 100644 index 00000000..3b4e62da --- /dev/null +++ b/queue/concurrent_array_blocking_queue.go @@ -0,0 +1,133 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "sync" +) + +// ConcurrentArrayBlockingQueue 有界并发阻塞队列 +type ConcurrentArrayBlockingQueue[T any] struct { + data []T + mutex *sync.RWMutex + + // 队头元素下标 + head int + // 队尾元素下标 + tail int + // 包含多少个元素 + count int + + notEmpty *cond + notFull *cond + + // zero 不能作为返回值返回,防止用户篡改 + zero T +} + +// NewConcurrentArrayBlockingQueue 创建一个有界阻塞队列 +// 容量会在最开始的时候就初始化好 +// capacity 必须为正数 +func NewConcurrentArrayBlockingQueue[T any](capacity int) *ConcurrentArrayBlockingQueue[T] { + mutex := &sync.RWMutex{} + res := &ConcurrentArrayBlockingQueue[T]{ + data: make([]T, capacity), + mutex: mutex, + notEmpty: newCond(mutex), + notFull: newCond(mutex), + } + return res +} + +// Enqueue 入队 +// 注意:目前我们已经通过broadcast实现了超时控制 +func (c *ConcurrentArrayBlockingQueue[T]) Enqueue(ctx context.Context, t T) error { + if ctx.Err() != nil { + return ctx.Err() + } + c.mutex.Lock() + for c.count == len(c.data) { + signal := c.notFull.signalCh() + select { + case <-ctx.Done(): + return ctx.Err() + case <-signal: + // 收到信号要重新加锁 + c.mutex.Lock() + } + } + c.data[c.tail] = t + c.tail++ + c.count++ + // c.tail 已经是最后一个了,重置下标 + if c.tail == cap(c.data) { + c.tail = 0 + } + // 这里会释放锁 + c.notEmpty.broadcast() + return nil +} + +// Dequeue 出队 +// 注意:目前我们已经通过broadcast实现了超时控制 +func (c *ConcurrentArrayBlockingQueue[T]) Dequeue(ctx context.Context) (T, error) { + if ctx.Err() != nil { + var t T + return t, ctx.Err() + } + c.mutex.Lock() + for c.count == 0 { + signal := c.notEmpty.signalCh() + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case <-signal: + c.mutex.Lock() + } + } + val := c.data[c.head] + // 为了释放内存,GC + c.data[c.head] = c.zero + c.count-- + c.head++ + // 重置下标 + if c.head == cap(c.data) { + c.head = 0 + } + c.notFull.broadcast() + return val, nil +} + +func (c *ConcurrentArrayBlockingQueue[T]) Len() int { + c.mutex.RLock() + defer c.mutex.RUnlock() + return c.count +} + +func (c *ConcurrentArrayBlockingQueue[T]) AsSlice() []T { + c.mutex.RLock() + defer c.mutex.RUnlock() + res := make([]T, 0, c.count) + cnt := 0 + capacity := cap(c.data) + for cnt < c.count { + index := (c.head + cnt) % capacity + res = append(res, c.data[index]) + cnt++ + } + return res +} diff --git a/queue/concurrent_array_blocking_queue_test.go b/queue/concurrent_array_blocking_queue_test.go new file mode 100644 index 00000000..b7ce3359 --- /dev/null +++ b/queue/concurrent_array_blocking_queue_test.go @@ -0,0 +1,433 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "fmt" + "math/rand" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConcurrentBlockingQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + q func() *ConcurrentArrayBlockingQueue[int] + val int + timeout time.Duration + wantErr error + wantData []int + wantSlice []int + wantLen int + wantHead int + wantTail int + }{ + { + name: "empty and enqueued", + q: func() *ConcurrentArrayBlockingQueue[int] { + return NewConcurrentArrayBlockingQueue[int](3) + }, + val: 123, + timeout: time.Second, + wantData: []int{123, 0, 0}, + wantSlice: []int{123}, + wantLen: 1, + wantTail: 1, + wantHead: 0, + }, + { + name: "invalid context", + q: func() *ConcurrentArrayBlockingQueue[int] { + return NewConcurrentArrayBlockingQueue[int](3) + }, + val: 123, + timeout: -time.Second, + wantData: []int{0, 0, 0}, + wantSlice: []int{}, + wantErr: context.DeadlineExceeded, + }, + { + // 入队之后就满了,恰好放在切片的最后一个位置 + name: "enqueued full last index", + q: func() *ConcurrentArrayBlockingQueue[int] { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := NewConcurrentArrayBlockingQueue[int](3) + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + return q + }, + val: 345, + timeout: time.Second, + wantData: []int{123, 234, 345}, + wantSlice: []int{123, 234, 345}, + wantLen: 3, + wantTail: 0, + wantHead: 0, + }, + { + // 入队之后就满了,恰好放在切片的第一个 + name: "enqueued full middle index", + q: func() *ConcurrentArrayBlockingQueue[int] { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := NewConcurrentArrayBlockingQueue[int](3) + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + return q + }, + val: 456, + timeout: time.Second, + wantData: []int{456, 234, 345}, + wantSlice: []int{234, 345, 456}, + wantLen: 3, + wantTail: 1, + wantHead: 1, + }, + { + // 入队之后就满了,恰好放在中间 + name: "enqueued full first index", + q: func() *ConcurrentArrayBlockingQueue[int] { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := NewConcurrentArrayBlockingQueue[int](3) + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + val, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 234, val) + err = q.Enqueue(ctx, 456) + require.NoError(t, err) + return q + }, + val: 567, + timeout: time.Second, + wantData: []int{456, 567, 345}, + wantSlice: []int{345, 456, 567}, + wantLen: 3, + wantTail: 2, + wantHead: 2, + }, + { + // 元素本身就是零值 + name: "all zero value ", + q: func() *ConcurrentArrayBlockingQueue[int] { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := NewConcurrentArrayBlockingQueue[int](3) + err := q.Enqueue(ctx, 0) + require.NoError(t, err) + err = q.Enqueue(ctx, 0) + require.NoError(t, err) + return q + }, + val: 0, + timeout: time.Second, + wantData: []int{0, 0, 0}, + wantSlice: []int{0, 0, 0}, + wantLen: 3, + wantTail: 0, + wantHead: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + q := tc.q() + err := q.Enqueue(ctx, tc.val) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantData, q.data) + assert.Equal(t, tc.wantSlice, q.AsSlice()) + assert.Equal(t, tc.wantLen, q.Len()) + assert.Equal(t, tc.wantHead, q.head) + assert.Equal(t, tc.wantTail, q.tail) + }) + } + + t.Run("enqueue timeout", func(t *testing.T) { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + err = q.Enqueue(ctx, 456) + require.Equal(t, context.DeadlineExceeded, err) + }) + + // 入队阻塞,而后出队,于是入队成功 + t.Run("enqueue blocking and dequeue", func(t *testing.T) { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + go func() { + time.Sleep(time.Millisecond * 100) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + }() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + err = q.Enqueue(ctx, 456) + require.NoError(t, err) + }) +} + +func TestConcurrentBlockingQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + q func() *ConcurrentArrayBlockingQueue[int] + val int + timeout time.Duration + wantErr error + wantVal int + wantData []int + wantSlice []int + wantLen int + wantHead int + wantTail int + }{ + { + name: "dequeued", + q: func() *ConcurrentArrayBlockingQueue[int] { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + return q + }, + wantVal: 123, + timeout: time.Second, + wantData: []int{0, 234, 0}, + wantSlice: []int{234}, + wantLen: 1, + wantHead: 1, + wantTail: 2, + }, + { + name: "invalid context", + q: func() *ConcurrentArrayBlockingQueue[int] { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + return q + }, + wantErr: context.DeadlineExceeded, + timeout: -time.Second, + wantData: []int{123, 234, 0}, + wantSlice: []int{123, 234}, + wantLen: 2, + wantHead: 0, + wantTail: 2, + }, + { + name: "dequeue and empty first", + q: func() *ConcurrentArrayBlockingQueue[int] { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + return q + }, + wantVal: 123, + timeout: time.Second, + wantData: []int{0, 0, 0}, + wantSlice: []int{}, + wantLen: 0, + wantHead: 1, + wantTail: 1, + }, + { + name: "dequeue and empty middle", + q: func() *ConcurrentArrayBlockingQueue[int] { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + return q + }, + wantVal: 234, + timeout: time.Second, + wantData: []int{0, 0, 0}, + wantSlice: []int{}, + wantLen: 0, + wantHead: 2, + wantTail: 2, + }, + { + name: "dequeue and empty last", + q: func() *ConcurrentArrayBlockingQueue[int] { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + val, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 234, val) + return q + }, + wantVal: 345, + timeout: time.Second, + wantData: []int{0, 0, 0}, + wantSlice: []int{}, + wantLen: 0, + wantHead: 0, + wantTail: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + q := tc.q() + val, err := q.Dequeue(ctx) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantVal, val) + assert.Equal(t, tc.wantData, q.data) + assert.Equal(t, tc.wantSlice, q.AsSlice()) + assert.Equal(t, tc.wantLen, q.Len()) + assert.Equal(t, tc.wantHead, q.head) + assert.Equal(t, tc.wantTail, q.tail) + }) + } + + t.Run("dequeue timeout", func(t *testing.T) { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + val, err := q.Dequeue(ctx) + require.Equal(t, context.DeadlineExceeded, err) + require.Equal(t, 0, val) + }) + + // 出队阻塞,然后入队,然后出队成功 + t.Run("dequeue blocking and enqueue", func(t *testing.T) { + q := NewConcurrentArrayBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + go func() { + time.Sleep(time.Millisecond * 100) + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + }() + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + }) +} + +func TestConcurrentArrayBlockingQueue(t *testing.T) { + // 并发测试,只是测试有没有死锁之类的问题 + // 先进先出这个特性依赖于其它单元测试 + // 也依赖于代码审查 + q := NewConcurrentArrayBlockingQueue[int](100) + var wg sync.WaitGroup + wg.Add(1000) + for i := 0; i < 1000; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + val := rand.Int() + err := q.Enqueue(ctx, val) + cancel() + require.NoError(t, err) + }() + } + go func() { + for i := 0; i < 1000; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := q.Dequeue(ctx) + cancel() + require.NoError(t, err) + wg.Done() + }() + } + }() + wg.Wait() +} + +func ExampleNewConcurrentArrayBlockingQueue() { + q := NewConcurrentArrayBlockingQueue[int](10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _ = q.Enqueue(ctx, 22) + val, err := q.Dequeue(ctx) + // 这是例子,实际中你不需要写得那么复杂 + switch err { + case context.Canceled: + // 有人主动取消了,即调用了 cancel 方法。在这个例子里不会出现这个情况 + case context.DeadlineExceeded: + // 超时了 + case nil: + fmt.Println(val) + default: + // 其它乱七八糟的 + } + // Output: + // 22 +} diff --git a/queue/concurrent_linked_blocking_queue.go b/queue/concurrent_linked_blocking_queue.go new file mode 100644 index 00000000..9cdc58a1 --- /dev/null +++ b/queue/concurrent_linked_blocking_queue.go @@ -0,0 +1,112 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "sync" + + "github.com/gotomicro/ekit/list" +) + +// ConcurrentLinkedBlockingQueue 基于链表的并发阻塞队列 +// 如果 maxSize 是正数。那么就是有界并发阻塞队列 +// 如果不是,就是无界并发阻塞队列, 在这种情况下,入队永远能够成功 +type ConcurrentLinkedBlockingQueue[T any] struct { + mutex *sync.RWMutex + + // 最大容量 + maxSize int + // 链表 + linkedlist *list.LinkedList[T] + + notEmpty *cond + notFull *cond +} + +// NewConcurrentLinkedBlockingQueue 创建链式阻塞队列 capacity <= 0 时,为无界队列 +func NewConcurrentLinkedBlockingQueue[T any](capacity int) *ConcurrentLinkedBlockingQueue[T] { + mutex := &sync.RWMutex{} + res := &ConcurrentLinkedBlockingQueue[T]{ + maxSize: capacity, + mutex: mutex, + notEmpty: newCond(mutex), + notFull: newCond(mutex), + linkedlist: list.NewLinkedList[T](), + } + return res +} + +// Enqueue 入队 +// 注意:目前我们已经通过broadcast实现了超时控制 +func (c *ConcurrentLinkedBlockingQueue[T]) Enqueue(ctx context.Context, t T) error { + if ctx.Err() != nil { + return ctx.Err() + } + c.mutex.Lock() + for c.maxSize > 0 && c.linkedlist.Len() == c.maxSize { + signal := c.notFull.signalCh() + select { + case <-ctx.Done(): + return ctx.Err() + case <-signal: + // 收到信号要重新加锁 + c.mutex.Lock() + } + } + + err := c.linkedlist.Append(t) + + // 这里会释放锁 + c.notEmpty.broadcast() + return err +} + +// Dequeue 出队 +// 注意:目前我们已经通过broadcast实现了超时控制 +func (c *ConcurrentLinkedBlockingQueue[T]) Dequeue(ctx context.Context) (T, error) { + if ctx.Err() != nil { + var t T + return t, ctx.Err() + } + c.mutex.Lock() + for c.linkedlist.Len() == 0 { + signal := c.notEmpty.signalCh() + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case <-signal: + c.mutex.Lock() + } + } + + val, err := c.linkedlist.Delete(0) + c.notFull.broadcast() + return val, err +} + +func (c *ConcurrentLinkedBlockingQueue[T]) Len() int { + c.mutex.RLock() + defer c.mutex.RUnlock() + return c.linkedlist.Len() +} + +func (c *ConcurrentLinkedBlockingQueue[T]) AsSlice() []T { + c.mutex.RLock() + defer c.mutex.RUnlock() + res := c.linkedlist.AsSlice() + return res +} diff --git a/queue/concurrent_linked_blocking_queue_test.go b/queue/concurrent_linked_blocking_queue_test.go new file mode 100644 index 00000000..4391f0e1 --- /dev/null +++ b/queue/concurrent_linked_blocking_queue_test.go @@ -0,0 +1,319 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "fmt" + "math/rand" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConcurrentLinkedBlockingQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + q func() *ConcurrentLinkedBlockingQueue[int] + val int + timeout time.Duration + wantErr error + wantSlice []int + wantLen int + }{ + { + name: "empty and enqueued", + q: func() *ConcurrentLinkedBlockingQueue[int] { + return NewConcurrentLinkedBlockingQueue[int](3) + }, + val: 123, + timeout: time.Second, + wantSlice: []int{123}, + wantLen: 1, + }, + { + name: "invalid context", + q: func() *ConcurrentLinkedBlockingQueue[int] { + return NewConcurrentLinkedBlockingQueue[int](3) + }, + val: 123, + timeout: -time.Second, + wantSlice: []int{}, + wantErr: context.DeadlineExceeded, + }, + { + // 入队之后就满了,恰好放在切片的最后一个位置 + name: "enqueued full last index", + q: func() *ConcurrentLinkedBlockingQueue[int] { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + q := NewConcurrentLinkedBlockingQueue[int](3) + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + return q + }, + val: 345, + timeout: time.Second, + wantSlice: []int{123, 234, 345}, + wantLen: 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + q := tc.q() + err := q.Enqueue(ctx, tc.val) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantSlice, q.AsSlice()) + assert.Equal(t, tc.wantLen, q.Len()) + }) + } + + t.Run("enqueue timeout", func(t *testing.T) { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + err = q.Enqueue(ctx, 456) + require.Equal(t, context.DeadlineExceeded, err) + }) + + // 入队阻塞,而后出队,于是入队成功 + t.Run("enqueue blocking and dequeue", func(t *testing.T) { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + go func() { + time.Sleep(time.Millisecond * 100) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + }() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + err = q.Enqueue(ctx, 456) + require.NoError(t, err) + }) + + // 无界的情况下,可以无限添加元素,当然小心内存, 以及goroutine调度导致的超时 + // capacity <= 0 时,为无界队列 + t.Run("capacity <= 0", func(t *testing.T) { + q := NewConcurrentLinkedBlockingQueue[int](-1) + for i := 0; i < 10; i++ { + go func() { + for i := 0; i < 1000; i++ { + ctx := context.Background() + val := rand.Int() + err := q.Enqueue(ctx, val) + require.NoError(t, err) + } + + }() + } + }) +} + +func TestConcurrentLinkedBlockingQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + q func() *ConcurrentLinkedBlockingQueue[int] + val int + timeout time.Duration + wantErr error + wantVal int + wantSlice []int + wantLen int + }{ + { + name: "dequeued", + q: func() *ConcurrentLinkedBlockingQueue[int] { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + return q + }, + wantVal: 123, + timeout: time.Second, + wantSlice: []int{234}, + wantLen: 1, + }, + { + name: "invalid context", + q: func() *ConcurrentLinkedBlockingQueue[int] { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + return q + }, + wantErr: context.DeadlineExceeded, + timeout: -time.Second, + wantSlice: []int{123, 234}, + wantLen: 2, + }, + { + name: "dequeue and empty first", + q: func() *ConcurrentLinkedBlockingQueue[int] { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + return q + }, + wantVal: 123, + timeout: time.Second, + wantSlice: []int{}, + wantLen: 0, + }, + { + name: "dequeue and empty last", + q: func() *ConcurrentLinkedBlockingQueue[int] { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + err = q.Enqueue(ctx, 234) + require.NoError(t, err) + err = q.Enqueue(ctx, 345) + require.NoError(t, err) + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + val, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 234, val) + return q + }, + wantVal: 345, + timeout: time.Second, + wantSlice: []int{}, + wantLen: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + q := tc.q() + val, err := q.Dequeue(ctx) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantVal, val) + assert.Equal(t, tc.wantSlice, q.AsSlice()) + assert.Equal(t, tc.wantLen, q.Len()) + }) + } + + t.Run("dequeue timeout", func(t *testing.T) { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + val, err := q.Dequeue(ctx) + require.Equal(t, context.DeadlineExceeded, err) + require.Equal(t, 0, val) + }) + + // 出队阻塞,然后入队,然后出队成功 + t.Run("dequeue blocking and enqueue", func(t *testing.T) { + q := NewConcurrentLinkedBlockingQueue[int](3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + go func() { + time.Sleep(time.Millisecond * 100) + err := q.Enqueue(ctx, 123) + require.NoError(t, err) + }() + val, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, val) + }) +} + +func TestConcurrentLinkedBlockingQueue(t *testing.T) { + // 并发测试,只是测试有没有死锁之类的问题 + // 先进先出这个特性依赖于其它单元测试 + // 也依赖于代码审查 + q := NewConcurrentLinkedBlockingQueue[int](100) + var wg sync.WaitGroup + wg.Add(1000) + for i := 0; i < 1000; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + val := rand.Int() + err := q.Enqueue(ctx, val) + cancel() + require.NoError(t, err) + }() + } + go func() { + for i := 0; i < 1000; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := q.Dequeue(ctx) + cancel() + require.NoError(t, err) + wg.Done() + }() + } + }() + wg.Wait() +} + +func ExampleNewConcurrentLinkedBlockingQueue() { + // 创建一个容量为 10 的有界并发阻塞队列,如果传入 0 或者负数,那么创建的是无界并发阻塞队列 + q := NewConcurrentLinkedBlockingQueue[int](10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _ = q.Enqueue(ctx, 22) + val, err := q.Dequeue(ctx) + // 这是例子,实际中你不需要写得那么复杂 + switch err { + case context.Canceled: + // 有人主动取消了,即调用了 cancel 方法。在这个例子里不会出现这个情况 + case context.DeadlineExceeded: + // 超时了 + case nil: + fmt.Println(val) + default: + // 其它乱七八糟的 + } + // Output: + // 22 +} diff --git a/queue/concurrent_linked_queue.go b/queue/concurrent_linked_queue.go new file mode 100644 index 00000000..b4df433f --- /dev/null +++ b/queue/concurrent_linked_queue.go @@ -0,0 +1,85 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "sync/atomic" + "unsafe" + + "github.com/gotomicro/ekit/internal/queue" +) + +// ConcurrentLinkedQueue 无界并发安全队列 +type ConcurrentLinkedQueue[T any] struct { + // *node[T] + head unsafe.Pointer + // *node[T] + tail unsafe.Pointer +} + +func NewConcurrentLinkedQueue[T any]() *ConcurrentLinkedQueue[T] { + head := &node[T]{} + ptr := unsafe.Pointer(head) + return &ConcurrentLinkedQueue[T]{ + head: ptr, + tail: ptr, + } +} + +func (c *ConcurrentLinkedQueue[T]) Enqueue(t T) error { + newNode := &node[T]{val: t} + newPtr := unsafe.Pointer(newNode) + for { + tailPtr := atomic.LoadPointer(&c.tail) + tail := (*node[T])(tailPtr) + tailNext := atomic.LoadPointer(&tail.next) + if tailNext != nil { + // 已经被人修改了,我们不需要修复,因为预期中修改的那个人会把 c.tail 指过去 + continue + } + if atomic.CompareAndSwapPointer(&tail.next, tailNext, newPtr) { + // 如果失败也不用担心,说明有人抢先一步了 + atomic.CompareAndSwapPointer(&c.tail, tailPtr, newPtr) + return nil + } + } +} + +func (c *ConcurrentLinkedQueue[T]) Dequeue() (T, error) { + for { + headPtr := atomic.LoadPointer(&c.head) + head := (*node[T])(headPtr) + tailPtr := atomic.LoadPointer(&c.tail) + tail := (*node[T])(tailPtr) + if head == tail { + // 不需要做更多检测,在当下这一刻,我们就认为没有元素,即便这时候正好有人入队 + // 但是并不妨碍我们在它彻底入队完成——即所有的指针都调整好——之前, + // 认为其实还是没有元素 + var t T + return t, queue.ErrEmptyQueue + } + headNextPtr := atomic.LoadPointer(&head.next) + if atomic.CompareAndSwapPointer(&c.head, headPtr, headNextPtr) { + headNext := (*node[T])(headNextPtr) + return headNext.val, nil + } + } +} + +type node[T any] struct { + val T + // *node[T] + next unsafe.Pointer +} diff --git a/queue/concurrent_linked_queue_test.go b/queue/concurrent_linked_queue_test.go new file mode 100644 index 00000000..77876a61 --- /dev/null +++ b/queue/concurrent_linked_queue_test.go @@ -0,0 +1,197 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConcurrentQueue_Enqueue(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + q func() *ConcurrentLinkedQueue[int] + val int + + wantData []int + wantErr error + }{ + { + name: "empty", + q: func() *ConcurrentLinkedQueue[int] { + return NewConcurrentLinkedQueue[int]() + }, + val: 123, + wantData: []int{123}, + }, + { + name: "multiple", + q: func() *ConcurrentLinkedQueue[int] { + q := NewConcurrentLinkedQueue[int]() + err := q.Enqueue(123) + require.NoError(t, err) + return q + }, + val: 234, + wantData: []int{123, 234}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := tc.q() + err := q.Enqueue(tc.val) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantData, q.asSlice()) + }) + } +} + +func TestConcurrentQueue_Dequeue(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + q func() *ConcurrentLinkedQueue[int] + wantVal int + wantData []int + wantErr error + }{ + { + name: "empty", + q: func() *ConcurrentLinkedQueue[int] { + q := NewConcurrentLinkedQueue[int]() + return q + }, + wantErr: errEmptyQueue, + }, + { + name: "single", + q: func() *ConcurrentLinkedQueue[int] { + q := NewConcurrentLinkedQueue[int]() + err := q.Enqueue(123) + assert.NoError(t, err) + return q + }, + wantVal: 123, + }, + { + name: "multiple", + q: func() *ConcurrentLinkedQueue[int] { + q := NewConcurrentLinkedQueue[int]() + err := q.Enqueue(123) + assert.NoError(t, err) + err = q.Enqueue(234) + assert.NoError(t, err) + return q + }, + wantVal: 123, + wantData: []int{234}, + }, + { + name: "enqueue and dequeue", + q: func() *ConcurrentLinkedQueue[int] { + q := NewConcurrentLinkedQueue[int]() + err := q.Enqueue(123) + assert.NoError(t, err) + err = q.Enqueue(234) + assert.NoError(t, err) + val, err := q.Dequeue() + assert.Equal(t, 123, val) + assert.NoError(t, err) + err = q.Enqueue(345) + assert.NoError(t, err) + return q + }, + wantVal: 234, + wantData: []int{345}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := tc.q() + val, err := q.Dequeue() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantVal, val) + assert.Equal(t, tc.wantData, q.asSlice()) + }) + } +} + +func TestConcurrentLinkedQueue(t *testing.T) { + t.Parallel() + // 仅仅是为了测试在入队出队期间不会出现 panic 或者死循环之类的问题 + // FIFO 特性参考其余测试 + q := NewConcurrentLinkedQueue[int]() + var wg sync.WaitGroup + wg.Add(10000) + for i := 0; i < 10; i++ { + go func() { + for j := 0; j < 1000; j++ { + val := rand.Int() + _ = q.Enqueue(val) + } + }() + } + var cnt int32 = 0 + for i := 0; i < 10; i++ { + go func() { + for { + if atomic.LoadInt32(&cnt) >= 10000 { + return + } + _, err := q.Dequeue() + if err == nil { + atomic.AddInt32(&cnt, 1) + wg.Done() + } + } + }() + } + wg.Wait() +} + +func (c *ConcurrentLinkedQueue[T]) asSlice() []T { + var res []T + cur := (*node[T])((*node[T])(c.head).next) + for cur != nil { + res = append(res, cur.val) + cur = (*node[T])(cur.next) + } + return res +} + +func ExampleNewConcurrentLinkedQueue() { + q := NewConcurrentLinkedQueue[int]() + _ = q.Enqueue(10) + val, err := q.Dequeue() + if err != nil { + // 一般意味着队列为空 + fmt.Println(err) + } + fmt.Println(val) + // Output: + // 10 +} diff --git a/queue/concurrent_priority_queue.go b/queue/concurrent_priority_queue.go new file mode 100644 index 00000000..ae722691 --- /dev/null +++ b/queue/concurrent_priority_queue.go @@ -0,0 +1,65 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "sync" + + "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/internal/queue" +) + +type ConcurrentPriorityQueue[T any] struct { + pq queue.PriorityQueue[T] + m sync.RWMutex +} + +func (c *ConcurrentPriorityQueue[T]) Len() int { + c.m.RLock() + defer c.m.RUnlock() + return c.pq.Len() +} + +// Cap 无界队列返回0,有界队列返回创建队列时设置的值 +func (c *ConcurrentPriorityQueue[T]) Cap() int { + c.m.RLock() + defer c.m.RUnlock() + return c.pq.Cap() +} + +func (c *ConcurrentPriorityQueue[T]) Peek() (T, error) { + c.m.RLock() + defer c.m.RUnlock() + return c.pq.Peek() +} + +func (c *ConcurrentPriorityQueue[T]) Enqueue(t T) error { + c.m.Lock() + defer c.m.Unlock() + return c.pq.Enqueue(t) +} + +func (c *ConcurrentPriorityQueue[T]) Dequeue() (T, error) { + c.m.Lock() + defer c.m.Unlock() + return c.pq.Dequeue() +} + +// NewConcurrentPriorityQueue 创建优先队列 capacity <= 0 时,为无界队列 +func NewConcurrentPriorityQueue[T any](capacity int, compare ekit.Comparator[T]) *ConcurrentPriorityQueue[T] { + return &ConcurrentPriorityQueue[T]{ + pq: *queue.NewPriorityQueue[T](capacity, compare), + } +} diff --git a/queue/concurrent_priority_queue_test.go b/queue/concurrent_priority_queue_test.go new file mode 100644 index 00000000..42eef1b8 --- /dev/null +++ b/queue/concurrent_priority_queue_test.go @@ -0,0 +1,311 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "fmt" + "sync" + "testing" + + "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/internal/queue" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + errOutOfCapacity = queue.ErrOutOfCapacity + errEmptyQueue = queue.ErrEmptyQueue +) + +func TestNewConcurrentPriorityQueue(t *testing.T) { + testCases := []struct { + name string + q *ConcurrentPriorityQueue[int] + capacity int + data []int + expect []int + }{ + { + name: "无边界", + q: NewConcurrentPriorityQueue(0, ekit.ComparatorRealNumber[int]), + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + expect: []int{1, 2, 3, 4, 5, 6}, + }, + { + name: "有边界 ", + q: NewConcurrentPriorityQueue(6, ekit.ComparatorRealNumber[int]), + capacity: 6, + data: []int{6, 5, 4, 3, 2, 1}, + expect: []int{1, 2, 3, 4, 5, 6}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, 0, tc.q.Len()) + for _, d := range tc.data { + require.NoError(t, tc.q.Enqueue(d)) + } + assert.Equal(t, tc.capacity, tc.q.Cap()) + assert.Equal(t, len(tc.data), tc.q.Len()) + res := make([]int, 0, len(tc.data)) + for tc.q.Len() > 0 { + head, err := tc.q.Peek() + require.NoError(t, err) + el, err := tc.q.Dequeue() + require.NoError(t, err) + assert.Equal(t, head, el) + res = append(res, el) + } + assert.Equal(t, tc.expect, res) + }) + + } + +} + +// 多个go routine 执行入队操作,完成后,主携程把元素逐一出队,只要有序,可以认为并发入队没问题 +func TestConcurrentPriorityQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + capacity int + concurrency int + perRoutine int + wantSlice []int + remain int + wantErr error + errCount int + }{ + { + name: "不超过capacity", + capacity: 1100, + concurrency: 100, + perRoutine: 10, + }, + { + name: "超过capacity", + capacity: 1000, + concurrency: 101, + perRoutine: 10, + wantErr: errOutOfCapacity, + errCount: 10, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewConcurrentPriorityQueue[int](tc.capacity, ekit.ComparatorRealNumber[int]) + wg := sync.WaitGroup{} + wg.Add(tc.concurrency) + errChan := make(chan error, tc.capacity) + for i := tc.concurrency; i > 0; i-- { + go func(i int) { + start := i * 10 + for j := 0; j < tc.perRoutine; j++ { + err := q.Enqueue(start + j) + if err != nil { + errChan <- err + } + } + wg.Done() + }(i) + } + wg.Wait() + assert.Equal(t, tc.errCount, len(errChan)) + prev := -1 + for q.Len() > 0 { + el, _ := q.Dequeue() + assert.Less(t, prev, el) + + // 入队元素总数小于capacity时,应该所有元素都入队了,出队顺序应该依次加1 + if prev > -1 && len(errChan) == 0 { + assert.Equal(t, prev+1, el) + } + prev = el + } + }) + + } +} + +// 预先入队一组数据,通过测试多个协程并发出队时,每个协程内出队元素有序,间接确认并发安全 +func TestConcurrentPriorityQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + total int + concurrency int + perRoutine int + wantSlice []int + remain int + wantErr error + errCount int + }{ + { + name: "入队大于出队", + total: 910, + concurrency: 100, + perRoutine: 9, + remain: 10, + }, + { + name: "入队小于出队", + total: 900, + concurrency: 101, + perRoutine: 9, + remain: 0, + wantErr: errEmptyQueue, + errCount: 9, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewConcurrentPriorityQueue[int](tc.total, ekit.ComparatorRealNumber[int]) + for i := tc.total; i > 0; i-- { + require.NoError(t, q.Enqueue(i)) + } + + resultChan := make(chan int, tc.concurrency*tc.perRoutine) + disOrderChan := make(chan bool, tc.concurrency*tc.perRoutine) + errChan := make(chan error, tc.errCount) + wg := sync.WaitGroup{} + wg.Add(tc.concurrency) + + for i := 0; i < tc.concurrency; i++ { + go func() { + prev := -1 + for i := 0; i < tc.perRoutine; i++ { + el, err := q.Dequeue() + if err != nil { + // 如果出队报错,把错误放到error通道,以便后续检查错误的内容和数量是否符合预期 + errChan <- err + } else { + // 如果出队不报错,则检查出队结果是否符合预期 + resultChan <- el + if prev >= el { + disOrderChan <- false + } + prev = el + } + + } + wg.Done() + }() + } + wg.Wait() + close(resultChan) + close(errChan) + close(disOrderChan) + + // 检查并发出队的元素数量,是否符合预期 + assert.Equal(t, tc.remain, q.Len()) + + // 检查所有协程中的执行错误,是否符合预期 + assert.Equal(t, tc.errCount, len(errChan)) + for err := range errChan { + assert.Equal(t, tc.wantErr, err) + } + + // 每个协程内部,出队元素应该有序,检查是否发现无序的情况 + assert.Equal(t, 0, len(disOrderChan)) + + // 每个协程的每次出队操作,出队元素都应该不同,检查是否符合预期 + resultSet := make(map[int]bool) + for el := range resultChan { + _, ok := resultSet[el] + assert.Equal(t, false, ok) + resultSet[el] = true + } + + }) + + } +} + +// 测试同时并发出入队。只要并发安全,并发出入队后的剩余元素数量+报错数量应该符合预期 +// TODO 有待设计更好的并发出入队测试方案 +func TestConcurrentPriorityQueue_EnqueueDequeue(t *testing.T) { + testCases := []struct { + name string + enqueue int + dequeue int + remain int + }{ + { + name: "出队等于入队", + enqueue: 50, + dequeue: 50, + remain: 0, + }, + { + name: "出队小于入队", + enqueue: 50, + dequeue: 40, + remain: 10, + }, + { + name: "出队大于入队", + enqueue: 50, + dequeue: 60, + remain: -10, + }, + } + for _, tt := range testCases { + tc := tt + t.Run(tc.name, func(t *testing.T) { + q := NewConcurrentPriorityQueue[int](0, ekit.ComparatorRealNumber[int]) + errChan := make(chan error, tc.dequeue) + wg := sync.WaitGroup{} + wg.Add(tc.enqueue + tc.dequeue) + go func() { + for i := 0; i < tc.enqueue; i++ { + go func(i int) { + require.NoError(t, q.Enqueue(i)) + wg.Done() + }(i) + } + }() + go func() { + for i := 0; i < tc.dequeue; i++ { + _, err := q.Dequeue() + if err != nil { + errChan <- err + } + wg.Done() + } + }() + + wg.Wait() + close(errChan) + assert.Equal(t, tc.remain, q.Len()-len(errChan)) + }) + } +} + +func ExampleNewConcurrentPriorityQueue() { + q := NewConcurrentPriorityQueue[int](10, ekit.ComparatorRealNumber[int]) + _ = q.Enqueue(3) + _ = q.Enqueue(2) + _ = q.Enqueue(1) + var vals []int + val, _ := q.Dequeue() + vals = append(vals, val) + val, _ = q.Dequeue() + vals = append(vals, val) + val, _ = q.Dequeue() + vals = append(vals, val) + fmt.Println(vals) + // Output: + // [1 2 3] +} diff --git a/queue/delay_queue.go b/queue/delay_queue.go new file mode 100644 index 00000000..470acabb --- /dev/null +++ b/queue/delay_queue.go @@ -0,0 +1,190 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gotomicro/ekit/internal/queue" +) + +// DelayQueue 延时队列 +// 每次出队的元素必然都是已经到期的元素,即 Delay() 返回的值小于等于 0 +// 延时队列本身对时间的精确度并不是很高,其时间精确度主要取决于 time.Timer +// 所以如果你需要极度精确的延时队列,那么这个结构并不太适合你。 +// 但是如果你能够容忍至多在毫秒级的误差,那么这个结构还是可以使用的 +type DelayQueue[T Delayable] struct { + q queue.PriorityQueue[T] + mutex *sync.Mutex + dequeueSignal *cond + enqueueSignal *cond +} + +func NewDelayQueue[T Delayable](c int) *DelayQueue[T] { + m := &sync.Mutex{} + res := &DelayQueue[T]{ + q: *queue.NewPriorityQueue[T](c, func(src T, dst T) int { + srcDelay := src.Delay() + dstDelay := dst.Delay() + if srcDelay > dstDelay { + return 1 + } + if srcDelay == dstDelay { + return 0 + } + return -1 + }), + mutex: m, + dequeueSignal: newCond(m), + enqueueSignal: newCond(m), + } + return res +} + +func (d *DelayQueue[T]) Enqueue(ctx context.Context, t T) error { + for { + select { + // 先检测 ctx 有没有过期 + case <-ctx.Done(): + return ctx.Err() + default: + } + d.mutex.Lock() + err := d.q.Enqueue(t) + switch err { + case nil: + d.enqueueSignal.broadcast() + return nil + case queue.ErrOutOfCapacity: + signal := d.dequeueSignal.signalCh() + select { + case <-ctx.Done(): + return ctx.Err() + case <-signal: + } + default: + d.mutex.Unlock() + return fmt.Errorf("ekit: 延时队列入队的时候遇到未知错误 %w,请上报", err) + } + } +} + +func (d *DelayQueue[T]) Dequeue(ctx context.Context) (T, error) { + var timer *time.Timer + defer func() { + if timer != nil { + timer.Stop() + } + }() + for { + select { + // 先检测 ctx 有没有过期 + case <-ctx.Done(): + var t T + return t, ctx.Err() + default: + } + d.mutex.Lock() + val, err := d.q.Peek() + switch err { + case nil: + delay := val.Delay() + if delay <= 0 { + val, err = d.q.Dequeue() + d.dequeueSignal.broadcast() + // 理论上来说这里 err 不可能不为 nil + return val, err + } + signal := d.enqueueSignal.signalCh() + if timer == nil { + timer = time.NewTimer(delay) + } else { + timer.Reset(delay) + } + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case <-timer.C: + // 到了时间 + d.mutex.Lock() + // 原队头可能已经被其他协程先出队,故再次检查队头 + val, err := d.q.Peek() + if err != nil || val.Delay() > 0 { + d.mutex.Unlock() + continue + } + // 验证元素过期后将其出队 + val, err = d.q.Dequeue() + d.dequeueSignal.broadcast() + return val, err + case <-signal: + // 进入下一个循环。这里可能是有新的元素入队,也可能是到期了 + } + case queue.ErrEmptyQueue: + signal := d.enqueueSignal.signalCh() + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case <-signal: + } + default: + d.mutex.Unlock() + var t T + return t, fmt.Errorf("ekit: 延时队列出队的时候遇到未知错误 %w,请上报", err) + } + } +} + +type Delayable interface { + Delay() time.Duration +} + +type cond struct { + signal chan struct{} + l sync.Locker +} + +func newCond(l sync.Locker) *cond { + return &cond{ + signal: make(chan struct{}), + l: l, + } +} + +// broadcast 唤醒等待者 +// 如果没有人等待,那么什么也不会发生 +// 必须加锁之后才能调用这个方法 +// 广播之后锁会被释放,这也是为了确保用户必然是在锁范围内调用的 +func (c *cond) broadcast() { + signal := make(chan struct{}) + old := c.signal + c.signal = signal + c.l.Unlock() + close(old) +} + +// signalCh 返回一个 channel,用于监听广播信号 +// 必须在锁范围内使用 +// 调用后,锁会被释放,这也是为了确保用户必然是在锁范围内调用的 +func (c *cond) signalCh() <-chan struct{} { + res := c.signal + c.l.Unlock() + return res +} diff --git a/queue/delay_queue_test.go b/queue/delay_queue_test.go new file mode 100644 index 00000000..0b803851 --- /dev/null +++ b/queue/delay_queue_test.go @@ -0,0 +1,351 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestDelayQueue_Dequeue(t *testing.T) { + t.Parallel() + now := time.Now() + testCases := []struct { + name string + q *DelayQueue[delayElem] + timeout time.Duration + wantVal int + wantErr error + }{ + { + name: "dequeued", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(time.Millisecond * 10), + val: 11, + }), + timeout: time.Second, + wantVal: 11, + }, + { + // 元素本身就已经过期了 + name: "already deadline", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(-time.Millisecond * 10), + val: 11, + }), + timeout: time.Second, + wantVal: 11, + }, + { + // 已经超时了的 context 设置 + name: "invalid context", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(time.Millisecond * 10), + val: 11, + }), + timeout: -time.Second, + wantErr: context.DeadlineExceeded, + }, + { + name: "empty and timeout", + q: NewDelayQueue[delayElem](10), + timeout: time.Second, + wantErr: context.DeadlineExceeded, + }, + { + name: "not empty but timeout", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(time.Second * 10), + val: 11, + }), + timeout: time.Second, + wantErr: context.DeadlineExceeded, + }, + } + + for _, tt := range testCases { + tc := tt + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + ele, err := tc.q.Dequeue(ctx) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantVal, ele.val) + }) + } + + // 最开始没有元素,然后进去了一个元素 + t.Run("dequeue while enqueue", func(t *testing.T) { + q := NewDelayQueue[delayElem](3) + go func() { + time.Sleep(time.Millisecond * 500) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, delayElem{ + val: 123, + deadline: time.Now().Add(time.Millisecond * 100), + }) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + }) + + // 进去了一个更加短超时时间的元素 + // 于是后面两个都会拿出来,但是时间短的会先拿出来 + t.Run("enqueue short ele", func(t *testing.T) { + q := NewDelayQueue[delayElem](3) + // 长时间过期的元素 + err := q.Enqueue(context.Background(), delayElem{ + val: 234, + deadline: time.Now().Add(time.Second), + }) + require.NoError(t, err) + + go func() { + time.Sleep(time.Millisecond * 200) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, delayElem{ + val: 123, + deadline: time.Now().Add(time.Millisecond * 300), + }) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + // 先拿出短时间的 + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + require.True(t, ele.deadline.Before(time.Now())) + + // 再拿出长时间的 + ele, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 234, ele.val) + require.True(t, ele.deadline.Before(time.Now())) + + // 没有元素了,会超时 + _, err = q.Dequeue(ctx) + require.Equal(t, context.DeadlineExceeded, err) + }) + + t.Run("dequeue two elements concurrently with larger delay intervals", func(t *testing.T) { + t.Parallel() + + capacity := 2 + q := NewDelayQueue[delayElem](capacity) + + // 使队列处于有元素状态,元素间的截止日期有较大时间差 + elem1 := delayElem{ + val: 10001, + deadline: time.Now().Add(50 * time.Millisecond), + } + require.NoError(t, q.Enqueue(context.Background(), elem1)) + + elem2 := delayElem{ + val: 10002, + deadline: time.Now().Add(500 * time.Millisecond), + } + require.NoError(t, q.Enqueue(context.Background(), elem2)) + + // 并发出队,使调用者协程并发地按照较小截止日期的元素的延迟时间进行等待 + elemsChan := make(chan delayElem, capacity) + var eg errgroup.Group + for i := 0; i < capacity; i++ { + eg.Go(func() error { + ele, err := q.Dequeue(context.Background()) + elemsChan <- ele + return err + }) + } + + assert.NoError(t, eg.Wait()) + + // 一定先拿出短时间的 + ele := <-elemsChan + require.Equal(t, elem1.val, ele.val) + require.True(t, ele.deadline.Before(time.Now())) + + // 再拿出长时间的,因为并发原因多个调用者协程可能都等待具有较小截止日期的元素 + // 防止后者未验证元素是否过期而直接将其出队 + ele = <-elemsChan + require.Equal(t, elem2.val, ele.val) + require.True(t, ele.deadline.Before(time.Now())) + }) +} + +func TestDelayQueue_Enqueue(t *testing.T) { + t.Parallel() + now := time.Now() + testCases := []struct { + name string + q *DelayQueue[delayElem] + timeout time.Duration + val delayElem + wantErr error + }{ + { + name: "enqueued", + q: NewDelayQueue[delayElem](3), + timeout: time.Second, + val: delayElem{val: 123, deadline: now.Add(time.Minute)}, + }, + { + // context 本身已经过期了 + name: "invalid context", + q: NewDelayQueue[delayElem](3), + timeout: -time.Second, + val: delayElem{val: 123, deadline: now.Add(time.Minute)}, + wantErr: context.DeadlineExceeded, + }, + { + // enqueue 的时候阻塞住了,直到超时 + name: "enqueue timeout", + q: newDelayQueue(t, delayElem{val: 123, deadline: now.Add(time.Minute)}), + timeout: time.Millisecond * 100, + val: delayElem{val: 234, deadline: now.Add(time.Minute)}, + wantErr: context.DeadlineExceeded, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + err := tc.q.Enqueue(ctx, tc.val) + assert.Equal(t, tc.wantErr, err) + }) + } + + // 队列满了,这时候入队。 + // 在等待一段时间之后,队列元素被取走一个 + t.Run("enqueue while dequeue", func(t *testing.T) { + t.Parallel() + q := newDelayQueue(t, delayElem{val: 123, deadline: time.Now().Add(time.Second)}) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + err := q.Enqueue(ctx, delayElem{val: 345, deadline: time.Now().Add(time.Millisecond * 1500)}) + require.NoError(t, err) + }) + + // 入队相同过期时间的元素 + // 但是因为我们在入队的时候是分别计算 Delay 的 + // 那么就会导致虽然过期时间是相同的,但是因为调用 Delay 有先后之分 + // 所以会造成 dstDelay 就是要比 srcDelay 小一点 + t.Run("enqueue with same deadline", func(t *testing.T) { + t.Parallel() + q := NewDelayQueue[delayElem](3) + deadline := time.Now().Add(time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + err := q.Enqueue(ctx, delayElem{val: 123, deadline: deadline}) + require.NoError(t, err) + err = q.Enqueue(ctx, delayElem{val: 456, deadline: deadline}) + require.NoError(t, err) + err = q.Enqueue(ctx, delayElem{val: 789, deadline: deadline}) + require.NoError(t, err) + + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + + ele, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 789, ele.val) + + ele, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 456, ele.val) + }) +} + +func newDelayQueue(t *testing.T, eles ...delayElem) *DelayQueue[delayElem] { + q := NewDelayQueue[delayElem](len(eles)) + for _, ele := range eles { + err := q.Enqueue(context.Background(), ele) + require.NoError(t, err) + } + return q +} + +type delayElem struct { + deadline time.Time + val int +} + +func (d delayElem) Delay() time.Duration { + return time.Until(d.deadline) +} + +func ExampleNewDelayQueue() { + q := NewDelayQueue[delayElem](10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + now := time.Now() + _ = q.Enqueue(ctx, delayElem{ + // 3 秒后过期 + deadline: now.Add(time.Second * 3), + val: 3, + }) + + _ = q.Enqueue(ctx, delayElem{ + // 2 秒后过期 + deadline: now.Add(time.Second * 2), + val: 2, + }) + + _ = q.Enqueue(ctx, delayElem{ + // 1 秒后过期 + deadline: now.Add(time.Second * 1), + val: 1, + }) + + var vals []int + val, _ := q.Dequeue(ctx) + vals = append(vals, val.val) + val, _ = q.Dequeue(ctx) + vals = append(vals, val.val) + val, _ = q.Dequeue(ctx) + vals = append(vals, val.val) + fmt.Println(vals) + duration := time.Since(now) + if duration > time.Second*3 { + fmt.Println("delay!") + } + // Output: + // [1 2 3] + // delay! +} diff --git a/queue/types.go b/queue/types.go new file mode 100644 index 00000000..f8ba0b80 --- /dev/null +++ b/queue/types.go @@ -0,0 +1,50 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" +) + +// BlockingQueue 阻塞队列 +// 参考 Queue 普通队列 +// 一个阻塞队列是否遵循 FIFO 取决于具体实现 +type BlockingQueue[T any] interface { + // Enqueue 将元素放入队列。如果在 ctx 超时之前,队列有空闲位置,那么元素会被放入队列; + // 否则返回 error。 + // 在超时或者调用者主动 cancel 的情况下,所有的实现都必须返回 ctx。 + // 调用者可以通过检查 error 是否为 context.DeadlineExceeded + // 或者 context.Canceled 来判断入队失败的原因 + // 注意,调用者必须使用 errors.Is 来判断,而不能直接使用 == + Enqueue(ctx context.Context, t T) error + // Dequeue 从队首获得一个元素 + // 如果在 ctx 超时之前,队列中有元素,那么会返回队首的元素,否则返回 error。 + // 在超时或者调用者主动 cancel 的情况下,所有的实现都必须返回 ctx。 + // 调用者可以通过检查 error 是否为 context.DeadlineExceeded + // 或者 context.Canceled 来判断入队失败的原因 + // 注意,调用者必须使用 errors.Is 来判断,而不能直接使用 == + Dequeue(ctx context.Context) (T, error) +} + +// Queue 普通队列 +// 参考 BlockingQueue 阻塞队列 +// 一个队列是否遵循 FIFO 取决于具体实现 +type Queue[T any] interface { + // Enqueue 将元素放入队列,如果此时队列已经满了,那么返回错误 + Enqueue(t T) error + // Dequeue 从队首获得一个元素 + // 如果此时队列里面没有元素,那么返回错误 + Dequeue() (T, error) +} diff --git a/sqlx/encrypt.go b/sqlx/encrypt.go index d6f8b53b..e232d839 100644 --- a/sqlx/encrypt.go +++ b/sqlx/encrypt.go @@ -39,6 +39,7 @@ type EncryptColumn[T any] struct { } var errInvalid = errors.New("ekit EncryptColumn无效") +var errKeyLengthInvalid = errors.New("ekit EncryptColumn仅支持 16/24/32 byte 的key") // Value 返回加密后的值 // 如果 T 是基本类型,那么会对 T 进行直接加密 @@ -47,6 +48,9 @@ func (e EncryptColumn[T]) Value() (driver.Value, error) { if !e.Valid { return nil, errInvalid } + if len(e.Key) != 16 && len(e.Key) != 24 && len(e.Key) != 32 { + return nil, errKeyLengthInvalid + } var val any = e.Val var err error var b []byte diff --git a/sqlx/encrypt_test.go b/sqlx/encrypt_test.go index ab4301d0..1a10e77c 100644 --- a/sqlx/encrypt_test.go +++ b/sqlx/encrypt_test.go @@ -36,6 +36,11 @@ func TestEncryptColumn_Basic(t *testing.T) { wantEnErr error wantDeErr error }{ + { + name: "wrong length key", + input: &EncryptColumn[string]{Key: "ABC", Val: "abc", Valid: true}, + wantEnErr: errKeyLengthInvalid, + }, { name: "int", input: &EncryptColumn[int32]{Key: "ABCDABCDABCDABCDABCDABCDABCDABCD", Val: 123, Valid: true}, diff --git a/syncx/atomicx/atomic.go b/syncx/atomicx/atomic.go new file mode 100644 index 00000000..705e5f83 --- /dev/null +++ b/syncx/atomicx/atomic.go @@ -0,0 +1,64 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomicx + +import "sync/atomic" + +// Value 是对 atomic.Value 的泛型封装 +// 相比直接使用 atomic.Value, +// - Load 方法大概开销多了 0.5 ns +// - Store 方法多了不到 2 ns +// - Swap 方法多了 14 ns +// - CompareAndSwap 在失败的情况下,会多 2 ns,成功的时候多了 0.3 ns +// 使用 NewValue 或者 NewValueOf 来创建实例 +type Value[T any] struct { + val atomic.Value +} + +// NewValue 会创建一个 Value 对象,里面存放着 T 的零值 +// 注意,这个零值是带了类型的零值 +func NewValue[T any]() *Value[T] { + var t T + return NewValueOf[T](t) +} + +// NewValueOf 会使用传入的值来创建一个 Value 对象 +func NewValueOf[T any](t T) *Value[T] { + val := atomic.Value{} + val.Store(t) + return &Value[T]{ + val: val, + } +} + +func (v *Value[T]) Load() (val T) { + data := v.val.Load() + val = data.(T) + return +} + +func (v *Value[T]) Store(val T) { + v.val.Store(val) +} + +func (v *Value[T]) Swap(new T) (old T) { + data := v.val.Swap(new) + old = data.(T) + return +} + +func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { + return v.val.CompareAndSwap(old, new) +} diff --git a/syncx/atomicx/atomic_test.go b/syncx/atomicx/atomic_test.go new file mode 100644 index 00000000..36633fc6 --- /dev/null +++ b/syncx/atomicx/atomic_test.go @@ -0,0 +1,233 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomicx + +import ( + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewValueOf(t *testing.T) { + testCases := []struct { + name string + input *User + }{ + { + name: "nil", + }, + { + name: "user", + input: &User{ + Name: "Tom", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val := NewValueOf[*User](tc.input) + assert.Equal(t, tc.input, val.Load()) + }) + } +} + +func TestValue_CompareAndSwap(t *testing.T) { + testCases := []struct { + name string + old *User + new *User + }{ + { + name: "both nil", + }, + { + name: "old nil", + new: &User{Name: "Tom"}, + }, + { + name: "new nil", + old: &User{Name: "Tom"}, + }, + { + name: "not nil", + new: &User{Name: "Tom"}, + old: &User{Name: "Jerry"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val := NewValueOf[*User](tc.old) + swapped := val.CompareAndSwap(tc.old, tc.new) + assert.True(t, swapped) + }) + } +} + +func TestValue_Swap(t *testing.T) { + testCases := []struct { + name string + old *User + new *User + }{ + { + name: "both nil", + }, + { + name: "old nil", + new: &User{Name: "Tom"}, + }, + { + name: "new nil", + old: &User{Name: "Tom"}, + }, + { + name: "not nil", + new: &User{Name: "Tom"}, + old: &User{Name: "Jerry"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val := NewValueOf[*User](tc.old) + oldVal := val.Swap(tc.new) + newVal := val.Load() + assert.Equal(t, tc.old, oldVal) + assert.Equal(t, tc.new, newVal) + }) + } +} + +func TestValue_Store_Load(t *testing.T) { + testCases := []struct { + name string + input *User + wantVal *User + }{ + { + name: "nil", + }, + { + name: "user", + input: &User{ + Name: "Tom", + }, + wantVal: &User{ + Name: "Tom", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val := NewValue[*User]() + val.Store(tc.input) + v := val.Load() + assert.Equal(t, tc.wantVal, v) + }) + } +} + +func BenchmarkValue_Load(b *testing.B) { + b.Run("Value", func(b *testing.B) { + val := NewValueOf[int](123) + for i := 0; i < b.N; i++ { + _ = val.Load() + } + }) + + b.Run("atomic Value", func(b *testing.B) { + val := &atomic.Value{} + val.Store(123) + for i := 0; i < b.N; i++ { + _ = val.Load() + } + }) +} + +func BenchmarkValue_Store(b *testing.B) { + b.Run("Value", func(b *testing.B) { + val := NewValue[int]() + for i := 0; i < b.N; i++ { + val.Store(123) + } + }) + + b.Run("atomic Value", func(b *testing.B) { + val := &atomic.Value{} + + for i := 0; i < b.N; i++ { + val.Store(123) + } + }) +} + +func BenchmarkValue_Swap(b *testing.B) { + b.Run("Value", func(b *testing.B) { + val := NewValueOf[int](123) + for i := 0; i < b.N; i++ { + _ = val.Swap(456) + } + }) + + b.Run("atomic Value", func(b *testing.B) { + val := &atomic.Value{} + val.Store(123) + for i := 0; i < b.N; i++ { + _ = val.Swap(456) + } + }) +} + +func BenchmarkValue_CompareAndSwap(b *testing.B) { + b.Run("Value", func(b *testing.B) { + b.Run("fail", func(b *testing.B) { + val := NewValueOf[int](123) + for i := 0; i < b.N; i++ { + _ = val.CompareAndSwap(-1, 100) + } + }) + b.Run("success", func(b *testing.B) { + val := NewValueOf[int](0) + for i := 0; i < b.N; i++ { + _ = val.CompareAndSwap(i, i+1) + } + }) + }) + + b.Run("atomic Value", func(b *testing.B) { + b.Run("fail", func(b *testing.B) { + val := &atomic.Value{} + val.Store(123) + for i := 0; i < b.N; i++ { + _ = val.CompareAndSwap(-1, 100) + } + }) + b.Run("success", func(b *testing.B) { + val := &atomic.Value{} + val.Store(0) + for i := 0; i < b.N; i++ { + _ = val.CompareAndSwap(i, i+1) + } + }) + }) +} + +type User struct { + Name string +} diff --git a/syncx/atomicx/example_test.go b/syncx/atomicx/example_test.go new file mode 100644 index 00000000..59196dff --- /dev/null +++ b/syncx/atomicx/example_test.go @@ -0,0 +1,74 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomicx + +import "fmt" + +func ExampleNewValue() { + val := NewValue[int]() + data := val.Load() + fmt.Println(data) + // Output: + // 0 +} + +func ExampleNewValueOf() { + val := NewValueOf[int](123) + data := val.Load() + fmt.Println(data) + // Output: + // 123 +} + +func ExampleValue_Load() { + val := NewValueOf[int](123) + data := val.Load() + fmt.Println(data) + // Output: + // 123 +} + +func ExampleValue_Store() { + val := NewValueOf[int](123) + data := val.Load() + fmt.Println(data) + val.Store(456) + data = val.Load() + fmt.Println(data) + // Output: + // 123 + // 456 +} + +func ExampleValue_Swap() { + val := NewValueOf[int](123) + oldVal := val.Swap(456) + newVal := val.Load() + fmt.Printf("old: %d, new: %d", oldVal, newVal) + // Output: + // old: 123, new: 456 +} + +func ExampleValue_CompareAndSwap() { + val := NewValueOf[int](123) + swapped := val.CompareAndSwap(123, 456) + fmt.Println(swapped) + + swapped = val.CompareAndSwap(455, 459) + fmt.Println(swapped) + // Output: + // true + // false +} diff --git a/types.go b/types.go new file mode 100644 index 00000000..9a60fbbb --- /dev/null +++ b/types.go @@ -0,0 +1,29 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ekit + +// Comparator 用于比较两个对象的大小 src < dst, 返回-1,src = dst, 返回0,src > dst, 返回1 +// 不要返回任何其它值! +type Comparator[T any] func(src T, dst T) int + +func ComparatorRealNumber[T RealNumber](src T, dst T) int { + if src < dst { + return -1 + } else if src == dst { + return 0 + } else { + return 1 + } +} diff --git a/value.go b/value.go new file mode 100644 index 00000000..e21c6c96 --- /dev/null +++ b/value.go @@ -0,0 +1,236 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ekit + +import ( + "reflect" + + "github.com/gotomicro/ekit/internal/errs" +) + +type AnyValue struct { + Val any + Err error +} + +// Int 返回 int 数据 +func (av AnyValue) Int() (int, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(int) + if !ok { + return 0, errs.NewErrInvalidType("int", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// IntOrDefault 返回 int 数据,或者默认值 +func (a AnyValue) IntOrDefault(def int) int { + val, err := a.Int() + if err != nil { + return def + } + return val +} + +// Uint 返回 uint 数据 +func (av AnyValue) Uint() (uint, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(uint) + if !ok { + return 0, errs.NewErrInvalidType("uint", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// UintOrDefault 返回 uint 数据,或者默认值 +func (a AnyValue) UintOrDefault(def uint) uint { + val, err := a.Uint() + if err != nil { + return def + } + return val +} + +// Int32 返回 int32 数据 +func (av AnyValue) Int32() (int32, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(int32) + if !ok { + return 0, errs.NewErrInvalidType("int32", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// Int32OrDefault 返回 int32 数据,或者默认值 +func (a AnyValue) Int32OrDefault(def int32) int32 { + val, err := a.Int32() + if err != nil { + return def + } + return val +} + +// Uint32 返回 uint32 数据 +func (av AnyValue) Uint32() (uint32, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(uint32) + if !ok { + return 0, errs.NewErrInvalidType("uint32", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// Uint32OrDefault 返回 uint32 数据,或者默认值 +func (a AnyValue) Uint32OrDefault(def uint32) uint32 { + val, err := a.Uint32() + if err != nil { + return def + } + return val +} + +// Int64 返回 int64 数据 +func (av AnyValue) Int64() (int64, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(int64) + if !ok { + return 0, errs.NewErrInvalidType("int64", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// Int64OrDefault 返回 int64 数据,或者默认值 +func (a AnyValue) Int64OrDefault(def int64) int64 { + val, err := a.Int64() + if err != nil { + return def + } + return val +} + +// Uint64 返回 uint64 数据 +func (av AnyValue) Uint64() (uint64, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(uint64) + if !ok { + return 0, errs.NewErrInvalidType("uint64", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// Uint64OrDefault 返回 uint64 数据,或者默认值 +func (a AnyValue) Uint64OrDefault(def uint64) uint64 { + val, err := a.Uint64() + if err != nil { + return def + } + return val +} + +// Float32 返回 float32 数据 +func (av AnyValue) Float32() (float32, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(float32) + if !ok { + return 0, errs.NewErrInvalidType("float32", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// Float32OrDefault 返回 float32 数据,或者默认值 +func (a AnyValue) Float32OrDefault(def float32) float32 { + val, err := a.Float32() + if err != nil { + return def + } + return val +} + +// Float64 返回 float64 数据 +func (av AnyValue) Float64() (float64, error) { + if av.Err != nil { + return 0, av.Err + } + val, ok := av.Val.(float64) + if !ok { + return 0, errs.NewErrInvalidType("float64", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// Float64OrDefault 返回 float64 数据,或者默认值 +func (a AnyValue) Float64OrDefault(def float64) float64 { + val, err := a.Float64() + if err != nil { + return def + } + return val +} + +// String 返回 string 数据 +func (av AnyValue) String() (string, error) { + if av.Err != nil { + return "", av.Err + } + val, ok := av.Val.(string) + if !ok { + return "", errs.NewErrInvalidType("string", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// StringOrDefault 返回 string 数据,或者默认值 +func (a AnyValue) StringOrDefault(def string) string { + val, err := a.String() + if err != nil { + return def + } + return val +} + +// Bytes 返回 []byte 数据 +func (av AnyValue) Bytes() ([]byte, error) { + if av.Err != nil { + return nil, av.Err + } + val, ok := av.Val.([]byte) + if !ok { + return nil, errs.NewErrInvalidType("[]byte", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// BytesOrDefault 返回 []byte 数据,或者默认值 +func (a AnyValue) BytesOrDefault(def []byte) []byte { + val, err := a.Bytes() + if err != nil { + return def + } + return val +} diff --git a/value_test.go b/value_test.go new file mode 100644 index 00000000..b7788220 --- /dev/null +++ b/value_test.go @@ -0,0 +1,886 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ekit + +import ( + "errors" + "reflect" + "testing" + + "github.com/gotomicro/ekit/internal/errs" + "github.com/stretchr/testify/assert" +) + +func TestAnyValue_Int(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int(1), + }, + want: int(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("int", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Int() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_IntOrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def int + want int + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: int(1), + Err: errors.New("error"), + }, + def: int(2), + want: int(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: int(1), + want: int(1), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.IntOrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Uint(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint(1), + }, + want: uint(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: []string{"111"}, + }, + err: errs.NewErrInvalidType("uint", reflect.TypeOf([]string{"111"}).String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Uint() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_UintOrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def uint + want uint + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: uint(1), + Err: errors.New("error"), + }, + def: uint(2), + want: uint(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: uint(2), + want: uint(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.UintOrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Int32(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int32 + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int32(1), + }, + want: int32(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("int32", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Int32() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_Int32OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def int32 + want int32 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int32(1), + }, + want: int32(1), + }, + { + name: "default case:", + val: AnyValue{ + Val: int32(1), + Err: errors.New("error"), + }, + def: int32(2), + want: int32(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: int32(2), + want: int32(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.Int32OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Uint32(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint32 + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint32(1), + }, + want: uint32(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("uint32", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Uint32() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_Uint32OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def uint32 + want uint32 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint32(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: uint32(1), + Err: errors.New("error"), + }, + + def: uint32(2), + want: uint32(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: uint32(2), + want: uint32(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.Uint32OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Int64(t *testing.T) { + tests := []struct { + name string + val AnyValue + want int64 + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int64(1), + }, + want: int64(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("int64", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Int64() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_Int64OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def int64 + want int64 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: int64(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: int64(1), + Err: errors.New("error"), + }, + def: int64(2), + want: int64(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: int64(2), + want: int64(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.Int64OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Uint64(t *testing.T) { + tests := []struct { + name string + val AnyValue + want uint64 + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint64(1), + }, + want: uint64(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("uint64", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Uint64() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_Uint64OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def uint64 + want uint64 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: uint64(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: uint64(1), + Err: errors.New("error"), + }, + def: uint64(2), + want: uint64(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: uint64(2), + want: uint64(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.Uint64OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Float32(t *testing.T) { + tests := []struct { + name string + val AnyValue + want float32 + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: float32(1), + }, + want: float32(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("float32", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Float32() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_Float32OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def float32 + want float32 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: float32(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: float32(1), + Err: errors.New("error"), + }, + def: float32(2), + want: float32(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: float32(2), + want: float32(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.Float32OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Float64(t *testing.T) { + + tests := []struct { + name string + val AnyValue + want float64 + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: float64(1), + }, + want: float64(1), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + err: errs.NewErrInvalidType("float64", reflect.TypeOf("").String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Float64() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_Float64OrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def float64 + want float64 + }{ + { + name: "normal case:", + val: AnyValue{ + Val: float64(1), + }, + want: 1, + }, + { + name: "default case:", + val: AnyValue{ + Val: float64(1), + Err: errors.New("error"), + }, + def: float64(2), + want: float64(2), + }, + { + name: "type error case:", + val: AnyValue{ + Val: "", + }, + def: float64(2), + want: float64(2), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.Float64OrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_String(t *testing.T) { + tests := []struct { + name string + val AnyValue + want string + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: "111", + }, + want: "111", + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: 1, + }, + err: errs.NewErrInvalidType("string", reflect.TypeOf(111).String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.String() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_StringOrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def string + want string + }{ + { + name: "normal case:", + val: AnyValue{ + Val: "111", + }, + want: "111", + }, + { + name: "default case:", + val: AnyValue{ + Val: "111", + Err: errors.New("error"), + }, + def: "222", + want: "222", + }, + { + name: "type error case:", + val: AnyValue{ + Val: 1, + }, + def: "222", + want: "222", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.StringOrDefault(tt.def), tt.want) + }) + } +} + +func TestAnyValue_Bytes(t *testing.T) { + tests := []struct { + name string + val AnyValue + want []byte + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: []byte("111"), + }, + want: []byte("111"), + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: 1, + }, + err: errs.NewErrInvalidType("[]byte", reflect.TypeOf(111).String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Bytes() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_BytesOrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def []byte + want []byte + }{ + { + name: "normal case:", + val: AnyValue{ + Val: []byte("111"), + }, + want: []byte("111"), + }, + { + name: "default case:", + val: AnyValue{ + Val: []byte("111"), + Err: errors.New("error"), + }, + def: []byte("222"), + want: []byte("222"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: 1, + }, + def: []byte("222"), + want: []byte("222"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, a.BytesOrDefault(tt.def), tt.want) + }) + } +} From a2a10507a0c779b94ed93c734779e5c9daa15c14 Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Mon, 25 Sep 2023 22:19:31 +0800 Subject: [PATCH 2/2] =?UTF-8?q?v0.0.8=20=E7=9A=84changelog?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 91281f4f..2ffc5fdc 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,4 +1,6 @@ # 开发中 + +# v0.0.8 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) - [queue: API 定义](https://github.com/gotomicro/ekit/pull/109) - [queue: 基于堆和切片的优先级队列](https://github.com/gotomicro/ekit/pull/110)