Skip to content

Commit

Permalink
Db conn refactor (#250)
Browse files Browse the repository at this point in the history
* Bits for copying source to target dbs

* Use transaction for destination DB

* Function comment clarified

* extract new method MakeDbInstance for use by import; do not migrate the input db

* Add some collision detection

* switch to streaming approach for data transfer

* Tiny formatting change

* Remove blank lines that lint dislikes

* Fix lint errors plus double-registration error for sqlite driver when
the import args are both sqlite://

* Start / WIP for import int test

* validate the source and target db using rownumbers

* undo baseFixtures rename

* Trust primary key / unique constraints to prevent collisions

* Wrap import destination in Tx per table

* Additional entities for db import to verify

* WIP more explicit handling, less generics

* Revert "WIP more explicit handling, less generics"

Also switched to "string" tablenames to avoid generics/polymorphism issues

* Fix lint

* Importer working for experiments, sqlite -> pg

* Comments and additional test

* Fix lint

* Another lint complaint

* remove comments and dry-run support

* Fix experiment_id association

* Add some additional testing fields

* Fix lint complaint

* disable import and testing of apps, dashboards

* Fixture rename

* Restore missing ArchiveRuns() function

* Add slices module and use to detect sqlite driver registration

* Add struct to hold import functions as methods

* Missed this chunk for adding Importer struct

* remove unneeded err check

* Add table-contents test and rename helpers

* remove separate declaration of variable

* Refactor db instance management to factory pattern

* Fix postgres factory

* comments and gofumpt

* Test both dbs in the integration-tests container

* Use makefile to pass var from github action to docker-compose

* Lint complaints

* Lower-case private methods

* moar lowercase/private structs

* FML_DATABASE_URI is set via Env -> docker-compose.yaml, so don't need
these -e args

* db.reset() move to shared code path so PG also uses

* PR requests

* PR change requests

* PR refactors

* Fumpt and change comments only

* Rename request from PR

* Rename Db() accessor to GormDB()

* Rename Db() accessor to GormDB() in tests

* PR requests error catch and changed error message

---------

Co-authored-by: Geoff Wilson <[email protected]>
  • Loading branch information
suprjinx and Geoff Wilson authored Aug 26, 2023
1 parent 3dc9e2b commit a052a5f
Show file tree
Hide file tree
Showing 21 changed files with 514 additions and 329 deletions.
2 changes: 1 addition & 1 deletion pkg/api/aim/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func DeleteExperiment(c *fiber.Ctx) error {
id32 := int32(id)

// TODO this code should move to service with injected repository
experimentRepo := repositories.NewExperimentRepository(database.DB.DB)
experimentRepo := repositories.NewExperimentRepository(database.DB)
experiment := models.Experiment{ID: &id32}
err = experimentRepo.Delete(c.Context(), &experiment)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/aim/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func GetProject(c *fiber.Ctx) error {
return c.JSON(fiber.Map{
"name": "FastTrackML",
"path": database.DB.DSN(),
"path": database.DB.Dialector.Name(),
"description": "",
"telemetry_enabled": 0,
})
Expand Down
10 changes: 5 additions & 5 deletions pkg/api/aim/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ func DeleteRun(c *fiber.Ctx) error {
}

// TODO this code should move to service with injected repository
runRepo := repositories.NewRunRepository(database.DB.DB)
runRepo := repositories.NewRunRepository(database.DB)
run := models.Run{ID: params.ID}
err := runRepo.Delete(c.Context(), &run)
if err != nil {
Expand Down Expand Up @@ -961,7 +961,7 @@ func UpdateRun(c *fiber.Ctx) error {

// TODO this code should move to service
run := models.Run{ID: params.ID}
runRepo := repositories.NewRunRepository(database.DB.DB)
runRepo := repositories.NewRunRepository(database.DB)
var err error
if update.Archived != nil {
if *update.Archived {
Expand All @@ -977,7 +977,7 @@ func UpdateRun(c *fiber.Ctx) error {

if update.Name != nil {
run.Name = *update.Name
err = database.DB.DB.Transaction(func(tx *gorm.DB) error {
err = database.DB.Transaction(func(tx *gorm.DB) error {
if err := runRepo.UpdateWithTransaction(c.Context(), tx, &run); err != nil {
return err
}
Expand All @@ -1002,7 +1002,7 @@ func ArchiveBatch(c *fiber.Ctx) error {
}

// TODO this code should move to service
runRepo := repositories.NewRunRepository(database.DB.DB)
runRepo := repositories.NewRunRepository(database.DB)
var err error
if c.Query("archive") == "true" {
err = runRepo.ArchiveBatch(c.Context(), ids)
Expand All @@ -1024,7 +1024,7 @@ func DeleteBatch(c *fiber.Ctx) error {
}

// TODO this code should move to service
runRepo := repositories.NewRunRepository(database.DB.DB)
runRepo := repositories.NewRunRepository(database.DB)
if err := runRepo.DeleteBatch(c.Context(), ids); err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ func importCmd(cmd *cobra.Command, args []string) error {
}

// initDBs inits the input and output DB connections.
func initDBs() (input, output *database.DbInstance, err error) {
func initDBs() (input, output database.DBProvider, err error) {
databaseSlowThreshold := time.Second * 1
databasePoolMax := 20
databaseReset := false
databaseMigrate := false
artifactRoot := "s3://fasttrackml"
input, err = database.MakeDBInstance(
input, err = database.MakeDBProvider(
viper.GetString("input-database-uri"),
databaseSlowThreshold,
databasePoolMax,
Expand All @@ -55,7 +55,7 @@ func initDBs() (input, output *database.DbInstance, err error) {
}

databaseMigrate = true
output, err = database.MakeDBInstance(
output, err = database.MakeDBProvider(
viper.GetString("output-database-uri"),
databaseSlowThreshold,
databasePoolMax,
Expand Down
24 changes: 13 additions & 11 deletions pkg/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,24 @@ func serverCmd(cmd *cobra.Command, args []string) error {
mlflowAPI.NewRouter(
controller.NewController(
run.NewService(
repositories.NewTagRepository(db.DB),
repositories.NewRunRepository(db.DB),
repositories.NewParamRepository(db.DB),
repositories.NewMetricRepository(db.DB),
repositories.NewExperimentRepository(db.DB),
repositories.NewTagRepository(db.GormDB()),
repositories.NewRunRepository(db.GormDB()),
repositories.NewParamRepository(db.GormDB()),
repositories.NewMetricRepository(db.GormDB()),
repositories.NewExperimentRepository(db.GormDB()),
),
model.NewService(),
metric.NewService(
repositories.NewMetricRepository(db.DB),
repositories.NewMetricRepository(db.GormDB()),
),
artifact.NewService(
storage,
repositories.NewRunRepository(db.DB),
repositories.NewRunRepository(db.GormDB()),
),
experiment.NewService(
mlflowConfig,
repositories.NewTagRepository(db.DB),
repositories.NewExperimentRepository(db.DB),
repositories.NewTagRepository(db.GormDB()),
repositories.NewExperimentRepository(db.GormDB()),
),
),
).Init(server)
Expand Down Expand Up @@ -124,8 +124,8 @@ func serverCmd(cmd *cobra.Command, args []string) error {
}

// initDB init DB connection.
func initDB(config *mlflowConfig.ServiceConfig) (*database.DbInstance, error) {
db, err := database.ConnectDB(
func initDB(config *mlflowConfig.ServiceConfig) (database.DBProvider, error) {
db, err := database.MakeDBProvider(
config.DatabaseURI,
config.DatabaseSlowThreshold,
config.DatabasePoolMax,
Expand All @@ -136,6 +136,8 @@ func initDB(config *mlflowConfig.ServiceConfig) (*database.DbInstance, error) {
if err != nil {
return nil, fmt.Errorf("error connecting to DB: %w", err)
}
// cache a global reference to the gorm.DB
database.DB = db.GormDB()
return db, nil
}

Expand Down
64 changes: 64 additions & 0 deletions pkg/database/db_factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package database

import (
"fmt"
"net/url"
"time"

"github.com/rotisserie/eris"
)

// MakeDBProvider creates a DBProvider of the correct type for the given DSN.
//
// The DSN scheme selects the backend: "sqlite", or "postgres"/"postgresql";
// any other scheme is rejected. Once the provider is constructed it is
// optionally reset (reset), handed to checkAndMigrate together with the
// migrate flag, and finally createDefaultExperiment is run with artifactRoot.
// If any of those steps fails, the provider is closed before the error is
// returned, so callers never receive a half-initialized provider.
func MakeDBProvider(
	dsn string, slowThreshold time.Duration, poolMax int, reset bool, migrate bool, artifactRoot string,
) (db DBProvider, err error) {
	dsnURL, err := url.Parse(dsn)
	if err != nil {
		return nil, fmt.Errorf("invalid database URL: %w", err)
	}

	switch dsnURL.Scheme {
	case "sqlite":
		db, err = NewSqliteDBInstance(*dsnURL, slowThreshold, poolMax, reset)
		if err != nil {
			return nil, eris.Wrap(err, "error creating sqlite provider")
		}
	case "postgres", "postgresql":
		db, err = NewPostgresDBInstance(*dsnURL, slowThreshold, poolMax, reset)
		if err != nil {
			return nil, eris.Wrap(err, "error creating postgres provider")
		}
	default:
		return nil, eris.New("unsupported database type")
	}

	if reset {
		if err := db.Reset(); err != nil {
			return nil, closeAfterFailure(db, err, "error resetting database")
		}
	}

	if err := checkAndMigrate(migrate, db); err != nil {
		return nil, closeAfterFailure(db, err, "error migrating database")
	}

	if err := createDefaultExperiment(artifactRoot, db); err != nil {
		return nil, closeAfterFailure(db, err, "error creating default experiment")
	}

	return db, nil
}

// closeAfterFailure closes db after a failed initialization step and returns
// err wrapped with msg. A failure to close is folded into the message so that
// it is reported instead of being silently dropped.
func closeAfterFailure(db DBProvider, err error, msg string) error {
	if closeErr := db.Close(); closeErr != nil {
		return eris.Wrapf(err, "%s (additionally failed to close database: %v)", msg, closeErr)
	}
	return eris.Wrap(err, msg)
}
44 changes: 44 additions & 0 deletions pkg/database/db_factory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package database

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// TestMakeDBProvider verifies that MakeDBProvider builds a provider whose gorm
// dialector matches the DSN scheme, and that building a provider does not
// assign the package-level DB.
func TestMakeDBProvider(t *testing.T) {
	tests := []struct {
		name              string
		dsn               string
		expectedDialector string
	}{
		{
			name:              "WithSqliteURI",
			dsn:               "sqlite:///tmp/fasttrack.db",
			expectedDialector: "sqlite",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			DB = nil
			db, err := MakeDBProvider(
				tt.dsn,
				time.Second*2,
				2,
				false,
				false,
				"s3://somewhere",
			)
			// Stop on failure: db would be nil and the calls below would panic
			// instead of producing a readable test failure.
			if !assert.NoError(t, err) || !assert.NotNil(t, db) {
				return
			}
			assert.Equal(t, tt.expectedDialector, db.GormDB().Dialector.Name())

			// expecting the global 'DB' not to be set
			assert.Nil(t, DB)

			assert.NoError(t, db.Close())
		})
	}
}
87 changes: 87 additions & 0 deletions pkg/database/db_instance.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package database

import (
"database/sql"
"errors"
"fmt"
"io"
"strings"
"time"

log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)

// DBProvider is the interface to access the DB.
// MakeDBProvider returns a value of this type for the requested backend.
type DBProvider interface {
// GormDB returns the gorm handle used to run queries.
GormDB() *gorm.DB
// Dsn returns the DSN string the provider holds.
Dsn() string
// Close releases the resources held by the provider.
Close() error
// Reset is invoked by MakeDBProvider when its reset flag is set.
// NOTE(review): the implementations are not visible here; confirm what state is cleared.
Reset() error
}

// DB is a global gorm.DB reference.
// MakeDBProvider does not assign it (TestMakeDBProvider asserts it stays nil);
// the server's initDB stores db.GormDB() here after creating its provider.
var DB *gorm.DB

// DBInstance is the base concrete type for DBProvider.
// It supplies GormDB, Dsn and Close.
// NOTE(review): Reset is not defined on this type in the visible code; presumably
// the backend-specific types add it. Confirm against the sqlite/postgres instances.
type DBInstance struct {
// Embedded gorm handle; returned by GormDB.
*gorm.DB
// dsn is the string returned by Dsn.
dsn string
// closers are closed, in slice order, by Close.
closers []io.Closer
}

// Close invokes every registered closer, in order.
//
// All closers are attempted even when one of them fails, so a single failure
// does not leak the remaining resources. The returned error is nil when every
// closer succeeded. Otherwise it wraps the first failure (so errors.Is and
// errors.As still match it) and appends the text of any later failures.
func (db *DBInstance) Close() error {
	var result error
	for _, c := range db.closers {
		err := c.Close()
		if err == nil {
			continue
		}
		if result == nil {
			result = err
			continue
		}
		// Keep the first error as the wrapped cause; append later ones as text.
		result = fmt.Errorf("%w; %v", result, err)
	}
	return result
}

// Dsn returns the DSN string stored on the instance.
func (db *DBInstance) Dsn() string {
return db.dsn
}

// GormDB returns the embedded gorm DB handle.
func (db *DBInstance) GormDB() *gorm.DB {
return db.DB
}

// createDefaultExperiment creates the default experiment if it does not exist.
//
// It queries for experiment 0 and returns nil when the record is found. When
// gorm reports ErrRecordNotFound it inserts an active experiment named
// "Default" with ID 0 and the current UTC time (in milliseconds) as creation
// and last-update time, then stores its artifact location as
// "<artifactRoot>/<id>" with trailing slashes trimmed from artifactRoot.
// Any other lookup error, and any failure of the insert or update, is
// returned wrapped so callers can still inspect the cause with errors.Is/As.
func createDefaultExperiment(artifactRoot string, db DBProvider) error {
	err := db.GormDB().First(&Experiment{}, 0).Error
	if err == nil {
		return nil
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return fmt.Errorf("unable to find default experiment: %w", err)
	}

	log.Info("Creating default experiment")
	id := int32(0)
	ts := time.Now().UTC().UnixMilli()
	exp := Experiment{
		ID:             &id,
		Name:           "Default",
		LifecycleStage: LifecycleStageActive,
		CreationTime: sql.NullInt64{
			Int64: ts,
			Valid: true,
		},
		LastUpdateTime: sql.NullInt64{
			Int64: ts,
			Valid: true,
		},
	}
	if err := db.GormDB().Create(&exp).Error; err != nil {
		return fmt.Errorf("error creating default experiment: %w", err)
	}

	// The artifact location embeds the experiment ID, so it is derived from the
	// ID held by exp after the insert and written with a follow-up update.
	exp.ArtifactLocation = fmt.Sprintf("%s/%d", strings.TrimRight(artifactRoot, "/"), *exp.ID)
	if err := db.GormDB().Model(&exp).Update("ArtifactLocation", exp.ArtifactLocation).Error; err != nil {
		return fmt.Errorf("error updating artifact_location for experiment '%s': %w", exp.Name, err)
	}
	return nil
}
6 changes: 3 additions & 3 deletions pkg/database/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ type Importer struct {
}

// NewImporter initializes an Importer.
func NewImporter(input, output *DbInstance) *Importer {
func NewImporter(input, output DBProvider) *Importer {
return &Importer{
sourceDB: input.DB,
destDB: output.DB,
sourceDB: input.GormDB(),
destDB: output.GormDB(),
experimentInfos: []experimentInfo{},
}
}
Expand Down
Loading

0 comments on commit a052a5f

Please sign in to comment.