diff --git a/pkg/api/aim/helpers.go b/pkg/api/aim/helpers.go index eac7fd43d..4dbb9ddc2 100644 --- a/pkg/api/aim/helpers.go +++ b/pkg/api/aim/helpers.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "time" "github.com/gofiber/fiber/v2" @@ -41,7 +42,7 @@ func RunsSearchAsCSVResponse(ctx *fiber.Ctx, runs []database.Run, excludeTraces, if metric.IsNan { v = math.NaN() } - key := fmt.Sprintf("%s %s", metric.Key, metric.Context.Json.String()) + key := fmt.Sprintf("%s %s", metric.Key, string(metric.Context.Json)) if _, ok := metricData[key]; ok { metricData[key][run.ID] = v } else { @@ -270,3 +271,15 @@ func RunsSearchAsStreamResponse( log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path()) }) } + +// CompareJson compares two json objects. +func CompareJson(json1, json2 []byte) bool { + var j, j2 interface{} + if err := json.Unmarshal(json1, &j); err != nil { + return false + } + if err := json.Unmarshal(json2, &j2); err != nil { + return false + } + return reflect.DeepEqual(j2, j) +} diff --git a/pkg/api/aim/runs.go b/pkg/api/aim/runs.go index 792c1cb2a..f6f9ff681 100644 --- a/pkg/api/aim/runs.go +++ b/pkg/api/aim/runs.go @@ -25,6 +25,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/mlflow/api" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/repositories" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/pkg/common/middleware/namespace" "github.com/G-Research/fasttrackml/pkg/database" ) @@ -170,14 +171,14 @@ func GetRunMetrics(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error()) } - metricKeysMap, contexts := make(fiber.Map, len(b)), make([]string, 0, len(b)) + metricKeysMap, contexts := make(fiber.Map, len(b)), make([]types.JSONB, 0, len(b)) for _, m := range b { if m.Context != nil { serializedContext, err := json.Marshal(m.Context) if err != nil { return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error()) } - contexts = append(contexts, string(serializedContext)) + contexts = append(contexts, serializedContext) } metricKeysMap[m.Name] = nil } @@ -217,11 +218,7 @@ func GetRunMetrics(c *fiber.Ctx) error { func() *gorm.DB { query := database.DB for _, context := range contexts { - if query.Dialector.Name() == database.SQLiteDialectorName { - query = query.Or("json(json) = json(?)", context) - } else { - query = query.Or("json = ?", context) - } + query = query.Or("json = ?", context) } return query }(), @@ -239,7 +236,7 @@ func GetRunMetrics(c *fiber.Ctx) error { name string iters []int values []*float64 - context datatypes.JSON + context json.RawMessage }, len(metricKeys)) for _, m := range data { @@ -254,7 +251,7 @@ func GetRunMetrics(c *fiber.Ctx) error { k.name = m.Key k.iters = append(k.iters, int(m.Iter)) k.values = append(k.values, pv) - k.context = m.Context.Json + k.context = json.RawMessage(m.Context.Json) metrics[key] = k } @@ -869,7 +866,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error()) } - values, capacity, contextsMap := []any{}, 0, map[string]string{} + values, capacity, contextsMap := []any{}, 0, map[string]types.JSONB{} for _, r := range b.Runs { for _, t := range r.Traces { l := t.Slice[2] @@ -885,20 +882,16 @@ func SearchAlignedMetrics(c *fiber.Ctx) error { contextHash := fmt.Sprintf("%x", sum) _, ok := contextsMap[contextHash] if !ok { - contextsMap[contextHash] = string(data) + contextsMap[contextHash] = data } - values = append(values, r.ID, t.Name, string(data), float32(l)) + values = append(values, r.ID, t.Name, data, float32(l)) } } // map context values to context ids query := database.DB for _, context := range contextsMap { - if query.Dialector.Name() == database.SQLiteDialectorName { - query = query.Or("json(contexts.json) = json(?)", context) - } else { - query = query.Or("contexts.json = ?", context) - } + query = query.Or("contexts.json = ?", context) } var contexts []database.Context if err := query.Find(&contexts).Error; err != nil { @@ -908,11 +901,7 @@ func SearchAlignedMetrics(c *fiber.Ctx) error { // add context ids to `values` for _, context := range contexts { for i := 2; i < len(values); i += 4 { - json, err := json.Marshal(context.Json) - if err != nil { - return api.NewInternalError("error serializing context: %s", err) - } - if values[i] == string(json) { + if CompareJson(values[i].([]byte), context.Json) { values[i] = context.ID } } diff --git a/pkg/api/mlflow/controller/metric.go b/pkg/api/mlflow/controller/metric.go index d2fc442ed..26459dad2 100644 --- a/pkg/api/mlflow/controller/metric.go +++ b/pkg/api/mlflow/controller/metric.go @@ -133,7 +133,7 @@ func (c Controller) GetMetricHistories(ctx *fiber.Ctx) error { } else { b.Field(4).(*array.Float64Builder).Append(m.Value) } - b.Field(5).(*array.StringBuilder).Append(m.Context.Json.String()) + b.Field(5).(*array.StringBuilder).Append(string(m.Context.Json)) if (i+1)%100000 == 0 { if err := WriteStreamingRecord(writer, b.NewRecord()); err != nil { return fmt.Errorf("unable to write Arrow record batch: %w", err) diff --git a/pkg/api/mlflow/dao/models/metric.go b/pkg/api/mlflow/dao/models/metric.go index de6b2e669..afced87e9 100644 --- a/pkg/api/mlflow/dao/models/metric.go +++ b/pkg/api/mlflow/dao/models/metric.go @@ -4,11 +4,11 @@ import ( "crypto/sha256" "fmt" - "gorm.io/datatypes" + "github.com/G-Research/fasttrackml/pkg/common/db/types" ) // DefaultContext is the default metric context -var DefaultContext = Context{Json: datatypes.JSON("{}")} +var DefaultContext = Context{Json: types.JSONB("{}")} // Metric represents model to work with `metrics` table. type Metric struct { @@ -48,8 +48,8 @@ func (m LatestMetric) UniqueKey() string { // Context represents model to work with `contexts` table. type Context struct { - ID uint `gorm:"primaryKey;autoIncrement"` - Json datatypes.JSON `gorm:"not null;unique;index"` + ID uint `gorm:"primaryKey;autoIncrement"` + Json types.JSONB `gorm:"not null;unique;index"` } // GetJsonHash returns hash of the Context.Json diff --git a/pkg/common/db/types/jsonb.go b/pkg/common/db/types/jsonb.go new file mode 100644 index 000000000..58fe701ad --- /dev/null +++ b/pkg/common/db/types/jsonb.go @@ -0,0 +1,86 @@ +package types + +import ( + "context" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// JSONB defined JSONB data type, need to implements driver.Valuer, sql.Scanner interface +type JSONB json.RawMessage + +// Value return json value, implement driver.Valuer interface +func (j JSONB) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return string(j), nil +} + +// Scan scan value into Jsonb, implements sql.Scanner interface +func (j *JSONB) Scan(value interface{}) error { + if value == nil { + *j = JSONB("null") + return nil + } + var bytes []byte + switch v := value.(type) { + case []byte: + if len(v) > 0 { + bytes = make([]byte, len(v)) + copy(bytes, v) + } + case string: + bytes = []byte(v) + default: + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) + } + + result := json.RawMessage(bytes) + *j = JSONB(result) + return nil +} + +// MarshalJSON to output non base64 encoded []byte +func (j JSONB) MarshalJSON() ([]byte, error) { + return json.RawMessage(j).MarshalJSON() +} + +// UnmarshalJSON to deserialize []byte +func (j *JSONB) UnmarshalJSON(b []byte) error { + result := json.RawMessage{} + err := result.UnmarshalJSON(b) + *j = JSONB(result) + return err +} + +func (j JSONB) String() string { + return string(j) +} + +// GormDataType gorm common data type +func (JSONB) GormDataType() string { + return "json" +} + +// GormDBDataType gorm db data type +func (JSONB) GormDBDataType(db *gorm.DB, field *schema.Field) string { + return "JSONB" +} + +// GormValue gorm db actual value +// nolint +func (js JSONB) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + if len(js) == 0 { + return gorm.Expr("NULL") + } + + data, _ := js.MarshalJSON() + return gorm.Expr("?", string(data)) +} diff --git a/pkg/database/migrate.go b/pkg/database/migrate.go index 6efa35665..4bf2f6166 100644 --- a/pkg/database/migrate.go +++ b/pkg/database/migrate.go @@ -9,12 +9,12 @@ import ( "time" log "github.com/sirupsen/logrus" - "gorm.io/datatypes" "gorm.io/gorm" "gorm.io/gorm/logger" "github.com/G-Research/fasttrackml/pkg/api/mlflow/common" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0001" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0002" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0003" @@ -24,6 +24,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0007" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0008" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0009" + "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0010" ) var supportedAlembicVersions = []string{ @@ -45,7 +46,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error { tx.First(&schemaVersion) } - if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0009.Version { + if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0010.Version { if !migrate && alembicVersion.Version != "" { return fmt.Errorf( "unsupported database schema versions alembic %s, FastTrackML %s", @@ -190,6 +191,13 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error { if err := v_0009.Migrate(db); err != nil { return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0009.Version, err) } + fallthrough + + case v_0009.Version: + log.Infof("Migrating database to FastTrackML schema %s", v_0010.Version) + if err := v_0010.Migrate(db); err != nil { + return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0010.Version, err) + } default: return fmt.Errorf("unsupported database FastTrackML schema version %s", schemaVersion.Version) @@ -221,7 +229,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error { Version: "97727af70f4d", }) tx.Create(&SchemaVersion{ - Version: v_0009.Version, + Version: v_0010.Version, }) tx.Commit() if tx.Error != nil { @@ -308,7 +316,7 @@ func CreateDefaultExperiment(db *gorm.DB, defaultArtifactRoot string) error { // CreateDefaultMetricContext creates the default metric context if it doesn't exist. func CreateDefaultMetricContext(db *gorm.DB) error { - defaultContext := Context{Json: datatypes.JSON("{}")} + defaultContext := Context{Json: types.JSONB("{}")} if err := db.First(&defaultContext).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Info("Creating default context") diff --git a/pkg/database/migrations/v_0010/migrate.go b/pkg/database/migrations/v_0010/migrate.go new file mode 100644 index 000000000..33480b634 --- /dev/null +++ b/pkg/database/migrations/v_0010/migrate.go @@ -0,0 +1,19 @@ +package v_0010 + +import ( + "gorm.io/gorm" +) + +const Version = "10d125c68d9a" + +func Migrate(db *gorm.DB) error { + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.AutoMigrate(&Context{}); err != nil { + return err + } + return tx.Model(&SchemaVersion{}). + Where("1 = 1"). + Update("Version", Version). + Error + }) +} diff --git a/pkg/database/migrations/v_0010/model.go b/pkg/database/migrations/v_0010/model.go new file mode 100644 index 000000000..268a616fb --- /dev/null +++ b/pkg/database/migrations/v_0010/model.go @@ -0,0 +1,241 @@ +package v_0010 + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/G-Research/fasttrackml/pkg/common/db/types" +) + +type Status string + +const ( + StatusRunning Status = "RUNNING" + StatusScheduled Status = "SCHEDULED" + StatusFinished Status = "FINISHED" + StatusFailed Status = "FAILED" + StatusKilled Status = "KILLED" +) + +type LifecycleStage string + +const ( + LifecycleStageActive LifecycleStage = "active" + LifecycleStageDeleted LifecycleStage = "deleted" +) + +type Namespace struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Apps []App `gorm:"constraint:OnDelete:CASCADE" json:"apps"` + Code string `gorm:"unique;index;not null" json:"code"` + Description string `json:"description"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` + DefaultExperimentID *int32 `gorm:"not null" json:"default_experiment_id"` + Experiments []Experiment `gorm:"constraint:OnDelete:CASCADE" json:"experiments"` +} + +type Experiment struct { + ID *int32 `gorm:"column:experiment_id;not null;primaryKey"` + Name string `gorm:"type:varchar(256);not null;index:,unique,composite:name"` + ArtifactLocation string `gorm:"type:varchar(256)"` + LifecycleStage LifecycleStage `gorm:"type:varchar(32);check:lifecycle_stage IN ('active', 'deleted')"` + CreationTime sql.NullInt64 `gorm:"type:bigint"` + LastUpdateTime sql.NullInt64 `gorm:"type:bigint"` + NamespaceID uint `gorm:"not null;index:,unique,composite:name"` + Namespace Namespace + Tags []ExperimentTag `gorm:"constraint:OnDelete:CASCADE"` + Runs []Run `gorm:"constraint:OnDelete:CASCADE"` +} + +type ExperimentTag struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value string `gorm:"type:varchar(5000)"` + ExperimentID int32 `gorm:"not null;primaryKey"` +} + +//nolint:lll +type Run struct { + ID string `gorm:"<-:create;column:run_uuid;type:varchar(32);not null;primaryKey"` + Name string `gorm:"type:varchar(250)"` + SourceType string `gorm:"<-:create;type:varchar(20);check:source_type IN ('NOTEBOOK', 'JOB', 'LOCAL', 'UNKNOWN', 'PROJECT')"` + SourceName string `gorm:"<-:create;type:varchar(500)"` + EntryPointName string `gorm:"<-:create;type:varchar(50)"` + UserID string `gorm:"<-:create;type:varchar(256)"` + Status Status `gorm:"type:varchar(9);check:status IN ('SCHEDULED', 'FAILED', 'FINISHED', 'RUNNING', 'KILLED')"` + StartTime sql.NullInt64 `gorm:"<-:create;type:bigint"` + EndTime sql.NullInt64 `gorm:"type:bigint"` + SourceVersion string `gorm:"<-:create;type:varchar(50)"` + LifecycleStage LifecycleStage `gorm:"type:varchar(20);check:lifecycle_stage IN ('active', 'deleted')"` + ArtifactURI string `gorm:"<-:create;type:varchar(200)"` + ExperimentID int32 + Experiment Experiment + DeletedTime sql.NullInt64 `gorm:"type:bigint"` + RowNum RowNum `gorm:"<-:create;index"` + Params []Param `gorm:"constraint:OnDelete:CASCADE"` + Tags []Tag `gorm:"constraint:OnDelete:CASCADE"` + Metrics []Metric `gorm:"constraint:OnDelete:CASCADE"` + LatestMetrics []LatestMetric `gorm:"constraint:OnDelete:CASCADE"` +} + +type RowNum int64 + +func (rn *RowNum) Scan(v interface{}) error { + nullInt := sql.NullInt64{} + if err := nullInt.Scan(v); err != nil { + return err + } + *rn = RowNum(nullInt.Int64) + return nil +} + +func (rn RowNum) GormDataType() string { + return "bigint" +} + +func (rn RowNum) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + if rn == 0 { + return clause.Expr{ + SQL: "(SELECT COALESCE(MAX(row_num), -1) FROM runs) + 1", + } + } + return clause.Expr{ + SQL: "?", + Vars: []interface{}{int64(rn)}, + } +} + +type Param struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value string `gorm:"type:varchar(500);not null"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` +} + +type Tag struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value string `gorm:"type:varchar(5000)"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` +} + +type Metric struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value float64 `gorm:"type:double precision;not null;primaryKey"` + Timestamp int64 `gorm:"not null;primaryKey"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` + Step int64 `gorm:"default:0;not null;primaryKey"` + IsNan bool `gorm:"default:false;not null;primaryKey"` + Iter int64 `gorm:"index"` + ContextID uint `gorm:"not null;primaryKey"` + Context Context +} + +type LatestMetric struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value float64 `gorm:"type:double precision;not null"` + Timestamp int64 + Step int64 `gorm:"not null"` + IsNan bool `gorm:"not null"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` + LastIter int64 + ContextID uint `gorm:"not null;primaryKey"` + Context Context +} + +type Context struct { + ID uint `gorm:"primaryKey;autoIncrement"` + Json types.JSONB `gorm:"not null;unique;index"` +} + +type AlembicVersion struct { + Version string `gorm:"column:version_num;type:varchar(32);not null;primaryKey"` +} + +func (AlembicVersion) TableName() string { + return "alembic_version" +} + +type SchemaVersion struct { + Version string `gorm:"not null;primaryKey"` +} + +func (SchemaVersion) TableName() string { + return "schema_version" +} + +type Base struct { + ID uuid.UUID `gorm:"type:uuid;primaryKey" json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + IsArchived bool `json:"-"` +} + +func (b *Base) BeforeCreate(tx *gorm.DB) error { + b.ID = uuid.New() + return nil +} + +type Dashboard struct { + Base + Name string `json:"name"` + Description string `json:"description"` + AppID *uuid.UUID `gorm:"type:uuid" json:"app_id"` + App App `json:"-"` +} + +func (d Dashboard) MarshalJSON() ([]byte, error) { + type localDashboard Dashboard + type jsonDashboard struct { + localDashboard + AppType *string `json:"app_type"` + } + jd := jsonDashboard{ + localDashboard: localDashboard(d), + } + if d.App.IsArchived { + jd.AppID = nil + } else { + jd.AppType = &d.App.Type + } + return json.Marshal(jd) +} + +type App struct { + Base + Type string `gorm:"not null" json:"type"` + State AppState `json:"state"` + Namespace Namespace `json:"-"` + NamespaceID uint `gorm:"not null" json:"-"` +} + +type AppState map[string]any + +func (s AppState) Value() (driver.Value, error) { + v, err := json.Marshal(s) + if err != nil { + return nil, err + } + return string(v), nil +} + +func (s *AppState) Scan(v interface{}) error { + var nullS sql.NullString + if err := nullS.Scan(v); err != nil { + return err + } + if nullS.Valid { + return json.Unmarshal([]byte(nullS.String), s) + } + return nil +} + +func (s AppState) GormDataType() string { + return "text" +} diff --git a/pkg/database/model.go b/pkg/database/model.go index ba486e7c5..d0cbe20dc 100644 --- a/pkg/database/model.go +++ b/pkg/database/model.go @@ -10,9 +10,10 @@ import ( "time" "github.com/google/uuid" - "gorm.io/datatypes" "gorm.io/gorm" "gorm.io/gorm/clause" + + "github.com/G-Research/fasttrackml/pkg/common/db/types" ) type Status string @@ -162,8 +163,8 @@ type LatestMetric struct { } type Context struct { - ID uint `gorm:"primaryKey;autoIncrement"` - Json datatypes.JSON `gorm:"not null;unique;index"` + ID uint `gorm:"primaryKey;autoIncrement"` + Json types.JSONB `gorm:"not null;unique;index"` } // GetJsonHash returns hash of the Context.Json diff --git a/tests/integration/golang/aim/metric/search_test.go b/tests/integration/golang/aim/metric/search_test.go index f1cd86b78..db4e5dfaa 100644 --- a/tests/integration/golang/aim/metric/search_test.go +++ b/tests/integration/golang/aim/metric/search_test.go @@ -9,11 +9,11 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/suite" - "gorm.io/datatypes" "github.com/G-Research/fasttrackml/pkg/api/aim/encoding" "github.com/G-Research/fasttrackml/pkg/api/aim/request" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/tests/integration/golang/helpers" ) @@ -162,7 +162,7 @@ func (s *SearchMetricsTestSuite) Test_Ok() { RunID: run2.ID, Iter: 3, Context: models.Context{ - Json: datatypes.JSON(`{"testkey":"testvalue"}`), + Json: types.JSONB(`{"testkey":"testvalue"}`), }, }) s.Require().Nil(err) @@ -175,7 +175,7 @@ func (s *SearchMetricsTestSuite) Test_Ok() { RunID: run2.ID, LastIter: 3, Context: models.Context{ - Json: datatypes.JSON(`{"testkey":"testvalue"}`), + Json: types.JSONB(`{"testkey":"testvalue"}`), }, }) s.Require().Nil(err) diff --git a/tests/integration/golang/aim/run/get_run_metrics_test.go b/tests/integration/golang/aim/run/get_run_metrics_test.go index 739d63031..d17f81be3 100644 --- a/tests/integration/golang/aim/run/get_run_metrics_test.go +++ b/tests/integration/golang/aim/run/get_run_metrics_test.go @@ -3,6 +3,7 @@ package run import ( "context" "database/sql" + "encoding/json" "net/http" "testing" @@ -12,6 +13,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/aim/request" "github.com/G-Research/fasttrackml/pkg/api/aim/response" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/tests/integration/golang/helpers" ) @@ -57,7 +59,7 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { RunID: run.ID, Iter: 1, Context: models.Context{ - Json: []byte(`{"key1": "key1", "value1": "value1"}`), + Json: types.JSONB(`{"key1":"key1","value1":"value1"}`), }, }) s.Require().Nil(err) @@ -71,7 +73,7 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { RunID: run.ID, Iter: 2, Context: models.Context{ - Json: []byte(`{"key2": "key2", "value2": "value2"}`), + Json: types.JSONB(`{"key2":"key2","value2":"value2"}`), }, }) s.Require().Nil(err) @@ -85,7 +87,7 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { RunID: run.ID, Iter: 3, Context: models.Context{ - Json: []byte(`{"key3": "key3", "value3": "value3"}`), + Json: types.JSONB(`{"key3":"key3","value3":"value3"}`), }, }) s.Require().Nil(err) @@ -99,7 +101,7 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { RunID: run.ID, Iter: 4, Context: models.Context{ - Json: []byte(`{"key4": "key4", "value4": "value4"}`), + Json: types.JSONB(`{"key4":"key4","value4":"value4"}`), }, }) s.Require().Nil(err) @@ -127,25 +129,25 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { Name: "key1", Iters: []int64{1}, Values: []float64{123.1}, - Context: []byte(`{"key1":"key1","value1":"value1"}`), + Context: json.RawMessage(`{"key1":"key1","value1":"value1"}`), }, response.RunMetrics{ Name: "key1", Iters: []int64{2}, Values: []float64{123.2}, - Context: []byte(`{"key2":"key2","value2":"value2"}`), + Context: json.RawMessage(`{"key2":"key2","value2":"value2"}`), }, response.RunMetrics{ Name: "key2", Iters: []int64{3}, Values: []float64{124.1}, - Context: []byte(`{"key3":"key3","value3":"value3"}`), + Context: json.RawMessage(`{"key3":"key3","value3":"value3"}`), }, response.RunMetrics{ Name: "key2", Iters: []int64{4}, Values: []float64{124.2}, - Context: []byte(`{"key4":"key4","value4":"value4"}`), + Context: json.RawMessage(`{"key4":"key4","value4":"value4"}`), }, }, }, @@ -187,25 +189,25 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { Name: "key1", Iters: []int64{1}, Values: []float64{123.1}, - Context: []byte(`{"key1":"key1","value1":"value1"}`), + Context: json.RawMessage(`{"key1":"key1","value1":"value1"}`), }, response.RunMetrics{ Name: "key1", Iters: []int64{2}, Values: []float64{123.2}, - Context: []byte(`{"key2":"key2","value2":"value2"}`), + Context: json.RawMessage(`{"key2":"key2","value2":"value2"}`), }, response.RunMetrics{ Name: "key2", Iters: []int64{3}, Values: []float64{124.1}, - Context: []byte(`{"key3":"key3","value3":"value3"}`), + Context: json.RawMessage(`{"key3":"key3","value3":"value3"}`), }, response.RunMetrics{ Name: "key2", Iters: []int64{4}, Values: []float64{124.2}, - Context: []byte(`{"key4":"key4","value4":"value4"}`), + Context: json.RawMessage(`{"key4":"key4","value4":"value4"}`), }, }, }, @@ -233,13 +235,13 @@ func (s *GetRunMetricsTestSuite) Test_Ok() { Name: "key1", Iters: []int64{1}, Values: []float64{123.1}, - Context: []byte(`{"key1":"key1","value1":"value1"}`), + Context: json.RawMessage(`{"key1":"key1","value1":"value1"}`), }, response.RunMetrics{ Name: "key2", Iters: []int64{3}, Values: []float64{124.1}, - Context: []byte(`{"key3":"key3","value3":"value3"}`), + Context: json.RawMessage(`{"key3":"key3","value3":"value3"}`), }, }, }, diff --git a/tests/integration/golang/aim/run/search_test.go b/tests/integration/golang/aim/run/search_test.go index a756ea96a..c82437e23 100644 --- a/tests/integration/golang/aim/run/search_test.go +++ b/tests/integration/golang/aim/run/search_test.go @@ -7,18 +7,16 @@ import ( "encoding/csv" "fmt" "io" - "strings" "testing" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "golang.org/x/exp/slices" - "gorm.io/datatypes" "github.com/G-Research/fasttrackml/pkg/api/aim/encoding" "github.com/G-Research/fasttrackml/pkg/api/aim/request" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/tests/integration/golang/helpers" ) @@ -92,7 +90,7 @@ func (s *SearchTestSuite) SetupTest() { RunID: run1.ID, LastIter: 1, Context: models.Context{ - Json: datatypes.JSON(`{"key":"value"}`), + Json: types.JSONB(`{"key": "value"}`), }, }) s.Require().Nil(err) @@ -105,7 +103,7 @@ func (s *SearchTestSuite) SetupTest() { RunID: run1.ID, LastIter: 1, Context: models.Context{ - Json: datatypes.JSON(`{"key":"value"}`), + Json: types.JSONB(`{"key": "value"}`), }, }) s.Require().Nil(err) @@ -154,7 +152,7 @@ func (s *SearchTestSuite) SetupTest() { RunID: run2.ID, LastIter: 1, Context: models.Context{ - Json: datatypes.JSON(`{"key":"value"}`), + Json: types.JSONB(`{"key": "value"}`), }, }) s.Require().Nil(err) @@ -203,7 +201,7 @@ func (s *SearchTestSuite) SetupTest() { RunID: run3.ID, LastIter: 3, Context: models.Context{ - Json: datatypes.JSON(`{"key":"value"}`), + Json: types.JSONB(`{"key": "value"}`), }, }) s.Require().Nil(err) @@ -252,7 +250,7 @@ func (s *SearchTestSuite) SetupTest() { RunID: run4.ID, LastIter: 1, Context: models.Context{ - Json: datatypes.JSON(`{"key":"value"}`), + Json: types.JSONB(`{"key": "value"}`), }, }) s.Require().Nil(err) @@ -297,8 +295,8 @@ func (s *SearchTestSuite) TestCSVReport_Ok() { "experiment_description", "date", "duration", - "TestMetric {\"key\":\"value\"}", - "TestMetric2 {\"key\":\"value\"}", + "TestMetric {\"key\": \"value\"}", + "TestMetric2 {\"key\": \"value\"}", "params[param1]", "params[param3]", "tags[mlflow.runName]", @@ -329,18 +327,8 @@ func (s *SearchTestSuite) TestCSVReport_Ok() { }, } - // check headers separately. headers could include information about metric + context and - // because of difference in `json` serialisation between `sqlite` and `postgres` could be - // some problems in comparing. remove all whitespaces for now. - // TODO remove such a case when when we will use `JSONB` instead of `JSON` for both `sqlite` and `postgres`. - for i, expectedRecord := range expectedResult[0] { - assert.Equal( - s.T(), strings.Replace(expectedRecord, " ", "", -1), strings.Replace(records[0][i], " ", "", -1), - ) - } - // check other data records normally. - s.Require().Equal(expectedResult[1:], records[1:]) + s.Require().Equal(expectedResult, records) } func (s *SearchTestSuite) TestStreamData_Ok() { diff --git a/tests/integration/golang/mlflow/metric/get_histories_test.go b/tests/integration/golang/mlflow/metric/get_histories_test.go index b4d56fa71..664df454e 100644 --- a/tests/integration/golang/mlflow/metric/get_histories_test.go +++ b/tests/integration/golang/mlflow/metric/get_histories_test.go @@ -14,6 +14,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/mlflow/api/request" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" "github.com/G-Research/fasttrackml/pkg/api/mlflow/service/metric" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/tests/integration/golang/helpers" ) @@ -68,7 +69,7 @@ func (s *GetHistoriesTestSuite) Test_Ok() { Step: 1, Iter: 1, Context: models.Context{ - Json: []byte(` + Json: types.JSONB(` { "metrickey1": "metricvalue1", "metrickey2": "metricvalue2", diff --git a/tests/integration/golang/mlflow/metric/get_history_test.go b/tests/integration/golang/mlflow/metric/get_history_test.go index e56fb6aa8..cdd663609 100644 --- a/tests/integration/golang/mlflow/metric/get_history_test.go +++ b/tests/integration/golang/mlflow/metric/get_history_test.go @@ -5,13 +5,13 @@ import ( "testing" "github.com/stretchr/testify/suite" - "gorm.io/datatypes" "github.com/G-Research/fasttrackml/pkg/api/mlflow" "github.com/G-Research/fasttrackml/pkg/api/mlflow/api" "github.com/G-Research/fasttrackml/pkg/api/mlflow/api/request" "github.com/G-Research/fasttrackml/pkg/api/mlflow/api/response" "github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models" + "github.com/G-Research/fasttrackml/pkg/common/db/types" "github.com/G-Research/fasttrackml/tests/integration/golang/helpers" ) @@ -50,7 +50,7 @@ func (s *GetHistoryTestSuite) Test_Ok() { IsNan: false, Iter: 1, Context: models.Context{ - Json: datatypes.JSON(`{"key": "key", "value": "value"}`), + Json: types.JSONB(`{"key": "key", "value": "value"}`), }, }) s.Require().Nil(err)