Skip to content

Commit

Permalink
Merge pull request #47 from fraenky8/ISSUE-45
Browse files Browse the repository at this point in the history
ISSUE-45: Support unknown column types
  • Loading branch information
fraenky8 authored Jul 25, 2022
2 parents 28b301b + f194358 commit 6ade5b0
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 24 deletions.
15 changes: 7 additions & 8 deletions internal/cli/tables-to-go-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,7 @@ func generateImports(content *strings.Builder, settings *settings.Settings, colu
}

func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column database.Column) (goType string, columnInfo columnInfo) {
if db.IsString(column) || db.IsText(column) {
goType = "string"
if db.IsNullable(column) {
goType = getNullType(s, "*string", "sql.NullString")
columnInfo.isNullable = true
}
} else if db.IsInteger(column) {
if db.IsInteger(column) {
goType = "int"
if db.IsNullable(column) {
goType = getNullType(s, "*int", "sql.NullInt64")
Expand Down Expand Up @@ -240,7 +234,12 @@ func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column
columnInfo.isNullable = true
}
default:
goType = getNullType(s, "*string", "sql.NullString")
// Everything else we cannot detect defaults to (nullable) string.
goType = "string"
if db.IsNullable(column) {
goType = getNullType(s, "*string", "sql.NullString")
columnInfo.isNullable = true
}
}
}

Expand Down
299 changes: 299 additions & 0 deletions internal/cli/tables-to-go-cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,305 @@ func TestRun_BooleanColumns(t *testing.T) {
}
}

func TestRun_UnknownColumns(t *testing.T) {
for dbType := range settings.SupportedDbTypes {
t.Run(dbType.String(), func(t *testing.T) {

s := settings.New()
s.DbType = dbType
db := database.New(s)

columnTypes := []string{
"enum", // MySQL
"USER-DEFINED", // Postgres
}

for _, columnType := range columnTypes {
t.Run(columnType, func(t *testing.T) {

t.Run("single table with NOT NULL column", func(t *testing.T) {
s := settings.New()
s.DbType = dbType

mdb := newMockDb(db)

table := &database.Table{
Name: "test_table",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name",
DataType: columnType,
},
},
}
mdb.tables = append(mdb.tables, table)

mdb.
On("GetTables").
Return(mdb.tables, nil)
mdb.
On("PrepareGetColumnsOfTableStmt").
Return(nil)
mdb.
On("GetColumnsOfTable", table)

w := newMockWriter()
w.
On(
"Write",
"TestTable",
"package dto\n\ntype TestTable struct {\nColumnName string `db:\"column_name\"`\n}",
)

err := Run(s, mdb, w)
assert.NoError(t, err)
})

t.Run("single table with NULL column", func(t *testing.T) {
s := settings.New()
s.DbType = dbType

mdb := newMockDb(db)

table := &database.Table{
Name: "test_table",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name",
DataType: columnType,
IsNullable: "YES",
},
},
}
mdb.tables = append(mdb.tables, table)

mdb.
On("GetTables").
Return(mdb.tables, nil)
mdb.
On("PrepareGetColumnsOfTableStmt").
Return(nil)
mdb.
On("GetColumnsOfTable", table)

w := newMockWriter()
w.
On(
"Write",
"TestTable",
"package dto\n\nimport (\n\t\"database/sql\"\n)\n\ntype TestTable struct {\nColumnName sql.NullString `db:\"column_name\"`\n}",
)

err := Run(s, mdb, w)
assert.NoError(t, err)
})

t.Run("single table with NULL column and native data type", func(t *testing.T) {
s := settings.New()
s.DbType = dbType
s.Null = settings.NullTypeNative

mdb := newMockDb(db)

table := &database.Table{
Name: "test_table",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name",
DataType: columnType,
IsNullable: "YES",
},
},
}
mdb.tables = append(mdb.tables, table)

mdb.
On("GetTables").
Return(mdb.tables, nil)
mdb.
On("PrepareGetColumnsOfTableStmt").
Return(nil)
mdb.
On("GetColumnsOfTable", table)

w := newMockWriter()
w.
On(
"Write",
"TestTable",
"package dto\n\nimport (\n)\n\ntype TestTable struct {\nColumnName *string `db:\"column_name\"`\n}",
)

err := Run(s, mdb, w)
assert.NoError(t, err)
})

t.Run("single table with two mixed columns", func(t *testing.T) {
s := settings.New()
s.DbType = dbType

mdb := newMockDb(db)

table := &database.Table{
Name: "test_table",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name_1",
DataType: columnType,
IsNullable: "YES",
},
{
OrdinalPosition: 2,
Name: "column_name_2",
DataType: columnType,
},
},
}
mdb.tables = append(mdb.tables, table)

mdb.
On("GetTables").
Return(mdb.tables, nil)
mdb.
On("PrepareGetColumnsOfTableStmt").
Return(nil)
mdb.
On("GetColumnsOfTable", table)

w := newMockWriter()
w.
On(
"Write",
"TestTable",
"package dto\n\nimport (\n\t\"database/sql\"\n)\n\ntype TestTable struct {\nColumnName1 sql.NullString `db:\"column_name_1\"`\nColumnName2 string `db:\"column_name_2\"`\n}",
)

err := Run(s, mdb, w)
assert.NoError(t, err)
})

