Skip to content

Commit

Permalink
feat: add String method to testcase (#108)
Browse files Browse the repository at this point in the history
- fix testcase parser to handle precision_timestamp literals 
- fix String & ShortString of FixedChar and PrecisionTimestamp types
  • Loading branch information
scgkiran authored Jan 21, 2025
1 parent d4f8747 commit aa74f3e
Show file tree
Hide file tree
Showing 14 changed files with 470 additions and 131 deletions.
4 changes: 2 additions & 2 deletions expr/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func TestExprBuilder(t *testing.T) {
err string
}{
{"literal", "i8?(5)", b.Wrap(expr.NewLiteral(int8(5), true)), ""},
{"preciseTimeStampliteral", "precisiontimestamp?<3>(1970-01-01 00:02:03.456)", b.Wrap(expr.NewPrecisionTimestampLiteral(123456, types.PrecisionMilliSeconds, types.NullabilityNullable), nil), ""},
{"preciseTimeStampTzliteral", "precisiontimestamptz?<6>(1970-01-01T00:00:00.123456Z)", b.Wrap(expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionMicroSeconds, types.NullabilityNullable), nil), ""},
{"preciseTimeStampliteral", "precision_timestamp?<3>(1970-01-01 00:02:03.456)", b.Wrap(expr.NewPrecisionTimestampLiteral(123456, types.PrecisionMilliSeconds, types.NullabilityNullable), nil), ""},
{"preciseTimeStampTzliteral", "precision_timestamp_tz?<6>(1970-01-01T00:00:00.123456Z)", b.Wrap(expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionMicroSeconds, types.NullabilityNullable), nil), ""},
{"simple add", "add(.field(1) => i8, i8(5)) => i8?",
b.ScalarFunc(addID).Args(
b.RootRef(expr.NewStructFieldRef(1)),
Expand Down
53 changes: 53 additions & 0 deletions expr/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,22 @@ func (*PrimitiveLiteral[T]) isRootRef() {}
func (t *PrimitiveLiteral[T]) String() string {
return fmt.Sprintf("%s(%s)", t.Type.String(), t.ValueString())
}

func (t *PrimitiveLiteral[T]) ValueString() string {
if lit, ok := any(t.Value).(types.TimePrinter); ok {
return lit.ToTimeString()
}
return fmt.Sprintf("%v", t.Value)
}

func (t *PrimitiveLiteral[T]) IsoValueString() string {
switch x := any(t.Value).(type) {
case types.IsoTimePrinter:
return x.ToIsoTimeString()
}
return t.ValueString()
}

func (t *PrimitiveLiteral[T]) GetType() types.Type { return t.Type }
func (t *PrimitiveLiteral[T]) ToProtoLiteral() *proto.Expression_Literal {
lit := &proto.Expression_Literal{
Expand Down Expand Up @@ -472,6 +482,49 @@ func (t *ProtoLiteral) ValueString() string {
return fmt.Sprintf("%s", t.Value)
}

// IsoValueString handles precision timestamp and interval literals to return a string in ISO 8601 format
func (t *ProtoLiteral) IsoValueString() string {
switch literalType := t.Type.(type) {
case *types.PrecisionTimestampType:
tm := types.Timestamp(t.Value.(int64)).ToPrecisionTime(literalType.Precision)
return tm.UTC().Format("2006-01-02T15:04:05.999999999")
case *types.PrecisionTimestampTzType:
tm := types.TimestampTz(t.Value.(int64)).ToPrecisionTime(literalType.Precision)
return tm.UTC().Format("2006-01-02T15:04:05.000-07:00")
case *types.IntervalYearType:
x, _ := t.Value.(*proto.Expression_Literal_IntervalYearToMonth)
// Validity is required by construction.
return fmt.Sprintf("P%dY%dM", x.GetYears(), x.GetMonths())
case *types.IntervalDayType:
x, _ := t.Value.(*proto.Expression_Literal_IntervalDayToSecond)
// Validity is required by construction.
seconds := x.GetSeconds()
minutes := seconds / 60
hours := minutes / 60
seconds = seconds % 60
minutes = minutes % 60
sb := strings.Builder{}
sb.WriteString("P")
if x.GetDays() > 0 {
sb.WriteString(fmt.Sprintf("%dD", x.GetDays()))
}
if minutes > 0 || seconds > 0 {
sb.WriteString("T")
if hours > 0 {
sb.WriteString(fmt.Sprintf("%dH", hours))
}
if minutes > 0 {
sb.WriteString(fmt.Sprintf("%dM", minutes))
}
if seconds > 0 {
sb.WriteString(fmt.Sprintf("%dS", seconds))
}
}
return sb.String()
}
return t.ValueString()
}

func (*ProtoLiteral) isRootRef() {}
func (t *ProtoLiteral) GetType() types.Type { return t.Type }
func (t *ProtoLiteral) String() string {
Expand Down
22 changes: 11 additions & 11 deletions expr/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ func TestLiteralToString(t *testing.T) {
Value: expr.NewFixedCharLiteral(types.FixedChar("bar"), false),
},
}, true),
}, true), "list?<map?<string, char<3>>>([map?<string, char<3>>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"},
}, true), "list?<map?<string, fixedchar<3>>>([map?<string, fixedchar<3>>([{string(foo) fixedchar<3>(bar)} {string(baz) fixedchar<3>(bar)}])])"},
{MustLiteral(expr.NewLiteral(float32(1.5), false)), "fp32(1.5)"},
{MustLiteral(expr.NewLiteral(&types.VarChar{Value: "foobar", Length: 7}, true)), "varchar?<7>(foobar)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precisiontimestamp?<0>(1970-01-02 10:17:36)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionMilliSeconds, types.NullabilityNullable), "precisiontimestamp?<3>(1970-01-01 00:02:03.456)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionMicroSeconds, types.NullabilityNullable), "precisiontimestamp?<6>(1970-01-01 00:00:00.123456)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionNanoSeconds, types.NullabilityNullable), "precisiontimestamp?<9>(1970-01-01 00:00:00.000123456)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precisiontimestamptz?<0>(1970-01-02T10:17:36Z)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionMilliSeconds, types.NullabilityNullable), "precisiontimestamptz?<3>(1970-01-01T00:02:03.456Z)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionMicroSeconds, types.NullabilityNullable), "precisiontimestamptz?<6>(1970-01-01T00:00:00.123456Z)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionNanoSeconds, types.NullabilityNullable), "precisiontimestamptz?<9>(1970-01-01T00:00:00.000123456Z)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precision_timestamp?<0>(1970-01-02 10:17:36)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionMilliSeconds, types.NullabilityNullable), "precision_timestamp?<3>(1970-01-01 00:02:03.456)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionMicroSeconds, types.NullabilityNullable), "precision_timestamp?<6>(1970-01-01 00:00:00.123456)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionNanoSeconds, types.NullabilityNullable), "precision_timestamp?<9>(1970-01-01 00:00:00.000123456)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precision_timestamp_tz?<0>(1970-01-02T10:17:36Z)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionMilliSeconds, types.NullabilityNullable), "precision_timestamp_tz?<3>(1970-01-01T00:02:03.456Z)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionMicroSeconds, types.NullabilityNullable), "precision_timestamp_tz?<6>(1970-01-01T00:00:00.123456Z)"},
{expr.NewPrecisionTimestampTzLiteral(123456, types.PrecisionNanoSeconds, types.NullabilityNullable), "precision_timestamp_tz?<9>(1970-01-01T00:00:00.000123456Z)"},
{MustLiteral(literal.NewDecimalFromString("12.345")), "decimal<5,3>(12.345)"},
{MustLiteral(literal.NewDecimalFromString("-12.345")), "decimal<5,3>(-12.345)"},
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestLiteralToValueString(t *testing.T) {
Value: expr.NewFixedCharLiteral(types.FixedChar("bar"), false),
},
}, true),
}, true), "[[{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}]]"},
}, true), "[[{string(foo) fixedchar<3>(bar)} {string(baz) fixedchar<3>(bar)}]]"},
{expr.NewNestedLiteral(expr.MapLiteralValue{
{
Key: expr.NewPrimitiveLiteral("foo", false),
Expand All @@ -112,7 +112,7 @@ func TestLiteralToValueString(t *testing.T) {
Key: expr.NewPrimitiveLiteral("baz", false),
Value: expr.NewFixedCharLiteral(types.FixedChar("bar"), false),
},
}, true), "[{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}]"},
}, true), "[{string(foo) fixedchar<3>(bar)} {string(baz) fixedchar<3>(bar)}]"},
{MustLiteral(expr.NewLiteral(float32(1.5), false)), "1.5"},
{MustLiteral(expr.NewLiteral(&types.VarChar{Value: "foobar", Length: 7}, true)), "foobar"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "1970-01-02 10:17:36"},
Expand Down
4 changes: 2 additions & 2 deletions functions/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ scalar_functions:
{"timestamp", "TIMESTAMP", &types.TimestampType{Nullability: types.NullabilityRequired}, true},
{"dec<10,2>", "NUMERIC(10,2)", &types.DecimalType{Nullability: types.NullabilityRequired, Precision: 10, Scale: 2}, true},
{"varchar<10>", "VARCHAR(10)", &types.VarCharType{Nullability: types.NullabilityRequired, Length: 10}, true},
{"char<10>", "CHAR(10)", &types.FixedCharType{Nullability: types.NullabilityRequired, Length: 10}, true},
{"fixedchar<10>", "CHAR(10)", &types.FixedCharType{Nullability: types.NullabilityRequired, Length: 10}, true},
{"fixedbinary<10>", "BINARY(10)", &types.FixedBinaryType{Nullability: types.NullabilityRequired, Length: 10}, true},

