diff --git a/helpers/values.go b/helpers/values.go index 3c773ef..ddaa7e6 100644 --- a/helpers/values.go +++ b/helpers/values.go @@ -117,28 +117,54 @@ func When[T any](condition bool, callbacks func() T, defaults ...T) T { return zero } -// Default returns defaultValue if value is zero, otherwise value. +// Default returns the first non-zero value. +// If all values are zero, return the zero value. // // Default("", "foo") // "foo" // Default("bar", "foo") // "bar" -func Default[T comparable](value T, defaultValue T) T { +// Default("", "", "foo") // "foo" +func Default[T comparable](values ...T) T { var zero T - if value == zero { - return defaultValue + for _, value := range values { + if value != zero { + return value + } } - return value + return zero +} + +func DefaultFunc[T comparable](callbacks ...func() T) T { + var zero, value T + for _, callback := range callbacks { + if callback != nil { + value = callback() + if value != zero { + return value + } + } + } + return zero } // DefaultWithFunc returns defaultValue if value is zero, otherwise value. // // DefaultWithFunc("", func() string { return "foo" }) // "foo" // DefaultWithFunc("bar", func() string { return "foo" }) // "bar" -func DefaultWithFunc[T comparable](value T, defaultValue func() T) T { +// DefaultWithFunc("", func() string { return "" }, func() string { return "foo" }) // "foo" +func DefaultWithFunc[T comparable](value T, callbacks ...func() T) T { var zero T - if value == zero { - return defaultValue() + if value != zero { + return value } - return value + for _, callback := range callbacks { + if callback != nil { + value = callback() + if value != zero { + return value + } + } + } + return zero } // Ptr returns a pointer to the value. diff --git a/helpers/values_test.go b/helpers/values_test.go index 8d8d560..5b35b8a 100644 --- a/helpers/values_test.go +++ b/helpers/values_test.go @@ -243,9 +243,60 @@ func TestDefault(t *testing.T) { // ptr got4 := Default(nil, &foo{Name: "bar"}) assert.Equal(t, "bar", got4.Name) + + // more values + got5 := Default(0, 10, 20, 30) + assert.Equal(t, 10, got5) + + got6 := Default(0, 0, 20) + assert.Equal(t, 20, got6) } -func TestDefaultWith(t *testing.T) { +func TestDefaultFunc(t *testing.T) { + // string + got := DefaultFunc(func() string { + return "" + }, func() string { + return "foo" + }) + assert.Equal(t, "foo", got) + + // int + got2 := DefaultFunc(func() int { + return 0 + }, func() int { + return 10 + }) + assert.Equal(t, 10, got2) + + // struct + got3 := DefaultFunc(func() foo { + return foo{} + }, func() foo { + return foo{Name: "bar"} + }) + assert.Equal(t, "bar", got3.Name) + + // ptr + got4 := DefaultFunc(func() *foo { + return nil + }, func() *foo { + return &foo{Name: "bar"} + }) + assert.Equal(t, "bar", got4.Name) + + // more values + got5 := DefaultFunc(func() int { + return 0 + }, func() int { + return 0 + }, func() int { + return 10 + }) + assert.Equal(t, 10, got5) +} + +func TestDefaultWithFunc(t *testing.T) { // string got := DefaultWithFunc("", func() string { return "foo" @@ -269,6 +320,14 @@ func TestDefaultWith(t *testing.T) { return &foo{Name: "bar"} }) assert.Equal(t, "bar", got4.Name) + + // more values + got5 := DefaultWithFunc(0, func() int { + return 0 + }, func() int { + return 10 + }) + assert.Equal(t, 10, got5) } func TestPtrAndVal(t *testing.T) {