t.Run("single table with two mixed columns and native data type", func(t *testing.T) {
s := settings.New()
s.DbType = dbType
s.Null = settings.NullTypeNative

mdb := newMockDb(db)

table := &database.Table{
Name: "test_table",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name_1",
DataType: columnType,
IsNullable: "YES",
},
{
OrdinalPosition: 2,
Name: "column_name_2",
DataType: columnType,
},
},
}
mdb.tables = append(mdb.tables, table)

mdb.
On("GetTables").
Return(mdb.tables, nil)
mdb.
On("PrepareGetColumnsOfTableStmt").
Return(nil)
mdb.
On("GetColumnsOfTable", table)

w := newMockWriter()
w.
On(
"Write",
"TestTable",
"package dto\n\nimport (\n)\n\ntype TestTable struct {\nColumnName1 *string `db:\"column_name_1\"`\nColumnName2 string `db:\"column_name_2\"`\n}",
)

err := Run(s, mdb, w)
assert.NoError(t, err)
})

t.Run("multi table with multi columns", func(t *testing.T) {
s := settings.New()
s.DbType = dbType

mdb := newMockDb(db)

table1 := &database.Table{
Name: "test_table_1",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name_1",
DataType: columnType,
IsNullable: "YES",
},
{
OrdinalPosition: 2,
Name: "column_name_2",
DataType: columnType,
},
},
}
table2 := &database.Table{
Name: "test_table_2",
Columns: []database.Column{
{
OrdinalPosition: 1,
Name: "column_name_1",
DataType: columnType,
},
{
OrdinalPosition: 2,
Name: "column_name_2",
DataType: columnType,
IsNullable: "YES",
},
},
}
mdb.tables = append(mdb.tables, table1, table2)

mdb.
On("GetTables").
Return(mdb.tables, nil)
mdb.
On("PrepareGetColumnsOfTableStmt").
Return(nil)
mdb.
On("GetColumnsOfTable", table1).
On("GetColumnsOfTable", table2)

w := newMockWriter()
w.
On(
"Write",
"TestTable1",
"package dto\n\nimport (\n\t\"database/sql\"\n)\n\ntype TestTable1 struct {\nColumnName1 sql.NullString `db:\"column_name_1\"`\nColumnName2 string `db:\"column_name_2\"`\n}",
).
On(
"Write",
"TestTable2",
"package dto\n\nimport (\n\t\"database/sql\"\n)\n\ntype TestTable2 struct {\nColumnName1 string `db:\"column_name_1\"`\nColumnName2 sql.NullString `db:\"column_name_2\"`\n}",
)

err := Run(s, mdb, w)
assert.NoError(t, err)
})
})
}
})
}
}

func TestValidVariableName(t *testing.T) {
type testCase struct {
name string
Expand Down
4 changes: 2 additions & 2 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ func (gdb *GeneralDatabase) IsNullable(column Column) bool {
return column.IsNullable == "YES"
}

// IsStringInSlice checks if needle (string) is in haystack ([]string).
func (gdb *GeneralDatabase) IsStringInSlice(needle string, haystack []string) bool {
// isStringInSlice checks if needle (string) is in haystack ([]string).
func isStringInSlice(needle string, haystack []string) bool {
for _, s := range haystack {
if s == needle {
return true
Expand Down
10 changes: 5 additions & 5 deletions pkg/database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (mysql *MySQL) GetStringDatatypes() []string {

// IsString returns true if the colum is of type string for the MySQL database.
func (mysql *MySQL) IsString(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetStringDatatypes())
return isStringInSlice(column.DataType, mysql.GetStringDatatypes())
}

// GetTextDatatypes returns the text datatypes for the MySQL database.
Expand All @@ -146,7 +146,7 @@ func (mysql *MySQL) GetTextDatatypes() []string {

// IsText returns true if colum is of type text for the MySQL database.
func (mysql *MySQL) IsText(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetTextDatatypes())
return isStringInSlice(column.DataType, mysql.GetTextDatatypes())
}

// GetIntegerDatatypes returns the integer datatypes for the MySQL database.
Expand All @@ -162,7 +162,7 @@ func (mysql *MySQL) GetIntegerDatatypes() []string {

// IsInteger returns true if colum is of type integer for the MySQL database.
func (mysql *MySQL) IsInteger(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetIntegerDatatypes())
return isStringInSlice(column.DataType, mysql.GetIntegerDatatypes())
}

// GetFloatDatatypes returns the float datatypes for the MySQL database.
Expand All @@ -178,7 +178,7 @@ func (mysql *MySQL) GetFloatDatatypes() []string {

// IsFloat returns true if colum is of type float for the MySQL database.
func (mysql *MySQL) IsFloat(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetFloatDatatypes())
return isStringInSlice(column.DataType, mysql.GetFloatDatatypes())
}

// GetTemporalDatatypes returns the temporal datatypes for the MySQL database.
Expand All @@ -194,5 +194,5 @@ func (mysql *MySQL) GetTemporalDatatypes() []string {

// IsTemporal returns true if colum is of type temporal for the MySQL database.
func (mysql *MySQL) IsTemporal(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetTemporalDatatypes())
return isStringInSlice(column.DataType, mysql.GetTemporalDatatypes())
}
Loading

0 comments on commit 6ade5b0

Please sign in to comment.