diff --git a/.gitmodules b/.gitmodules index 4ae12907..498b2ea7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "tests/testdata"] path = tests/testdata - url = https://github.com/go-jet/jet-test-data + url = https://github.com/arjen-ag5/jet-test-data diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index ecd61e22..5b6ce8ab 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -28,4 +28,5 @@ type DataType struct { Name string Kind DataTypeKind IsUnsigned bool + Dimensions int // The number of array dimensions } diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index fc4135ab..a60c01aa 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -66,6 +66,7 @@ select not attr.attnotnull as "column.isNullable", attr.attgenerated = 's' as "column.isGenerated", attr.atthasdef as "column.hasDefault", + attr.attndims as "dataType.dimensions", (case when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base' when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array' diff --git a/generator/template/model_template.go b/generator/template/model_template.go index 84f46d70..ffce3189 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgtype" "path/filepath" + "github.com/lib/pq" "reflect" "strings" "time" @@ -251,7 +252,7 @@ func getUserDefinedType(column metadata.Column) string { switch column.DataType.Kind { case metadata.EnumType: return dbidentifier.ToGoIdentifier(column.DataType.Name) - case metadata.UserDefinedType, metadata.ArrayType: + case metadata.UserDefinedType: return "string" } @@ -270,6 +271,11 @@ func getGoType(column metadata.Column) interface{} { // toGoType returns model type for column info. func toGoType(column metadata.Column) interface{} { + // We don't support multi-dimensional arrays + if column.DataType.Dimensions > 1 { + return "" + } + switch strings.ToLower(column.DataType.Name) { case "user-defined", "enum": return "" @@ -335,6 +341,16 @@ func toGoType(column metadata.Column) interface{} { return pgtype.Int8range{} case "numrange": return pgtype.Numrange{} + case "bool[]", "boolean[]": + return pq.BoolArray{} + case "integer[]", "int4[]": + return pq.Int32Array{} + case "bigint[]", "int8[]": + return pq.Int64Array{} + case "bytea[]": + return pq.ByteaArray{} + case "text[]", "jsonb[]", "json[]": + return pq.StringArray{} default: fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.") return "" diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index a72e8e99..a4bdf5c3 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -156,53 +156,101 @@ func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilde // getSqlBuilderColumnType returns type of jet sql builder column func getSqlBuilderColumnType(columnMetaData metadata.Column) string { if columnMetaData.DataType.Kind != metadata.BaseType && - columnMetaData.DataType.Kind != metadata.RangeType { + columnMetaData.DataType.Kind != metadata.RangeType && + columnMetaData.DataType.Kind != metadata.ArrayType { return "String" } - switch strings.ToLower(columnMetaData.DataType.Name) { + typeName := columnMetaData.DataType.Name + columnName := columnMetaData.Name + + var columnType string + var supported bool + + if columnMetaData.DataType.Kind == metadata.ArrayType { + if columnMetaData.DataType.Dimensions > 1 { + fmt.Println("- [SQL Builder] Unsupported sql array with multiple dimensions column '" + columnName + " " + typeName + "', using StringColumn instead.") + return "String" + } + + columnType, supported = sqlArrayToColumnType(strings.TrimSuffix(typeName, "[]")) + } else { + columnType, supported = sqlToColumnType(typeName) + } + + if !supported { + fmt.Printf("- [SQL Builder] Unsupported SQL column '" + columnName + " " + typeName + "', using StringColumn instead.\n") + return "String" + } + + return columnType +} + +// sqlArrayToColumnType maps the type of an SQL array column type to a go jet sql builder column. Note that you don't +// pass the brackets `[]`, signifying an SQL array type, into this function. The second return value returns whether the +// given type is supported +func sqlArrayToColumnType(typeName string) (string, bool) { + switch strings.ToLower(typeName) { + case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", + "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", + "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", + "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL + return "StringArray", true + case "smallint", "integer", "bigint", "int2", "int4", "int8", + "tinyint", "mediumint", "int", "year": //MySQL + return "IntegerArray", true case "boolean", "bool": - return "Bool" + return "BoolArray", true + default: + return "", false + } +} + +// sqlToColumnType maps the type of a SQL column type to a go jet sql builder column. The second return value returns +// whether the given type is supported. +func sqlToColumnType(typeName string) (string, bool) { + switch strings.ToLower(typeName) { + case "boolean", "bool": + return "Bool", true case "smallint", "integer", "bigint", "int2", "int4", "int8", "tinyint", "mediumint", "int", "year": //MySQL - return "Integer" + return "Integer", true case "date": - return "Date" + return "Date", true case "timestamp without time zone", "timestamp", "datetime": //MySQL: - return "Timestamp" + return "Timestamp", true case "timestamp with time zone", "timestamptz": - return "Timestampz" + return "Timestampz", true case "time without time zone", "time": //MySQL - return "Time" + return "Time", true case "time with time zone", "timetz": - return "Timez" + return "Timez", true case "interval": - return "Interval" + return "Interval", true case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL - return "String" + return "String", true case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL - return "Float" + return "Float", true case "daterange": - return "DateRange" + return "DateRange", true case "tsrange": - return "TimestampRange" + return "TimestampRange", true case "tstzrange": - return "TimestampzRange" + return "TimestampzRange", true case "int4range": - return "Int4Range" + return "Int4Range", true case "int8range": - return "Int8Range" + return "Int8Range", true case "numrange": - return "NumericRange" + return "NumericRange", true default: - fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.") - return "String" + return "", false } } diff --git a/internal/jet/array_expression.go b/internal/jet/array_expression.go new file mode 100644 index 00000000..9ac562c6 --- /dev/null +++ b/internal/jet/array_expression.go @@ -0,0 +1,93 @@ +package jet + +// Array interface +type Array[E Expression] interface { + Expression + + EQ(rhs Array[E]) BoolExpression + NOT_EQ(rhs Array[E]) BoolExpression + LT(rhs Array[E]) BoolExpression + GT(rhs Array[E]) BoolExpression + LT_EQ(rhs Array[E]) BoolExpression + GT_EQ(rhs Array[E]) BoolExpression + + CONTAINS(rhs Array[E]) BoolExpression + IS_CONTAINED_BY(rhs Array[E]) BoolExpression + OVERLAP(rhs Array[E]) BoolExpression + CONCAT(rhs Array[E]) Array[E] + CONCAT_ELEMENT(E) Array[E] + + AT(expression IntegerExpression) E +} + +type arrayInterfaceImpl[E Expression] struct { + parent Array[E] +} + +type BinaryBoolOp func(Expression, Expression) BoolExpression + +func (a arrayInterfaceImpl[E]) EQ(rhs Array[E]) BoolExpression { + return Eq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) NOT_EQ(rhs Array[E]) BoolExpression { + return NotEq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) LT(rhs Array[E]) BoolExpression { + return Lt(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) GT(rhs Array[E]) BoolExpression { + return Gt(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) LT_EQ(rhs Array[E]) BoolExpression { + return LtEq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) GT_EQ(rhs Array[E]) BoolExpression { + return GtEq(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) CONTAINS(rhs Array[E]) BoolExpression { + return Contains(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) IS_CONTAINED_BY(rhs Array[E]) BoolExpression { + return IsContainedBy(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) OVERLAP(rhs Array[E]) BoolExpression { + return Overlap(a.parent, rhs) +} + +func (a arrayInterfaceImpl[E]) CONCAT(rhs Array[E]) Array[E] { + return ArrayExp[E](NewBinaryOperatorExpression(a.parent, rhs, "||")) +} + +func (a arrayInterfaceImpl[E]) CONCAT_ELEMENT(rhs E) Array[E] { + return ArrayExp[E](NewBinaryOperatorExpression(a.parent, rhs, "||")) +} + +func (a arrayInterfaceImpl[E]) AT(expression IntegerExpression) E { + return arrayElementTypeCaster[E](a.parent, arraySubscriptExpr(a.parent, expression)) +} + +type arrayExpressionWrapper[E Expression] struct { + arrayInterfaceImpl[E] + Expression +} + +func newArrayExpressionWrap[E Expression](expression Expression) Array[E] { + arrayExpressionWrapper := arrayExpressionWrapper[E]{Expression: expression} + arrayExpressionWrapper.arrayInterfaceImpl.parent = &arrayExpressionWrapper + return &arrayExpressionWrapper +} + +// ArrayExp is array expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as array expression. +// Does not add sql cast to generated sql builder output. +func ArrayExp[E Expression](expression Expression) Array[E] { + return newArrayExpressionWrap[E](expression) +} diff --git a/internal/jet/array_expression_test.go b/internal/jet/array_expression_test.go new file mode 100644 index 00000000..508b5f96 --- /dev/null +++ b/internal/jet/array_expression_test.go @@ -0,0 +1,59 @@ +package jet + +import ( + "github.com/lib/pq" + "testing" +) + +func TestArrayExpressionEQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.EQ(table2ColArray), "(table1.col_array_string = table2.col_array_string)") +} + +func TestArrayExpressionNOT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.NOT_EQ(table2ColArray), "(table1.col_array_string != table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.NOT_EQ(StringArray([]string{"x"})), "(table1.col_array_string != $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionLT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.LT(table2ColArray), "(table1.col_array_string < table2.col_array_string)") +} + +func TestArrayExpressionGT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.GT(table2ColArray), "(table1.col_array_string > table2.col_array_string)") +} + +func TestArrayExpressionLT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.LT_EQ(table2ColArray), "(table1.col_array_string <= table2.col_array_string)") +} + +func TestArrayExpressionGT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.GT_EQ(table2ColArray), "(table1.col_array_string >= table2.col_array_string)") +} + +func TestArrayExpressionCONTAINS(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.CONTAINS(table2ColArray), "(table1.col_array_string @> table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.CONTAINS(StringArray([]string{"x"})), "(table1.col_array_string @> $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionCONTAINED_BY(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.IS_CONTAINED_BY(table2ColArray), "(table1.col_array_string <@ table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.IS_CONTAINED_BY(StringArray([]string{"x"})), "(table1.col_array_string <@ $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionOVERLAP(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.OVERLAP(table2ColArray), "(table1.col_array_string && table2.col_array_string)") +} + +func TestArrayExpressionCONCAT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.CONCAT(table2ColArray), "(table1.col_array_string || table2.col_array_string)") + assertClauseSerialize(t, table1ColStringArray.CONCAT(StringArray([]string{"x"})), "(table1.col_array_string || $1)", pq.StringArray{"x"}) +} + +func TestArrayExpressionCONCAT_ELEMENT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.CONCAT_ELEMENT(StringExp(table2ColArray.AT(Int(1)))), "(table1.col_array_string || table2.col_array_string[$1])", int64(1)) + assertClauseSerialize(t, table1ColStringArray.CONCAT_ELEMENT(String("x")), "(table1.col_array_string || $1)", "x") +} + +func TestArrayExpressionAT(t *testing.T) { + assertClauseSerialize(t, table1ColStringArray.AT(Int(1)), "table1.col_array_string[$1]", int64(1)) +} diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index a7320615..2c47b103 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -121,6 +121,46 @@ func IntegerColumn(name string) ColumnInteger { //------------------------------------------------------// +type ColumnArray[E Expression] interface { + Array[E] + Column + + From(subQuery SelectTable) ColumnArray[E] + SET(stringExp Array[E]) ColumnAssigment +} + +type arrayColumnImpl[E Expression] struct { + arrayInterfaceImpl[E] + + ColumnExpressionImpl +} + +func (a arrayColumnImpl[E]) From(subQuery SelectTable) ColumnArray[E] { + newArrayColumn := ArrayColumn[E](a.name) + newArrayColumn.setTableName(a.tableName) + newArrayColumn.setSubQuery(subQuery) + + return newArrayColumn +} + +func (a *arrayColumnImpl[E]) SET(stringExp Array[E]) ColumnAssigment { + return columnAssigmentImpl{ + column: a, + expression: stringExp, + } +} + +// StringColumn creates named string column. +func ArrayColumn[E Expression](name string) ColumnArray[E] { + arrayColumn := &arrayColumnImpl[E]{} + arrayColumn.arrayInterfaceImpl.parent = arrayColumn + arrayColumn.ColumnExpressionImpl = NewColumnImpl(name, "", arrayColumn) + + return arrayColumn +} + +//------------------------------------------------------// + // ColumnString is interface for SQL text, character, character varying // bytea, uuid columns and enums types. type ColumnString interface { diff --git a/internal/jet/column_types_test.go b/internal/jet/column_types_test.go index 059d722d..38d9e96b 100644 --- a/internal/jet/column_types_test.go +++ b/internal/jet/column_types_test.go @@ -1,6 +1,7 @@ package jet import ( + "github.com/lib/pq" "testing" ) @@ -8,6 +9,42 @@ var subQuery = &selectTableImpl{ alias: "sub_query", } +func TestNewArrayColumnString(t *testing.T) { + stringArrayColumn := ArrayColumn[StringExpression]("colArray").From(subQuery) + assertClauseSerialize(t, stringArrayColumn, `sub_query."colArray"`) + assertClauseSerialize(t, stringArrayColumn.EQ(StringArray([]string{"X"})), `(sub_query."colArray" = $1)`, pq.StringArray{"X"}) + assertProjectionSerialize(t, stringArrayColumn, `sub_query."colArray" AS "colArray"`) + + arrayColumn2 := table1ColStringArray.From(subQuery) + assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_string"`) + assertClauseSerialize(t, arrayColumn2.EQ(StringArray([]string{"X"})), `(sub_query."table1.col_array_string" = $1)`, pq.StringArray{"X"}) + assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_string" AS "table1.col_array_string"`) +} + +func TestNewArrayColumnBool(t *testing.T) { + boolArrayColumn := ArrayColumn[BoolExpression]("colArrayBool").From(subQuery) + assertClauseSerialize(t, boolArrayColumn, `sub_query."colArrayBool"`) + assertClauseSerialize(t, boolArrayColumn.EQ(BoolArray([]bool{true})), `(sub_query."colArrayBool" = $1)`, pq.BoolArray{true}) + assertProjectionSerialize(t, boolArrayColumn, `sub_query."colArrayBool" AS "colArrayBool"`) + + arrayColumn2 := table1ColBoolArray.From(subQuery) + assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_bool"`) + assertClauseSerialize(t, arrayColumn2.EQ(BoolArray([]bool{true})), `(sub_query."table1.col_array_bool" = $1)`, pq.BoolArray{true}) + assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_bool" AS "table1.col_array_bool"`) +} + +func TestNewArrayColumnInteger(t *testing.T) { + intArrayColumn := ArrayColumn[IntegerExpression]("colArrayInt").From(subQuery) + assertClauseSerialize(t, intArrayColumn, `sub_query."colArrayInt"`) + assertClauseSerialize(t, intArrayColumn.EQ(Int32Array([]int32{42})), `(sub_query."colArrayInt" = $1)`, pq.Int32Array{42}) + assertProjectionSerialize(t, intArrayColumn, `sub_query."colArrayInt" AS "colArrayInt"`) + + arrayColumn2 := table1ColIntArray.From(subQuery) + assertClauseSerialize(t, arrayColumn2, `sub_query."table1.col_array_int"`) + assertClauseSerialize(t, arrayColumn2.EQ(Int32Array([]int32{42})), `(sub_query."table1.col_array_int" = $1)`, pq.Int32Array{42}) + assertProjectionSerialize(t, arrayColumn2, `sub_query."table1.col_array_int" AS "table1.col_array_int"`) +} + func TestNewBoolColumn(t *testing.T) { boolColumn := BoolColumn("colBool").From(subQuery) assertClauseSerialize(t, boolColumn, `sub_query."colBool"`) diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 9999803f..c2488baa 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -316,6 +316,29 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, } } +//type arraySubscriptExpression struct { +// ExpressionInterfaceImpl +// array Expression +// subscript IntegerExpression +//} +// +//func (a arraySubscriptExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +// if !contains(options, NoWrap) { +// out.WriteString("(") +// } +// a.array.serialize(statement, out, FallTrough(options)...) // FallTrough here because complexExpression is just a wrapper +// out.WriteString("[") +// a.subscript.serialize(statement, out, FallTrough(options)...) // FallTrough here because complexExpression is just a wrapper +// out.WriteString("]") +// if !contains(options, NoWrap) { +// out.WriteString(")") +// } +//} + +func arraySubscriptExpr(array Expression, subscript IntegerExpression) Expression { + return CustomExpression(array, Token("["), subscript, Token("]")) +} + func wrap(expressions ...Expression) Expression { return NewFunc("", expressions, nil) } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index ddc579e4..b776b718 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -646,6 +646,47 @@ func LEAST(value Expression, values ...Expression) Expression { return NewFunc("LEAST", allValues, nil) } +// -------------------- Array Expressions Functions ------------------// + +// ANY should be used in combination with a boolean operator. The result of ANY is "true" if any true result is obtained +func ANY[E Expression](arr Array[E]) E { + return arrayElementTypeCaster(arr, Func("ANY", arr)) +} + +// ALL should be used in combination with a boolean operator. TThe result of ALL is “true” if all comparisons yield true +func ALL[E Expression](arr Array[E]) E { + return arrayElementTypeCaster(arr, Func("ALL", arr)) +} + +func arrayElementTypeCaster[E Expression](arrayExp Array[E], exp Expression) E { + var i Expression + switch arrayExp.(type) { + case Array[StringExpression]: + i = StringExp(exp) + case Array[IntegerExpression]: + i = IntExp(exp) + case Array[Int4Expression]: + i = IntExp(exp) + case Array[Int8Expression]: + i = IntExp(exp) + case Array[BoolExpression]: + i = BoolExp(exp) + } + + return i.(E) +} + +func ARRAY[E Expression](elems ...E) Array[E] { + var args = make([]Serializer, len(elems)) + for i, each := range elems { + args[i] = each + } + return ArrayExp[E](CustomExpression(Token("ARRAY["), ListSerializer{ + Serializers: args, + Separator: ",", + }, Token("]"))) +} + //--------------------------------------------------------------------// type funcExpressionImpl struct { diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 251d3ab9..b8238c62 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -2,6 +2,7 @@ package jet import ( "fmt" + "github.com/lib/pq" "time" ) @@ -160,6 +161,24 @@ func Decimal(value string) FloatExpression { return &floatLiteral } +// ---------------------------------------------------// + +func BoolArray(values []bool) Array[BoolExpression] { + return ArrayExp[BoolExpression](literal(pq.BoolArray(values))) +} + +func Int64Array(values []int64) Array[IntegerExpression] { + return ArrayExp[IntegerExpression](literal(pq.Int64Array(values))) +} + +func Int32Array(values []int32) Array[IntegerExpression] { + return ArrayExp[IntegerExpression](literal(pq.Int32Array(values))) +} + +func StringArray(values []string) Array[StringExpression] { + return ArrayExp[StringExpression](literal(pq.StringArray(values))) +} + // ---------------------------------------------------// type stringLiteral struct { stringInterfaceImpl diff --git a/internal/jet/operators.go b/internal/jet/operators.go index c453c3e0..46b36ec3 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -74,6 +74,11 @@ func Contains(lhs Expression, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "@>") } +// IsContainedBy returns a representation of "a <@ b" +func IsContainedBy(lhs Expression, rhs Expression) BoolExpression { + return newBinaryBoolOperatorExpression(lhs, rhs, "<@") +} + // Overlap returns a representation of "a && b" func Overlap(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "&&") diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 46f47ad4..035bb38f 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/v2/internal/3rdparty/pq" "github.com/go-jet/jet/v2/internal/utils/is" "github.com/google/uuid" + pq2 "github.com/lib/pq" "reflect" "sort" "strconv" @@ -81,11 +82,11 @@ func (s *SQLBuilder) write(data []byte) { } func isPreSeparator(b byte) bool { - return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' + return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' || b == '[' } func isPostSeparator(b byte) bool { - return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' + return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' || b == '[' || b == ']' } // WriteAlias is used to add alias to output SQL @@ -226,6 +227,8 @@ func argToString(value interface{}) string { case string: return stringQuote(bindVal) + case []string: + return stringArrayQuote(bindVal) case []byte: return stringQuote(string(bindVal)) case uuid.UUID: @@ -253,6 +256,13 @@ func argToString(value interface{}) string { } } +func stringArrayQuote(val []string) string { + // We'll rely on the internals of pq2.StringArray here. We know it will never return an error, and the returned + // value is a string + dv, _ := pq2.StringArray(val).Value() + return dv.(string) +} + func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { case int: diff --git a/internal/jet/string_expression_test.go b/internal/jet/string_expression_test.go index 0f461acc..19837d33 100644 --- a/internal/jet/string_expression_test.go +++ b/internal/jet/string_expression_test.go @@ -76,6 +76,14 @@ func TestStringNOT_REGEXP_LIKE(t *testing.T) { assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP $1)", "JOHN") } +func TestStringANY_EQ(t *testing.T) { + assertClauseSerialize(t, table2ColStr.EQ(ANY[StringExpression](table1ColStringArray)), "(table2.col_str = ANY(table1.col_array_string))") +} + +func TestStringALL_EQ(t *testing.T) { + assertClauseSerialize(t, table2ColStr.EQ(ALL[StringExpression](table1ColStringArray)), "(table2.col_str = ALL(table1.col_array_string))") +} + func TestStringExp(t *testing.T) { assertClauseSerialize(t, StringExp(table2ColFloat), "table2.col_float") assertClauseSerialize(t, StringExp(table2ColFloat).NOT_LIKE(String("abc")), "(table2.col_float NOT LIKE $1)", "abc") diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 70b21c77..0f4ff8a6 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -15,19 +15,22 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests }) var ( - table1Col1 = IntegerColumn("col1") - table1ColInt = IntegerColumn("col_int") - table1ColFloat = FloatColumn("col_float") - table1Col3 = IntegerColumn("col3") - table1ColTime = TimeColumn("col_time") - table1ColTimez = TimezColumn("col_timez") - table1ColTimestamp = TimestampColumn("col_timestamp") - table1ColTimestampz = TimestampzColumn("col_timestampz") - table1ColBool = BoolColumn("col_bool") - table1ColDate = DateColumn("col_date") - table1ColRange = RangeColumn[Int8Expression]("col_range") + table1Col1 = IntegerColumn("col1") + table1ColInt = IntegerColumn("col_int") + table1ColFloat = FloatColumn("col_float") + table1Col3 = IntegerColumn("col3") + table1ColTime = TimeColumn("col_time") + table1ColTimez = TimezColumn("col_timez") + table1ColTimestamp = TimestampColumn("col_timestamp") + table1ColTimestampz = TimestampzColumn("col_timestampz") + table1ColBool = BoolColumn("col_bool") + table1ColDate = DateColumn("col_date") + table1ColRange = RangeColumn[Int8Expression]("col_range") + table1ColStringArray = ArrayColumn[StringExpression]("col_array_string") + table1ColBoolArray = ArrayColumn[BoolExpression]("col_array_bool") + table1ColIntArray = ArrayColumn[IntegerExpression]("col_array_int") ) -var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz) +var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz, table1ColStringArray, table1ColBoolArray, table1ColIntArray) var ( table2Col3 = IntegerColumn("col3") @@ -42,8 +45,9 @@ var ( table2ColTimestampz = TimestampzColumn("col_timestampz") table2ColDate = DateColumn("col_date") table2ColRange = RangeColumn[Int8Expression]("col_range") + table2ColArray = ArrayColumn[StringExpression]("col_array_string") ) -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz, table2ColArray) var ( table3Col1 = IntegerColumn("col1") diff --git a/postgres/columns.go b/postgres/columns.go index a70c234b..d4a8b1a0 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -101,6 +101,24 @@ type ColumnInt8Range jet.ColumnRange[jet.Int8Expression] // Int8RangeColumn creates named range with range column var Int8RangeColumn = jet.RangeColumn[jet.Int8Expression] +// ColumnStringArray is interface of column +type ColumnStringArray jet.ColumnArray[StringExpression] + +// StringArrayColumn creates named string array column +var StringArrayColumn = jet.ArrayColumn[StringExpression] + +// ColumnIntegerArray is interface of column +type ColumnIntegerArray jet.ColumnArray[IntegerExpression] + +// IntegerArrayColumn creates named integer array column +var IntegerArrayColumn = jet.ArrayColumn[IntegerExpression] + +// ColumnBoolArray is interface of column +type ColumnBoolArray jet.ColumnArray[BoolExpression] + +// BoolArrayColumn creates named bool array column +var BoolArrayColumn = jet.ArrayColumn[BoolExpression] + //------------------------------------------------------// // ColumnInterval is interface of PostgreSQL interval columns. diff --git a/postgres/expressions.go b/postgres/expressions.go index d8ad34b4..8b780377 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -9,15 +9,24 @@ type Expression = jet.Expression // BoolExpression interface type BoolExpression = jet.BoolExpression +// BoolArrayExpression interface +type BoolArrayExpression = jet.Array[BoolExpression] + // StringExpression interface type StringExpression = jet.StringExpression +// StringArrayExpression interface +type StringArrayExpression = jet.Array[StringExpression] + // NumericExpression interface type NumericExpression = jet.NumericExpression // IntegerExpression interface type IntegerExpression = jet.IntegerExpression +// IntegerArrayExpression interface +type IntegerArrayExpression = jet.Array[IntegerExpression] + // FloatExpression is interface type FloatExpression = jet.FloatExpression diff --git a/postgres/functions.go b/postgres/functions.go index bce2e987..c6fae677 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -267,7 +267,7 @@ var TO_ASCII = jet.TO_ASCII // TO_HEX converts number to its equivalent hexadecimal representation var TO_HEX = jet.TO_HEX -//----------Data Type Formatting Functions ----------------------// +//---------- Range Functions ----------------------// // LOWER_BOUND returns range expressions lower bound func LOWER_BOUND[T Expression](expression jet.Range[T]) T { @@ -279,7 +279,61 @@ func UPPER_BOUND[T Expression](expression jet.Range[T]) T { return jet.UPPER_BOUND[T](expression) } -//----------Data Type Formatting Functions ----------------------// +// ---------- Array Functions ----------------------// + +// ANY should be used in combination with a boolean operator. The result of ANY is "true" if any true result is obtained +func ANY[T Expression](expression jet.Array[T]) T { + return jet.ANY[T](expression) +} + +// ALL should be used in combination with a boolean operator. TThe result of ALL is “true” if all comparisons yield true +func ALL[T Expression](expression jet.Array[T]) T { + return jet.ALL[T](expression) +} + +func ARRAY_APPEND[E Expression](arr jet.Array[E], el E) jet.Array[E] { + return arrayTypeCaster[E](arr, Func("ARRAY_APPEND", arr, el)) +} + +func ARRAY_CAT[E Expression](arr1, arr2 jet.Array[E]) jet.Array[E] { + return arrayTypeCaster[E](arr1, Func("ARRAY_CAT", arr1, arr2)) +} + +func ARRAY_PREPEND[E Expression](el E, arr jet.Array[E]) jet.Array[E] { + return jet.ArrayExp[E](Func("ARRAY_PREPEND", el, arr)) +} + +func ARRAY_LENGTH[E Expression](arr jet.Array[E], el IntegerExpression) IntegerExpression { + return IntExp(Func("ARRAY_LENGTH", arr, el)) +} + +func ARRAY_REMOVE[E Expression](arr jet.Array[E], el Expression) IntegerExpression { + return IntExp(Func("ARRAY_REMOVE", arr, el)) +} + +func ARRAY_TO_STRING(arr Expression, delim StringExpression) StringExpression { + return StringExp(Func("ARRAY_TO_STRING", arr, delim)) +} + +func arrayTypeCaster[E Expression](arrayExp Expression, exp Expression) jet.Array[E] { + var i Expression + switch arrayExp.(type) { + case jet.Array[StringExpression]: + i = jet.ArrayExp[StringExpression](exp) + case jet.Array[IntegerExpression]: + i = jet.ArrayExp[IntegerExpression](exp) + case jet.Array[BoolExpression]: + i = jet.ArrayExp[BoolExpression](exp) + } + return i.(jet.Array[E]) +} + +// ARRAY constructor +func ARRAY[T Expression](elems ...T) jet.Array[T] { + return jet.ARRAY[T](elems...) +} + +//---------- Data Type Formatting Functions ----------------------// // TO_CHAR converts expression to string with format var TO_CHAR = jet.TO_CHAR diff --git a/postgres/literal.go b/postgres/literal.go index 4f1c2c87..951e3f06 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -11,6 +11,11 @@ func Bool(value bool) BoolExpression { return CAST(jet.Bool(value)).AS_BOOL() } +// BoolArray creates new bool array literal expression +func BoolArray(elements []bool) BoolArrayExpression { + return jet.BoolArray(elements) +} + // Int is constructor for 64 bit signed integer expressions literals. var Int = jet.Int @@ -29,11 +34,21 @@ func Int32(value int32) IntegerExpression { return CAST(jet.Int32(value)).AS_INTEGER() } +// Int32Array creates new 32 bit signed integer literal expression +func Int32Array(elements []int32) IntegerArrayExpression { + return jet.Int32Array(elements) +} + // Int64 is constructor for 64 bit signed integer expressions literals. func Int64(value int64) IntegerExpression { return CAST(jet.Int(value)).AS_BIGINT() } +// Int64Array creates new 64 bit signed integer literal expression +func Int64Array(elements []int64) IntegerArrayExpression { + return jet.Int64Array(elements) +} + // Uint8 is constructor for 8 bit unsigned integer expressions literals. func Uint8(value uint8) IntegerExpression { return CAST(jet.Uint8(value)).AS_SMALLINT() @@ -65,6 +80,7 @@ func Double(value float64) FloatExpression { // Decimal creates new float literal expression var Decimal = jet.Decimal +// String creates new string literal expression // String is a parameter constructor for the PostgreSQL text type. Using the `Text` constructor is // generally preferable. // @@ -75,6 +91,11 @@ func String(value string) StringExpression { return CAST(jet.String(value)).AS_TEXT() } +// StringArray creates new string array literal expression +func StringArray(elements []string) StringArrayExpression { + return jet.StringArray(elements) +} + // Text is a parameter constructor for the PostgreSQL text type. This constructor also adds an // explicit placeholder type cast to text in the generated query, such as `$3::text`. // Example usage: diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 96bb13b0..d89b17bd 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -18,6 +18,8 @@ var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") var table1ColInterval = IntervalColumn("col_interval") var table1ColRange = Int8RangeColumn("col_range") +var table1ColStringArray = StringArrayColumn("col_string_array") +var table1ColIntArray = IntegerArrayColumn("col_int_array") var table1 = NewTable( "db", @@ -34,6 +36,8 @@ var table1 = NewTable( table1ColTimestampz, table1ColInterval, table1ColRange, + table1ColStringArray, + table1ColIntArray, ) var table2Col3 = IntegerColumn("col3") @@ -49,8 +53,10 @@ var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") var table2ColInterval = IntervalColumn("col_interval") var table2ColRange = Int8RangeColumn("col_range") +var table2ColStringArray = StringArrayColumn("col_string_array") +var table2ColIntArray = IntegerArrayColumn("col_int_array") -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval, table2ColRange) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval, table2ColRange, table2ColStringArray, table2ColIntArray) var table3Col1 = IntegerColumn("col1") var table3ColInt = IntegerColumn("col_int") diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index bcbbb25f..9aba1ec2 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -1,4 +1,3 @@ -version: '3' services: postgres: image: postgres:14.1 @@ -13,7 +12,7 @@ services: - ./testdata/init/postgres:/docker-entrypoint-initdb.d mysql: - image: mysql:8.0.27 + image: mysql/mysql-server:8.0.27 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 6c3755a2..48a2b814 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/go-jet/jet/v2/qrm" + "github.com/lib/pq" "testing" "time" @@ -1461,11 +1462,11 @@ var allTypesRow0 = model.AllTypes{ JSON: `{"a": 1, "b": 3}`, JsonbPtr: ptr.Of(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: ptr.Of("{1,2,3}"), - IntegerArray: "{1,2,3}", - TextArrayPtr: ptr.Of("{breakfast,consulting}"), - TextArray: "{breakfast,consulting}", - JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, + IntegerArrayPtr: &pq.Int32Array{1, 2, 3}, + IntegerArray: pq.Int32Array{1, 2, 3}, + TextArrayPtr: &pq.StringArray{"breakfast", "consulting"}, + TextArray: pq.StringArray{"breakfast", "consulting"}, + JsonbArray: pq.StringArray{`{"a": 1, "b": 2}`, `{"a": 3, "b": 4}`}, TextMultiDimArrayPtr: ptr.Of("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: &moodSad, @@ -1530,10 +1531,10 @@ var allTypesRow1 = model.AllTypes{ JsonbPtr: nil, Jsonb: `{"a": 1, "b": 3}`, IntegerArrayPtr: nil, - IntegerArray: "{1,2,3}", + IntegerArray: pq.Int32Array{1, 2, 3}, TextArrayPtr: nil, - TextArray: "{breakfast,consulting}", - JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, + TextArray: pq.StringArray{"breakfast", "consulting"}, + JsonbArray: pq.StringArray{`{"a": 1, "b": 2}`, `{"a": 3, "b": 4}`}, TextMultiDimArrayPtr: nil, TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: nil, diff --git a/tests/postgres/array_test.go b/tests/postgres/array_test.go new file mode 100644 index 00000000..61d857a7 --- /dev/null +++ b/tests/postgres/array_test.go @@ -0,0 +1,276 @@ +package postgres + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/qrm" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" + "github.com/google/go-cmp/cmp" + "github.com/lib/pq" + "github.com/stretchr/testify/require" + "math/big" + "testing" +) + +func TestArrayTableSelect(t *testing.T) { + skipForCockroachDB(t) + + textArray := StringArray([]string{"a"}) + boolArray := BoolArray([]bool{true}) + int4Array := Int32Array([]int32{1, 2}) + int8Array := Int64Array([]int64{10, 11}) + + query := SELECT( + SampleArrays.AllColumns, + SampleArrays.TextArray.EQ(SampleArrays.TextArray).AS("sample.text_eq"), + SampleArrays.BoolArray.EQ(boolArray).AS("sample.bool_eq"), + SampleArrays.TextArray.NOT_EQ(textArray).AS("sample.text_neq"), + SampleArrays.Int4Array.LT(int4Array).IS_TRUE().AS("sample.int4_lt"), + SampleArrays.Int8Array.LT_EQ(int8Array).IS_FALSE().AS("sample.int8_lteq"), + SampleArrays.TextArray.GT(textArray).AS("sample.text_gt"), + SampleArrays.Int4Array.GT_EQ(int4Array).AS("sample.bool_gteq"), + Int32(22).EQ(ANY[IntegerExpression](SampleArrays.Int4Array)).AS("sample.int4_eq_any"), + Int32(22).NOT_EQ(ANY[IntegerExpression](SampleArrays.Int4Array)).AS("sample.int4_neq_any"), + Int32(22).EQ(ALL[IntegerExpression](SampleArrays.Int4Array)).AS("sample.int4_eq_all"), + SampleArrays.Int8Array.CONTAINS(Int64Array([]int64{75364})).AS("sample.int8cont"), + SampleArrays.Int8Array.IS_CONTAINED_BY(Int64Array([]int64{75364})).AS("sample.int8cont_by"), + SampleArrays.Int4Array.OVERLAP(int4Array).AS("sample.int4_overlap"), + SampleArrays.BoolArray.CONCAT(boolArray).AS("sample.bool_concat"), + SampleArrays.TextArray.CONCAT_ELEMENT(String("z")).AS("sample.text_concat_el"), + SampleArrays.TextArray.AT(Int32(1)).AS("sample.text_at"), + ARRAY_APPEND[StringExpression](SampleArrays.TextArray, String("after")).AS("sample.text_append"), + ARRAY_CAT[StringExpression](SampleArrays.TextArray, textArray).AS("sample.text_cat"), + ARRAY_LENGTH[StringExpression](SampleArrays.TextArray, Int32(1)).AS("sample.text_length"), + ARRAY_PREPEND[StringExpression](String("before"), SampleArrays.TextArray).AS("sample.text_prepend"), + ).FROM( + SampleArrays, + ).WHERE( + SampleArrays.BoolArray.CONTAINS(BoolArray([]bool{true})), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array", + (sample_arrays.text_array = sample_arrays.text_array) AS "sample.text_eq", + (sample_arrays.bool_array = $1) AS "sample.bool_eq", + (sample_arrays.text_array != $2) AS "sample.text_neq", + (sample_arrays.int4_array < $3) IS TRUE AS "sample.int4_lt", + (sample_arrays.int8_array <= $4) IS FALSE AS "sample.int8_lteq", + (sample_arrays.text_array > $5) AS "sample.text_gt", + (sample_arrays.int4_array >= $6) AS "sample.bool_gteq", + ($7::integer = ANY(sample_arrays.int4_array)) AS "sample.int4_eq_any", + ($8::integer != ANY(sample_arrays.int4_array)) AS "sample.int4_neq_any", + ($9::integer = ALL(sample_arrays.int4_array)) AS "sample.int4_eq_all", + (sample_arrays.int8_array @> $10) AS "sample.int8cont", + (sample_arrays.int8_array <@ $11) AS "sample.int8cont_by", + (sample_arrays.int4_array && $12) AS "sample.int4_overlap", + (sample_arrays.bool_array || $13) AS "sample.bool_concat", + (sample_arrays.text_array || $14::text) AS "sample.text_concat_el", + sample_arrays.text_array[$15::integer] AS "sample.text_at", + ARRAY_APPEND(sample_arrays.text_array, $16::text) AS "sample.text_append", + ARRAY_CAT(sample_arrays.text_array, $17) AS "sample.text_cat", + ARRAY_LENGTH(sample_arrays.text_array, $18::integer) AS "sample.text_length", + ARRAY_PREPEND($19::text, sample_arrays.text_array) AS "sample.text_prepend" +FROM test_sample.sample_arrays +WHERE sample_arrays.bool_array @> $20; +`) + + type sample struct { + model.SampleArrays + TextEq bool + BoolEq bool + TextNeq bool + Int4Lt bool + Int8Lteq bool + TextGt bool + BoolGteq bool + Int4EqAny bool + Int4NeqAny bool + Int4EqAll bool + Int8Cont bool + Int8ContBy bool + Int4Overlap bool + BoolConcat pq.BoolArray + TextConcatEl pq.StringArray + TextAt string + TextAppend pq.StringArray + TextCat pq.StringArray + TextLength int32 + TextPrepend pq.StringArray + } + + var dest sample + err := query.Query(db, &dest) + require.NoError(t, err) + + expectedRow := sample{ + SampleArrays: sampleArrayRow, + TextEq: true, + BoolEq: true, + TextNeq: true, + Int4Lt: false, + Int8Lteq: true, + TextGt: true, + BoolGteq: true, + Int4EqAny: false, + Int4NeqAny: true, + Int4EqAll: false, + Int8Cont: false, + Int8ContBy: false, + Int4Overlap: true, + BoolConcat: pq.BoolArray{true, true}, + TextConcatEl: pq.StringArray{"a", "b", "z"}, + TextAt: "a", + TextAppend: pq.StringArray{"a", "b", "after"}, + TextCat: pq.StringArray{"a", "b", "a"}, + TextLength: 2, + TextPrepend: pq.StringArray{"before", "a", "b"}, + } + + testutils.AssertDeepEqual(t, dest, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestArraySelectColumnsFromSubQuery(t *testing.T) { + skipForCockroachDB(t) + + subQuery := SELECT( + SampleArrays.AllColumns, + SampleArrays.Int4Array.AS("array4"), + ).FROM( + SampleArrays, + ).AsTable("sub_query") + + int4Array := IntegerArrayColumn("array4").From(subQuery) + + stmt := SELECT( + subQuery.AllColumns(), + int4Array, + ).FROM( + subQuery, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT sub_query."sample_arrays.text_array" AS "sample_arrays.text_array", + sub_query."sample_arrays.bool_array" AS "sample_arrays.bool_array", + sub_query."sample_arrays.int4_array" AS "sample_arrays.int4_array", + sub_query."sample_arrays.int8_array" AS "sample_arrays.int8_array", + sub_query.array4 AS "array4", + sub_query.array4 AS "array4" +FROM ( + SELECT sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array", + sample_arrays.int4_array AS "array4" + FROM test_sample.sample_arrays + ) AS sub_query; +`) + + var dest struct { + model.SampleArrays + Array4 pq.Int32Array + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest.SampleArrays.Int4Array, sampleArrayRow.Int4Array) + testutils.AssertDeepEqual(t, dest.SampleArrays.Int8Array, sampleArrayRow.Int8Array) + testutils.AssertDeepEqual(t, dest.Array4, sampleArrayRow.Int4Array) +} + +func TestArrayTable_InsertColumn(t *testing.T) { + skipForCockroachDB(t) + + insertQuery := SampleArrays.INSERT(SampleArrays.AllColumns). + VALUES( + ARRAY(String("A"), String("B")), + ARRAY(Bool(true)), + ARRAY(Int32(1)), + ARRAY(Int64(2)), + ). + MODEL( + sampleArrayRow, + ). + RETURNING(SampleArrays.AllColumns) + + expectedQuery := ` +INSERT INTO test_sample.sample_arrays (text_array, bool_array, int4_array, int8_array) +VALUES (ARRAY['A'::text,'B'::text], ARRAY[TRUE::boolean], ARRAY[1::integer], ARRAY[2::bigint]), + ('{"a","b"}', '{t}', '{1,2,3}', '{10,11,12}') +RETURNING sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array"; +` + testutils.AssertDebugStatementSql(t, insertQuery, expectedQuery) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + var dest []model.SampleArrays + err := insertQuery.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 2) + testutils.AssertDeepEqual(t, sampleArrayRow, dest[1], cmp.AllowUnexported(big.Int{})) + }) +} + +func TestArrayTableUpdate(t *testing.T) { + skipForCockroachDB(t) + + t.Run("using model", func(t *testing.T) { + stmt := SampleArrays.UPDATE(SampleArrays.AllColumns). + MODEL(sampleArrayRow). + WHERE(String("a").EQ(ANY[StringExpression](SampleArrays.TextArray))). + RETURNING(SampleArrays.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE test_sample.sample_arrays +SET (text_array, bool_array, int4_array, int8_array) = ($1, $2, $3, $4) +WHERE $5::text = ANY(sample_arrays.text_array) +RETURNING sample_arrays.text_array AS "sample_arrays.text_array", + sample_arrays.bool_array AS "sample_arrays.bool_array", + sample_arrays.int4_array AS "sample_arrays.int4_array", + sample_arrays.int8_array AS "sample_arrays.int8_array"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + var dest []model.SampleArrays + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + testutils.AssertDeepEqual(t, sampleArrayRow, dest[0], cmp.AllowUnexported(big.Int{})) + }) + }) + + t.Run("update using SET", func(t *testing.T) { + stmt := SampleArrays.UPDATE(). + SET( + SampleArrays.Int4Array.SET(ARRAY(Int32(-10), Int32(11))), + SampleArrays.Int8Array.SET(ARRAY(Int64(-1200), Int64(7800))), + ). + WHERE(String("a").EQ(ANY[StringExpression](SampleArrays.TextArray))) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.sample_arrays +SET int4_array = ARRAY[-10::integer,11::integer], + int8_array = ARRAY[-1200::bigint,7800::bigint] +WHERE 'a'::text = ANY(sample_arrays.text_array); +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + testutils.AssertExec(t, stmt, tx, 1) + }) + }) + +} + +var sampleArrayRow = model.SampleArrays{ + TextArray: pq.StringArray([]string{"a", "b"}), + BoolArray: pq.BoolArray([]bool{true}), + Int4Array: pq.Int32Array([]int32{1, 2, 3}), + Int8Array: pq.Int64Array([]int64{10, 11, 12}), +} diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index e518db71..e72cb16c 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -446,7 +446,7 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { require.Contains(t, data, "\"database/sql\"") require.Contains(t, data, "Description sql.NullString") require.Contains(t, data, "ReleaseYear sql.NullInt32") - require.Contains(t, data, "SpecialFeatures sql.NullString") + require.Contains(t, data, "SpecialFeatures *pq.StringArray") } func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 93fb9c3a..4fdc4a31 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -757,13 +757,13 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", - "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go") + "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go", "sample_arrays.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) testutils.AssertFileContent(t, modelDir+"/link.go", linkModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", - "components.go", "vulnerabilities.go", "sample_ranges.go") + "components.go", "vulnerabilities.go", "sample_ranges.go", "sample_arrays.go") testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent) @@ -836,6 +836,7 @@ package model import ( "github.com/google/uuid" + "github.com/lib/pq" "time" ) @@ -894,11 +895,11 @@ type AllTypes struct { JSON string JsonbPtr *string Jsonb string - IntegerArrayPtr *string - IntegerArray string - TextArrayPtr *string - TextArray string - JsonbArray string + IntegerArrayPtr *pq.Int32Array + IntegerArray pq.Int32Array + TextArrayPtr *pq.StringArray + TextArray pq.StringArray + JsonbArray pq.StringArray TextMultiDimArrayPtr *string TextMultiDimArray string MoodPtr *Mood @@ -999,11 +1000,11 @@ type allTypesTable struct { JSON postgres.ColumnString JsonbPtr postgres.ColumnString Jsonb postgres.ColumnString - IntegerArrayPtr postgres.ColumnString - IntegerArray postgres.ColumnString - TextArrayPtr postgres.ColumnString - TextArray postgres.ColumnString - JsonbArray postgres.ColumnString + IntegerArrayPtr postgres.ColumnIntegerArray + IntegerArray postgres.ColumnIntegerArray + TextArrayPtr postgres.ColumnStringArray + TextArray postgres.ColumnStringArray + JsonbArray postgres.ColumnStringArray TextMultiDimArrayPtr postgres.ColumnString TextMultiDimArray postgres.ColumnString MoodPtr postgres.ColumnString @@ -1102,11 +1103,11 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { JSONColumn = postgres.StringColumn("json") JsonbPtrColumn = postgres.StringColumn("jsonb_ptr") JsonbColumn = postgres.StringColumn("jsonb") - IntegerArrayPtrColumn = postgres.StringColumn("integer_array_ptr") - IntegerArrayColumn = postgres.StringColumn("integer_array") - TextArrayPtrColumn = postgres.StringColumn("text_array_ptr") - TextArrayColumn = postgres.StringColumn("text_array") - JsonbArrayColumn = postgres.StringColumn("jsonb_array") + IntegerArrayPtrColumn = postgres.IntegerArrayColumn("integer_array_ptr") + IntegerArrayColumn = postgres.IntegerArrayColumn("integer_array") + TextArrayPtrColumn = postgres.StringArrayColumn("text_array_ptr") + TextArrayColumn = postgres.StringArrayColumn("text_array") + JsonbArrayColumn = postgres.StringArrayColumn("jsonb_array") TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") MoodPtrColumn = postgres.StringColumn("mood_ptr") diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 321fc383..594070f4 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "github.com/lib/pq" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/volatiletech/null/v8" "testing" @@ -968,10 +969,11 @@ func TestScanIntoCustomBaseTypes(t *testing.T) { ReplacementCost MyFloat64 Rating *model.MpaaRating LastUpdate MyTime - SpecialFeatures *MyString + SpecialFeatures pq.StringArray Fulltext MyString } + // We'll skip special features, because it's a slice and it does not implement sql.Scanner stmt := SELECT( Film.AllColumns, ).FROM( @@ -980,14 +982,12 @@ func TestScanIntoCustomBaseTypes(t *testing.T) { Film.FilmID.ASC(), ).LIMIT(3) - var films []model.Film - - err := stmt.Query(db, &films) - require.NoError(t, err) - var myFilms []film + err := stmt.Query(db, &myFilms) + require.NoError(t, err) - err = stmt.Query(db, &myFilms) + var films []model.Film + err = stmt.Query(db, &films) require.NoError(t, err) require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms)) @@ -1161,7 +1161,7 @@ var film1 = model.Film{ ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: ptr.Of("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: &pq.StringArray{"Deleted Scenes", "Behind the Scenes"}, Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } @@ -1177,7 +1177,7 @@ var film2 = model.Film{ ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: ptr.Of(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: &pq.StringArray{"Trailers", "Deleted Scenes"}, Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index d39987d6..8d0cdce3 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/lib/pq" "testing" "time" @@ -1838,7 +1839,7 @@ ORDER BY film.film_id ASC; Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: ptr.Of("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: &pq.StringArray{"Trailers", "Deleted Scenes"}, Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -3360,7 +3361,10 @@ func TestSelectRecursionScanNxM(t *testing.T) { "ReplacementCost": 20.99, "Rating": "PG", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "SpecialFeatures": [ + "Deleted Scenes", + "Behind the Scenes" + ], "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", "Actors": [ { @@ -3384,7 +3388,10 @@ func TestSelectRecursionScanNxM(t *testing.T) { "ReplacementCost": 9.99, "Rating": "R", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "SpecialFeatures": [ + "Trailers", + "Deleted Scenes" + ], "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", "Actors": [ { @@ -3432,7 +3439,10 @@ func TestSelectRecursionScanNxM(t *testing.T) { "ReplacementCost": 20.99, "Rating": "PG", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "SpecialFeatures": [ + "Deleted Scenes", + "Behind the Scenes" + ], "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", "Actors": null }, @@ -3448,7 +3458,10 @@ func TestSelectRecursionScanNxM(t *testing.T) { "ReplacementCost": 9.99, "Rating": "R", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "SpecialFeatures": [ + "Trailers", + "Deleted Scenes" + ], "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", "Actors": null } diff --git a/tests/postgres/values_test.go b/tests/postgres/values_test.go index bdb631c8..f46ca299 100644 --- a/tests/postgres/values_test.go +++ b/tests/postgres/values_test.go @@ -172,7 +172,9 @@ ORDER BY film_values.title; "ReplacementCost": 15.99, "Rating": "R", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{Trailers}", + "SpecialFeatures": [ + "Trailers" + ], "Fulltext": "'airport':1 'ancient':18 'confront':14 'epic':4 'girl':11 'india':19 'monkey':16 'moos':8 'must':13 'pollock':2 'tale':5" }, "Title": "Airport Pollock", @@ -194,7 +196,9 @@ ORDER BY film_values.title; "ReplacementCost": 12.99, "Rating": "PG-13", "LastUpdate": "2013-05-26T14:50:58.951Z", - "SpecialFeatures": "{Trailers}", + "SpecialFeatures": [ + "Trailers" + ], "Fulltext": "'boat':20 'bright':1 'conquer':14 'encount':2 'fate':4 'feminist':11 'jet':19 'lumberjack':8 'must':13 'student':16 'yarn':5" }, "Title": "Bright Encounters", diff --git a/tests/testdata b/tests/testdata index 1c501acb..8433df98 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1c501acb72bea389788404988ef0130b733f9cee +Subproject commit 8433df982dbd9862f20d3d6dcc5ab80f6d44e0cd