Skip to content

Commit

Permalink
fix: set nullability correctly in testcase parser (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Jan 24, 2025
1 parent 3fcd4ed commit 1aaa0fa
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 63 deletions.
2 changes: 1 addition & 1 deletion expr/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ type NullLiteral struct {
}

func NewNullLiteral(t types.Type) *NullLiteral {
return &NullLiteral{Type: t}
return &NullLiteral{Type: t.WithNullability(types.NullabilityNullable)}
}

func (*NullLiteral) IsScalar() bool { return true }
Expand Down
71 changes: 47 additions & 24 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ add(120::i8, 10::i8) [overflow:ERROR] = <!ERROR>
arithURI := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
ids := []string{"add:i8_i8", "add:i16_i16", "add:i8_i8"}
argTypes := [][]types.Type{
{&types.Int8Type{}, &types.Int8Type{}},
{&types.Int16Type{}, &types.Int16Type{}},
{&types.Int8Type{}, &types.Int8Type{}},
{&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}},
{&types.Int16Type{Nullability: types.NullabilityRequired}, &types.Int16Type{Nullability: types.NullabilityRequired}},
{&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}},
}
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
basicGroupDesc := "'Basic examples without any special cases'"
Expand Down Expand Up @@ -91,8 +91,8 @@ lt('2016-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = true::bool
assert.Equal(t, tsLiteral, testFile.TestCases[0].Args[1].Value)
boolLiteral := literal.NewBool(true)
assert.Equal(t, boolLiteral, testFile.TestCases[0].Result.Value)
assert.Equal(t, &types.BooleanType{}, testFile.TestCases[0].Result.Type)
timestampType := &types.TimestampType{Nullability: types.NullabilityUnspecified}
assert.Equal(t, &types.BooleanType{Nullability: types.NullabilityRequired}, testFile.TestCases[0].Result.Type)
timestampType := &types.TimestampType{Nullability: types.NullabilityRequired}
assert.Equal(t, timestampType, testFile.TestCases[0].Args[0].Type)
assert.Equal(t, timestampType, testFile.TestCases[0].Args[1].Type)
assert.Equal(t, ScalarFuncType, testFile.TestCases[0].FuncType)
Expand Down Expand Up @@ -155,17 +155,19 @@ func TestParseTestWithVariousTypes(t *testing.T) {
{testCaseStr: "f10('P10Y5M'::iyear, 5::i64) = 'P15Y5M'::iyear", expTestStr: "f10('P10Y5M'::interval_year, 5::i64) = 'P15Y5M'::interval_year"},
{testCaseStr: "f11('P10DT5H6M7S'::interval_day, 5::i64) = 'P10DT10H6M7S'::interval_day", expTestStr: "f11('P10DT5H6M7S'::interval_day<0>, 5::i64) = 'P10DT10H6M7S'::interval_day<0>"},
{testCaseStr: "f11('P10DT6M7S'::interval_day, 5::i64) = 'P10DT11M7S'::interval_day", expTestStr: "f11('P10DT6M7S'::interval_day<0>, 5::i64) = 'P10DT11M7S'::interval_day<0>"},
{testCaseStr: "or(false::bool, null::bool) = null::bool", expTestStr: "or(false::boolean, null::boolean) = null::boolean"},
{testCaseStr: "or(false::bool, null::bool) = null::bool", expTestStr: "or(false::boolean, null::boolean?) = null::boolean?"},
{testCaseStr: "f12('a'::vchar<9>, 'b'::varchar<4>) = 'c'::varchar<3>", expTestStr: "f12('a'::varchar<9>, 'b'::varchar<4>) = 'c'::varchar<3>"},
{testCaseStr: "f8('1991-01-01T01:02:03.456'::pts<3>, '1991-01-01T00:00:00.000000'::pts<6>) = '1991-01-01T22:33:44'::pts<0>", expTestStr: "f8('1991-01-01T01:02:03.456'::precision_timestamp<3>, '1991-01-01T00:00:00'::precision_timestamp<6>) = '1991-01-01T22:33:44'::precision_timestamp<0>"},
{testCaseStr: "f8('1991-01-01T01:02:03.456+05:30'::ptstz<3>, '1991-01-01T00:00:00+15:30'::ptstz<0>) = '1991-01-01T22:33:44+15:30'::ptstz<0>", expTestStr: "f8('1990-12-31T19:32:03.456+00:00'::precision_timestamp_tz<3>, '1990-12-31T08:30:00.000+00:00'::precision_timestamp_tz<0>) = '1991-01-01T07:03:44.000+00:00'::precision_timestamp_tz<0>"},
//{"f12('P10DT6M7.2000S'::iday<4>, 5::i64) = 'P10DT11M7.2000S'::iday<4>"}, // TODO enable after fixing the grammar
{testCaseStr: "f12('P10DT6M7S'::interval_day, 5::i64) = 'P10DT11M7S'::interval_day", expTestStr: "f12('P10DT6M7S'::interval_day<0>, 5::i64) = 'P10DT11M7S'::interval_day<0>"},
{testCaseStr: "concat('abcd'::varchar<9>, Null::str) [null_handling:ACCEPT_NULLS] = Null::str", expTestStr: "concat('abcd'::varchar<9>, null::string) [null_handling:ACCEPT_NULLS] = null::string"},
{testCaseStr: "concat('abcd'::varchar<9>, null::string) [null_handling:ACCEPT_NULLS] = null::string"},
{testCaseStr: "concat('abcd'::vchar<9>, 'ef'::varchar<9>) = Null::vchar<9>", expTestStr: "concat('abcd'::varchar<9>, 'ef'::varchar<9>) = null::varchar<9>"},
{testCaseStr: "concat('abcd'::vchar<9>, 'ef'::fixedchar<9>) = Null::fchar<9>", expTestStr: "concat('abcd'::varchar<9>, 'ef'::fixedchar<9>) = null::fixedchar<9>"},
{testCaseStr: "concat('abcd'::fbin<9>, 'ef'::fixedbinary<9>) = Null::fbin<9>", expTestStr: "concat('0x61626364'::fixedbinary<9>, '0x6566'::fixedbinary<9>) = null::fixedbinary<9>"},
{testCaseStr: "concat('abcd'::varchar<9>, Null::str) [null_handling:ACCEPT_NULLS] = Null::str", expTestStr: "concat('abcd'::varchar<9>, null::string?) [null_handling:ACCEPT_NULLS] = null::string?"},
{testCaseStr: "concat('abcd'::varchar<9>, null::string) [null_handling:ACCEPT_NULLS] = null::string", expTestStr: "concat('abcd'::varchar<9>, null::string?) [null_handling:ACCEPT_NULLS] = null::string?"},
{testCaseStr: "concat('abcd'::varchar<9>, null::varchar?<9>) [null_handling:ACCEPT_NULLS] = null::varchar?<9>"},
{testCaseStr: "concat('abcd'::vchar<9>, 'ef'::varchar<9>) = 'abcdef'::vchar<9>", expTestStr: "concat('abcd'::varchar<9>, 'ef'::varchar<9>) = 'abcdef'::varchar<9>"},
{testCaseStr: "concat('abcd'::vchar<9>, Null::varchar<9>) = Null::vchar<9>", expTestStr: "concat('abcd'::varchar<9>, null::varchar?<9>) = null::varchar?<9>"},
{testCaseStr: "concat('abcd'::vchar<9>, Null::fixedchar<9>) = Null::fchar<9>", expTestStr: "concat('abcd'::varchar<9>, null::fixedchar?<9>) = null::fixedchar?<9>"},
{testCaseStr: "concat('abcd'::fbin<9>, Null::fixedbinary<9>) = Null::fbin<9>", expTestStr: "concat('0x61626364'::fixedbinary<9>, null::fixedbinary?<9>) = null::fixedbinary?<9>"},
{testCaseStr: "f35('1991-01-01T01:02:03.456'::pts<3>) = '1991-01-01T01:02:30.123123'::precision_timestamp<3>", expTestStr: "f35('1991-01-01T01:02:03.456'::precision_timestamp<3>) = '1991-01-01T01:02:30.123'::precision_timestamp<3>"},
{testCaseStr: "f36('1991-01-01T01:02:03.456'::pts<3>, '1991-01-01T01:02:30.123123'::precision_timestamp<3>) = 123456::i64", expTestStr: "f36('1991-01-01T01:02:03.456'::precision_timestamp<3>, '1991-01-01T01:02:30.123'::precision_timestamp<3>) = 123456::i64"},
{testCaseStr: "f37('1991-01-01T01:02:03.123456'::pts<6>, '1991-01-01T04:05:06.456'::precision_timestamp<6>) = 123456::i64", expTestStr: "f37('1991-01-01T01:02:03.123456'::precision_timestamp<6>, '1991-01-01T04:05:06.456'::precision_timestamp<6>) = 123456::i64"},
Expand All @@ -183,10 +185,24 @@ func TestParseTestWithVariousTypes(t *testing.T) {
} else {
assert.Equal(t, test.testCaseStr, testFile.TestCases[0].String())
}
for _, arg := range testFile.TestCases[0].Args {
assert.NotNil(t, arg.Value)
checkNullability(t, arg.Value, arg.Type)
}
assert.NotNil(t, testFile.TestCases[0].Result.Value)
checkNullability(t, testFile.TestCases[0].Result.Value, testFile.TestCases[0].Result.Type)
})
}
}

