Skip to content

Commit

Permalink
Merge pull request #37 from fraenky8/ISSUE-34
Browse files Browse the repository at this point in the history
ISSUE-34: Support unix socket connections
  • Loading branch information
fraenky8 authored Feb 11, 2022
2 parents ee8194e + a84ec71 commit 7f3258b
Show file tree
Hide file tree
Showing 18 changed files with 274 additions and 162 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,23 @@ type SomeUserInfo struct {
Print usage with `-?` or `-help`

```
tables-to-go -help
Usage of tables-to-go:
-? shows help and usage
-d string
database name (default "postgres")
-f
force, skip tables that encounter errors but construct all others
-format string
format of struct fields (columns): camelCase (c) or original (o) (default "c")
-f force; skip tables that encounter errors
-fn-format string
format of the filename: camelCase (c, default) or snake_case (s)
format of the filename: camelCase (c, default) or snake_case (s) (default c)
-format string
format of struct fields (columns): camelCase (c) or original (o) (default c)
-h string
host of database (default "127.0.0.1")
-help
shows help and usage
-no-initialism
disable the conversion to upper-case words in column names
disable the conversion to upper-case words in column names
-null string
representation of NULL columns: sql.Null* (sql) or primitive pointers (native|primitive) (default "sql")
representation of NULL columns: sql.Null* (sql) or primitive pointers (native|primitive) (default sql)
-of string
output file path (default "current working directory")
-p string
Expand All @@ -215,12 +214,14 @@ tables-to-go -help
prefix for file- and struct names
-s string
schema name (default "public")
-socket string
The socket file to use for connection. Takes precedence over host:port.
-structable-recorder
generate a structable.Recorder field
-suf string
suffix for file- and struct names
-t string
type of database to use, currently supported: [pg mysql] (default "pg")
type of database to use, currently supported: [pg mysql sqlite3] (default pg)
-tags-no-db
do not create db-tags
-tags-structable
Expand Down
4 changes: 2 additions & 2 deletions internal/cli/tables-to-go-cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1733,10 +1733,10 @@ func TestFormatColumnName(t *testing.T) {
{"semicolons", "MyColumn;"},
{"brackets", "MyColumn()"},
}
settings := settings.New()
s := settings.New()
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := formatColumnName(settings, tc.input, "MyTable")
_, err := formatColumnName(s, tc.input, "MyTable")
if err == nil {
t.Errorf("formatColumnName(%q) should have thrown error but didn't", tc.input)
}
Expand Down
32 changes: 16 additions & 16 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import (
)

var (
// dbTypeToDriverMap maps the database type to the driver names
dbTypeToDriverMap = map[settings.DbType]string{
settings.DbTypePostgresql: "postgres",
settings.DbTypeMySQL: "mysql",
settings.DbTypeSQLite: "sqlite3",
// dbTypeToDriverMap maps the database type to the driver names.
dbTypeToDriverMap = map[settings.DBType]string{
settings.DBTypePostgresql: "postgres",
settings.DBTypeMySQL: "mysql",
settings.DBTypeSQLite: "sqlite3",
}
)

// Database interface for the concrete databases
// Database interface for the concrete databases.
type Database interface {
DSN() string
Connect() (err error)
Expand Down Expand Up @@ -53,13 +53,13 @@ type Database interface {
// TODO mysql: bit, enums, set
}

// Table has a name and a set (slice) of columns
// Table has a name and a set (slice) of columns.
type Table struct {
Name string `db:"table_name"`
Columns []Column
}

// Column stores information about a column
// Column stores information about a column.
type Column struct {
OrdinalPosition int `db:"ordinal_position"`
Name string `db:"column_name"`
Expand All @@ -74,8 +74,8 @@ type Column struct {
ConstraintType sql.NullString `db:"constraint_type"` // pg specific
}

// GeneralDatabase represents a base "class" database - for all other concrete databases
// it implements partly the Database interface
// GeneralDatabase represents a base "class" database - for all other concrete
// databases it implements partly the Database interface.
type GeneralDatabase struct {
GetColumnsOfTableStmt *sqlx.Stmt
*sqlx.DB
Expand All @@ -89,11 +89,11 @@ func New(s *settings.Settings) Database {
var db Database

switch s.DbType {
case settings.DbTypeSQLite:
case settings.DBTypeSQLite:
db = NewSQLite(s)
case settings.DbTypeMySQL:
case settings.DBTypeMySQL:
db = NewMySQL(s)
case settings.DbTypePostgresql:
case settings.DBTypePostgresql:
fallthrough
default:
db = NewPostgresql(s)
Expand All @@ -120,17 +120,17 @@ func (gdb *GeneralDatabase) Connect(dsn string) (err error) {
return gdb.Ping()
}

// Close closes the database connection
// Close closes the database connection.
func (gdb *GeneralDatabase) Close() error {
return gdb.DB.Close()
}

// IsNullable returns true if column is a nullable one
// IsNullable returns true if the column is a nullable column.
func (gdb *GeneralDatabase) IsNullable(column Column) bool {
return column.IsNullable == "YES"
}

// IsStringInSlice checks if needle (string) is in haystack ([]string)
// IsStringInSlice checks if needle (string) is in haystack ([]string).
func (gdb *GeneralDatabase) IsStringInSlice(needle string, haystack []string) bool {
for _, s := range haystack {
if s == needle {
Expand Down
63 changes: 40 additions & 23 deletions pkg/database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,52 @@ import (
_ "github.com/go-sql-driver/mysql"
)

// MySQL implemenmts the Database interface with help of generalDatabase
// MySQL implements the Database interface with help of GeneralDatabase.
type MySQL struct {
*GeneralDatabase

defaultUserName string
}

// NewMySQL creates a new MySQL database
// NewMySQL creates a new MySQL database.
func NewMySQL(s *settings.Settings) *MySQL {
return &MySQL{
GeneralDatabase: &GeneralDatabase{
Settings: s,
driver: dbTypeToDriverMap[s.DbType],
},
defaultUserName: "root",
}
}

// Connect connects to the database by the given data source name (dsn) of the concrete database
// Connect connects to the database by the given data source name (dsn) of the
// concrete database.
func (mysql *MySQL) Connect() error {
return mysql.GeneralDatabase.Connect(mysql.DSN())
}

// DSN creates the DSN String to connect to this database
// DSN creates the DSN String to connect to this database.
func (mysql *MySQL) DSN() string {
return fmt.Sprintf("%v:%v@tcp(%v:%v)/%v",
mysql.Settings.User, mysql.Settings.Pswd, mysql.Settings.Host, mysql.Settings.Port, mysql.Settings.DbName)
user := mysql.defaultUserName
if mysql.Settings.User != "" {
user = mysql.Settings.User
}

if mysql.Settings.Socket != "" {
return fmt.Sprintf("%s:%s@unix(%s)/%s",
user, mysql.Settings.Pswd, mysql.Settings.Socket, mysql.Settings.DbName)
}
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s",
user, mysql.Settings.Pswd, mysql.Settings.Host, mysql.Settings.Port, mysql.Settings.DbName)
}

// GetDriverImportLibrary returns the golang sql driver specific fot the MySQL database
// GetDriverImportLibrary returns the golang sql driver specific fot the
// MySQL database.
func (mysql *MySQL) GetDriverImportLibrary() string {
return `"github.com/go-sql-driver/mysql"`
}

// GetTables gets all tables for a given database by name
// GetTables gets all tables for a given database by name.
func (mysql *MySQL) GetTables() (tables []*Table, err error) {

err = mysql.Select(&tables, `
Expand All @@ -62,7 +76,8 @@ func (mysql *MySQL) GetTables() (tables []*Table, err error) {
return tables, err
}

// PrepareGetColumnsOfTableStmt prepares the statement for retrieving the columns of a specific table for a given database
// PrepareGetColumnsOfTableStmt prepares the statement for retrieving the
// columns of a specific table for a given database.
func (mysql *MySQL) PrepareGetColumnsOfTableStmt() (err error) {

mysql.GetColumnsOfTableStmt, err = mysql.Preparex(`
Expand All @@ -85,7 +100,8 @@ func (mysql *MySQL) PrepareGetColumnsOfTableStmt() (err error) {
return err
}

// GetColumnsOfTable executes the statement for retrieving the columns of a specific table for a given database
// GetColumnsOfTable executes the statement for retrieving the columns of a
// specific table for a given database.
func (mysql *MySQL) GetColumnsOfTable(table *Table) (err error) {

err = mysql.GetColumnsOfTableStmt.Select(&table.Columns, table.Name, mysql.DbName)
Expand All @@ -101,17 +117,17 @@ func (mysql *MySQL) GetColumnsOfTable(table *Table) (err error) {
return err
}

// IsPrimaryKey checks if column belongs to primary key
// IsPrimaryKey checks if the column belongs to the primary key.
func (mysql *MySQL) IsPrimaryKey(column Column) bool {
return strings.Contains(column.ColumnKey, "PRI")
}

// IsAutoIncrement checks if column is a auto_increment column
// IsAutoIncrement checks if the column is an auto_increment column.
func (mysql *MySQL) IsAutoIncrement(column Column) bool {
return strings.Contains(column.Extra, "auto_increment")
}

// GetStringDatatypes returns the string datatypes for the MySQL database
// GetStringDatatypes returns the string datatypes for the MySQL database.
func (mysql *MySQL) GetStringDatatypes() []string {
return []string{
"char",
Expand All @@ -121,25 +137,25 @@ func (mysql *MySQL) GetStringDatatypes() []string {
}
}

// IsString returns true if colum is of type string for the MySQL database
// IsString returns true if the colum is of type string for the MySQL database.
func (mysql *MySQL) IsString(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetStringDatatypes())
}

// GetTextDatatypes returns the text datatypes for the MySQL database
// GetTextDatatypes returns the text datatypes for the MySQL database.
func (mysql *MySQL) GetTextDatatypes() []string {
return []string{
"text",
"blob",
}
}

// IsText returns true if colum is of type text for the MySQL database
// IsText returns true if colum is of type text for the MySQL database.
func (mysql *MySQL) IsText(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetTextDatatypes())
}

// GetIntegerDatatypes returns the integer datatypes for the MySQL database
// GetIntegerDatatypes returns the integer datatypes for the MySQL database.
func (mysql *MySQL) GetIntegerDatatypes() []string {
return []string{
"tinyint",
Expand All @@ -150,12 +166,12 @@ func (mysql *MySQL) GetIntegerDatatypes() []string {
}
}

// IsInteger returns true if colum is of type integer for the MySQL database
// IsInteger returns true if colum is of type integer for the MySQL database.
func (mysql *MySQL) IsInteger(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetIntegerDatatypes())
}

// GetFloatDatatypes returns the float datatypes for the MySQL database
// GetFloatDatatypes returns the float datatypes for the MySQL database.
func (mysql *MySQL) GetFloatDatatypes() []string {
return []string{
"numeric",
Expand All @@ -166,12 +182,12 @@ func (mysql *MySQL) GetFloatDatatypes() []string {
}
}

// IsFloat returns true if colum is of type float for the MySQL database
// IsFloat returns true if colum is of type float for the MySQL database.
func (mysql *MySQL) IsFloat(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetFloatDatatypes())
}

// GetTemporalDatatypes returns the temporal datatypes for the MySQL database
// GetTemporalDatatypes returns the temporal datatypes for the MySQL database.
func (mysql *MySQL) GetTemporalDatatypes() []string {
return []string{
"time",
Expand All @@ -182,12 +198,13 @@ func (mysql *MySQL) GetTemporalDatatypes() []string {
}
}

// IsTemporal returns true if colum is of type temporal for the MySQL database
// IsTemporal returns true if colum is of type temporal for the MySQL database.
func (mysql *MySQL) IsTemporal(column Column) bool {
return mysql.IsStringInSlice(column.DataType, mysql.GetTemporalDatatypes())
}

// GetTemporalDriverDataType returns the time data type specific for the MySQL database
// GetTemporalDriverDataType returns the time data type specific for the
// MySQL database.
func (mysql *MySQL) GetTemporalDriverDataType() string {
return "mysql.NullTime"
}
54 changes: 54 additions & 0 deletions pkg/database/mysql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package database

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/fraenky8/tables-to-go/pkg/settings"
)

func TestMySQL_DSN(t *testing.T) {
tests := []struct {
desc string
settings func() *settings.Settings
expected func(*settings.Settings) string
}{
{
desc: "no username given, defaults to `root` with tcp",
settings: func() *settings.Settings {
s := settings.New()
s.DbType = settings.DBTypeMySQL
s.Pswd = "mysecretpassword"
s.DbName = "my-cool-db"
s.Port = "3306"
return s
},
expected: func(s *settings.Settings) string {
return "root:mysecretpassword@tcp(127.0.0.1:3306)/my-cool-db"
},
},
{
desc: "username given, with socket",
settings: func() *settings.Settings {
s := settings.New()
s.DbType = settings.DBTypeMySQL
s.User = "admin"
s.Pswd = "mysecretpassword"
s.DbName = "my-cool-db"
s.Socket = "/tmp/mysql.sock"
return s
},
expected: func(s *settings.Settings) string {
return "admin:mysecretpassword@unix(/tmp/mysql.sock)/my-cool-db"
},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
db := NewMySQL(test.settings())
actual := db.DSN()
assert.Equal(t, test.expected(db.Settings), actual)
})
}
}
Loading

0 comments on commit 7f3258b

Please sign in to comment.