Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(testutil): add unit tests and remove unused functions #937

Merged
merged 8 commits into from
Feb 3, 2025
2 changes: 1 addition & 1 deletion internal/storagenode/client/log_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type SubscribeResult struct {
}

var InvalidSubscribeResult = SubscribeResult{
LogEntry: varlogpb.InvalidLogEntry(),
LogEntry: varlogpb.LogEntry{},
Error: errors.New("invalid subscribe result"),
}

Expand Down
82 changes: 82 additions & 0 deletions pkg/util/netutil/netutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,97 @@ package netutil

import (
"context"
"net"
"strconv"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/kakao/varlog/pkg/types"
"github.com/kakao/varlog/pkg/verrors"
)

func TestStoppableListener(t *testing.T) {
tcs := []struct {
name string
addr string
wantErr bool
}{
{
name: "ValidAddress",
addr: "127.0.0.1:0",
wantErr: false,
},
{
name: "InvalidAddress",
addr: "127.0.0.1:-1",
wantErr: true,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
lis, err := NewStoppableListener(context.Background(), tc.addr)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, lis)

err = lis.Close()
require.NoError(t, err)
})
}
}

func TestStoppableListener_AcceptStopped(t *testing.T) {
const expireDuration = 10 * time.Millisecond

ctx, cancel := context.WithTimeout(context.Background(), expireDuration)
defer cancel()

lis, err := NewStoppableListener(ctx, "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := lis.Close()
require.NoError(t, err)
})

_, err = lis.Accept()
require.Equal(t, verrors.ErrStopped, err)
}

func TestStoppableListener_AcceptSucceed(t *testing.T) {
lis, err := NewStoppableListener(context.Background(), "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := lis.Close()
require.NoError(t, err)
})

addr := lis.Addr().String()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
conn, err := net.Dial("tcp", addr)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}()

conn, err := lis.Accept()
require.NoError(t, err)
require.NotNil(t, conn)

err = conn.Close()
require.NoError(t, err)
}

func TestGetListenerAddr(t *testing.T) {
tests := []struct {
in string
Expand Down
58 changes: 0 additions & 58 deletions pkg/util/testutil/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"strings"
"time"

"github.com/pkg/errors"

"github.com/kakao/varlog/internal/vtesting"
)

Expand Down Expand Up @@ -36,62 +34,6 @@ func CompareWaitN(factor int64, cmp func() bool) bool {
return CompareWait(cmp, vtesting.TimeoutUnitTimesFactor(factor))
}

func CompareWait100(cmp func() bool) bool {
return CompareWaitN(100, cmp)
}

func CompareWait10(cmp func() bool) bool {
return CompareWaitN(10, cmp)
}

func CompareWait1(cmp func() bool) bool {
return CompareWaitN(1, cmp)
}

func CompareWaitErrorWithRetryInterval(cmp func() (bool, error), timeout time.Duration, retryInterval time.Duration) error {
after := time.NewTimer(timeout)
defer after.Stop()

numTries := 0
for {
select {
case <-after.C:
return errors.Errorf("compare wait timeout (%s,tries=%d)", timeout.String(), numTries)
default:
numTries++
ok, err := cmp()
if err != nil {
return err
}

if ok {
return nil
}
time.Sleep(retryInterval)
}
}
}

func CompareWaitError(cmp func() (bool, error), timeout time.Duration) error {
return CompareWaitErrorWithRetryInterval(cmp, timeout, time.Millisecond)
}

func CompareWaitErrorWithRetryIntervalN(factor int64, retryInterval time.Duration, cmp func() (bool, error)) error {
if factor < 1 {
factor = 1
}

return CompareWaitErrorWithRetryInterval(cmp, vtesting.TimeoutUnitTimesFactor(factor), retryInterval)
}

func CompareWaitErrorN(factor int64, cmp func() (bool, error)) error {
if factor < 1 {
factor = 1
}

return CompareWaitError(cmp, vtesting.TimeoutUnitTimesFactor(factor))
}

func GetFunctionName(i interface{}) string {
a := runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
s := strings.Split(a, "/")
Expand Down
62 changes: 62 additions & 0 deletions pkg/util/testutil/testutil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package testutil_test

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/kakao/varlog/pkg/util/testutil"
)

func TestCompareWait(t *testing.T) {
tcs := []struct {
cmp func() bool
want bool
}{
{cmp: func() bool { return true }, want: true},
{cmp: func() bool { return false }, want: false},
}

for _, tc := range tcs {
got := testutil.CompareWait(tc.cmp, time.Second)
require.Equal(t, tc.want, got)
}
}

func TestCompareWaitN_Factor0(t *testing.T) {
ts := time.Now()
testutil.CompareWaitN(0, func() bool {
return false
})
factor0 := time.Since(ts)

ts = time.Now()
testutil.CompareWaitN(1, func() bool {
return false
})
factor1 := time.Since(ts)

require.InEpsilon(t, 1.0, factor1/factor0, float64((10 * time.Millisecond).Nanoseconds()))
}

func TestCompareWaitN_Factor2(t *testing.T) {
ts := time.Now()
testutil.CompareWaitN(1, func() bool {
return false
})
factor1 := time.Since(ts)

ts = time.Now()
testutil.CompareWaitN(2, func() bool {
return false
})
factor2 := time.Since(ts)

require.InEpsilon(t, 2.0, factor2/factor1, float64((10 * time.Millisecond).Nanoseconds()))
}

func TestGetFunctionName(t *testing.T) {
got := testutil.GetFunctionName(TestGetFunctionName)
require.Equal(t, "testutil_test.TestGetFunctionName", got)
}
4 changes: 2 additions & 2 deletions pkg/varlog/subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func (p *dispatcher) dispatch(_ context.Context) {
sentErr = sentErr || res.Error != nil
}
if !sentErr {
p.onNextFunc(varlogpb.InvalidLogEntry(), io.EOF)
p.onNextFunc(varlogpb.LogEntry{}, io.EOF)
}
}