// short names
Expand All @@ -160,7 +160,7 @@ scalar_functions:
{"timestamp?", "TIMESTAMP", &types.TimestampType{Nullability: types.NullabilityNullable}, true},
{"dec?<10,2>", "NUMERIC(10,2)", &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2}, true},
{"varchar?<10>", "VARCHAR(10)", &types.VarCharType{Nullability: types.NullabilityNullable, Length: 10}, true},
{"char?<10>", "CHAR(10)", &types.FixedCharType{Nullability: types.NullabilityNullable, Length: 10}, true},
{"fixedchar?<10>", "CHAR(10)", &types.FixedCharType{Nullability: types.NullabilityNullable, Length: 10}, true},
{"fixedbinary?<10>", "BINARY(10)", &types.FixedBinaryType{Nullability: types.NullabilityNullable, Length: 10}, true},
}
for _, tt := range tests {
Expand Down
16 changes: 16 additions & 0 deletions literal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ func NewPrecisionTimestamp(precision types.TimePrecision, value int64) (expr.Lit
}, false)
}

func NewPrecisionTimestampFromString(precision types.TimePrecision, value string) (expr.Literal, error) {
tm, err := parseTimeFromString(value)
if err != nil {
return nil, err
}
return NewPrecisionTimestampFromTime(precision, tm)
}

