Skip to content

Commit

Permalink
Merge pull request #38 from fraenky8/ISSUE-33
Browse files Browse the repository at this point in the history
ISSUE-33: Remove nullable-temporal logic
  • Loading branch information
fraenky8 authored Feb 11, 2022
2 parents 7f3258b + bf13ef5 commit 58b2740
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 66 deletions.
38 changes: 12 additions & 26 deletions internal/cli/tables-to-go-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,12 @@ func Run(settings *settings.Settings, db database.Database, out output.Writer) (
}

type columnInfo struct {
isNullable bool
isTemporal bool
isNullablePrimitive bool
isNullableTemporal bool
isNullable bool
isTemporal bool
}

func (c columnInfo) hasTrue() bool {
return c.isNullable || c.isTemporal || c.isNullableTemporal || c.isNullablePrimitive
func (c columnInfo) isNullableOrTemporal() bool {
return c.isNullable || c.isTemporal
}

func createTableStructString(settings *settings.Settings, db database.Database, table *database.Table) (string, string, error) {
Expand Down Expand Up @@ -125,7 +123,7 @@ func createTableStructString(settings *settings.Settings, db database.Database,

// ISSUE-4: if columns are part of multiple constraints
// then the sql returns multiple rows per column name.
// Therefore we check if we already added a column with
// Therefore, we check if we already added a column with
// that name to the struct, if so, skip.
if _, ok := columns[columnName]; ok {
continue
Expand All @@ -142,11 +140,8 @@ func createTableStructString(settings *settings.Settings, db database.Database,
if !columnInfo.isTemporal {
columnInfo.isTemporal = col.isTemporal
}
if !columnInfo.isNullableTemporal {
columnInfo.isNullableTemporal = col.isNullableTemporal
}
if !columnInfo.isNullablePrimitive {
columnInfo.isNullablePrimitive = col.isNullablePrimitive
if !columnInfo.isNullable {
columnInfo.isNullable = col.isNullable
}

structFields.WriteString(columnName)
Expand All @@ -169,7 +164,7 @@ func createTableStructString(settings *settings.Settings, db database.Database,
fileContent.WriteString("\n\n")

// write imports
generateImports(&fileContent, settings, db, columnInfo)
generateImports(&fileContent, settings, columnInfo)

// write struct with fields
fileContent.WriteString("type ")
Expand All @@ -181,28 +176,22 @@ func createTableStructString(settings *settings.Settings, db database.Database,
return tableName, fileContent.String(), nil
}

func generateImports(content *strings.Builder, settings *settings.Settings, db database.Database, columnInfo columnInfo) {
func generateImports(content *strings.Builder, settings *settings.Settings, columnInfo columnInfo) {

if !columnInfo.hasTrue() && !settings.IsMastermindStructableRecorder {
if !columnInfo.isNullableOrTemporal() && !settings.IsMastermindStructableRecorder {
return
}

content.WriteString("import (\n")

if columnInfo.isNullablePrimitive && settings.IsNullTypeSQL() {
if columnInfo.isNullable && settings.IsNullTypeSQL() {
content.WriteString("\t\"database/sql\"\n")
}

if columnInfo.isTemporal {
content.WriteString("\t\"time\"\n")
}

if columnInfo.isNullableTemporal && settings.IsNullTypeSQL() {
content.WriteString("\t\n")
content.WriteString(db.GetDriverImportLibrary())
content.WriteString("\n")
}

if settings.IsMastermindStructableRecorder {
content.WriteString("\t\n\"github.com/Masterminds/structable\"\n")
}
Expand Down Expand Up @@ -234,9 +223,8 @@ func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column
goType = "time.Time"
columnInfo.isTemporal = true
} else {
goType = getNullType(s, "*time.Time", db.GetTemporalDriverDataType())
goType = getNullType(s, "*time.Time", "sql.NullTime")
columnInfo.isTemporal = s.Null == settings.NullTypeNative
columnInfo.isNullableTemporal = true
columnInfo.isNullable = true
}
} else {
Expand All @@ -253,8 +241,6 @@ func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column
}
}

columnInfo.isNullablePrimitive = columnInfo.isNullable && !db.IsTemporal(column)

return goType, columnInfo
}

Expand Down
8 changes: 4 additions & 4 deletions internal/cli/tables-to-go-cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ func TestRun_TemporalColumns(t *testing.T) {
On(
"Write",
"TestTable",
"package dto\n\nimport (\n\t\n"+db.GetDriverImportLibrary()+"\n)\n\ntype TestTable struct {\nColumnName "+dbType.String()+".NullTime `db:\"column_name\"`\n}",
"package dto\n\nimport (\n\t\"database/sql\"\n)\n\ntype TestTable struct {\nColumnName sql.NullTime `db:\"column_name\"`\n}",
)

err := Run(s, mdb, w)
Expand Down Expand Up @@ -1203,7 +1203,7 @@ func TestRun_TemporalColumns(t *testing.T) {
On(
"Write",
"TestTable",
"package dto\n\nimport (\n\t\"time\"\n\t\n"+db.GetDriverImportLibrary()+"\n)\n\ntype TestTable struct {\nColumnName1 "+dbType.String()+".NullTime `db:\"column_name_1\"`\nColumnName2 time.Time `db:\"column_name_2\"`\n}",
"package dto\n\nimport (\n\t\"database/sql\"\n\t\"time\"\n)\n\ntype TestTable struct {\nColumnName1 sql.NullTime `db:\"column_name_1\"`\nColumnName2 time.Time `db:\"column_name_2\"`\n}",
)

err := Run(s, mdb, w)
Expand Down Expand Up @@ -1311,12 +1311,12 @@ func TestRun_TemporalColumns(t *testing.T) {
On(
"Write",
"TestTable1",
"package dto\n\nimport (\n\t\"time\"\n\t\n"+db.GetDriverImportLibrary()+"\n)\n\ntype TestTable1 struct {\nColumnName1 "+dbType.String()+".NullTime `db:\"column_name_1\"`\nColumnName2 time.Time `db:\"column_name_2\"`\n}",
"package dto\n\nimport (\n\t\"database/sql\"\n\t\"time\"\n)\n\ntype TestTable1 struct {\nColumnName1 sql.NullTime `db:\"column_name_1\"`\nColumnName2 time.Time `db:\"column_name_2\"`\n}",
).
On(
"Write",
"TestTable2",
"package dto\n\nimport (\n\t\"time\"\n\t\n"+db.GetDriverImportLibrary()+"\n)\n\ntype TestTable2 struct {\nColumnName1 time.Time `db:\"column_name_1\"`\nColumnName2 "+dbType.String()+".NullTime `db:\"column_name_2\"`\n}",
"package dto\n\nimport (\n\t\"database/sql\"\n\t\"time\"\n)\n\ntype TestTable2 struct {\nColumnName1 time.Time `db:\"column_name_1\"`\nColumnName2 sql.NullTime `db:\"column_name_2\"`\n}",
)

err := Run(s, mdb, w)
Expand Down
2 changes: 0 additions & 2 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ type Database interface {
DSN() string
Connect() (err error)
Close() (err error)
GetDriverImportLibrary() string

GetTables() (tables []*Table, err error)
PrepareGetColumnsOfTableStmt() (err error)
Expand All @@ -47,7 +46,6 @@ type Database interface {

GetTemporalDatatypes() []string
IsTemporal(column Column) bool
GetTemporalDriverDataType() string

// TODO pg: bitstrings, enum, range, other special types
// TODO mysql: bit, enums, set
Expand Down
12 changes: 0 additions & 12 deletions pkg/database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ func (mysql *MySQL) DSN() string {
user, mysql.Settings.Pswd, mysql.Settings.Host, mysql.Settings.Port, mysql.Settings.DbName)
}

// GetDriverImportLibrary returns the golang sql driver specific fot the
// MySQL database.
func (mysql *MySQL) GetDriverImportLibrary() string {
return `"github.com/go-sql-driver/mysql"`
}

// GetTables gets all tables for a given database by name.
func (mysql *MySQL) GetTables() (tables []*Table, err error) {

Expand Down Expand Up @@ -202,9 +196,3 @@ func (mysql *MySQL) GetTemporalDatatypes() []string {
func (mysql *MySQL) IsTemporal(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetTemporalDatatypes())
}

// GetTemporalDriverDataType returns the time data type specific for the
// MySQL database.
func (mysql *MySQL) GetTemporalDriverDataType() string {
return "mysql.NullTime"
}
12 changes: 0 additions & 12 deletions pkg/database/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ func (pg *Postgresql) DSN() string {
pg.Settings.Host, pg.Settings.Port, user, pg.Settings.DbName, pg.Settings.Pswd)
}

// GetDriverImportLibrary returns the golang sql driver specific fot the
// Postgresql database.
func (pg *Postgresql) GetDriverImportLibrary() string {
return `pg "github.com/lib/pq"`
}

// GetTables gets all tables for a given schema by name.
func (pg *Postgresql) GetTables() (tables []*Table, err error) {

Expand Down Expand Up @@ -207,9 +201,3 @@ func (pg *Postgresql) GetTemporalDatatypes() []string {
func (pg *Postgresql) IsTemporal(column Column) bool {
return pg.IsStringInSlice(column.DataType, pg.GetTemporalDatatypes())
}

// GetTemporalDriverDataType returns the time data type specific for the
// Postgresql database.
func (pg *Postgresql) GetTemporalDriverDataType() string {
return "pg.NullTime"
}
10 changes: 0 additions & 10 deletions pkg/database/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ func (s *SQLite) DSN() string {
return strings.ReplaceAll(u.RequestURI(), "_auth=&", "_auth&")
}

// GetDriverImportLibrary returns the golang sql driver specific fot the
// SQLite database.
func (s *SQLite) GetDriverImportLibrary() string {
return `"github.com/mattn/go-sqlite3"`
}

func (s *SQLite) GetTables() (tables []*Table, err error) {

err = s.Select(&tables, `
Expand Down Expand Up @@ -194,7 +188,3 @@ func (s *SQLite) GetTemporalDatatypes() []string {
func (s *SQLite) IsTemporal(_ Column) bool {
return false
}

func (s *SQLite) GetTemporalDriverDataType() string {
return ""
}

0 comments on commit 58b2740

Please sign in to comment.