Skip to content

Commit

Permalink
Copy the contents of one db instance to another (#200)
Browse files Browse the repository at this point in the history
* Bits for copying source to target dbs

* Use transaction for destination DB

* Function comment clarified

* extract new method MakeDbInstance for use by import; do not migrate the input db

* Add some collision detection

* switch to streaming approach for data transfer

* Tiny formatting change

* Remove blank lines that lint dislikes

* Fix lint errors plus double-registration error for sqlite driver when
the import args are both sqlite://

* Start / WIP for import int test

* validate the source and target db using rownumbers

* undo baseFixtures rename

* Trust primary key / unique constraints to prevent collisions

* Wrap import destination in Tx per table

* Additional entities for db import to verify

* WIP more explicit handling, less generics

* Revert "WIP more explicit handling, less generics"

Also switched to "string" tablenames to avoid generics/polymorphism issues

* Fix lint

* Importer working for experiments, sqlite -> pg

* Comments and additional test

* Fix lint

* Another lint complaint

* remove comments and dry-run support

* Fix experiment_id association

* Add some additional testing fields

* Fix lint complaint

* disable import and testing of apps, dashboards

* Fixture rename

* Restore missing ArchiveRuns() function

* Add slices module and use to detect sqlite driver registration

* Add struct to hold import functions as methods

* Missed this chunk for adding Importer struct

* remove unneeded err check

* Add table-contents test and rename helpers

* remove separate declaration of variable

---------

Co-authored-by: Geoff Wilson <[email protected]>
  • Loading branch information
suprjinx and Geoff Wilson authored Aug 22, 2023
1 parent d511346 commit a92e205
Show file tree
Hide file tree
Showing 13 changed files with 606 additions and 88 deletions.
79 changes: 79 additions & 0 deletions pkg/cmd/import.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package cmd

import (
"fmt"
"time"

"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/G-Research/fasttrackml/pkg/database"
)

// ImportCmd is the cobra command for "import". Its RunE handler is importCmd,
// which copies the contents of the input database to the output database.
var ImportCmd = &cobra.Command{
Use: "import",
Short: "Copies an input database to an output database",
Long: `The import command will transfer the contents of the input
database to the output database. Please make sure that the
FasttrackML server is not currently connected to the input
database.`,
RunE: importCmd,
}

// importCmd is the RunE handler of ImportCmd. It opens the input and output
// databases, runs the importer from input to output, and closes both
// connections when it returns.
func importCmd(cmd *cobra.Command, args []string) error {
	source, target, err := initDBs()
	if err != nil {
		return err
	}
	defer source.Close()
	defer target.Close()

	return database.NewImporter(source, target).Import()
}

// initDBs opens the input and output DB connections named by the
// "input-database-uri" and "output-database-uri" settings.
//
// The input database is opened with migrate=false and the output database
// with migrate=true. On success both instances are returned and the caller is
// responsible for closing them. On failure both results are nil; if the input
// connection had already been opened it is closed before returning, so nothing
// is leaked.
func initDBs() (input, output *database.DbInstance, err error) {
	const (
		slowThreshold = time.Second
		poolMax       = 20
		reset         = false
		// artifactRoot is handed to MakeDBInstance, which uses it when
		// creating the default experiment.
		artifactRoot = "s3://fasttrackml"
	)

	input, err = database.MakeDBInstance(
		viper.GetString("input-database-uri"),
		slowThreshold,
		poolMax,
		reset,
		false, // never migrate the input database
		artifactRoot,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("error connecting to input DB: %w", err)
	}

	output, err = database.MakeDBInstance(
		viper.GetString("output-database-uri"),
		slowThreshold,
		poolMax,
		reset,
		true, // bring the output database schema up to date
		artifactRoot,
	)
	if err != nil {
		// The caller only defers Close after a successful return, so release
		// the already-open input connection here.
		input.Close()
		return nil, nil, fmt.Errorf("error connecting to output DB: %w", err)
	}
	return input, output, nil
}

// init registers ImportCmd on RootCmd and declares its two required flags,
// the input and output database URIs.
func init() {
	RootCmd.AddCommand(ImportCmd)

	flags := ImportCmd.Flags()
	flags.StringP("input-database-uri", "i", "", "Input Database URI (eg., sqlite://fasttrackml.db)")
	flags.StringP("output-database-uri", "o", "", "Output Database URI (eg., postgres://user:psw@postgres:5432)")
	for _, name := range []string{"input-database-uri", "output-database-uri"} {
		ImportCmd.MarkFlagRequired(name)
	}
}
50 changes: 33 additions & 17 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,29 @@ func (db *DbInstance) DSN() string {
return db.dsn
}

// DB is a global db instance.
var DB *DbInstance = &DbInstance{}

func ConnectDB(
// ConnectDB builds a DbInstance through MakeDBInstance, stores it in the
// package-level DB variable, and returns it. When MakeDBInstance fails the
// global DB is left untouched and the error is returned as-is.
func ConnectDB(
	dsn string, slowThreshold time.Duration, poolMax int, reset bool, migrate bool, artifactRoot string,
) (*DbInstance, error) {
	instance, err := MakeDBInstance(dsn, slowThreshold, poolMax, reset, migrate, artifactRoot)
	if err != nil {
		return nil, err
	}
	// publish the new instance as the global DB
	DB = instance
	return instance, nil
}

// MakeDBInstance will create a DbInstance from the parameters.
func MakeDBInstance(
dsn string, slowThreshold time.Duration, poolMax int, reset bool, migrate bool, artifactRoot string,
) (*DbInstance, error) {
DB.dsn = dsn
// local db instance
db := DbInstance{}
db.dsn = dsn
var sourceConn gorm.Dialector
var replicaConn gorm.Dialector
u, err := url.Parse(dsn)
Expand Down Expand Up @@ -112,7 +129,7 @@ func ConnectDB(
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
DB.closers = append(DB.closers, s)
db.closers = append(db.closers, s)
s.SetMaxIdleConns(1)
s.SetMaxOpenConns(1)
s.SetConnMaxIdleTime(0)
Expand All @@ -125,10 +142,10 @@ func ConnectDB(
dbURL.RawQuery = q.Encode()
r, err := sql.Open(SQLiteCustomDriverName, strings.Replace(dbURL.String(), "sqlite://", "file:", 1))
if err != nil {
DB.Close()
db.Close()
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
DB.closers = append(DB.closers, r)
db.closers = append(db.closers, r)
replicaConn = sqlite.Dialector{
Conn: r,
}
Expand All @@ -148,7 +165,7 @@ func ConnectDB(
if log.GetLevel() == log.DebugLevel {
dbLogLevel = logger.Info
}
DB.DB, err = gorm.Open(sourceConn, &gorm.Config{
db.DB, err = gorm.Open(sourceConn, &gorm.Config{
Logger: logger.New(
glog.New(
log.StandardLogger().WriterLevel(log.WarnLevel),
Expand All @@ -163,12 +180,12 @@ func ConnectDB(
),
})
if err != nil {
DB.Close()
db.Close()
return nil, fmt.Errorf("failed to connect to database: %w", err)
}

if replicaConn != nil {
DB.Use(
db.Use(
dbresolver.Register(dbresolver.Config{
Replicas: []gorm.Dialector{
replicaConn,
Expand All @@ -178,30 +195,29 @@ func ConnectDB(
}

if u.Scheme != "sqlite" {
sqlDB, _ := DB.DB.DB()
sqlDB, _ := db.DB.DB()
sqlDB.SetConnMaxIdleTime(time.Minute)
sqlDB.SetMaxIdleConns(poolMax)
sqlDB.SetMaxOpenConns(poolMax)

if reset {
if err := resetDB(DB); err != nil {
DB.Close()
if err := resetDB(&db); err != nil {
db.Close()
return nil, err
}
}
}

if err := checkAndMigrateDB(DB, migrate); err != nil {
DB.Close()
if err := checkAndMigrateDB(&db, migrate); err != nil {
db.Close()
return nil, err
}

if err := createDefaultExperiment(DB, artifactRoot); err != nil {
DB.Close()
if err := createDefaultExperiment(&db, artifactRoot); err != nil {
db.Close()
return nil, err
}

return DB, nil
return &db, nil
}

func resetDB(db *DbInstance) error {
Expand Down
169 changes: 169 additions & 0 deletions pkg/database/import.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package database

import (
"github.com/rotisserie/eris"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

// experimentInfo pairs the ID of an experiment in the source database with
// the ID of the corresponding experiment in the destination database.
type experimentInfo struct {
// sourceID is the experiment ID as read from the source database.
sourceID int64
// destID is the ID of the matching experiment row in the destination database.
destID int64
}

// Importer will handle transport of data from source to destination db.
type Importer struct {
// sourceDB is the database rows are read from.
sourceDB *gorm.DB
// destDB is the database rows are written to.
destDB *gorm.DB
// experimentInfos holds the source-to-destination experiment ID mappings
// recorded by saveExperimentInfo and consulted by translateFields.
experimentInfos []experimentInfo
}

// NewImporter returns an Importer that reads from input's gorm handle and
// writes to output's gorm handle, starting with no recorded experiment ID
// mappings.
func NewImporter(input, output *DbInstance) *Importer {
	importer := new(Importer)
	importer.sourceDB = input.DB
	importer.destDB = output.DB
	importer.experimentInfos = []experimentInfo{}
	return importer
}

// Import will copy the contents of input db to output db.
//
// The experiments table is imported first so that the source-to-destination
// experiment ID mappings exist before the remaining tables, whose rows may
// carry an experiment_id, are copied. The remaining tables are then imported
// in the order listed below. The first failure aborts the import and is
// returned wrapped with the name of the table being imported.
func (s *Importer) Import() error {
	// experiments needs special handling
	if err := s.importExperiments(); err != nil {
		return eris.Wrapf(err, "error importing table %s", "experiments")
	}

	// all other tables
	tables := []string{
		"experiment_tags",
		"runs",
		"tags",
		"params",
		"metrics",
		"latest_metrics",
		// "apps",
		// "dashboards",
	}
	for _, table := range tables {
		if err := s.importTable(table); err != nil {
			return eris.Wrapf(err, "error importing table %s", table)
		}
	}
	return nil
}

// importExperiments will copy the contents of the experiments table from sourceDB to destDB,
// while recording the new ID.
//
// Rows are streamed from the source and written inside a single destination
// transaction. For each source experiment the destination row is looked up by
// Name and created only when no such row exists (FirstOrCreate), and the pair
// of source and destination IDs is recorded via saveExperimentInfo.
//
// It returns an error when the source query, a row scan, the row iteration
// itself, or a destination write fails.
func (s *Importer) importExperiments() error {
	// Start transaction in the destDB
	err := s.destDB.Transaction(func(destTX *gorm.DB) error {
		// Query data from the source database
		rows, err := s.sourceDB.Model(Experiment{}).Rows()
		if err != nil {
			return eris.Wrap(err, "error creating Rows instance from source")
		}
		defer rows.Close()

		count := 0
		for rows.Next() {
			var scannedItem Experiment
			if err := s.sourceDB.ScanRows(rows, &scannedItem); err != nil {
				return eris.Wrap(err, "error scanning source row")
			}
			// The source ID is deliberately not copied: the destination
			// assigns (or already has) its own ID for this experiment.
			newItem := Experiment{
				Name:             scannedItem.Name,
				ArtifactLocation: scannedItem.ArtifactLocation,
				LifecycleStage:   scannedItem.LifecycleStage,
				CreationTime:     scannedItem.CreationTime,
				LastUpdateTime:   scannedItem.LastUpdateTime,
			}
			if err := destTX.
				Where(Experiment{Name: scannedItem.Name}).
				FirstOrCreate(&newItem).Error; err != nil {
				return eris.Wrap(err, "error creating destination row")
			}
			s.saveExperimentInfo(scannedItem, newItem)
			count++
		}
		// rows.Next returns false both at the end of the result set and on
		// failure; without this check a truncated read would be committed as
		// if it were a complete import.
		if err := rows.Err(); err != nil {
			return eris.Wrap(err, "error iterating source rows")
		}
		log.Infof("Importing %s - found %v records", "experiments", count)
		return nil
	})
	if err != nil {
		return eris.Wrap(err, "error copying experiments table")
	}
	return nil
}

// importTable will copy the contents of one table from sourceDB to destDB,
// passing every row through translateFields (which, among other things,
// replaces a source experiment_id with the destination one).
//
// Rows are streamed from the source and inserted inside a single destination
// transaction. Inserts use ON CONFLICT DO NOTHING, so a row that collides with
// an existing destination row is skipped rather than reported as an error.
//
// It returns an error when the source query, a row scan, field translation,
// the row iteration itself, or a destination insert fails.
func (s *Importer) importTable(table string) error {
	// Start transaction in the destDB
	return s.destDB.Transaction(func(destTX *gorm.DB) error {
		// Query data from the source database
		rows, err := s.sourceDB.Table(table).Rows()
		if err != nil {
			return eris.Wrap(err, "error creating Rows instance from source")
		}
		defer rows.Close()

		count := 0
		for rows.Next() {
			var item map[string]any
			if err := s.sourceDB.ScanRows(rows, &item); err != nil {
				return eris.Wrap(err, "error scanning source row")
			}
			item, err = s.translateFields(item)
			if err != nil {
				return eris.Wrap(err, "error translating fields")
			}
			if err := destTX.
				Table(table).
				Clauses(clause.OnConflict{DoNothing: true}).
				Create(&item).Error; err != nil {
				return eris.Wrap(err, "error creating destination row")
			}
			count++
		}
		// rows.Next returns false both at the end of the result set and on
		// failure; without this check a truncated read would be committed as
		// if it were a complete import.
		if err := rows.Err(); err != nil {
			return eris.Wrap(err, "error iterating source rows")
		}
		log.Infof("Importing %s - found %v records", table, count)
		return nil
	})
}

// saveExperimentInfo appends a mapping from the source experiment's ID to the
// destination experiment's ID so that later rows can be re-pointed at the
// destination experiment. Both experiments must have a non-nil ID.
func (s *Importer) saveExperimentInfo(source, dest Experiment) {
	var mapping experimentInfo
	mapping.sourceID = int64(*source.ID)
	mapping.destID = int64(*dest.ID)
	s.experimentInfos = append(s.experimentInfos, mapping)
}

// translateFields will alter row before creation as needed (especially, replacing old experiment_id with new).
//
// The map is modified in place and also returned. Two columns are adjusted
// when present:
//   - "is_nan": a numeric value (sqlite hands booleans back as numbers) is
//     converted to bool, where any non-zero number is true. A value that is
//     neither bool nor numeric is stored as true, as it always has been.
//   - "experiment_id": replaced by the destination ID recorded for that
//     source experiment; left unchanged when no mapping is recorded.
//
// It returns an error when experiment_id is present but is not an integer.
func (s *Importer) translateFields(item map[string]any) (map[string]any, error) {
	// boolean is numeric when coming from sqlite
	if isNaN, ok := item["is_nan"]; ok {
		switch v := isNaN.(type) {
		case bool:
			// already a boolean, nothing to translate
		case float32:
			item["is_nan"] = v != 0
		case float64:
			item["is_nan"] = v != 0
		default:
			// Comparing the untyped value against 0.0 directly would compare
			// dynamic types as well, making an integer 0 count as "not
			// zero"; compare on the concrete integer value instead.
			if n, isInt := importInt64(v); isInt {
				item["is_nan"] = n != 0
			} else {
				item["is_nan"] = true
			}
		}
	}
	// items with experiment_id fk need to reference the new ID
	if expID, ok := item["experiment_id"]; ok {
		id, ok := importInt64(expID)
		if !ok {
			return nil, eris.Errorf("unable to assert experiment_id as int64: %v", expID)
		}
		for _, expInfo := range s.experimentInfos {
			if expInfo.sourceID == id {
				item["experiment_id"] = expInfo.destID
			}
		}
	}
	return item, nil
}

// importInt64 converts v to int64 when v holds any built-in integer type; ok
// is false for every other dynamic type, including nil.
func importInt64(v any) (n int64, ok bool) {
	switch i := v.(type) {
	case int:
		return int64(i), true
	case int8:
		return int64(i), true
	case int16:
		return int64(i), true
	case int32:
		return int64(i), true
	case int64:
		return i, true
	case uint:
		return int64(i), true
	case uint8:
		return int64(i), true
	case uint16:
		return int64(i), true
	case uint32:
		return int64(i), true
	case uint64:
		return int64(i), true
	}
	return 0, false
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *GetExperimentActivityTestSuite) Test_Ok() {
assert.Nil(s.T(), err)

archivedRunsIds := []string{runs[0].ID, runs[1].ID}
err = s.runFixtures.ArchiveRun(context.Background(), archivedRunsIds)
err = s.runFixtures.ArchiveRuns(context.Background(), archivedRunsIds)
assert.Nil(s.T(), err)

var resp response.GetExperimentActivity
Expand Down
Loading

0 comments on commit a92e205

Please sign in to comment.