From 0e4a5770e985e8f20ac67f0d8ab5887433640724 Mon Sep 17 00:00:00 2001 From: Ilya Ozherelyev Date: Fri, 8 Nov 2024 13:29:11 +0100 Subject: [PATCH] [common] Add unit test for convertions (#1395) * [common] Add unit test for convertions --- internal/common/convert.go | 27 +++-- internal/common/convert_test.go | 57 +++++++++++ internal/common/thrift_util.go | 54 +++------- internal/common/thrift_util_test.go | 150 ++++++++++++++++++++++++++++ 4 files changed, 238 insertions(+), 50 deletions(-) create mode 100644 internal/common/convert_test.go create mode 100644 internal/common/thrift_util_test.go diff --git a/internal/common/convert.go b/internal/common/convert.go index 0646de53f..2f30788c7 100644 --- a/internal/common/convert.go +++ b/internal/common/convert.go @@ -38,55 +38,60 @@ func Int64Ceil(v float64) int64 { // Int32Ptr makes a copy and returns the pointer to an int32. func Int32Ptr(v int32) *int32 { - return &v + return PtrOf(v) } // Float64Ptr makes a copy and returns the pointer to a float64. func Float64Ptr(v float64) *float64 { - return &v + return PtrOf(v) } // Int64Ptr makes a copy and returns the pointer to an int64. func Int64Ptr(v int64) *int64 { - return &v + return PtrOf(v) } // StringPtr makes a copy and returns the pointer to a string. func StringPtr(v string) *string { - return &v + return PtrOf(v) } // BoolPtr makes a copy and returns the pointer to a string. func BoolPtr(v bool) *bool { - return &v + return PtrOf(v) } // TaskListPtr makes a copy and returns the pointer to a TaskList. func TaskListPtr(v s.TaskList) *s.TaskList { - return &v + return PtrOf(v) } // DecisionTypePtr makes a copy and returns the pointer to a DecisionType. func DecisionTypePtr(t s.DecisionType) *s.DecisionType { - return &t + return PtrOf(t) } // EventTypePtr makes a copy and returns the pointer to a EventType. func EventTypePtr(t s.EventType) *s.EventType { - return &t + return PtrOf(t) } // QueryTaskCompletedTypePtr makes a copy and returns the pointer to a QueryTaskCompletedType. func QueryTaskCompletedTypePtr(t s.QueryTaskCompletedType) *s.QueryTaskCompletedType { - return &t + return PtrOf(t) } // TaskListKindPtr makes a copy and returns the pointer to a TaskListKind. func TaskListKindPtr(t s.TaskListKind) *s.TaskListKind { - return &t + return PtrOf(t) } // QueryResultTypePtr makes a copy and returns the pointer to a QueryResultType. func QueryResultTypePtr(t s.QueryResultType) *s.QueryResultType { - return &t + return PtrOf(t) +} + +// PtrOf makes a copy and returns the pointer to a value. +func PtrOf[T any](v T) *T { + return &v } diff --git a/internal/common/convert_test.go b/internal/common/convert_test.go new file mode 100644 index 000000000..f09c90191 --- /dev/null +++ b/internal/common/convert_test.go @@ -0,0 +1,57 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package common + +import ( + "testing" + + s "go.uber.org/cadence/.gen/go/shared" + + "github.com/stretchr/testify/assert" +) + +func TestPtrOf(t *testing.T) { + assert.Equal(t, "a", *PtrOf("a")) + assert.Equal(t, 1, *PtrOf(1)) + assert.Equal(t, int32(1), *PtrOf(int32(1))) + assert.Equal(t, int64(1), *PtrOf(int64(1))) + assert.Equal(t, float64(1.1), *PtrOf(float64(1.1))) + assert.Equal(t, true, *PtrOf(true)) +} + +func TestPtrHelpers(t *testing.T) { + assert.Equal(t, int32(1), *Int32Ptr(1)) + assert.Equal(t, int64(1), *Int64Ptr(1)) + assert.Equal(t, 1.1, *Float64Ptr(1.1)) + assert.Equal(t, true, *BoolPtr(true)) + assert.Equal(t, "a", *StringPtr("a")) + assert.Equal(t, s.TaskList{Name: PtrOf("a")}, *TaskListPtr(s.TaskList{Name: PtrOf("a")})) + assert.Equal(t, s.DecisionTypeScheduleActivityTask, *DecisionTypePtr(s.DecisionTypeScheduleActivityTask)) + assert.Equal(t, s.EventTypeWorkflowExecutionStarted, *EventTypePtr(s.EventTypeWorkflowExecutionStarted)) + assert.Equal(t, s.QueryTaskCompletedTypeCompleted, *QueryTaskCompletedTypePtr(s.QueryTaskCompletedTypeCompleted)) + assert.Equal(t, s.TaskListKindNormal, *TaskListKindPtr(s.TaskListKindNormal)) + assert.Equal(t, s.QueryResultTypeFailed, *QueryResultTypePtr(s.QueryResultTypeFailed)) +} + +func TestCeilHelpers(t *testing.T) { + assert.Equal(t, int32(2), Int32Ceil(1.1)) + assert.Equal(t, int64(2), Int64Ceil(1.1)) +} diff --git a/internal/common/thrift_util.go b/internal/common/thrift_util.go index 34763293f..730127e62 100644 --- a/internal/common/thrift_util.go +++ b/internal/common/thrift_util.go @@ -27,19 +27,13 @@ import ( "github.com/apache/thrift/lib/go/thrift" ) -// TSerialize is used to serialize thrift TStruct to []byte -func TSerialize(ctx context.Context, t thrift.TStruct) (b []byte, err error) { - return thrift.NewTSerializer().Write(ctx, t) -} - // TListSerialize is used to serialize list of thrift TStruct to []byte -func TListSerialize(ts []thrift.TStruct) (b []byte, err error) { +func TListSerialize(ts []thrift.TStruct) ([]byte, error) { if ts == nil { - return + return nil, nil } t := thrift.NewTSerializer() - t.Transport.Reset() // NOTE: we don't write any markers as thrift by design being a streaming protocol doesn't // recommend writing length. @@ -48,26 +42,11 @@ func TListSerialize(ts []thrift.TStruct) (b []byte, err error) { ctx := context.Background() for _, v := range ts { if e := v.Write(ctx, t.Protocol); e != nil { - err = thrift.PrependError("error writing TStruct: ", e) - return + return nil, thrift.PrependError("error writing TStruct: ", e) } } - if err = t.Protocol.Flush(ctx); err != nil { - return - } - - if err = t.Transport.Flush(ctx); err != nil { - return - } - - b = t.Transport.Bytes() - return -} - -// TDeserialize is used to deserialize []byte to thrift TStruct -func TDeserialize(ctx context.Context, t thrift.TStruct, b []byte) (err error) { - return thrift.NewTDeserializer().Read(ctx, t, b) + return t.Transport.Bytes(), t.Protocol.Flush(ctx) } // TListDeserialize is used to deserialize []byte to list of thrift TStruct @@ -92,15 +71,13 @@ func TListDeserialize(ts []thrift.TStruct, b []byte) (err error) { // IsUseThriftEncoding checks if the objects passed in are all encoded using thrift. func IsUseThriftEncoding(objs []interface{}) bool { - // NOTE: our criteria to use which encoder is simple if all the types are serializable using thrift then we use - // thrift encoder. For everything else we default to gob. - if len(objs) == 0 { return false } - - for i := 0; i < len(objs); i++ { - if !IsThriftType(objs[i]) { + // NOTE: our criteria to use which encoder is simple if all the types are serializable using thrift then we use + // thrift encoder. For everything else we default to gob. + for _, obj := range objs { + if !IsThriftType(obj) { return false } } @@ -109,15 +86,13 @@ func IsUseThriftEncoding(objs []interface{}) bool { // IsUseThriftDecoding checks if the objects passed in are all de-serializable using thrift. func IsUseThriftDecoding(objs []interface{}) bool { - // NOTE: our criteria to use which encoder is simple if all the types are de-serializable using thrift then we use - // thrift decoder. For everything else we default to gob. - if len(objs) == 0 { return false } - - for i := 0; i < len(objs); i++ { - rVal := reflect.ValueOf(objs[i]) + // NOTE: our criteria to use which encoder is simple if all the types are de-serializable using thrift then we use + // thrift decoder. For everything else we default to gob. + for _, obj := range objs { + rVal := reflect.ValueOf(obj) if rVal.Kind() != reflect.Ptr || !IsThriftType(reflect.Indirect(rVal).Interface()) { return false } @@ -133,6 +108,7 @@ func IsThriftType(v interface{}) bool { if reflect.ValueOf(v).Kind() != reflect.Ptr { return false } - t := reflect.TypeOf((*thrift.TStruct)(nil)).Elem() - return reflect.TypeOf(v).Implements(t) + return reflect.TypeOf(v).Implements(tStructType) } + +var tStructType = reflect.TypeOf((*thrift.TStruct)(nil)).Elem() diff --git a/internal/common/thrift_util_test.go b/internal/common/thrift_util_test.go new file mode 100644 index 000000000..a73647455 --- /dev/null +++ b/internal/common/thrift_util_test.go @@ -0,0 +1,150 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package common + +import ( + "context" + "testing" + + "github.com/apache/thrift/lib/go/thrift" + "github.com/stretchr/testify/assert" +) + +func TestTListSerialize(t *testing.T) { + t.Run("nil", func(t *testing.T) { + data, err := TListSerialize(nil) + assert.NoError(t, err) + assert.Nil(t, data) + }) + t.Run("normal", func(t *testing.T) { + ts := []thrift.TStruct{ + &mockThriftStruct{Field1: "value1", Field2: 1}, + &mockThriftStruct{Field1: "value2", Field2: 2}, + } + + _, err := TListSerialize(ts) + assert.NoError(t, err) + }) +} + +func TestTListDeserialize(t *testing.T) { + ts := []thrift.TStruct{ + &mockThriftStruct{}, + &mockThriftStruct{}, + } + + data, err := TListSerialize(ts) + assert.NoError(t, err) + + err = TListDeserialize(ts, data) + assert.NoError(t, err) +} + +func TestIsUseThriftEncoding(t *testing.T) { + for _, tc := range []struct { + name string + input []interface{} + expected bool + }{ + { + name: "nil", + input: nil, + expected: false, + }, + { + name: "success", + input: []interface{}{ + &mockThriftStruct{}, + &mockThriftStruct{}, + }, + expected: true, + }, + { + name: "fail", + input: []interface{}{ + &mockThriftStruct{}, + PtrOf("string"), + }, + expected: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsUseThriftEncoding(tc.input)) + }) + } +} + +func TestIsUseThriftDecoding(t *testing.T) { + for _, tc := range []struct { + name string + input []interface{} + expected bool + }{ + { + name: "nil", + input: nil, + expected: false, + }, + { + name: "success", + input: []interface{}{ + PtrOf(&mockThriftStruct{}), + PtrOf(&mockThriftStruct{}), + }, + expected: true, + }, + { + name: "fail", + input: []interface{}{ + PtrOf(&mockThriftStruct{}), + PtrOf("string"), + }, + expected: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsUseThriftDecoding(tc.input)) + }) + } +} + +func TestIsThriftType(t *testing.T) { + assert.True(t, IsThriftType(&mockThriftStruct{})) + + assert.False(t, IsThriftType(mockThriftStruct{})) +} + +type mockThriftStruct struct { + Field1 string + Field2 int +} + +func (m *mockThriftStruct) Read(ctx context.Context, iprot thrift.TProtocol) error { + return nil +} + +func (m *mockThriftStruct) Write(ctx context.Context, oprot thrift.TProtocol) error { + return nil +} + +func (m *mockThriftStruct) String() string { + return "" +}