-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Copy the contents of one db instance to another (#200)
* Bits for copying source to target dbs * Use transaction for destination DB * Function comment clarified * extract new method MakeDbInstance for use by import; do not migrate the input db * Add some collision detection * switch to streaming approach for data transfer * Tiny formatting change * Remove blank lines that lint dislikes * Fix lint errors plus double-registration error for sqlite driver when the import args are both sqlite:// * Start / WIP for import int test * validate the source and target db using rownumbers * undo baseFixtures rename * Trust primary key / unique constraints to prevent collisions * Wrap import destination in Tx per table * Additional entitites for db import to verify * WIP more explicit handling, less generics * Revert "WIP more explicit handling, less generics" Also switched to "string" tablenames to avoid generics/polymophism issues * Fix lint * Importer working for experiments, sqlite -> pg * Comments and additional test * Fix lint * Another lint complaint * remove comments and dry-run support * Fix experiment_id association * Add some additional testing fields * Fix lint complaint * disable import and testing of apps, dashboards * Fixture rename * Restore missing ArchiveRuns() function * Add slices module and use to detect sqlite driver registration * Add struct to hold import functions as methods * Missed this chunk for adding Importer struct * remove unneeded err check * Add table-contents test and rename helpers * remove separate declaration of variable --------- Co-authored-by: Geoff Wilson <[email protected]>
- Loading branch information
Showing
13 changed files
with
606 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package cmd | ||
|
||
import ( | ||
"fmt" | ||
"time" | ||
|
||
"github.com/spf13/cobra" | ||
"github.com/spf13/viper" | ||
|
||
"github.com/G-Research/fasttrackml/pkg/database" | ||
) | ||
|
||
var ImportCmd = &cobra.Command{ | ||
Use: "import", | ||
Short: "Copies an input database to an output database", | ||
Long: `The import command will transfer the contents of the input | ||
database to the output database. Please make sure that the | ||
FasttrackML server is not currently connected to the input | ||
database.`, | ||
RunE: importCmd, | ||
} | ||
|
||
func importCmd(cmd *cobra.Command, args []string) error { | ||
inputDB, outputDB, err := initDBs() | ||
if err != nil { | ||
return err | ||
} | ||
defer inputDB.Close() | ||
defer outputDB.Close() | ||
|
||
importer := database.NewImporter(inputDB, outputDB) | ||
if err := importer.Import(); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// initDBs inits the input and output DB connections. | ||
func initDBs() (input, output *database.DbInstance, err error) { | ||
databaseSlowThreshold := time.Second * 1 | ||
databasePoolMax := 20 | ||
databaseReset := false | ||
databaseMigrate := false | ||
artifactRoot := "s3://fasttrackml" | ||
input, err = database.MakeDBInstance( | ||
viper.GetString("input-database-uri"), | ||
databaseSlowThreshold, | ||
databasePoolMax, | ||
databaseReset, | ||
databaseMigrate, | ||
artifactRoot, | ||
) | ||
if err != nil { | ||
return input, output, fmt.Errorf("error connecting to input DB: %w", err) | ||
} | ||
|
||
databaseMigrate = true | ||
output, err = database.MakeDBInstance( | ||
viper.GetString("output-database-uri"), | ||
databaseSlowThreshold, | ||
databasePoolMax, | ||
databaseReset, | ||
databaseMigrate, | ||
artifactRoot, | ||
) | ||
if err != nil { | ||
return input, output, fmt.Errorf("error connecting to output DB: %w", err) | ||
} | ||
return | ||
} | ||
|
||
func init() { | ||
RootCmd.AddCommand(ImportCmd) | ||
|
||
ImportCmd.Flags().StringP("input-database-uri", "i", "", "Input Database URI (eg., sqlite://fasttrackml.db)") | ||
ImportCmd.Flags().StringP("output-database-uri", "o", "", "Output Database URI (eg., postgres://user:psw@postgres:5432)") | ||
ImportCmd.MarkFlagRequired("input-database-uri") | ||
ImportCmd.MarkFlagRequired("output-database-uri") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package database | ||
|
||
import ( | ||
"github.com/rotisserie/eris" | ||
log "github.com/sirupsen/logrus" | ||
"gorm.io/gorm" | ||
"gorm.io/gorm/clause" | ||
) | ||
|
||
type experimentInfo struct { | ||
sourceID int64 | ||
destID int64 | ||
} | ||
|
||
// Importer will handle transport of data from source to destination db. | ||
type Importer struct { | ||
sourceDB *gorm.DB | ||
destDB *gorm.DB | ||
experimentInfos []experimentInfo | ||
} | ||
|
||
// NewImporter initializes an Importer. | ||
func NewImporter(input, output *DbInstance) *Importer { | ||
return &Importer{ | ||
sourceDB: input.DB, | ||
destDB: output.DB, | ||
experimentInfos: []experimentInfo{}, | ||
} | ||
} | ||
|
||
// Import will copy the contents of input db to output db. | ||
func (s *Importer) Import() error { | ||
tables := []string{ | ||
"experiment_tags", | ||
"runs", | ||
"tags", | ||
"params", | ||
"metrics", | ||
"latest_metrics", | ||
// "apps", | ||
// "dashboards", | ||
} | ||
// experiments needs special handling | ||
if err := s.importExperiments(); err != nil { | ||
return eris.Wrapf(err, "error importing table %s", "experiements") | ||
} | ||
// all other tables | ||
for _, table := range tables { | ||
if err := s.importTable(table); err != nil { | ||
return eris.Wrapf(err, "error importing table %s", table) | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
// importExperiments will copy the contents of the experiments table from sourceDB to destDB, | ||
// while recording the new ID. | ||
func (s *Importer) importExperiments() error { | ||
// Start transaction in the destDB | ||
err := s.destDB.Transaction(func(destTX *gorm.DB) error { | ||
// Query data from the source database | ||
rows, err := s.sourceDB.Model(Experiment{}).Rows() | ||
if err != nil { | ||
return eris.Wrap(err, "error creating Rows instance from source") | ||
} | ||
defer rows.Close() | ||
|
||
count := 0 | ||
for rows.Next() { | ||
var scannedItem Experiment | ||
if err := s.sourceDB.ScanRows(rows, &scannedItem); err != nil { | ||
return eris.Wrap(err, "error creating Rows instance from source") | ||
} | ||
newItem := Experiment{ | ||
Name: scannedItem.Name, | ||
ArtifactLocation: scannedItem.ArtifactLocation, | ||
LifecycleStage: scannedItem.LifecycleStage, | ||
CreationTime: scannedItem.CreationTime, | ||
LastUpdateTime: scannedItem.LastUpdateTime, | ||
} | ||
if err := destTX. | ||
Where(Experiment{Name: scannedItem.Name}). | ||
FirstOrCreate(&newItem).Error; err != nil { | ||
return eris.Wrap(err, "error creating destination row") | ||
} | ||
s.saveExperimentInfo(scannedItem, newItem) | ||
count++ | ||
} | ||
log.Infof("Importing %s - found %v records", "experiments", count) | ||
return nil | ||
}) | ||
if err != nil { | ||
return eris.Wrap(err, "error copying experiments table") | ||
} | ||
return nil | ||
} | ||
|
||
// importTablewill copy the contents of one table (model) from sourceDB | ||
// while updating the experiment_id to destDB. | ||
func (s *Importer) importTable(table string) error { | ||
// Start transaction in the destDB | ||
err := s.destDB.Transaction(func(destTX *gorm.DB) error { | ||
// Query data from the source database | ||
rows, err := s.sourceDB.Table(table).Rows() | ||
if err != nil { | ||
return eris.Wrap(err, "error creating Rows instance from source") | ||
} | ||
defer rows.Close() | ||
|
||
count := 0 | ||
for rows.Next() { | ||
var item map[string]any | ||
if err := s.sourceDB.ScanRows(rows, &item); err != nil { | ||
return eris.Wrap(err, "error scanning source row") | ||
} | ||
item, err = s.translateFields(item) | ||
if err != nil { | ||
return eris.Wrap(err, "error translating fields") | ||
} | ||
if err := destTX. | ||
Table(table). | ||
Clauses(clause.OnConflict{DoNothing: true}). | ||
Create(&item).Error; err != nil { | ||
return eris.Wrap(err, "error creating destination row") | ||
} | ||
count++ | ||
} | ||
log.Infof("Importing %s - found %v records", table, count) | ||
return nil | ||
}) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// saveExperimentInfo will relate the source and destination experiment for later id mapping. | ||
func (s *Importer) saveExperimentInfo(source, dest Experiment) { | ||
s.experimentInfos = append(s.experimentInfos, experimentInfo{ | ||
sourceID: int64(*source.ID), | ||
destID: int64(*dest.ID), | ||
}) | ||
} | ||
|
||
// translateFields will alter row before creation as needed (especially, replacing old experiment_id with new). | ||
func (s *Importer) translateFields(item map[string]any) (map[string]any, error) { | ||
// boolean is numeric when coming from sqlite | ||
if isNaN, ok := item["is_nan"]; ok { | ||
switch v := isNaN.(type) { | ||
case bool: | ||
break | ||
default: | ||
item["is_nan"] = (v != 0.0) | ||
} | ||
} | ||
// items with experiment_id fk need to reference the new ID | ||
if expID, ok := item["experiment_id"]; ok { | ||
id, ok := expID.(int64) | ||
if !ok { | ||
return nil, eris.Errorf("unable to assert experiment_id as int64: %v", expID) | ||
} | ||
for _, expInfo := range s.experimentInfos { | ||
if expInfo.sourceID == id { | ||
item["experiment_id"] = expInfo.destID | ||
} | ||
} | ||
} | ||
return item, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.