// NewPrecisionTimestampTzFromTime creates a new PrecisionTimestampTz literal from a time.Time timestamp value with given precision.
func NewPrecisionTimestampTzFromTime(precision types.TimePrecision, tm time.Time) (expr.Literal, error) {
return NewPrecisionTimestampTz(precision, getTimeValueByPrecision(tm, precision))
Expand All @@ -351,6 +359,14 @@ func NewPrecisionTimestampTz(precision types.TimePrecision, value int64) (expr.L
}, false)
}

func NewPrecisionTimestampTzFromString(precision types.TimePrecision, value string) (expr.Literal, error) {
tm, err := parseTimeFromString(value)
if err != nil {
return nil, err
}
return NewPrecisionTimestampTzFromTime(precision, tm)
}

func getTimeValueByPrecision(tm time.Time, precision types.TimePrecision) int64 {
switch precision {
case types.PrecisionSeconds:
Expand Down
174 changes: 174 additions & 0 deletions testcases/parser/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,57 @@ type CaseLiteral struct {
SubstraitError *SubstraitError
}

func (c *CaseLiteral) String() string {
if c.SubstraitError != nil {
return c.SubstraitError.String()
}
if c.Value == nil {
return "NULL"
}
return literalToString(c.Value) + "::" + c.Type.String()
}

func literalToString(literal expr.Literal) string {
if literal == nil {
panic("literal is nil")
}

switch lit := literal.(type) {
case *expr.NullLiteral:
return literal.ValueString()
case types.IsoValuePrinter:
switch literal.GetType().(type) {
// for these types enclose in single quotes
case *types.IntervalYearType, *types.IntervalDayType,
*types.PrecisionTimestampType, *types.PrecisionTimestampTzType,
*types.TimestampType, *types.TimeType, *types.TimestampTzType:
return fmt.Sprintf("'%s'", lit.IsoValueString())
}
}
switch literal.GetType().(type) {
// for these types enclose in single quotes
case *types.StringType, *types.FixedCharType, *types.VarCharType,
*types.FixedBinaryType, *types.BinaryType, *types.DateType:
return fmt.Sprintf("'%s'", literal.ValueString())
default:
return literal.ValueString()
}
}

func (c *CaseLiteral) AsAggregateArgumentString() string {
if c.SubstraitError != nil {
return c.SubstraitError.String()
}
if list, ok := c.Value.(*expr.ListLiteral); ok {
var elements []string
for _, element := range list.Value {
elements = append(elements, literalToString(element))
}
return "(" + strings.Join(elements, ", ") + ")::" + c.Type.String()
}
return c.Value.ValueString() + "::" + c.Type.String()
}

type TestFileHeader struct {
Version string
FuncType TestFuncType
Expand All @@ -47,6 +98,111 @@ type TestCase struct {
FuncType TestFuncType
}

func (tc *TestCase) String() string {
switch tc.FuncType {
case ScalarFuncType:
return tc.getScalarTestString()
case AggregateFuncType:
return tc.getAggregateTestString()
default:
panic(fmt.Sprintf("unsupported function type: %s", tc.FuncType))
}
}

func (tc *TestCase) getScalarTestString() string {
var b strings.Builder
b.WriteString(tc.FuncName)
b.WriteByte('(')
for i, arg := range tc.Args {
if i != 0 {
b.WriteString(", ")
}
b.WriteString(arg.String())
}
b.WriteByte(')')
b.WriteString(tc.getOptionString())
b.WriteString(" = ")
b.WriteString(tc.Result.String())
return b.String()
}

func (tc *TestCase) getAggregateTestString() string {
var b strings.Builder
if tc.needCompactAggregateFuncCall() {
b.WriteString(tc.getAggregateFuncTableString())
b.WriteByte(' ')
}

b.WriteString(tc.FuncName)
b.WriteByte('(')
for i, arg := range tc.AggregateArgs {
if i != 0 {
b.WriteString(", ")
}
b.WriteString(arg.String())
}
b.WriteByte(')')
b.WriteString(tc.getOptionString())
b.WriteString(" = ")
b.WriteString(tc.Result.String())
return b.String()
}

func (tc *TestCase) needCompactAggregateFuncCall() bool {
if tc.FuncType == ScalarFuncType {
return false
}
if len(tc.AggregateArgs) == 0 {
return true
}
for _, arg := range tc.AggregateArgs {
if arg.IsScalar || arg.ColumnName != "" {
return true
}
}
// common case of single column aggregate function
return false
}

func (tc *TestCase) getAggregateFuncTableString() string {
var b strings.Builder
if len(tc.Columns) == 0 {
return ""
}
b.WriteByte('(')
numRows := len(tc.Columns[0])
for i := 0; i < numRows; i++ {
if i != 0 {
b.WriteString(", ")
}
b.WriteByte('(')
for j, column := range tc.Columns {
if j != 0 {
b.WriteString(", ")
}
b.WriteString(literalToString(column[i]))
}
b.WriteByte(')')
}
b.WriteByte(')')
return b.String()
}

func (tc *TestCase) getOptionString() string {
if len(tc.Options) == 0 {
return ""
}
var b strings.Builder
b.WriteString(" [")
var options []string
for k, v := range tc.Options {
options = append(options, fmt.Sprintf("%s:%s", k, v))
}
b.WriteString(strings.Join(options, ","))
b.WriteByte(']')
return b.String()
}

func (tc *TestCase) GetFunctionOptions() []*types.FunctionOption {
if len(tc.Options) == 0 {
return nil
Expand Down Expand Up @@ -234,6 +390,16 @@ type AggregateArgument struct {
IsScalar bool
}

func (a *AggregateArgument) String() string {
if a.IsScalar {
return a.Argument.String()
}
if a.ColumnName == "" {
return a.Argument.AsAggregateArgumentString()
}
return a.ColumnName + "::" + a.ColumnType.String()
}

func (a *AggregateArgument) GetType() types.Type {
if a.IsScalar {
return a.Argument.Type
Expand Down Expand Up @@ -262,3 +428,11 @@ type CompactAggregateFuncCall struct {
Rows [][]expr.Literal
AggregateArgs []*AggregateArgument
}

type SubstraitError struct {
Error string
}

func (e SubstraitError) String() string {
return "<!" + e.Error + ">"
}
Loading

0 comments on commit aa74f3e

Please sign in to comment.