Expand All @@ -532,7 +532,7 @@ type invalidSubscriber struct {
}

func (s invalidSubscriber) Next() (varlogpb.LogEntry, error) {
return varlogpb.InvalidLogEntry(), s.err
return varlogpb.LogEntry{}, s.err
}

func (s invalidSubscriber) Close() error {
Expand Down
7 changes: 3 additions & 4 deletions pkg/varlogtest/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (c *testAdmin) ListStorageNodes(ctx context.Context, opts ...varlog.AdminCa

return ret, nil
}

func (c *testAdmin) GetStorageNodes(ctx context.Context, opts ...varlog.AdminCallOption) (map[types.StorageNodeID]admpb.StorageNodeMetadata, error) {
snms, err := c.ListStorageNodes(ctx)
if err != nil {
Expand Down Expand Up @@ -182,8 +183,7 @@ func (c *testAdmin) AddTopic(ctx context.Context, opts ...varlog.AdminCallOption
c.vt.topics[topicID] = topicDesc
c.vt.trimGLSNs[topicID] = types.InvalidGLSN

invalidLogEntry := varlogpb.InvalidLogEntry()
c.vt.globalLogEntries[topicID] = []*varlogpb.LogEntry{&invalidLogEntry}
c.vt.globalLogEntries[topicID] = []*varlogpb.LogEntry{{}}

return proto.Clone(&topicDesc).(*varlogpb.TopicDescriptor), nil
}
Expand Down Expand Up @@ -294,8 +294,7 @@ func (c *testAdmin) AddLogStream(_ context.Context, topicID types.TopicID, logSt

c.vt.logStreams[logStreamID] = lsd

invalidLogEntry := varlogpb.InvalidLogEntry()
c.vt.localLogEntries[logStreamID] = []*varlogpb.LogEntry{&invalidLogEntry}
c.vt.localLogEntries[logStreamID] = []*varlogpb.LogEntry{{}}

topicDesc.LogStreams = append(topicDesc.LogStreams, logStreamID)
c.vt.topics[topicID] = topicDesc
Expand Down
7 changes: 3 additions & 4 deletions pkg/varlogtest/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (c *testLog) Subscribe(ctx context.Context, topicID types.TopicID, begin ty
for _, logEntry := range copiedLogEntries {
onNextFunc(logEntry, nil)
}
onNextFunc(varlogpb.InvalidLogEntry(), io.EOF)
onNextFunc(varlogpb.LogEntry{}, io.EOF)
}()

return func() {
Expand Down Expand Up @@ -319,7 +319,6 @@ func (c *testLog) PeekLogStream(ctx context.Context, tpid types.TopicID, lsid ty
GLSN: tail.GLSN,
}
return first, last, nil

}

func (c *testLog) AppendableLogStreams(tpid types.TopicID) map[types.LogStreamID]struct{} {
Expand Down Expand Up @@ -463,7 +462,7 @@ func newErrSubscriber(err error) *errSubscriber {
}

func (s errSubscriber) Next() (varlogpb.LogEntry, error) {
return varlogpb.InvalidLogEntry(), s.err
return varlogpb.LogEntry{}, s.err
}

func (s errSubscriber) Close() error {
Expand Down Expand Up @@ -496,7 +495,7 @@ func (s *subscriberImpl) Next() (varlogpb.LogEntry, error) {
logEntry, err := s.next()
if err != nil {
s.setErr(err)
return varlogpb.InvalidLogEntry(), err
return varlogpb.LogEntry{}, err
}
if s.cursor == s.end {
s.setErr(io.EOF)
Expand Down
14 changes: 14 additions & 0 deletions proto/snpb/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"github.com/kakao/varlog/proto/varlogpb"
)

// ToStorageNodeDescriptor converts a StorageNodeMetadataDescriptor to a
// varlogpb.StorageNodeDescriptor. It returns nil if the
// StorageNodeMetadataDescriptor is nil.
func (snmd *StorageNodeMetadataDescriptor) ToStorageNodeDescriptor() *varlogpb.StorageNodeDescriptor {
if snmd == nil {
return nil
Expand All @@ -19,6 +22,9 @@ func (snmd *StorageNodeMetadataDescriptor) ToStorageNodeDescriptor() *varlogpb.S
return snd
}

// GetLogStream retrieves a LogStreamReplicaMetadataDescriptor by its
// LogStreamID. It returns the LogStreamReplicaMetadataDescriptor and true if
// found, otherwise an empty descriptor and false.
func (snmd *StorageNodeMetadataDescriptor) GetLogStream(logStreamID types.LogStreamID) (LogStreamReplicaMetadataDescriptor, bool) {
logStreams := snmd.GetLogStreamReplicas()
for i := range logStreams {
Expand All @@ -29,6 +35,10 @@ func (snmd *StorageNodeMetadataDescriptor) GetLogStream(logStreamID types.LogStr
return LogStreamReplicaMetadataDescriptor{}, false
}

// Head returns the varlogpb.LogEntryMeta corresponding to the local low
// watermark of the LogStreamReplicaMetadataDescriptor. The "head" represents
// the earliest log entry in the log stream replica. It returns an empty
// varlogpb.LogEntryMeta if the LogStreamReplicaMetadataDescriptor is nil.
func (lsrmd *LogStreamReplicaMetadataDescriptor) Head() varlogpb.LogEntryMeta {
if lsrmd == nil {
return varlogpb.LogEntryMeta{}
Expand All @@ -41,6 +51,10 @@ func (lsrmd *LogStreamReplicaMetadataDescriptor) Head() varlogpb.LogEntryMeta {
}
}

// Tail returns the varlogpb.LogEntryMeta corresponding to the local high
// watermark of the LogStreamReplicaMetadataDescriptor. The "tail" represents
// the latest log entry in the log stream replica. It returns an empty
// varlogpb.LogEntryMeta if the LogStreamReplicaMetadataDescriptor is nil.
func (lsrmd *LogStreamReplicaMetadataDescriptor) Tail() varlogpb.LogEntryMeta {
if lsrmd == nil {
return varlogpb.LogEntryMeta{}
Expand Down
Loading
Loading