func checkNullability(t *testing.T, lit expr.Literal, argType types.Type) {
if _, ok := lit.(*expr.NullLiteral); !ok {
assert.Equal(t, types.NullabilityRequired, argType.GetNullability())
} else {
assert.Equal(t, types.NullabilityNullable, argType.GetNullability())
}
}

func TestParseStringTestCases(t *testing.T) {
header := makeHeader("v1.0", "extensions/functions_arithmetic_decimal.yaml")
tests := `# basic
Expand Down Expand Up @@ -249,7 +265,7 @@ some_func('abc'::str, 'def'::str) = [1, 2, 3, 4, 5, 6]::List<i8>`
strDef := literal.NewString("def")
assert.Equal(t, strAbc, testFile.TestCases[0].Args[0].Value)
assert.Equal(t, strDef, testFile.TestCases[0].Args[1].Value)
i8List := &types.ListType{Type: &types.Int8Type{}}
i8List := &types.ListType{Type: &types.Int8Type{Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}
list, _ := literal.NewList([]expr.Literal{
literal.NewInt8(1), literal.NewInt8(2), literal.NewInt8(3),
literal.NewInt8(4), literal.NewInt8(5), literal.NewInt8(6),
Expand Down Expand Up @@ -328,9 +344,9 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
require.Equal(t, 1, aggregateFunc.NArgs())
aggArg, ok := aggregateFunc.Arg(0).(*expr.FieldReference)
require.True(t, ok)
assert.Equal(t, &types.Float32Type{}, aggArg.GetType())
assert.Equal(t, &types.Float32Type{Nullability: types.NullabilityRequired}, aggArg.GetType())
assert.Equal(t, ".field(0) => fp32", aggArg.String())
assert.Equal(t, []types.Type{&types.Float32Type{}}, tc.GetArgTypes())
assert.Equal(t, []types.Type{&types.Float32Type{Nullability: types.NullabilityRequired}}, tc.GetArgTypes())
assert.Equal(t, testStrings[0], tc.String())

tc = testFile.TestCases[1]
Expand All @@ -355,9 +371,9 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
require.Equal(t, 1, aggregateFunc.NArgs())
aggArg, ok = aggregateFunc.Arg(0).(*expr.FieldReference)
require.True(t, ok)
assert.Equal(t, &types.Int64Type{}, aggArg.GetType())
assert.Equal(t, &types.Int64Type{Nullability: types.NullabilityRequired}, aggArg.GetType())
assert.Equal(t, ".field(0) => i64", aggArg.String())
assert.Equal(t, []types.Type{&types.Int64Type{}}, tc.GetArgTypes())
assert.Equal(t, []types.Type{&types.Int64Type{Nullability: types.NullabilityRequired}}, tc.GetArgTypes())
}

func newInt64List(values ...int64) interface{} {
Expand Down Expand Up @@ -404,7 +420,7 @@ func TestParseAggregateFuncCompact(t *testing.T) {
assert.Len(t, testFile.TestCases[0].AggregateArgs, 2)
assert.Equal(t, newFloat32Values(20, -3, 1, 10, 5), testFile.TestCases[0].Columns[0])
assert.Equal(t, newFloat32Values(20, -3, 1, 10, 5), testFile.TestCases[0].Columns[1])
f32Type := &types.Float32Type{}
f32Type := &types.Float32Type{Nullability: types.NullabilityRequired}
args := []*AggregateArgument{
createAggregateArg(t, "", "col0", f32Type),
createAggregateArg(t, "", "col1", f32Type),
Expand Down Expand Up @@ -493,7 +509,7 @@ LIST_AGG(t1.col0, ','::string) = 1::fp64
require.NoError(t, err)
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, 2)
expectedArgTypes := []types.Type{&types.Float32Type{}, &types.StringType{}}
expectedArgTypes := []types.Type{&types.Float32Type{Nullability: types.NullabilityRequired}, &types.StringType{Nullability: types.NullabilityRequired}}
for i, tc := range testFile.TestCases {
assert.Equal(t, AggregateFuncType, tc.FuncType)
assert.Equal(t, expectedArgTypes, tc.GetArgTypes(), "unexpected arg types in test case %d", i)
Expand Down Expand Up @@ -603,6 +619,7 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) {
{testCaseStr: "f8(('1991-01-01T01:02:03.456', '1991-01-01T00:00:00')::timestamp) = '1991-01-01T22:33:44'::ts", expTestStr: "f8(('1991-01-01T01:02:03.456', '1991-01-01T00:00:00')::timestamp) = '1991-01-01T22:33:44'::timestamp"},
{testCaseStr: "f8(('1991-01-01T01:02:03.456+05:30', '1991-01-01T00:00:00+15:30')::tstz) = 23::i32", expTestStr: "f8(('1990-12-31T19:32:03.456', '1990-12-31T08:30:00')::timestamp_tz) = 23::i32"},
{testCaseStr: "f10(('P10Y5M', 'P11Y5M')::interval_year) = 'P21Y10M'::interval_year"},
{testCaseStr: "f10(('P10Y5M', null)::interval_year) = null::interval_year", expTestStr: "f10(('P10Y5M', null)::interval_year?) = null::interval_year?"},
{testCaseStr: "f10(('P10Y2M', 'P10Y7M')::iyear) = 'P20Y9M'::iyear", expTestStr: "f10(('P10Y2M', 'P10Y7M')::interval_year) = 'P20Y9M'::interval_year"},
{testCaseStr: "f11(('P10DT5H6M7S', 'P10DT6M7S')::interval_day) = 'P20DT11H6M7S'::interval_day", expTestStr: "f11(('P10DT5H6M7S', 'P10DT6M7S')::interval_day<0>) = 'P20DT11H6M7S'::interval_day<0>"},
{testCaseStr: "f11(('P10DT5H6M7S', 'P10DT6M7S')::iday?) = 'P20DT11H6M7S'::iday", expTestStr: "f11(('P10DT5H6M7S', 'P10DT6M7S')::interval_day?<0>) = 'P20DT11H6M7S'::interval_day<0>"},
Expand All @@ -611,12 +628,18 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) {
{testCaseStr: "((20), (3), (1), (10), (5)) count_star() = 1::fp64", expTestStr: "(('20'), ('3'), ('1'), ('10'), ('5')) count_star() = 1::fp64"}, // no type specified for columns in the test case
{testCaseStr: `DEFINE t1(fp32, fp32) = ((20, 20), (-3, -3), (1, 1), (10,10), (5,5))
count_star() = 1::fp64`, expTestStr: "((20, 20), (-3, -3), (1, 1), (10, 10), (5, 5)) count_star() = 1::fp64"},
{testCaseStr: "f20(('abcd', 'ef')::fchar?<9>) = Null::fchar<9>", expTestStr: "f20(('abcd', 'ef')::fixedchar?<9>) = null::fixedchar<9>"},
{testCaseStr: "f20(('abcd', 'ef')::fixedchar<9>) = Null::fchar<9>", expTestStr: "f20(('abcd', 'ef')::fixedchar<9>) = null::fixedchar<9>"},
{testCaseStr: "f20(('abcd', 'ef', null)::vchar?<9>) = Null::vchar<9>", expTestStr: "f20(('abcd', 'ef', null)::varchar?<9>) = null::varchar<9>"},
{testCaseStr: "f20(('abcd', 'ef')::varchar<9>) = Null::vchar<9>", expTestStr: "f20(('abcd', 'ef')::varchar<9>) = null::varchar<9>"},
{testCaseStr: "f20(('abcd', 'ef')::fbin<9>) = Null::fbin<9>", expTestStr: "f20(('abcd', 'ef')::fixedbinary<9>) = null::fixedbinary<9>"},
{testCaseStr: "f20(('abcd', 'ef')::fixedbinary?<9>) = Null::fixedbinary<9>", expTestStr: "f20(('abcd', 'ef')::fixedbinary?<9>) = null::fixedbinary<9>"},
{testCaseStr: `DEFINE t1(varchar<5>) = (('cat'), ('bat'), ('rat'), (null))
count_star() = 1::fp64`, expTestStr: "(('cat'), ('bat'), ('rat'), (null)) count_star() = 1::fp64"}, // no arguments, so no type info in the output format
{testCaseStr: `DEFINE t1(varchar<5>) = (('cat'), ('bat'), ('rat'), (null))
count(t1.col0) = 4::fp64`, expTestStr: "(('cat'), ('bat'), ('rat'), (null)) count(col0::varchar?<5>) = 4::fp64"},
{testCaseStr: "f20(('abcd', 'ef')::fchar?<9>) = Null::fchar<9>", expTestStr: "f20(('abcd', 'ef')::fixedchar?<9>) = null::fixedchar?<9>"},
{testCaseStr: "f20(('abcd', 'ef')::fixedchar<9>) = Null::fchar<9>", expTestStr: "f20(('abcd', 'ef')::fixedchar<9>) = null::fixedchar?<9>"},
{testCaseStr: "f20(('abcd', null)::fixedchar<9>) = Null::fchar<9>", expTestStr: "f20(('abcd', null)::fixedchar?<9>) = null::fixedchar?<9>"},
{testCaseStr: "f20(('abcd', 'ef', null)::vchar?<9>) = Null::vchar<9>", expTestStr: "f20(('abcd', 'ef', null)::varchar?<9>) = null::varchar?<9>"},
{testCaseStr: "f20(('abcd', 'ef')::varchar<9>) = Null::vchar<9>", expTestStr: "f20(('abcd', 'ef')::varchar<9>) = null::varchar?<9>"},
{testCaseStr: "f20(('abcd', 'ef')::fbin<9>) = Null::fbin<9>", expTestStr: "f20(('abcd', 'ef')::fixedbinary<9>) = null::fixedbinary?<9>"},
{testCaseStr: "f20(('abcd', 'ef')::varchar?<9>) = 'abcdef'::varchar<9>", expTestStr: "f20(('abcd', 'ef')::varchar?<9>) = 'abcdef'::varchar<9>"},
{testCaseStr: "f20(('abcd', null)::fixedchar?<9>) = Null::fixedchar<9>", expTestStr: "f20(('abcd', null)::fixedchar?<9>) = null::fixedchar?<9>"},
{testCaseStr: "f35(('1991-01-01T01:02:03.456')::pts?<3>) = '1991-01-01T01:02:30.123123'::precision_timestamp<3>",
expTestStr: "f35(('1991-01-01T01:02:03.456')::precision_timestamp?<3>) = '1991-01-01T01:02:30.123'::precision_timestamp<3>"},
{testCaseStr: "f36(('1991-01-01T01:02:03.456', '1991-01-01T01:02:30.123123')::precision_timestamp<3>) = 123456::i64"},
Expand Down
Loading

0 comments on commit 1aaa0fa

Please sign in to comment.