diff --git a/crencoding/var_int_test.go b/crencoding/var_int_test.go index 214e013..0493093 100644 --- a/crencoding/var_int_test.go +++ b/crencoding/var_int_test.go @@ -19,19 +19,19 @@ import ( "math" "math/rand/v2" "testing" + + "github.com/cockroachdb/crlib/testutils/require" ) // TestUvarintLen tests UvarintLen32 and UvarintLen64. func TestUvarintLen(t *testing.T) { check := func(n uint64) { + t := require.WithMsgf(t, "n=%d", n) res64 := UvarintLen64(n) - if expected := len(binary.AppendUvarint(nil, n)); res64 != expected { - t.Fatalf("invalid result for %d: %d instead of %d", n, res64, expected) - } + require.Equal(t, res64, len(binary.AppendUvarint(nil, n))) + res32 := UvarintLen32(uint32(n)) - if expected := len(binary.AppendUvarint(nil, uint64(uint32(n)))); res32 != expected { - t.Fatalf("invalid result for %d: %d instead of %d", n, res32, expected) - } + require.Equal(t, res32, len(binary.AppendUvarint(nil, uint64(uint32(n))))) } check(0) check(math.MaxUint64) diff --git a/crstrings/utils_test.go b/crstrings/utils_test.go index 2f86f7e..7e8eb92 100644 --- a/crstrings/utils_test.go +++ b/crstrings/utils_test.go @@ -18,6 +18,8 @@ import ( "fmt" "strings" "testing" + + "github.com/cockroachdb/crlib/testutils/require" ) type num int @@ -28,10 +30,10 @@ func (n num) String() string { func TestJoinStringers(t *testing.T) { nums := []num{0, 1, 2, 3} - expect(t, "", JoinStringers(", ", nums[:0]...)) - expect(t, "000", JoinStringers(", ", nums[0])) - expect(t, "000, 001", JoinStringers(", ", nums[0], nums[1])) - expect(t, "000, 001, 002, 003", JoinStringers(", ", nums...)) + require.Equal(t, "", JoinStringers(", ", nums[:0]...)) + require.Equal(t, "000", JoinStringers(", ", nums[0])) + require.Equal(t, "000, 001", JoinStringers(", ", nums[0], nums[1])) + require.Equal(t, "000, 001, 002, 003", JoinStringers(", ", nums...)) } func TestMapAndJoin(t *testing.T) { @@ -39,45 +41,38 @@ func TestMapAndJoin(t *testing.T) { fn := func(n int) string { return fmt.Sprintf("%d", n) } - expect(t, "", MapAndJoin(fn, ", ", nums[:0]...)) - expect(t, "0", MapAndJoin(fn, ", ", nums[0])) - expect(t, "0, 1", MapAndJoin(fn, ", ", nums[0], nums[1])) - expect(t, "0, 1, 2, 3", MapAndJoin(fn, ", ", nums...)) -} - -func expect(t *testing.T, expected, actual string) { - t.Helper() - if actual != expected { - t.Errorf("expected %q got %q", expected, actual) - } + require.Equal(t, "", MapAndJoin(fn, ", ", nums[:0]...)) + require.Equal(t, "0", MapAndJoin(fn, ", ", nums[0])) + require.Equal(t, "0, 1", MapAndJoin(fn, ", ", nums[0], nums[1])) + require.Equal(t, "0, 1, 2, 3", MapAndJoin(fn, ", ", nums...)) } func TestIf(t *testing.T) { - expect(t, "", If(false, "true")) - expect(t, "true", If(true, "true")) + require.Equal(t, "", If(false, "true")) + require.Equal(t, "true", If(true, "true")) } func TestIfElse(t *testing.T) { - expect(t, "false", IfElse(false, "true", "false")) - expect(t, "true", IfElse(true, "true", "false")) + require.Equal(t, "false", IfElse(false, "true", "false")) + require.Equal(t, "true", IfElse(true, "true", "false")) } func TestWithSep(t *testing.T) { - expect(t, "a,b", WithSep("a", ",", "b")) - expect(t, "a", WithSep("a", ",", "")) - expect(t, "b", WithSep("", ",", "b")) + require.Equal(t, "a,b", WithSep("a", ",", "b")) + require.Equal(t, "a", WithSep("a", ",", "")) + require.Equal(t, "b", WithSep("", ",", "b")) } func TestFilterEmpty(t *testing.T) { s := []string{"a", "", "b", "", "c", ""} - expect(t, "a,b,c", strings.Join(FilterEmpty(s), ",")) + require.Equal(t, "a,b,c", strings.Join(FilterEmpty(s), ",")) } func TestLines(t *testing.T) { - expect(t, `["a" "b" "c"]`, fmt.Sprintf("%q", Lines("a\nb\nc"))) - expect(t, `["a" "b" "c"]`, fmt.Sprintf("%q", Lines("a\nb\nc\n"))) - expect(t, `["a" "b" "c" ""]`, fmt.Sprintf("%q", Lines("a\nb\nc\n\n"))) - expect(t, `["" "a" "b" "c"]`, fmt.Sprintf("%q", Lines("\na\nb\nc\n"))) - expect(t, `[]`, fmt.Sprintf("%q", Lines(""))) - expect(t, `[]`, fmt.Sprintf("%q", Lines("\n"))) + require.Equal(t, `["a" "b" "c"]`, fmt.Sprintf("%q", Lines("a\nb\nc"))) + require.Equal(t, `["a" "b" "c"]`, fmt.Sprintf("%q", Lines("a\nb\nc\n"))) + require.Equal(t, `["a" "b" "c" ""]`, fmt.Sprintf("%q", Lines("a\nb\nc\n\n"))) + require.Equal(t, `["" "a" "b" "c"]`, fmt.Sprintf("%q", Lines("\na\nb\nc\n"))) + require.Equal(t, `[]`, fmt.Sprintf("%q", Lines(""))) + require.Equal(t, `[]`, fmt.Sprintf("%q", Lines("\n"))) } diff --git a/crtime/monotonic_test.go b/crtime/monotonic_test.go index d386f90..3d484ad 100644 --- a/crtime/monotonic_test.go +++ b/crtime/monotonic_test.go @@ -17,18 +17,17 @@ package crtime import ( "testing" "time" + + "github.com/cockroachdb/crlib/testutils/require" ) func TestMono(t *testing.T) { a := NowMono() time.Sleep(10 * time.Millisecond) b := NowMono() - if delta := b.Sub(a); delta < 9*time.Millisecond { - t.Errorf("expected 10+ms, got %s", delta) - } + require.GE(t, b.Sub(a), 9*time.Millisecond) c := MonoFromTime(time.Now()) d := NowMono() - if c < b || c > d { - t.Errorf("expected %d <= %d <= %d", b, c, d) - } + require.LE(t, b, c) + require.LE(t, c, d) } diff --git a/fifo/queue_test.go b/fifo/queue_test.go index ffc72ec..f8c3419 100644 --- a/fifo/queue_test.go +++ b/fifo/queue_test.go @@ -17,34 +17,36 @@ package fifo import ( "math/rand" "testing" + + "github.com/cockroachdb/crlib/testutils/require" ) var pool = MakeQueueBackingPool[int]() func TestQueue(t *testing.T) { q := MakeQueue[int](&pool) - requireEqual(t, q.PeekFront(), nil) - requireEqual(t, q.Len(), 0) + require.Equal(t, q.PeekFront(), nil) + require.Equal(t, q.Len(), 0) q.PushBack(1) q.PushBack(2) q.PushBack(3) - requireEqual(t, q.Len(), 3) - requireEqual(t, *q.PeekFront(), 1) + require.Equal(t, q.Len(), 3) + require.Equal(t, *q.PeekFront(), 1) q.PopFront() - requireEqual(t, *q.PeekFront(), 2) + require.Equal(t, *q.PeekFront(), 2) q.PopFront() - requireEqual(t, *q.PeekFront(), 3) + require.Equal(t, *q.PeekFront(), 3) q.PopFront() - requireEqual(t, q.PeekFront(), nil) + require.Equal(t, q.PeekFront(), nil) for i := 1; i <= 1000; i++ { q.PushBack(i) - requireEqual(t, q.Len(), i) + require.Equal(t, q.Len(), i) } for i := 1; i <= 1000; i++ { - requireEqual(t, *q.PeekFront(), i) + require.Equal(t, *q.PeekFront(), i) q.PopFront() - requireEqual(t, q.Len(), 1000-i) + require.Equal(t, q.Len(), 1000-i) } } @@ -55,20 +57,13 @@ func TestQueueRand(t *testing.T) { for n := rand.Intn(100); n > 0; n-- { r++ q.PushBack(r) - requireEqual(t, q.Len(), r-l) + require.Equal(t, q.Len(), r-l) } for n := rand.Intn(q.Len() + 1); n > 0; n-- { l++ - requireEqual(t, *q.PeekFront(), l) + require.Equal(t, *q.PeekFront(), l) q.PopFront() - requireEqual(t, q.Len(), r-l) + require.Equal(t, q.Len(), r-l) } } } - -func requireEqual[T comparable](t *testing.T, actual, expected T) { - t.Helper() - if actual != expected { - t.Fatalf("expected %v, but found %v", expected, actual) - } -} diff --git a/fifo/semaphore_test.go b/fifo/semaphore_test.go index 46a7c90..2cb9243 100644 --- a/fifo/semaphore_test.go +++ b/fifo/semaphore_test.go @@ -22,14 +22,16 @@ import ( "sync" "testing" "time" + + "github.com/cockroachdb/crlib/testutils/require" ) func TestSemaphoreAPI(t *testing.T) { s := NewSemaphore(10) - requireEqual(t, s.TryAcquire(5), true) - requireEqual(t, s.TryAcquire(10), false) - requireEqual(t, s.Acquire(context.Background(), 20), ErrRequestExceedsCapacity) - requireEqual(t, "capacity: 10, outstanding: 5, num-had-to-wait: 0", s.Stats().String()) + require.Equal(t, s.TryAcquire(5), true) + require.Equal(t, s.TryAcquire(10), false) + require.Equal(t, s.Acquire(context.Background(), 20), ErrRequestExceedsCapacity) + require.Equal(t, "capacity: 10, outstanding: 5, num-had-to-wait: 0", s.Stats().String()) ch := make(chan struct{}, 10) go func() { @@ -46,15 +48,15 @@ func TestSemaphoreAPI(t *testing.T) { } ch <- struct{}{} }() - assertNoRecv(t, ch) + require.NoRecv(t, ch) s.Release(5) - assertRecv(t, ch) - assertRecv(t, ch) - assertNoRecv(t, ch) + require.Recv(t, ch) + require.Recv(t, ch) + require.NoRecv(t, ch) s.Release(1) - assertNoRecv(t, ch) + require.NoRecv(t, ch) s.Release(8) - assertRecv(t, ch) + require.Recv(t, ch) // Test UpdateCapacity. go func() { @@ -66,16 +68,16 @@ func TestSemaphoreAPI(t *testing.T) { t.Error(err) } ch <- struct{}{} - requireEqual(t, s.Acquire(context.Background(), 5), ErrRequestExceedsCapacity) + require.Equal(t, s.Acquire(context.Background(), 5), ErrRequestExceedsCapacity) ch <- struct{}{} }() - assertNoRecv(t, ch) + require.NoRecv(t, ch) s.UpdateCapacity(15) - assertRecv(t, ch) - assertRecv(t, ch) - assertNoRecv(t, ch) + require.Recv(t, ch) + require.Recv(t, ch) + require.NoRecv(t, ch) s.UpdateCapacity(2) - assertRecv(t, ch) + require.Recv(t, ch) } // TestSemaphoreBasic is a test with multiple goroutines acquiring a unit and @@ -103,7 +105,7 @@ func TestSemaphoreBasic(t *testing.T) { } for i := 0; i < numGoroutines; i++ { - if err := assertRecv(t, resCh); err != nil { + if err := require.Recv(t, resCh); err != nil { t.Fatal(err) } } @@ -132,14 +134,14 @@ func TestSemaphoreContextCancellation(t *testing.T) { cancel() - err := assertRecv(t, errCh) + err := require.Recv(t, errCh) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context cancellation error, got %v", err) } stats := s.Stats() - requireEqual(t, stats.Capacity, 1) - requireEqual(t, stats.Outstanding, 1) + require.Equal(t, stats.Capacity, 1) + require.Equal(t, stats.Outstanding, 1) } // TestSemaphoreCanceledAcquisitions tests the behavior where we enqueue @@ -163,7 +165,7 @@ func TestSemaphoreCanceledAcquisitions(t *testing.T) { } for i := 0; i < numGoroutines; i++ { - if err := assertRecv(t, errCh); !errors.Is(err, context.Canceled) { + if err := require.Recv(t, errCh); !errors.Is(err, context.Canceled) { t.Fatalf("expected context cancellation error, got %v", err) } } @@ -173,7 +175,7 @@ func TestSemaphoreCanceledAcquisitions(t *testing.T) { errCh <- s.Acquire(context.Background(), 1) }() - if err := assertRecv(t, errCh); err != nil { + if err := require.Recv(t, errCh); err != nil { t.Fatal(err) } } @@ -206,25 +208,25 @@ func TestSemaphoreNumHadToWait(t *testing.T) { } } // Initially s should have no waiters. - requireEqual(t, s.Stats().NumHadToWait, 0) + require.Equal(t, s.Stats().NumHadToWait, 0) if err := s.Acquire(ctx, 1); err != nil { t.Fatal(err) } // Still no waiters. - requireEqual(t, s.Stats().NumHadToWait, 0) + require.Equal(t, s.Stats().NumHadToWait, 0) for i := 0; i < 10; i++ { go doAcquire(ctx) } assertNumWaitersSoon(10) s.Release(1) - assertRecv(t, doneCh) + require.Recv(t, doneCh) go doAcquire(ctx) assertNumWaitersSoon(11) for i := 0; i < 10; i++ { s.Release(1) - assertRecv(t, doneCh) + require.Recv(t, doneCh) } - requireEqual(t, s.Stats().NumHadToWait, 11) + require.Equal(t, s.Stats().NumHadToWait, 11) } func TestConcurrentUpdatesAndAcquisitions(t *testing.T) { @@ -258,26 +260,6 @@ func TestConcurrentUpdatesAndAcquisitions(t *testing.T) { wg.Wait() s.UpdateCapacity(maxCap) stats := s.Stats() - requireEqual(t, stats.Capacity, 100) - requireEqual(t, stats.Outstanding, 0) -} - -func assertRecv[T any](t *testing.T, ch chan T) T { - t.Helper() - select { - case v := <-ch: - return v - case <-time.After(time.Second): - t.Fatal("did not receive notification") - panic("unreachable") - } -} - -func assertNoRecv[T any](t *testing.T, ch chan T) { - t.Helper() - select { - case <-ch: - t.Fatal("received unexpected notification") - case <-time.After(10 * time.Millisecond): - } + require.Equal(t, stats.Capacity, 100) + require.Equal(t, stats.Outstanding, 0) } diff --git a/testutils/require/channels.go b/testutils/require/channels.go new file mode 100644 index 0000000..4b7e451 --- /dev/null +++ b/testutils/require/channels.go @@ -0,0 +1,64 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package require + +import "time" + +// Recv asserts that a value is received on the channel within the specified +// duration within 1 second and returns that value. +func Recv[T any](tb TB, ch chan T) T { + select { + case v := <-ch: + return v + case <-time.After(1 * time.Second): + tb.Helper() + tb.Fatal("did not receive on channel") + panic("unreachable") + } +} + +// RecvWithin asserts that a value is received on the channel within the specified +// duration, and returns that value. +func RecvWithin[T any](tb TB, ch chan T, within time.Duration) T { + select { + case v := <-ch: + return v + case <-time.After(within): + tb.Helper() + tb.Fatal("did not receive on channel") + panic("unreachable") + } +} + +// NoRecv asserts that no value is received on the channel within 10ms. +func NoRecv[T any](tb TB, ch chan T) { + select { + case <-ch: + tb.Helper() + tb.Fatal("received unexpected notification") + case <-time.After(10 * time.Millisecond): + } +} + +// NoRecvWithin asserts that no value is received on the channel within the +// specified duration. +func NoRecvWithin[T any](tb TB, ch chan T, within time.Duration) { + select { + case <-ch: + tb.Helper() + tb.Fatal("received unexpected notification") + case <-time.After(within): + } +} diff --git a/testutils/require/comparisons.go b/testutils/require/comparisons.go new file mode 100644 index 0000000..7ddfb89 --- /dev/null +++ b/testutils/require/comparisons.go @@ -0,0 +1,51 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package require + +// LT asserts that a < b. +func LT[T ordered](tb TB, a, b T) { + if !(a < b) { + tb.Helper() + tb.Fatalf("expected %v < %v", a, b) + } +} + +// LE asserts that a <= b. +func LE[T ordered](tb TB, a, b T) { + if !(a <= b) { + tb.Helper() + tb.Fatalf("expected %v <= %v", a, b) + } +} + +// GT asserts that a > b. +func GT[T ordered](tb TB, a, b T) { + if !(a > b) { + tb.Helper() + tb.Fatalf("expected %v > %v", a, b) + } +} + +// GE asserts that a >= b. +func GE[T ordered](tb TB, a, b T) { + if !(a >= b) { + tb.Helper() + tb.Fatalf("expected %v >= %v", a, b) + } +} + +type ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~float32 | ~float64 | ~string +} diff --git a/testutils/require/doc.go b/testutils/require/doc.go new file mode 100644 index 0000000..2bbc7ab --- /dev/null +++ b/testutils/require/doc.go @@ -0,0 +1,52 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +/* +Package require implements convenience wrappers around checking conditions and +failing tests. + +The interface is inspired from `github.com/stretchr/testify/require` but the +implementation is simpler and uses generics. The benefit of generics is that we +don't have to add casts to make the types match, e.g. +[require.Equal](t, uint32Var, 2). + +Failed assertions result in a t.Fatal() call. + +# Equality + + - [require.Equal] + - [require.NotEqual] + - [require.True] + - [require.False] + +# Comparisons + + - [require.LT] + - [require.LE] + - [require.GT] + - [require.GE] + +# Channels + + - [require.Recv], [require.RecvWithin] + - [require.NoRecv], [require.NoRecvWithin] + +# Errors + - [require.NoError] + - [require.NoError1], [require.NoError2] + +# Including info in error messages + - [require.WithMsg], [require.WithMsgf] +*/ +package require diff --git a/testutils/require/equality.go b/testutils/require/equality.go new file mode 100644 index 0000000..ac9bc0a --- /dev/null +++ b/testutils/require/equality.go @@ -0,0 +1,58 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package require + +import ( + "fmt" + "reflect" +) + +// Equal asserts that a and b are deeply equal. +func Equal[T any](tb TB, a, b T) { + if !reflect.DeepEqual(a, b) { + tb.Helper() + aStr := fmt.Sprint(a) + bStr := fmt.Sprint(b) + if len(aStr)+len(bStr) > 80 { + tb.Fatalf("expected equality:\n a: %s\n b: %s", aStr, bStr) + } else { + tb.Fatalf("expected %s == %s", aStr, bStr) + } + } +} + +// NotEqual asserts that a and b are deeply equal. +func NotEqual[T any](tb TB, a, b T) { + if reflect.DeepEqual(a, b) { + tb.Helper() + tb.Fatalf("expected %v != %v", a, b) + } +} + +// True asserts that the value is true. +func True[T ~bool](tb TB, a T) { + if !a { + tb.Helper() + tb.Fatalf("expected true") + } +} + +// False asserts that the value is false. +func False[T ~bool](tb TB, a T) { + if a { + tb.Helper() + tb.Fatalf("expected false") + } +} diff --git a/testutils/require/errors.go b/testutils/require/errors.go new file mode 100644 index 0000000..75668df --- /dev/null +++ b/testutils/require/errors.go @@ -0,0 +1,67 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package require + +import "fmt" + +// NoError asserts that err is nil. +func NoError(tb TB, err error) { + if err != nil { + tb.Helper() + tb.Fatalf("unexpected error: %v", err) + } +} + +// NoError1 is passed an arbitrary value and an error and panics if the error is +// not-nil, otherwise returns the value. It can be used to get the return value +// of a fallible function that must succeed. +// +// Instead of: +// +// v, err := SomeFunc() +// if err != nil { +// t.Fatal(err) +// } +// +// We can use: +// +// v := require.NoError1(SomeFunc()) +func NoError1[T any](a T, err error) T { + if err != nil { + panic(fmt.Sprintf("unexpected error: %+v", err)) + } + return a +} + +// NoError2 is passed two arbitrary values and an error and panics if the error +// is not-nil, otherwise returns the values. It can be used to get the return +// values of a fallible function that must succeed. +// +// Instead of: +// +// v, w, err := SomeFunc() +// if err != nil { +// t.Fatal(err) +// } +// +// We can use: +// +// v, w := require.NoError2(SomeFunc()) +func NoError2[T any, U any](a T, b U, err error) (T, U) { + if err != nil { + panic(fmt.Sprintf("unexpected error: %+v", err)) + } + return a, b +} diff --git a/testutils/require/require.go b/testutils/require/require.go new file mode 100644 index 0000000..2d0336f --- /dev/null +++ b/testutils/require/require.go @@ -0,0 +1,101 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package require + +import "fmt" + +// TB is an interface common to *testing.T and *testing.B. +type TB interface { + Error(args ...any) + Errorf(format string, args ...any) + Fatal(args ...any) + Fatalf(format string, args ...any) + Helper() + Log(args ...any) + Logf(format string, args ...any) +} + +// withMsg implements TB and prepends some information to all logs or error +// messages. +type withMsg struct { + TB + + msg string +} + +func (w *withMsg) Error(args ...any) { + w.TB.Helper() + w.TB.Errorf("%s: %s", w.msg, fmt.Sprint(args...)) +} +func (w *withMsg) Errorf(format string, args ...any) { + w.TB.Helper() + w.TB.Errorf("%s: %s", w.msg, fmt.Sprintf(format, args...)) +} + +func (w *withMsg) Fatal(args ...any) { + w.TB.Helper() + w.TB.Fatalf("%s: %s", w.msg, fmt.Sprint(args...)) +} + +func (w *withMsg) Fatalf(format string, args ...any) { + w.TB.Helper() + w.TB.Fatalf("%s: %s", w.msg, fmt.Sprintf(format, args...)) +} + +func (w *withMsg) Log(args ...any) { + w.TB.Helper() + w.TB.Logf("%s: %s", w.msg, fmt.Sprint(args...)) +} + +func (w *withMsg) Logf(format string, args ...any) { + w.TB.Helper() + w.TB.Logf("%s: %s", w.msg, fmt.Sprintf(format, args...)) +} + +// WithMsg returns a TB that can be used with assertions and logs which +// prepends a message to any log or error message. +// +// Example: +// +// { +// t := require.WithMsg(t, "n=", n) +// require.Equal(t, a, b) +// require.LT(t, c, d) +// } +// +// A failure message would look like: +// +// n=5: expected 6 == 7 +func WithMsg(tb TB, args ...any) TB { + return &withMsg{TB: tb, msg: fmt.Sprint(args...)} +} + +// WithMsgf returns a TB that can be used with assertions and logs which +// prepends a message to any log or error message. +// +// Example: +// +// { +// t := require.WithMsgf(t, "n=%d", n) +// require.Equal(t, a, b) +// require.LT(t, c, d) +// } +// +// A failure message would look like: +// +// n=5: expected 6 == 7 +func WithMsgf(tb TB, format string, args ...any) TB { + return &withMsg{TB: tb, msg: fmt.Sprintf(format, args...)} +} diff --git a/testutils/require/require_test.go b/testutils/require/require_test.go new file mode 100644 index 0000000..8b77dd7 --- /dev/null +++ b/testutils/require/require_test.go @@ -0,0 +1,35 @@ +// Copyright 2024 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package require_test + +import ( + "testing" + + "github.com/cockroachdb/crlib/testutils/require" +) + +func TestWithMsg(t *testing.T) { + t2 := require.WithMsg(t, "foo") + // foo: hello1 + t2.Logf("hello%d", 1) + + // 1.2: hello2 + t2 = require.WithMsgf(t, "%d.%d", 1, 2) + t2.Log("hello2") + + // 1.2: bar: hello3 + t3 := require.WithMsgf(t2, "bar") + t3.Log("hello3") +}