From 94dfbe0a3070b15385d5ac76b7871234172b5515 Mon Sep 17 00:00:00 2001 From: Nick Zelei <2420177+nickzelei@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:06:49 -0700 Subject: [PATCH] NEOS-1568: simplifies rdbms url handling (#2855) --- .../pkg/dbconnect-config/dbconnect-config.go | 96 ------- .../dbconnect-config/dbconnect-config_test.go | 108 -------- backend/pkg/dbconnect-config/interface.go | 6 + backend/pkg/dbconnect-config/mssql.go | 60 ++--- backend/pkg/dbconnect-config/mssql_test.go | 144 ++++++++-- backend/pkg/dbconnect-config/mysql.go | 140 +++++----- backend/pkg/dbconnect-config/mysql_test.go | 253 +++++++++++------- backend/pkg/dbconnect-config/postgres.go | 146 +++++----- backend/pkg/dbconnect-config/postgres_test.go | 234 +++++++++++++--- backend/pkg/sqlconnect/sql-connector.go | 8 +- .../v1alpha1/connection-service/connection.go | 74 ++--- .../connections/[id]/components/MysqlForm.tsx | 2 +- .../new/connection/mysql/MysqlForm.tsx | 2 +- 13 files changed, 670 insertions(+), 603 deletions(-) delete mode 100644 backend/pkg/dbconnect-config/dbconnect-config.go delete mode 100644 backend/pkg/dbconnect-config/dbconnect-config_test.go create mode 100644 backend/pkg/dbconnect-config/interface.go diff --git a/backend/pkg/dbconnect-config/dbconnect-config.go b/backend/pkg/dbconnect-config/dbconnect-config.go deleted file mode 100644 index f39c7eee3e..0000000000 --- a/backend/pkg/dbconnect-config/dbconnect-config.go +++ /dev/null @@ -1,96 +0,0 @@ -package dbconnectconfig - -import ( - "fmt" - "net/url" -) - -const ( - mysqlDriver = "mysql" - postgresDriver = "postgres" - mssqlDriver = "sqlserver" -) - -type GeneralDbConnectConfig struct { - driver string - - host string - port *int32 - // For mssql this is actually the path..the database is provided as a query parameter - database *string - user string - pass string - - mysqlProtocol *string - - queryParams url.Values -} - -func (g *GeneralDbConnectConfig) String() string { - if g.driver == postgresDriver { - u := url.URL{ - Scheme: "postgres", - Host: buildDbUrlHost(g.host, g.port), - } - if g.database != nil { - u.Path = *g.database - } - - // Add user info - if g.user != "" || g.pass != "" { - u.User = url.UserPassword(g.user, g.pass) - } - u.RawQuery = g.queryParams.Encode() - return u.String() - } - if g.driver == mysqlDriver { - protocol := "tcp" - if g.mysqlProtocol != nil { - protocol = *g.mysqlProtocol - } - address := fmt.Sprintf("(%s)", buildDbUrlHost(g.host, g.port)) - - // User info - // dont use url.UserPassword as it escapes the password - // host and password should not be escaped. even if they contain special characters - userInfo := g.user - if g.pass != "" { - userInfo += ":" + g.pass - } - // Base DSN - dsn := fmt.Sprintf("%s@%s%s", userInfo, protocol, address) - if g.database != nil { - dsn = fmt.Sprintf("%s/%s", dsn, *g.database) - } - - // Append query parameters if any - if len(g.queryParams) > 0 { - query := g.queryParams.Encode() - dsn += "?" + query - } - return dsn - } - if g.driver == mssqlDriver { - u := url.URL{ - Scheme: mssqlDriver, - Host: buildDbUrlHost(g.host, g.port), - } - if g.database != nil { - u.Path = *g.database - } - // Add user info - if g.user != "" || g.pass != "" { - u.User = url.UserPassword(g.user, g.pass) - } - u.RawQuery = g.queryParams.Encode() - return u.String() - } - return "" -} - -func buildDbUrlHost(host string, port *int32) string { - if port != nil { - return fmt.Sprintf("%s:%d", host, *port) - } - return host -} diff --git a/backend/pkg/dbconnect-config/dbconnect-config_test.go b/backend/pkg/dbconnect-config/dbconnect-config_test.go deleted file mode 100644 index 108c9a9025..0000000000 --- a/backend/pkg/dbconnect-config/dbconnect-config_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package dbconnectconfig - -import ( - "net/url" - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_GeneralDbConnectionConfig_String(t *testing.T) { - type testcase struct { - name string - input GeneralDbConnectConfig - expected string - } - testcases := []testcase{ - { - name: "empty", - input: GeneralDbConnectConfig{}, - expected: "", - }, - { - name: "postgres", - input: GeneralDbConnectConfig{ - driver: "postgres", - host: "localhost", - port: ptr(int32(5432)), - database: ptr("mydb"), - user: "test-user", - pass: "test-pass", - queryParams: url.Values{"sslmode": []string{"verify"}}, - }, - expected: "postgres://test-user:test-pass@localhost:5432/mydb?sslmode=verify", - }, - { - name: "mysql", - input: GeneralDbConnectConfig{ - driver: "mysql", - host: "localhost", - port: ptr(int32(3309)), - database: ptr("mydb"), - user: "test-user", - pass: "test-pass", - mysqlProtocol: ptr("tcp"), - queryParams: url.Values{"foo": []string{"bar"}}, - }, - expected: "test-user:test-pass@tcp(localhost:3309)/mydb?foo=bar", - }, - { - name: "mysql", - input: GeneralDbConnectConfig{ - driver: "mysql", - host: "localhost", - port: ptr(int32(3309)), - database: ptr("mydb"), - user: "specialuser!*-", - pass: "46!ZfMv3@Uh8*-<", - mysqlProtocol: ptr("tcp"), - queryParams: url.Values{"foo": []string{"bar"}}, - }, - expected: "specialuser!*-:46!ZfMv3@Uh8*-<@tcp(localhost:3309)/mydb?foo=bar", - }, - { - name: "mssql", - input: GeneralDbConnectConfig{ - driver: "sqlserver", - host: "localhost", - port: ptr(int32(1433)), - database: ptr("myinstance"), - user: "sa", - pass: "myStr0ngP@assword", - queryParams: url.Values{"database": []string{"master"}}, - }, - expected: "sqlserver://sa:myStr0ngP%40assword@localhost:1433/myinstance?database=master", - }, - { - name: "mssql-noinstance", - input: GeneralDbConnectConfig{ - driver: "sqlserver", - host: "localhost", - port: ptr(int32(1433)), - database: nil, - user: "sa", - pass: "myStr0ngP@assword", - queryParams: url.Values{"database": []string{"master"}}, - }, - expected: "sqlserver://sa:myStr0ngP%40assword@localhost:1433?database=master", - }, - { - name: "mssql-noinstance-noport", - input: GeneralDbConnectConfig{ - driver: "sqlserver", - host: "localhost", - port: nil, - database: nil, - user: "sa", - pass: "myStr0ngP@assword", - queryParams: url.Values{"database": []string{"master"}}, - }, - expected: "sqlserver://sa:myStr0ngP%40assword@localhost?database=master", - }, - } - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.input.String(), tc.expected) - }) - } -} diff --git a/backend/pkg/dbconnect-config/interface.go b/backend/pkg/dbconnect-config/interface.go new file mode 100644 index 0000000000..d29210eea3 --- /dev/null +++ b/backend/pkg/dbconnect-config/interface.go @@ -0,0 +1,6 @@ +package dbconnectconfig + +type DbConnectConfig interface { + String() string + GetUser() string +} diff --git a/backend/pkg/dbconnect-config/mssql.go b/backend/pkg/dbconnect-config/mssql.go index 67549dbecf..4887ec5084 100644 --- a/backend/pkg/dbconnect-config/mssql.go +++ b/backend/pkg/dbconnect-config/mssql.go @@ -4,18 +4,32 @@ import ( "errors" "fmt" "net/url" - "strconv" - "strings" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" - "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared" ) -func NewFromMssqlConnection(config *mgmtv1alpha1.ConnectionConfig_MssqlConfig, connectionTimeout *uint32) (*GeneralDbConnectConfig, error) { +type mssqlConnectConfig struct { + url string + user string +} + +var _ DbConnectConfig = (*mssqlConnectConfig)(nil) + +func (m *mssqlConnectConfig) String() string { + return m.url +} +func (m *mssqlConnectConfig) GetUser() string { + return m.user +} + +func NewFromMssqlConnection( + config *mgmtv1alpha1.ConnectionConfig_MssqlConfig, + connectionTimeout *uint32, +) (DbConnectConfig, error) { switch cc := config.MssqlConfig.ConnectionConfig.(type) { case *mgmtv1alpha1.MssqlConnectionConfig_Url: - u, err := url.Parse(cc.Url) + uriconfig, err := url.Parse(cc.Url) if err != nil { var urlErr *url.Error if errors.As(err, &urlErr) { @@ -23,43 +37,15 @@ func NewFromMssqlConnection(config *mgmtv1alpha1.ConnectionConfig_MssqlConfig, c } return nil, fmt.Errorf("unable to parse mssql url: %w", err) } - user := u.User.Username() - pass, _ := u.User.Password() - - host, portStr := u.Hostname(), u.Port() - query := u.Query() - - var port *int32 - if portStr != "" { - parsedPort, err := strconv.ParseInt(portStr, 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid port when processing mssql connection url: %w", err) - } - port = shared.Ptr(int32(parsedPort)) - } - - var instance *string - if u.Path != "" { - trimmed := strings.TrimPrefix(u.Path, "/") - if trimmed != "" { - instance = &trimmed - } - } + query := uriconfig.Query() - if connectionTimeout != nil { + if !query.Has("connection timeout") && connectionTimeout != nil { query.Add("connection timeout", fmt.Sprintf("%d", *connectionTimeout)) } + uriconfig.RawQuery = query.Encode() - return &GeneralDbConnectConfig{ - driver: mssqlDriver, - host: host, - port: port, - database: instance, - user: user, - pass: pass, - queryParams: query, - }, nil + return &mssqlConnectConfig{url: uriconfig.String(), user: getUserFromInfo(uriconfig.User)}, nil default: return nil, nucleuserrors.NewBadRequest(fmt.Sprintf("must provide valid mssql connection: %T", cc)) } diff --git a/backend/pkg/dbconnect-config/mssql_test.go b/backend/pkg/dbconnect-config/mssql_test.go index bcc6c13937..1998ec30be 100644 --- a/backend/pkg/dbconnect-config/mssql_test.go +++ b/backend/pkg/dbconnect-config/mssql_test.go @@ -1,8 +1,6 @@ package dbconnectconfig import ( - "net/url" - "reflect" "testing" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" @@ -10,29 +8,127 @@ import ( ) func Test_NewFromMssqlConnection(t *testing.T) { - t.Run("standard string url", func(t *testing.T) { - out, err := NewFromMssqlConnection(&mgmtv1alpha1.ConnectionConfig_MssqlConfig{ - MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ - Url: "sqlserver://test-user:test-pass@localhost:1433/myinstance?database=master", + t.Run("URL", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + actual, err := NewFromMssqlConnection( + &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ + MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ + Url: "sqlserver://test-user:test-pass@localhost:1433/myinstance?database=master", + }, + }, }, - }, - }, ptr(uint32(5))) - - assert.NoError(t, err) - assert.NotNil(t, out) - expected := &GeneralDbConnectConfig{ - driver: "sqlserver", - host: "localhost", - port: ptr(int32(1433)), - database: ptr("myinstance"), - user: "test-user", - pass: "test-pass", - queryParams: url.Values{"database": []string{"master"}, "connection timeout": []string{"5"}}, - } - if !reflect.DeepEqual(out, expected) { - t.Errorf("Expected %v, got %v", expected, out) - } + &testConnectionTimeout, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "sqlserver://test-user:test-pass@localhost:1433/myinstance?connection+timeout=5&database=master", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_timeout", func(t *testing.T) { + actual, err := NewFromMssqlConnection( + &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ + MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ + Url: "sqlserver://test-user:test-pass@localhost:1433/myinstance?database=master", + }, + }, + }, + nil, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "sqlserver://test-user:test-pass@localhost:1433/myinstance?database=master", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_user_provided_timeout", func(t *testing.T) { + actual, err := NewFromMssqlConnection( + &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ + MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ + Url: "sqlserver://test-user:test-pass@localhost:1433/myinstance?connection+timeout=10&database=master", + }, + }, + }, + &testConnectionTimeout, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "sqlserver://test-user:test-pass@localhost:1433/myinstance?connection+timeout=10&database=master", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_strong_password", func(t *testing.T) { + actual, err := NewFromMssqlConnection( + &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ + MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ + Url: "sqlserver://sa:myStr0ngP%40assword@localhost:1433/myinstance?database=master", + }, + }, + }, + &testConnectionTimeout, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "sqlserver://sa:myStr0ngP%40assword@localhost:1433/myinstance?connection+timeout=5&database=master", + actual.String(), + ) + assert.Equal(t, "sa", actual.GetUser()) + }) + t.Run("ok_no_instance", func(t *testing.T) { + actual, err := NewFromMssqlConnection( + &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ + MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ + Url: "sqlserver://sa:myStr0ngP%40assword@localhost:1433?database=master", + }, + }, + }, + &testConnectionTimeout, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "sqlserver://sa:myStr0ngP%40assword@localhost:1433?connection+timeout=5&database=master", + actual.String(), + ) + assert.Equal(t, "sa", actual.GetUser()) + }) + t.Run("ok_no_instance_no_port", func(t *testing.T) { + actual, err := NewFromMssqlConnection( + &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ + MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ + Url: "sqlserver://sa:myStr0ngP%40assword@localhost?database=master", + }, + }, + }, + &testConnectionTimeout, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "sqlserver://sa:myStr0ngP%40assword@localhost?connection+timeout=5&database=master", + actual.String(), + ) + assert.Equal(t, "sa", actual.GetUser()) + }) }) } diff --git a/backend/pkg/dbconnect-config/mysql.go b/backend/pkg/dbconnect-config/mysql.go index fa91141db2..90a4befb95 100644 --- a/backend/pkg/dbconnect-config/mysql.go +++ b/backend/pkg/dbconnect-config/mysql.go @@ -1,93 +1,97 @@ package dbconnectconfig import ( + "errors" "fmt" + "log/slog" "net/url" - "strconv" "strings" + "time" + "github.com/go-sql-driver/mysql" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" ) -func NewFromMysqlConnection(config *mgmtv1alpha1.ConnectionConfig_MysqlConfig, connectionTimeout *uint32) (*GeneralDbConnectConfig, error) { - switch cc := config.MysqlConfig.ConnectionConfig.(type) { - case *mgmtv1alpha1.MysqlConnectionConfig_Connection: - query := url.Values{} - if connectionTimeout != nil { - query.Add("timeout", fmt.Sprintf("%ds", *connectionTimeout)) - } - query.Add("multiStatements", "true") - query.Add("parseTime", "true") - return &GeneralDbConnectConfig{ - driver: mysqlDriver, - host: cc.Connection.Host, - port: &cc.Connection.Port, - database: &cc.Connection.Name, - user: cc.Connection.User, - pass: cc.Connection.Pass, - mysqlProtocol: &cc.Connection.Protocol, - queryParams: query, - }, nil - case *mgmtv1alpha1.MysqlConnectionConfig_Url: - // follows the format [scheme://][user[:password]@][/schema][?option=value&option=value...] - // from the format - https://dev.mysql.com/doc/dev/mysqlsh-api-javascript/8.0/classmysqlsh_1_1_shell.html#a639614cf6b980f0d5267cc7057b81012 +type mysqlConnectConfig struct { + dsn string + user string +} - u, err := url.Parse(cc.Url) - if err != nil { - return nil, err - } +var _ DbConnectConfig = (*mysqlConnectConfig)(nil) - // mysqlx is a newer connection protocol meant for more flexible schemas and supports mysqls nosql db capabilities - // more information here - https://dev.mysql.com/doc/refman/8.4/en/connecting-using-uri-or-key-value-pairs.html +func (m *mysqlConnectConfig) String() string { + return m.dsn +} +func (m *mysqlConnectConfig) GetUser() string { + return m.user +} - if u.Scheme != "mysql" && u.Scheme != "mysqlx" { - return nil, fmt.Errorf("scheme is not mysql ,unsupported scheme: %s", u.Scheme) +func NewFromMysqlConnection( + config *mgmtv1alpha1.ConnectionConfig_MysqlConfig, + connectionTimeout *uint32, + logger *slog.Logger, +) (DbConnectConfig, error) { + switch cc := config.MysqlConfig.GetConnectionConfig().(type) { + case *mgmtv1alpha1.MysqlConnectionConfig_Connection: + cfg := mysql.NewConfig() + cfg.DBName = cc.Connection.GetName() + cfg.Addr = cc.Connection.GetHost() + if cc.Connection.GetPort() > 0 { + cfg.Addr += fmt.Sprintf(":%d", cc.Connection.GetPort()) } - - var user string - var pass string - - if u.User != nil { - user = u.User.Username() - pass, _ = u.User.Password() + cfg.User = cc.Connection.GetUser() + cfg.Passwd = cc.Connection.GetPass() + if connectionTimeout != nil { + cfg.Timeout = time.Duration(*connectionTimeout) * time.Second } + cfg.Net = cc.Connection.GetProtocol() + cfg.MultiStatements = true + cfg.ParseTime = true + + return &mysqlConnectConfig{dsn: cfg.FormatDSN(), user: cfg.User}, nil + case *mgmtv1alpha1.MysqlConnectionConfig_Url: + mysqlurl := cc.Url - port := int32(3306) - if p := u.Port(); p != "" { - portInt, err := strconv.Atoi(p) + cfg, err := mysql.ParseDSN(mysqlurl) + if err != nil { + logger.Warn(fmt.Sprintf("failed to parse mysql url as DSN: %v", err)) + uriConfig, err := url.Parse(mysqlurl) if err != nil { - return nil, err + var urlErr *url.Error + if errors.As(err, &urlErr) { + return nil, fmt.Errorf("unable to parse mysql url [%s]: %w", urlErr.Op, urlErr.Err) + } + return nil, fmt.Errorf("unable to parse mysql url: %w", err) + } + cfg = mysql.NewConfig() + cfg.Net = "tcp" + cfg.DBName = strings.TrimPrefix(uriConfig.Path, "/") + cfg.Addr = uriConfig.Host + cfg.User = uriConfig.User.Username() + if passwd, ok := uriConfig.User.Password(); ok { + cfg.Passwd = passwd } - // #nosec G109 - // this throws a linter error due to strconv.Atoi conversion above from string -> int32 - // mysql ports are unsigned 16-bit numbers so they should never overflow in an in32 - // https://stackoverflow.com/questions/20379491/what-is-the-optimal-way-to-store-port-numbers-in-a-mysql-database#:~:text=Port%20number%20is%20an%20unsinged,highest%20value%20can%20be%2065535. - // https://downloads.mysql.com/docs/mysql-port-reference-en.pdf - port = int32(portInt) //nolint:gosec // Ignoring for now + if connectionTimeout != nil { + cfg.Timeout = time.Duration(*connectionTimeout) * time.Second + } + cfg.MultiStatements = true + cfg.ParseTime = true + for k, values := range uriConfig.Query() { + for _, value := range values { + cfg.Params[k] = value + } + } + return &mysqlConnectConfig{dsn: cfg.FormatDSN(), user: cfg.User}, nil } - database := strings.TrimPrefix(u.Path, "/") - - query := u.Query() - if connectionTimeout != nil { - query.Add("timeout", fmt.Sprintf("%ds", *connectionTimeout)) + if cfg.Timeout == 0 && connectionTimeout != nil { + cfg.Timeout = time.Duration(*connectionTimeout) * time.Second } - query.Add("multiStatements", "true") - query.Add("parseTime", "true") - - return &GeneralDbConnectConfig{ - driver: u.Scheme, - host: u.Hostname(), - port: &port, - database: &database, - user: user, - pass: pass, - mysqlProtocol: nil, - queryParams: query, - }, nil + cfg.MultiStatements = true + cfg.ParseTime = true + return &mysqlConnectConfig{dsn: cfg.FormatDSN(), user: cfg.User}, nil default: - return nil, nucleuserrors.NewBadRequest("must provide valid mysql connection") + return nil, fmt.Errorf("unsupported mysql connection config: %T", cc) } } diff --git a/backend/pkg/dbconnect-config/mysql_test.go b/backend/pkg/dbconnect-config/mysql_test.go index 612cecff3a..27207c226a 100644 --- a/backend/pkg/dbconnect-config/mysql_test.go +++ b/backend/pkg/dbconnect-config/mysql_test.go @@ -1,7 +1,8 @@ package dbconnectconfig import ( - "net/url" + "io" + "log/slog" "testing" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" @@ -17,108 +18,164 @@ var ( Pass: "test-pass", Protocol: "tcp", } + discardLogger = slog.New(slog.NewTextHandler(io.Discard, nil)) + testConnectionTimeout = uint32(5) ) -func Test_NewFromMysqlConnection_Connection(t *testing.T) { - out, err := NewFromMysqlConnection(&mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ - Connection: mysqlconnectionFixture, - }, - }, - }, ptr(uint32(5))) - - assert.NoError(t, err) - assert.NotNil(t, out) - assert.Equal(t, out, &GeneralDbConnectConfig{ - driver: "mysql", - host: "localhost", - port: ptr(int32(3309)), - database: ptr("mydb"), - user: "test-user", - pass: "test-pass", - mysqlProtocol: ptr("tcp"), - queryParams: url.Values{"timeout": []string{"5s"}, "multiStatements": []string{"true"}, "parseTime": []string{"true"}}, +func Test_NewFromMysqlConnection(t *testing.T) { + t.Run("Connection", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ + Connection: mysqlconnectionFixture, + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "test-user:test-pass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true&timeout=5s", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_timeout", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ + Connection: mysqlconnectionFixture, + }, + }, + }, + nil, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "test-user:test-pass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) }) -} -func Test_NewFromMysqlConnection_Url_mysql(t *testing.T) { - out, err := NewFromMysqlConnection(&mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: "mysql://myuser:mypassword@localhost:3306/mydatabase?ssl=true", - }, - }, - }, ptr(uint32(5))) - - assert.NoError(t, err) - assert.NotNil(t, out) - assert.Equal(t, out, &GeneralDbConnectConfig{ - driver: "mysql", - host: "localhost", - port: ptr(int32(3306)), - database: ptr("mydatabase"), - user: "myuser", - pass: "mypassword", - mysqlProtocol: nil, - queryParams: url.Values{"ssl": []string{"true"}, "multiStatements": []string{"true"}, "timeout": []string{"5s"}, "parseTime": []string{"true"}}, + t.Run("URL_DSN", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ + Url: "test-user:testpass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true", + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "test-user:testpass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true&timeout=5s", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_timeout", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ + Url: "test-user:testpass@tcp(localhost:3309)/mydb", + }, + }, + }, + nil, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "test-user:testpass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_specialchars_userpass", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ + Url: "specialuser!*-:46!ZfMv3@Uh8*-<@@tcp(localhost:3309)/mydb", + }, + }, + }, + nil, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "specialuser!*-:46!ZfMv3@Uh8*-<@@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true", + actual.String(), + ) + assert.Equal(t, "specialuser!*-", actual.GetUser()) + }) }) -} -func Test_NewFromMysqlConnection_Url_mysqlx(t *testing.T) { - out, err := NewFromMysqlConnection(&mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: "mysqlx://myuser:mypassword@localhost:3306/mydatabase?ssl=true", - }, - }, - }, ptr(uint32(5))) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.Equal(t, out, &GeneralDbConnectConfig{ - driver: "mysqlx", - host: "localhost", - port: ptr(int32(3306)), - database: ptr("mydatabase"), - user: "myuser", - pass: "mypassword", - mysqlProtocol: nil, - queryParams: url.Values{"ssl": []string{"true"}, "multiStatements": []string{"true"}, "timeout": []string{"5s"}, "parseTime": []string{"true"}}, + t.Run("URL_URI", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ + Url: "mysql://test-user:testpass@localhost:3309/mydb", + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "test-user:testpass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true&timeout=5s", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_timeout", func(t *testing.T) { + actual, err := NewFromMysqlConnection( + &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ + Url: "mysql://test-user:testpass@localhost:3309/mydb", + }, + }, + }, + nil, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "test-user:testpass@tcp(localhost:3309)/mydb?multiStatements=true&parseTime=true", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) }) } - -func Test_NewFromMysqlConnection_Url_Error(t *testing.T) { - _, err := NewFromMysqlConnection(&mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: "mysql://myuser:mypassword/mydatabase?ssl=true", - }, - }, - }, ptr(uint32(5))) - - assert.Error(t, err) -} - -func Test_NewFromMysqlConnection_Url_NoScheme(t *testing.T) { - _, err := NewFromMysqlConnection(&mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: "mysqlxxx://myuser:mypassword@localhost:3306/mydatabase?ssl=true", - }, - }, - }, ptr(uint32(5))) - - assert.Error(t, err) -} - -func Test_NewFromMysqlConnection_Url_NoPort(t *testing.T) { - _, err := NewFromMysqlConnection(&mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: "mysqlxxx://myuser:mypassword@localhost/mydatabase?ssl=true", - }, - }, - }, ptr(uint32(5))) - - assert.Error(t, err) -} diff --git a/backend/pkg/dbconnect-config/postgres.go b/backend/pkg/dbconnect-config/postgres.go index a19c61a019..32c2dc1749 100644 --- a/backend/pkg/dbconnect-config/postgres.go +++ b/backend/pkg/dbconnect-config/postgres.go @@ -3,47 +3,68 @@ package dbconnectconfig import ( "errors" "fmt" + "log/slog" "net/url" - "strconv" - "strings" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" "github.com/nucleuscloud/neosync/backend/pkg/clienttls" - "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared" ) -func NewFromPostgresConnection(config *mgmtv1alpha1.ConnectionConfig_PgConfig, connectionTimeout *uint32) (*GeneralDbConnectConfig, error) { - switch cc := config.PgConfig.ConnectionConfig.(type) { +type pgConnectConfig struct { + url string + user string +} + +var _ DbConnectConfig = (*pgConnectConfig)(nil) + +func (m *pgConnectConfig) String() string { + return m.url +} +func (m *pgConnectConfig) GetUser() string { + return m.user +} + +func NewFromPostgresConnection( + config *mgmtv1alpha1.ConnectionConfig_PgConfig, + connectionTimeout *uint32, + logger *slog.Logger, +) (DbConnectConfig, error) { + switch cc := config.PgConfig.GetConnectionConfig().(type) { case *mgmtv1alpha1.PostgresConnectionConfig_Connection: - query := url.Values{} - if cc.Connection.SslMode != nil { - query.Add("sslmode", *cc.Connection.SslMode) + host := cc.Connection.GetHost() + if cc.Connection.GetPort() > 0 { + host += fmt.Sprintf(":%d", cc.Connection.GetPort()) } - if connectionTimeout != nil { - query.Add("connect_timeout", fmt.Sprintf("%d", *connectionTimeout)) + + pgurl := url.URL{ + Scheme: "postgres", + Host: host, + } + if cc.Connection.GetUser() != "" && cc.Connection.GetPass() != "" { + pgurl.User = url.UserPassword(cc.Connection.GetUser(), cc.Connection.GetPass()) + } else if cc.Connection.GetUser() != "" && cc.Connection.GetPass() == "" { + pgurl.User = url.User(cc.Connection.GetUser()) + } + if cc.Connection.GetName() != "" { + pgurl.Path = cc.Connection.GetName() + } + query := url.Values{} + if cc.Connection.GetSslMode() != "" { + query.Set("sslmode", cc.Connection.GetSslMode()) } if config.PgConfig.GetClientTls() != nil { - filenames := clienttls.GetClientTlsFileNames(config.PgConfig.GetClientTls()) - if filenames.RootCert != nil { - query.Add("sslrootcert", *filenames.RootCert) - } - if filenames.ClientCert != nil && filenames.ClientKey != nil { - query.Add("sslcert", *filenames.ClientCert) - query.Add("sslkey", *filenames.ClientKey) - } + query = setPgClientTlsQueryParams(query, config.PgConfig.GetClientTls()) + } + if connectionTimeout != nil { + query.Set("connect_timeout", fmt.Sprintf("%d", *connectionTimeout)) } - return &GeneralDbConnectConfig{ - driver: postgresDriver, - host: cc.Connection.Host, - port: &cc.Connection.Port, - database: &cc.Connection.Name, - user: cc.Connection.User, - pass: cc.Connection.Pass, - queryParams: query, - }, nil + pgurl.RawQuery = query.Encode() + + return &pgConnectConfig{url: pgurl.String(), user: getUserFromInfo(pgurl.User)}, nil case *mgmtv1alpha1.PostgresConnectionConfig_Url: - u, err := url.Parse(cc.Url) + pgurl := cc.Url + + uriconfig, err := url.Parse(pgurl) if err != nil { var urlErr *url.Error if errors.As(err, &urlErr) { @@ -51,46 +72,39 @@ func NewFromPostgresConnection(config *mgmtv1alpha1.ConnectionConfig_PgConfig, c } return nil, fmt.Errorf("unable to parse postgres url: %w", err) } - - user := u.User.Username() - pass, ok := u.User.Password() - if !ok { - return nil, errors.New("unable to get password for pg string") - } - - host, portStr := u.Hostname(), u.Port() - - var port int64 - if portStr != "" { - port, err = strconv.ParseInt(portStr, 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid port: %w", err) - } - } else { - // default to standard postgres port 5432 if port not provided - port = int64(5432) + query := uriconfig.Query() + if !query.Has("connect_timeout") && connectionTimeout != nil { + query.Set("connect_timeout", fmt.Sprintf("%d", *connectionTimeout)) } - query := u.Query() + // todo: move this out of here into the driver if config.PgConfig.GetClientTls() != nil { - filenames := clienttls.GetClientTlsFileNames(config.PgConfig.GetClientTls()) - if filenames.RootCert != nil { - query.Add("sslrootcert", *filenames.RootCert) - } - if filenames.ClientCert != nil && filenames.ClientKey != nil { - query.Add("sslcert", *filenames.ClientCert) - query.Add("sslkey", *filenames.ClientKey) - } + query = setPgClientTlsQueryParams(query, config.PgConfig.GetClientTls()) } - return &GeneralDbConnectConfig{ - driver: postgresDriver, - host: host, - port: shared.Ptr(int32(port)), - database: shared.Ptr(strings.TrimPrefix(u.Path, "/")), - user: user, - pass: pass, - queryParams: query, - }, nil + uriconfig.RawQuery = query.Encode() + return &pgConnectConfig{url: uriconfig.String(), user: getUserFromInfo(uriconfig.User)}, nil default: - return nil, nucleuserrors.NewBadRequest("must provide valid postgres connection") + return nil, fmt.Errorf("unsupported pg connection config: %T", cc) + } +} + +func setPgClientTlsQueryParams( + query url.Values, + cfg *mgmtv1alpha1.ClientTlsConfig, +) url.Values { + filenames := clienttls.GetClientTlsFileNames(cfg) + if filenames.RootCert != nil { + query.Set("sslrootcert", *filenames.RootCert) + } + if filenames.ClientCert != nil && filenames.ClientKey != nil { + query.Set("sslcert", *filenames.ClientCert) + query.Set("sslkey", *filenames.ClientKey) + } + return query +} + +func getUserFromInfo(u *url.Userinfo) string { + if u == nil { + return "" } + return u.Username() } diff --git a/backend/pkg/dbconnect-config/postgres_test.go b/backend/pkg/dbconnect-config/postgres_test.go index 9d5a295205..dd932acb42 100644 --- a/backend/pkg/dbconnect-config/postgres_test.go +++ b/backend/pkg/dbconnect-config/postgres_test.go @@ -1,7 +1,6 @@ package dbconnectconfig import ( - "net/url" "testing" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" @@ -19,48 +18,199 @@ var ( } ) -func Test_getGeneralDbConnectConfigFromPg_Connection(t *testing.T) { - out, err := NewFromPostgresConnection(&mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ - Connection: pgconnectionFixture, - }, - }, - }, ptr(uint32(5))) - - assert.NoError(t, err) - assert.NotNil(t, out) - assert.Equal(t, out, &GeneralDbConnectConfig{ - driver: "postgres", - host: "localhost", - port: ptr(int32(5432)), - database: ptr("postgres"), - user: "test-user", - pass: "test-pass", - mysqlProtocol: nil, - queryParams: url.Values{"sslmode": []string{"verify"}, "connect_timeout": []string{"5"}}, +func Test_NewFromPostgresConnection(t *testing.T) { + t.Run("Connection", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ + Connection: pgconnectionFixture, + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user:test-pass@localhost:5432/postgres?connect_timeout=5&sslmode=verify", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_timeout", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ + Connection: pgconnectionFixture, + }, + }, + }, + nil, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user:test-pass@localhost:5432/postgres?sslmode=verify", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_port", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ + Connection: &mgmtv1alpha1.PostgresConnection{ + Host: "localhost", + // Port: 5432, + Name: "postgres", + User: "test-user", + Pass: "test-pass", + SslMode: ptr("verify"), + }, + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user:test-pass@localhost/postgres?connect_timeout=5&sslmode=verify", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_pass", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ + Connection: &mgmtv1alpha1.PostgresConnection{ + Host: "localhost", + Port: 5432, + Name: "postgres", + User: "test-user", + // Pass: "test-pass", + SslMode: ptr("verify"), + }, + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user@localhost:5432/postgres?connect_timeout=5&sslmode=verify", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_creds", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ + Connection: &mgmtv1alpha1.PostgresConnection{ + Host: "localhost", + Port: 5432, + Name: "postgres", + // User: "test-user", + // Pass: "test-pass", + SslMode: ptr("verify"), + }, + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://localhost:5432/postgres?connect_timeout=5&sslmode=verify", + actual.String(), + ) + assert.Equal(t, "", actual.GetUser()) + }) }) -} - -func Test_getGeneralDbConnectConfigFromPg_Url(t *testing.T) { - out, err := NewFromPostgresConnection(&mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ - Url: "postgres://test-user:test-pass@localhost:5432/postgres?sslmode=verify&connect_timeout=5", - }, - }, - }, ptr(uint32(5))) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.Equal(t, out, &GeneralDbConnectConfig{ - driver: "postgres", - host: "localhost", - port: ptr(int32(5432)), - database: ptr("postgres"), - user: "test-user", - pass: "test-pass", - mysqlProtocol: nil, - queryParams: url.Values{"sslmode": []string{"verify"}, "connect_timeout": []string{"5"}}, + t.Run("URL", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ + Url: "postgres://test-user:test-pass@localhost:5432/postgres?sslmode=disable", + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user:test-pass@localhost:5432/postgres?connect_timeout=5&sslmode=disable", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_no_timeout", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ + Url: "postgres://test-user:test-pass@localhost:5432/postgres", + }, + }, + }, + nil, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user:test-pass@localhost:5432/postgres", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) + t.Run("ok_user_provided_timeout", func(t *testing.T) { + actual, err := NewFromPostgresConnection( + &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ + Url: "postgres://test-user:test-pass@localhost:5432/postgres?connect_timeout=10", + }, + }, + }, + &testConnectionTimeout, + discardLogger, + ) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal( + t, + "postgres://test-user:test-pass@localhost:5432/postgres?connect_timeout=10", + actual.String(), + ) + assert.Equal(t, "test-user", actual.GetUser()) + }) }) } diff --git a/backend/pkg/sqlconnect/sql-connector.go b/backend/pkg/sqlconnect/sql-connector.go index 51cbb21fcd..7afc0f6390 100644 --- a/backend/pkg/sqlconnect/sql-connector.go +++ b/backend/pkg/sqlconnect/sql-connector.go @@ -55,7 +55,7 @@ func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.Connectio return nil, fmt.Errorf("unable to upsert client tls files: %w", err) } } - connDetails, err := dbconnectconfig.NewFromPostgresConnection(config, connectionTimeout) + connDetails, err := dbconnectconfig.NewFromPostgresConnection(config, connectionTimeout, logger) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.Connectio return newStdlibContainer("pgx", dsn, dbconnopts), nil } case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - connDetails, err := dbconnectconfig.NewFromMysqlConnection(config, connectionTimeout) + connDetails, err := dbconnectconfig.NewFromMysqlConnection(config, connectionTimeout, logger) if err != nil { return nil, err } @@ -330,12 +330,12 @@ func (s *stdlibContainer) Close() error { } type ConnectionDetails struct { - dbconnectconfig.GeneralDbConnectConfig + dbconnectconfig.DbConnectConfig MaxConnectionLimit *int32 } func (c *ConnectionDetails) String() string { - return c.GeneralDbConnectConfig.String() + return c.DbConnectConfig.String() } type ClientCertConfig struct { diff --git a/backend/services/mgmt/v1alpha1/connection-service/connection.go b/backend/services/mgmt/v1alpha1/connection-service/connection.go index 1b59c07b59..71949c4c4c 100644 --- a/backend/services/mgmt/v1alpha1/connection-service/connection.go +++ b/backend/services/mgmt/v1alpha1/connection-service/connection.go @@ -5,7 +5,7 @@ import ( "database/sql" "errors" "fmt" - "net/url" + "log/slog" "strings" "sync" @@ -17,6 +17,7 @@ import ( "github.com/nucleuscloud/neosync/backend/internal/dtomaps" nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" "github.com/nucleuscloud/neosync/backend/internal/neosyncdb" + dbconnectconfig "github.com/nucleuscloud/neosync/backend/pkg/dbconnect-config" pg_models "github.com/nucleuscloud/neosync/backend/sql/postgresql/models" "golang.org/x/sync/errgroup" @@ -31,7 +32,7 @@ func (s *Service) CheckConnectionConfig( switch req.Msg.GetConnectionConfig().GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig, *mgmtv1alpha1.ConnectionConfig_MysqlConfig, *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - role, err := getDbRoleFromConnectionConfig(req.Msg.GetConnectionConfig()) + role, err := getDbRoleFromConnectionConfig(req.Msg.GetConnectionConfig(), logger) if err != nil { return nil, err } @@ -188,75 +189,32 @@ func (s *Service) CheckConnectionConfigById( }), nil } -func getDbRoleFromConnectionConfig(cconfig *mgmtv1alpha1.ConnectionConfig) (string, error) { +func getDbRoleFromConnectionConfig(cconfig *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger) (string, error) { if cconfig == nil { return "", errors.New("connection config was nil, unable to retrieve db role") } switch typedconfig := cconfig.GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: - return getPostgresUserFromConnectionConfig(typedconfig.PgConfig) - case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - return getMysqlUserFromConnectionConfig(typedconfig.MysqlConfig) - case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - return getMssqlUserFromConnectionConfig(typedconfig.MssqlConfig) - default: - return "", fmt.Errorf("invalid database connection config (%T) for retrieving db role: %w", typedconfig, errors.ErrUnsupported) - } -} - -func getPostgresUserFromConnectionConfig(pgconfig *mgmtv1alpha1.PostgresConnectionConfig) (string, error) { - switch config := pgconfig.ConnectionConfig.(type) { - case *mgmtv1alpha1.PostgresConnectionConfig_Connection: - return config.Connection.User, nil - case *mgmtv1alpha1.PostgresConnectionConfig_Url: - u, err := url.Parse(config.Url) + parsedCfg, err := dbconnectconfig.NewFromPostgresConnection(typedconfig, nil, logger) if err != nil { - var urlErr *url.Error - if errors.As(err, &urlErr) { - return "", fmt.Errorf("unable to parse postgres url [%s]: %w", urlErr.Op, urlErr.Err) - } - return "", fmt.Errorf("unable to parse postgres url: %w", err) + return "", fmt.Errorf("unable to parse pg connection: %w", err) } - return u.User.Username(), nil - default: - return "", fmt.Errorf("unable to parse connection url from postgres config: %T", config) - } -} - -func getMysqlUserFromConnectionConfig(pgconfig *mgmtv1alpha1.MysqlConnectionConfig) (string, error) { - switch config := pgconfig.ConnectionConfig.(type) { - case *mgmtv1alpha1.MysqlConnectionConfig_Connection: - return config.Connection.User, nil - case *mgmtv1alpha1.MysqlConnectionConfig_Url: - u, err := url.Parse(config.Url) + return parsedCfg.GetUser(), nil + case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: + parsedCfg, err := dbconnectconfig.NewFromMysqlConnection(typedconfig, nil, logger) if err != nil { - var urlErr *url.Error - if errors.As(err, &urlErr) { - return "", fmt.Errorf("unable to parse mysql url [%s]: %w", urlErr.Op, urlErr.Err) - } - return "", fmt.Errorf("unable to parse mysql url: %w", err) + return "", fmt.Errorf("unable to parse mysql connection: %w", err) } - return u.User.Username(), nil - default: - return "", fmt.Errorf("unable to parse connection url from postgres config: %T", config) - } -} - -func getMssqlUserFromConnectionConfig(ccfg *mgmtv1alpha1.MssqlConnectionConfig) (string, error) { - switch config := ccfg.ConnectionConfig.(type) { - case *mgmtv1alpha1.MssqlConnectionConfig_Url: - u, err := url.Parse(config.Url) + return parsedCfg.GetUser(), nil + case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: + parsedCfg, err := dbconnectconfig.NewFromMssqlConnection(typedconfig, nil) if err != nil { - var urlErr *url.Error - if errors.As(err, &urlErr) { - return "", fmt.Errorf("unable to parse mssql url [%s]: %w", urlErr.Op, urlErr.Err) - } - return "", fmt.Errorf("unable to parse mssql url: %w", err) + return "", fmt.Errorf("unable to parse mssql connection: %w", err) } - return u.User.Username(), nil + return parsedCfg.GetUser(), nil default: - return "", fmt.Errorf("unable to parse connection url from postgres config: %T", config) + return "", fmt.Errorf("invalid database connection config (%T) for retrieving db role: %w", typedconfig, errors.ErrUnsupported) } } diff --git a/frontend/apps/web/app/(mgmt)/[account]/connections/[id]/components/MysqlForm.tsx b/frontend/apps/web/app/(mgmt)/[account]/connections/[id]/components/MysqlForm.tsx index 4fe71e36e0..3f72abb88d 100644 --- a/frontend/apps/web/app/(mgmt)/[account]/connections/[id]/components/MysqlForm.tsx +++ b/frontend/apps/web/app/(mgmt)/[account]/connections/[id]/components/MysqlForm.tsx @@ -174,7 +174,7 @@ export default function MysqlForm(props: Props) { Your connection URL diff --git a/frontend/apps/web/app/(mgmt)/[account]/new/connection/mysql/MysqlForm.tsx b/frontend/apps/web/app/(mgmt)/[account]/new/connection/mysql/MysqlForm.tsx index 46d8ae289c..3a8745bb60 100644 --- a/frontend/apps/web/app/(mgmt)/[account]/new/connection/mysql/MysqlForm.tsx +++ b/frontend/apps/web/app/(mgmt)/[account]/new/connection/mysql/MysqlForm.tsx @@ -317,7 +317,7 @@ the hook in the useEffect conditionally. This is used to retrieve the values for Your connection URL