diff --git a/pkg/api/aim/api/request/run.go b/pkg/api/aim/api/request/run.go index 3d962f05e..d2e867910 100644 --- a/pkg/api/aim/api/request/run.go +++ b/pkg/api/aim/api/request/run.go @@ -1,6 +1,9 @@ package request import ( + "strconv" + "strings" + "github.com/gofiber/fiber/v2" ) @@ -95,16 +98,16 @@ type SearchAlignedMetricsRequest struct { AlignBy string `json:"align_by"` } -// SearchArtifactsRequest is a request struct for `GET /runs/search/image` endpoint. +// SearchArtifactsRequest is a request struct for `POST /runs/search/image` endpoint. type SearchArtifactsRequest struct { BaseSearchRequest - Query string `query:"q"` - SkipSystem bool `query:"skip_system"` - RecordDensity int `query:"record_density"` - IndexDensity int `query:"index_density"` - RecordRange string `query:"record_range"` - IndexRange string `query:"index_range"` - CalcRanges bool `query:"calc_ranges"` + Query string `json:"q"` + SkipSystem bool `json:"skip_system"` + RecordDensity any `json:"record_density"` + IndexDensity any `json:"index_density"` + RecordRange string `json:"record_range"` + IndexRange string `json:"index_range"` + CalcRanges bool `json:"calc_ranges"` } // DeleteRunRequest is a request struct for `DELETE /runs/:id` endpoint. @@ -129,3 +132,81 @@ type DeleteRunTagRequest struct { RunID string `params:"id"` TagID string `params:"tagID"` } + +// RecordRangeMin returns the low end of the record range. +func (req SearchArtifactsRequest) RecordRangeMin() int { + return rangeMin(req.RecordRange) +} + +// RecordRangeMax returns the high end of the record range. +func (req SearchArtifactsRequest) RecordRangeMax(dflt int) int { + return rangeMax(req.RecordRange, dflt) +} + +// IndexRangeMin returns the low end of the index range. +func (req SearchArtifactsRequest) IndexRangeMin() int { + return rangeMin(req.IndexRange) +} + +// IndexRangeMax returns the high end of the index range. +func (req SearchArtifactsRequest) IndexRangeMax(dflt int) int { + return rangeMax(req.IndexRange, dflt) +} + +// StepCount returns the RecordDensity requested or -1 if not limited. +func (req SearchArtifactsRequest) StepCount() int { + switch v := req.RecordDensity.(type) { + case float64: + return int(v) + case string: + num, err := strconv.Atoi(v) + if err != nil || num < 1 { + return -1 + } + return num + default: + return -1 + } +} + +// ItemsPerStep returns the IndexDensity requested or -1 if not limited. +func (req SearchArtifactsRequest) ItemsPerStep() int { + switch v := req.IndexDensity.(type) { + case float64: + return int(v) + case string: + num, err := strconv.Atoi(v) + if err != nil || num < 1 { + return -1 + } + return num + default: + return -1 + } +} + +// rangeMin will extract the lower end of a range string in the request. +func rangeMin(r string) int { + rangeVals := strings.Split(r, ":") + if len(rangeVals) != 2 { + return 0 + } + num, err := strconv.Atoi(rangeVals[0]) + if err == nil { + return num + } + return 0 +} + +// rangeMax will extract the lower end of a range string in the request. +func rangeMax(r string, dflt int) int { + rangeVals := strings.Split(r, ":") + if len(rangeVals) != 2 { + return dflt + } + num, err := strconv.Atoi(rangeVals[1]) + if err == nil { + return num + } + return dflt +} diff --git a/pkg/api/aim/api/response/project.go b/pkg/api/aim/api/response/project.go index 689b110d7..02d3ff44f 100644 --- a/pkg/api/aim/api/response/project.go +++ b/pkg/api/aim/api/response/project.go @@ -101,6 +101,12 @@ func NewProjectParamsResponse(projectParams *models.ProjectParams, } } + // process images + images := make(fiber.Map, len(projectParams.Images)) + for _, imageName := range projectParams.Images { + images[imageName] = []fiber.Map{} + } + rsp := ProjectParamsResponse{} if !excludeParams { rsp.Params = ¶ms @@ -118,7 +124,7 @@ func NewProjectParamsResponse(projectParams *models.ProjectParams, for _, s := range sequences { switch s { case "images": - rsp.Images = &fiber.Map{} + rsp.Images = &images case "texts": rsp.Texts = &fiber.Map{} case "figures": diff --git a/pkg/api/aim/api/response/run.go b/pkg/api/aim/api/response/run.go index a08f8be4e..c55d1915f 100644 --- a/pkg/api/aim/api/response/run.go +++ b/pkg/api/aim/api/response/run.go @@ -517,8 +517,8 @@ func NewStreamMetricsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, // NewStreamArtifactsResponse streams the provided sql.Rows to the fiber context. // //nolint:gocyclo -func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, - result repositories.ArtifactSearchSummary, req request.SearchArtifactsRequest, +func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, runs map[string]models.Run, + summary repositories.ArtifactSearchSummary, req request.SearchArtifactsRequest, ) { ctx.Context().SetBodyStreamWriter(func(w *bufio.Writer) { //nolint:errcheck @@ -528,17 +528,17 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, if err := func() error { var ( - runID string - runData fiber.Map - traces []fiber.Map - cur int64 + runID string + runData fiber.Map + tracesMap map[string]fiber.Map + cur int64 ) reportProgress := func() error { if !req.ReportProgress { return nil } err := encoding.EncodeTree(w, fiber.Map{ - fmt.Sprintf("progress_%d", cur): []int64{cur, totalRuns}, + fmt.Sprintf("progress_%d", cur): []int64{cur, int64(len(runs))}, }) if err != nil { return err @@ -546,23 +546,44 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, cur++ return w.Flush() } - addImage := func(img models.Artifact) { + addImage := func(img models.Artifact, run models.Run) { + maxIndex := summary.MaxIndex(img.RunID, img.Name) + maxStep := summary.MaxStep(img.RunID, img.Name) if runData == nil { - imagesPerStep := result.StepImageCount(img.RunID, 0) runData = fiber.Map{ "ranges": fiber.Map{ - "record_range_total": []int{0, result.TotalSteps(img.RunID)}, - "record_range_used": []int{0, int(img.Step)}, - "index_range_total": []int{0, imagesPerStep}, - "index_range_used": []int{0, int(img.Index)}, + "record_range_total": []int{0, maxStep}, + "record_range_used": []int{req.RecordRangeMin(), req.RecordRangeMax(maxStep)}, + "index_range_total": []int{0, maxIndex}, + "index_range_used": []int{req.IndexRangeMin(), req.IndexRangeMax(maxIndex)}, }, "params": fiber.Map{ - "images_per_step": imagesPerStep, + "images_per_step": maxIndex, }, + "props": renderProps(run), } - traces = []fiber.Map{} + tracesMap = map[string]fiber.Map{} } - traces = append(traces, fiber.Map{ + trace, ok := tracesMap[img.Name] + if !ok { + trace = fiber.Map{ + "name": img.Name, + "context": fiber.Map{}, + "caption": img.Caption, + } + tracesMap[img.Name] = trace + } + traceValues, ok := trace["values"].([][]fiber.Map) + if !ok { + stepsSlice := make([][]fiber.Map, maxStep+1) + traceValues = stepsSlice + } + + iters, ok := trace["iters"].([]int64) + if !ok { + iters = make([]int64, maxStep+1) + } + value := fiber.Map{ "blob_uri": img.BlobURI, "caption": img.Caption, "height": img.Height, @@ -571,13 +592,72 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, "iter": img.Iter, "index": img.Index, "step": img.Step, - }) + } + + stepImages := traceValues[img.Step] + if stepImages == nil { + stepImages = []fiber.Map{} + } + stepImages = append(stepImages, value) + traceValues[img.Step] = stepImages + iters[img.Step] = img.Iter // TODO maybe not correct + trace["values"] = traceValues + trace["iters"] = iters + tracesMap[img.Name] = trace + } + selectTraces := func() { + // collect the traces for this run, limiting to RecordDensity and IndexDensity. + selectIndices := func(trace fiber.Map) fiber.Map { + // limit steps slice to len of RecordDensity. + stepCount := req.StepCount() + imgCount := req.ItemsPerStep() + steps, ok := trace["values"].([][]fiber.Map) + if !ok { + return trace + } + iters, ok := trace["iters"].([]int64) + if !ok { + return trace + } + filteredSteps := [][]fiber.Map{} + filteredIters := []int64{} + stepInterval := len(steps) / stepCount + for stepIndex := 0; stepIndex < len(steps); stepIndex++ { + if stepCount == -1 || + len(steps) <= stepCount || + stepIndex%stepInterval == 0 { + step := steps[stepIndex] + newStep := []fiber.Map{} + imgInterval := len(step) / imgCount + for imgIndex := 0; imgIndex < len(step); imgIndex++ { + if imgCount == -1 || + len(step) <= imgCount || + imgIndex%imgInterval == 0 { + newStep = append(newStep, step[imgIndex]) + } + } + filteredSteps = append(filteredSteps, newStep) + filteredIters = append(filteredIters, iters[stepIndex]) + } + } + trace["values"] = filteredSteps + trace["iters"] = filteredIters + return trace + } + + traces := make([]fiber.Map, len(tracesMap)) + i := 0 + for _, trace := range tracesMap { + traces[i] = selectIndices(trace) + i++ + } + runData["traces"] = traces } flushImages := func() error { if runID == "" { return nil } - runData["traces"] = traces + selectTraces() if err := encoding.EncodeTree(w, fiber.Map{ runID: runData, }); err != nil { @@ -588,12 +668,12 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, } return w.Flush() } + hasRows := false for rows.Next() { var image models.Artifact if err := database.DB.ScanRows(rows, &image); err != nil { return err } - // flush after each change in runID // (assumes order by runID) if image.RunID != runID { @@ -602,18 +682,18 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64, } runID = image.RunID runData = nil - traces = nil } - addImage(image) - + addImage(image, runs[image.RunID]) + hasRows = true } - if err := flushImages(); err != nil { - return err - } - - if err := reportProgress(); err != nil { - return err + if hasRows { + if err := flushImages(); err != nil { + return err + } + if err := reportProgress(); err != nil { + return err + } } return nil @@ -784,19 +864,7 @@ func NewRunsSearchStreamResponse( if err := func() error { for i, r := range runs { run := fiber.Map{ - "props": fiber.Map{ - "name": r.Name, - "description": nil, - "experiment": fiber.Map{ - "id": fmt.Sprintf("%d", *r.Experiment.ID), - "name": r.Experiment.Name, - }, - "tags": ConvertTagsToMaps(r.SharedTags), - "creation_time": float64(r.StartTime.Int64) / 1000, - "end_time": float64(r.EndTime.Int64) / 1000, - "archived": r.LifecycleStage == models.LifecycleStageDeleted, - "active": r.Status == models.StatusRunning, - }, + "props": renderProps(r), } if !excludeTraces { @@ -886,21 +954,7 @@ func NewActiveRunsStreamResponse(ctx *fiber.Ctx, runs []models.Run, reportProgre start := time.Now() if err := func() error { for i, r := range runs { - - props := fiber.Map{ - "name": r.Name, - "description": nil, - "experiment": fiber.Map{ - "id": fmt.Sprintf("%d", *r.Experiment.ID), - "name": r.Experiment.Name, - }, - "tags": ConvertTagsToMaps(r.SharedTags), - "creation_time": float64(r.StartTime.Int64) / 1000, - "end_time": float64(r.EndTime.Int64) / 1000, - "archived": r.LifecycleStage == models.LifecycleStageDeleted, - "active": r.Status == models.StatusRunning, - } - + props := renderProps(r) metrics := make([]fiber.Map, len(r.LatestMetrics)) for i, m := range r.LatestMetrics { v := m.Value @@ -960,6 +1014,25 @@ func NewActiveRunsStreamResponse(ctx *fiber.Ctx, runs []models.Run, reportProgre return nil } +// renderProps makes the "props" map for a run. +func renderProps(r models.Run) fiber.Map { + m := fiber.Map{ + "name": r.Name, + "description": nil, + "experiment": fiber.Map{ + "id": fmt.Sprintf("%d", r.ExperimentID), + "name": r.Experiment.Name, + "artifact_location": r.Experiment.ArtifactLocation, + }, + "tags": ConvertTagsToMaps(r.SharedTags), + "creation_time": float64(r.StartTime.Int64) / 1000, + "end_time": float64(r.EndTime.Int64) / 1000, + "archived": r.LifecycleStage == models.LifecycleStageDeleted, + "active": r.Status == models.StatusRunning, + } + return m +} + // NewRunImagesStreamResponse streams the provided images to the fiber context. func NewRunImagesStreamResponse(ctx *fiber.Ctx, images []models.Image) error { ctx.Set("Content-Type", "application/octet-stream") diff --git a/pkg/api/aim/controller/runs.go b/pkg/api/aim/controller/runs.go index bc8c59ea2..6f219f06a 100644 --- a/pkg/api/aim/controller/runs.go +++ b/pkg/api/aim/controller/runs.go @@ -250,7 +250,7 @@ func (c Controller) SearchImages(ctx *fiber.Ctx) error { log.Debugf("searchMetrics namespace: %s", ns.Code) req := request.SearchArtifactsRequest{} - if err = ctx.QueryParser(&req); err != nil { + if err = ctx.BodyParser(&req); err != nil { return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error()) } if ctx.Query("report_progress") == "" { @@ -263,12 +263,12 @@ func (c Controller) SearchImages(ctx *fiber.Ctx) error { } //nolint:rowserrcheck - rows, totalRuns, result, err := c.runService.SearchArtifacts(ctx.Context(), ns.ID, tzOffset, req) + rows, runs, result, err := c.runService.SearchArtifacts(ctx.Context(), ns.ID, tzOffset, req) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, err.Error()) } - response.NewStreamArtifactsResponse(ctx, rows, totalRuns, result, req) + response.NewStreamArtifactsResponse(ctx, rows, runs, result, req) return nil } diff --git a/pkg/api/aim/dao/models/artifact.go b/pkg/api/aim/dao/models/artifact.go index 21921d9db..d67e28464 100644 --- a/pkg/api/aim/dao/models/artifact.go +++ b/pkg/api/aim/dao/models/artifact.go @@ -8,7 +8,8 @@ import ( // Artifact represents the artifact model. type Artifact struct { - RowNum int64 // this field is virtual, not persistent + RowNum int64 // RowNum is calculated, not persistent + Name string ID uuid.UUID `gorm:"type:uuid;primaryKey" json:"id"` Iter int64 `gorm:"index"` Step int64 `gorm:"default:0;not null"` diff --git a/pkg/api/aim/dao/models/project.go b/pkg/api/aim/dao/models/project.go index c73d71053..9df291b09 100644 --- a/pkg/api/aim/dao/models/project.go +++ b/pkg/api/aim/dao/models/project.go @@ -14,4 +14,5 @@ type ProjectParams struct { Metrics []LatestMetric TagKeys []string ParamKeys []string + Images []string } diff --git a/pkg/api/aim/dao/repositories/artifact.go b/pkg/api/aim/dao/repositories/artifact.go index 09a7af26e..a2fd332c0 100644 --- a/pkg/api/aim/dao/repositories/artifact.go +++ b/pkg/api/aim/dao/repositories/artifact.go @@ -3,12 +3,16 @@ package repositories import ( "context" "database/sql" + "fmt" + "math" + "strings" "gorm.io/gorm" "github.com/rotisserie/eris" "github.com/G-Research/fasttrackml/pkg/api/aim/api/request" + "github.com/G-Research/fasttrackml/pkg/api/aim/dao/models" "github.com/G-Research/fasttrackml/pkg/api/aim/query" "github.com/G-Research/fasttrackml/pkg/common/dao/repositories" ) @@ -16,21 +20,42 @@ import ( // ArtifactSearchStepInfo is a search summary for a Run Step. type ArtifactSearchStepInfo struct { RunUUID string `gorm:"column:run_uuid"` + Name string `gorm:"column:name"` Step int `gorm:"column:step"` ImgCount int `gorm:"column:img_count"` + MaxIndex int `gorm:"column:max_index"` } -// ArtifactSearchSummary is a search summary for whole run. -type ArtifactSearchSummary map[string][]ArtifactSearchStepInfo +// ArtifactSearchSummary is a search summary for run and name. +type ArtifactSearchSummary map[string]map[string][]ArtifactSearchStepInfo + +// MaxStep figures out the max step belonging to the runID and sequence name. +func (r ArtifactSearchSummary) MaxStep(runID, name string) int { + runSequence := r[runID][name] + maxStep := 0 + for _, step := range runSequence { + if step.Step > maxStep { + maxStep = step.Step + } + } + return maxStep +} -// TotalSteps figures out how many steps belong to the runID. -func (r ArtifactSearchSummary) TotalSteps(runID string) int { - return len(r[runID]) +// MaxIndex figures out the maximum index for the runID and sequence name. +func (r ArtifactSearchSummary) MaxIndex(runID, name string) int { + runSequence := r[runID][name] + maxIndex := 1 + for _, step := range runSequence { + if step.MaxIndex > maxIndex { + maxIndex = step.MaxIndex + } + } + return maxIndex } // StepImageCount figures out how many steps belong to the runID and step. -func (r ArtifactSearchSummary) StepImageCount(runID string, step int) int { - runStepImages := r[runID] +func (r ArtifactSearchSummary) StepImageCount(runID, name string, step int) int { + runStepImages := r[runID][name] return runStepImages[step].ImgCount } @@ -43,7 +68,10 @@ type ArtifactRepositoryProvider interface { namespaceID uint, timeZoneOffset int, req request.SearchArtifactsRequest, - ) (*sql.Rows, int64, ArtifactSearchSummary, error) + ) (*sql.Rows, map[string]models.Run, ArtifactSearchSummary, error) + GetArtifactNamesByExperiments( + ctx context.Context, namespaceID uint, experiments []int, + ) ([]string, error) } // ArtifactRepository repository to work with `artifact` entity. @@ -64,7 +92,7 @@ func (r ArtifactRepository) Search( namespaceID uint, timeZoneOffset int, req request.SearchArtifactsRequest, -) (*sql.Rows, int64, ArtifactSearchSummary, error) { +) (*sql.Rows, map[string]models.Run, ArtifactSearchSummary, error) { qp := query.QueryParser{ Default: query.DefaultExpression{ Contains: "run.archived", @@ -73,59 +101,124 @@ func (r ArtifactRepository) Search( Tables: map[string]string{ "runs": "runs", "experiments": "experiments", + "artifacts": "artifacts", }, TzOffset: timeZoneOffset, Dialector: r.GetDB().Dialector.Name(), } pq, err := qp.Parse(req.Query) if err != nil { - return nil, 0, nil, err + return nil, nil, nil, err } runIDs := []string{} + runs := []models.Run{} if tx := pq.Filter(r.GetDB().WithContext(ctx). - Select("runs.run_uuid"). Table("runs"). Joins(`INNER JOIN experiments ON experiments.experiment_id = runs.experiment_id AND experiments.namespace_id = ?`, namespaceID, )). - Find(&runIDs); tx.Error != nil { - return nil, 0, nil, eris.Wrap(err, "error finding runs for artifact search") + Preload("Experiment"). + Find(&runs); tx.Error != nil { + return nil, nil, nil, eris.Wrap(err, "error finding runs for artifact search") + } + + runMap := make(map[string]models.Run, len(runs)) + for _, run := range runs { + if _, ok := runMap[run.ID]; !ok { + runIDs = append(runIDs, run.ID) + runMap[run.ID] = run + } } // collect some summary data for progress indicator stepInfo := []ArtifactSearchStepInfo{} if tx := r.GetDB().WithContext(ctx). - Raw(`SELECT run_uuid, step, count(id) as img_count - FROM artifacts - WHERE run_uuid IN (?) - GROUP BY run_uuid, step;`, runIDs). + Raw(`SELECT run_uuid, name, step, count(id) as img_count, max("index") as max_index + FROM artifacts + WHERE run_uuid IN (?) + GROUP BY run_uuid, name, step;`, + runIDs). Find(&stepInfo); tx.Error != nil { - return nil, 0, nil, eris.Wrap(err, "error find result summary for artifact search") + return nil, nil, nil, eris.Wrap(err, "error find result summary for artifact search") } + imageNames := []string{} + imageNameQueryTemplate := `images.name == "%s"` resultSummary := make(ArtifactSearchSummary, len(runIDs)) for _, rslt := range stepInfo { - resultSummary[rslt.RunUUID] = append(resultSummary[rslt.RunUUID], rslt) + traceMap, ok := resultSummary[rslt.RunUUID] + if !ok { + traceMap = map[string][]ArtifactSearchStepInfo{} + } + traceMap[rslt.Name] = append(traceMap[rslt.Name], rslt) + resultSummary[rslt.RunUUID] = traceMap + qImage := fmt.Sprintf(imageNameQueryTemplate, rslt.Name) + if strings.Contains(req.Query, qImage) { + imageNames = append(imageNames, rslt.Name) + } } // get a cursor for the artifacts tx := r.GetDB().WithContext(ctx). - Table("artifacts"). - Where("run_uuid IN ?", runIDs). - Order("run_uuid"). - Order("step"). - Order("created_at") + Raw(` + SELECT artifacts.*, rows.row_num + FROM artifacts + JOIN ( + SELECT id, ROW_NUMBER() OVER() row_num + FROM artifacts + ) rows USING (id) + WHERE run_uuid IN ? + AND step BETWEEN ? AND ? + AND "index" BETWEEN ? AND ? + AND name IN ? + ORDER BY run_uuid, name, step + `, + runIDs, + req.RecordRangeMin(), + req.RecordRangeMax(math.MaxInt16), + req.IndexRangeMin(), + req.IndexRangeMax(math.MaxInt16), + imageNames) rows, err := tx.Rows() if err != nil { - return nil, 0, nil, eris.Wrap(err, "error searching artifacts") + return nil, nil, nil, eris.Wrap(err, "error searching artifacts") } if err := rows.Err(); err != nil { - return nil, 0, nil, eris.Wrap(err, "error getting artifacts rows cursor") + return nil, nil, nil, eris.Wrap(err, "error getting artifacts rows cursor") + } + + return rows, runMap, resultSummary, nil +} + +// GetArtifactNamesByExperiments will find image names in the selected experiments. +func (r ArtifactRepository) GetArtifactNamesByExperiments( + ctx context.Context, namespaceID uint, experiments []int, +) ([]string, error) { + runIDs := []string{} + if err := r.GetDB().WithContext(ctx). + Select("run_uuid"). + Table("runs"). + Joins(`INNER JOIN experiments + ON experiments.experiment_id = runs.experiment_id + AND experiments.namespace_id = ? + AND experiments.experiment_id IN ?`, + namespaceID, experiments, + ). + Find(&runIDs).Error; err != nil { + return nil, eris.Wrap(err, "error finding runs for artifacts") } - return rows, int64(len(runIDs)), resultSummary, nil + imageNames := []string{} + if err := r.GetDB().WithContext(ctx). + Distinct("name"). + Table("artifacts"). + Where("run_uuid IN ?", runIDs). + Find(&imageNames).Error; err != nil { + return nil, eris.Wrap(err, "error finding runs for artifact search") + } + return imageNames, nil } diff --git a/pkg/api/aim/query/query.go b/pkg/api/aim/query/query.go index df634fcfc..db2188e1d 100644 --- a/pkg/api/aim/query/query.go +++ b/pkg/api/aim/query/query.go @@ -706,6 +706,37 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) { ).UnixMilli(), nil }, ), nil + case "images": + table, ok := pq.qp.Tables["runs"] + if !ok { + return nil, errors.New("unsupported name identifier 'runs'") + } + return attributeGetter( + func(attr string) (any, error) { + joinKey := fmt.Sprintf("artifacts:%s", attr) + j, ok := pq.joins[joinKey] + alias := fmt.Sprintf("artifacts_%d", len(pq.joins)) + if !ok { + j = join{ + alias: alias, + query: fmt.Sprintf( + "INNER JOIN artifacts %s ON %s.run_uuid = %s.run_uuid", + alias, table, alias, + ), + args: []any{attr}, + } + pq.AddJoin(joinKey, j) + } + switch attr { + case "name": + return clause.Column{ + Table: j.alias, + Name: "name", + }, nil + } + return nil, fmt.Errorf("unsupported name identifier %q", node.Id) + }, + ), nil default: return nil, fmt.Errorf("unsupported name identifier %q", node.Id) } diff --git a/pkg/api/aim/query/query_test.go b/pkg/api/aim/query/query_test.go index c46f10fc6..de5a2e5da 100644 --- a/pkg/api/aim/query/query_test.go +++ b/pkg/api/aim/query/query_test.go @@ -374,6 +374,14 @@ func (s *QueryTestSuite) TestSqliteDialector_Ok() { `AND ("metrics_0"."value" < $4 AND "runs"."lifecycle_stage" <> $5)`, expectedVars: []interface{}{"my_metric", "$.key1", "value1", -1, models.LifecycleStageDeleted}, }, + { + name: "TestImagesName", + query: `(images.name == 'my-image')`, + expectedSQL: `SELECT "run_uuid" FROM "runs" ` + + `INNER JOIN artifacts artifacts_0 ON runs.run_uuid = artifacts_0.run_uuid ` + + `WHERE "artifacts_0"."name" = $1 AND "runs"."lifecycle_stage" <> $2`, + expectedVars: []interface{}{"my-image", models.LifecycleStageDeleted}, + }, } for _, tt := range tests { @@ -387,6 +395,7 @@ func (s *QueryTestSuite) TestSqliteDialector_Ok() { "runs": "runs", "experiments": "Experiment", "metrics": "metrics", + "images": "images", }, Dialector: sqlite.Dialector{}.Name(), } diff --git a/pkg/api/aim/routes.go b/pkg/api/aim/routes.go index f51008a78..ef82dec78 100644 --- a/pkg/api/aim/routes.go +++ b/pkg/api/aim/routes.go @@ -64,7 +64,7 @@ func (r *Router) Init(server fiber.Router) { runs.Get("/search/run/", r.controller.SearchRuns) runs.Post("/search/metric/", r.controller.SearchMetrics) runs.Post("/search/metric/align/", r.controller.SearchAlignedMetrics) - runs.Post("/search/image/", r.controller.SearchImages) + runs.Post("/search/images/", r.controller.SearchImages) runs.Get("/:id/info/", r.controller.GetRunInfo) runs.Post("/:id/tags/new", r.controller.AddRunTag) runs.Delete("/:id/tags/:tagID", r.controller.DeleteRunTag) diff --git a/pkg/api/aim/services/project/service.go b/pkg/api/aim/services/project/service.go index 344303b1e..53e9f00de 100644 --- a/pkg/api/aim/services/project/service.go +++ b/pkg/api/aim/services/project/service.go @@ -18,6 +18,7 @@ type Service struct { paramRepository repositories.ParamRepositoryProvider metricRepository repositories.MetricRepositoryProvider experimentRepository repositories.ExperimentRepositoryProvider + artifactRepository repositories.ArtifactRepositoryProvider liveUpdatesEnabled bool } @@ -28,6 +29,7 @@ func NewService( paramRepository repositories.ParamRepositoryProvider, metricRepository repositories.MetricRepositoryProvider, experimentRepository repositories.ExperimentRepositoryProvider, + artifactRepository repositories.ArtifactRepositoryProvider, liveUpdatesEnabled bool, ) *Service { return &Service{ @@ -36,6 +38,7 @@ func NewService( paramRepository: paramRepository, metricRepository: metricRepository, experimentRepository: experimentRepository, + artifactRepository: artifactRepository, liveUpdatesEnabled: liveUpdatesEnabled, } } @@ -113,5 +116,15 @@ func (s Service) GetProjectParams( } projectParams.Metrics = metrics } + if slices.Contains(req.Sequences, "images") { + // fetch images available for requested Experiments. + images, err := s.artifactRepository.GetArtifactNamesByExperiments( + ctx, namespaceID, req.Experiments, + ) + if err != nil { + return nil, api.NewInternalError("error getting images: %s", err) + } + projectParams.Images = images + } return &projectParams, nil } diff --git a/pkg/api/aim/services/run/service.go b/pkg/api/aim/services/run/service.go index 446c05443..0604a8806 100644 --- a/pkg/api/aim/services/run/service.go +++ b/pkg/api/aim/services/run/service.go @@ -233,12 +233,12 @@ func (s Service) SearchMetrics( // SearchArtifacts returns the list of artifacts (images) by provided search criteria. func (s Service) SearchArtifacts( ctx context.Context, namespaceID uint, timeZoneOffset int, req request.SearchArtifactsRequest, -) (*sql.Rows, int64, repositories.ArtifactSearchSummary, error) { - rows, total, result, err := s.artifactRepository.Search(ctx, namespaceID, timeZoneOffset, req) +) (*sql.Rows, map[string]models.Run, repositories.ArtifactSearchSummary, error) { + rows, runs, result, err := s.artifactRepository.Search(ctx, namespaceID, timeZoneOffset, req) if err != nil { - return nil, 0, nil, api.NewInternalError("error searching artifacts: %s", err) + return nil, nil, nil, api.NewInternalError("error searching artifacts: %s", err) } - return rows, total, result, nil + return rows, runs, result, nil } // SearchAlignedMetrics returns the list of aligned metrics. diff --git a/pkg/api/mlflow/api/request/log.go b/pkg/api/mlflow/api/request/log.go index 474a924f7..b6dbc6fa9 100644 --- a/pkg/api/mlflow/api/request/log.go +++ b/pkg/api/mlflow/api/request/log.go @@ -76,6 +76,7 @@ type LogOutputRequest struct { // LogArtifactRequest is a request object for `POST mlflow/runs/log-artifact` endpoint. type LogArtifactRequest struct { + Name string `json:"name"` Iter int64 `json:"iter"` Step int64 `json:"step"` Caption string `json:"caption"` diff --git a/pkg/api/mlflow/dao/models/artifact.go b/pkg/api/mlflow/dao/models/artifact.go index 59cbc100e..b7eb1803e 100644 --- a/pkg/api/mlflow/dao/models/artifact.go +++ b/pkg/api/mlflow/dao/models/artifact.go @@ -4,11 +4,14 @@ import ( "time" "github.com/google/uuid" + "github.com/rotisserie/eris" + "gorm.io/gorm" ) // Artifact represents the artifact model. type Artifact struct { ID uuid.UUID `gorm:"type:uuid;primaryKey" json:"id"` + Name string `gorm:"not null;index"` Iter int64 `gorm:"index"` Step int64 `gorm:"default:0;not null"` Run Run @@ -22,3 +25,23 @@ type Artifact struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } + +// AfterSave will calculate the iter number for this step sequence based on creation time. +func (u *Artifact) AfterSave(tx *gorm.DB) error { + if err := tx.Exec( + `UPDATE artifacts + SET iter = rows.new_iter + FROM ( + SELECT id, ROW_NUMBER() OVER (ORDER BY created_at) as new_iter + FROM artifacts + WHERE run_uuid = ? + AND name = ? + AND step = ? + ) as rows + WHERE artifacts.id = rows.id`, + u.RunID, u.Name, u.Step, + ).Error; err != nil { + return eris.Wrap(err, "error updating artifacts iter") + } + return nil +} diff --git a/pkg/api/mlflow/services/run/converters.go b/pkg/api/mlflow/services/run/converters.go index 2ce84c2fc..3c728aac7 100644 --- a/pkg/api/mlflow/services/run/converters.go +++ b/pkg/api/mlflow/services/run/converters.go @@ -32,6 +32,7 @@ func ConvertCreateRunArtifactRequestToModel( ) *models.Artifact { return &models.Artifact{ ID: uuid.New(), + Name: req.Name, Iter: req.Iter, Step: req.Step, RunID: req.RunID, diff --git a/pkg/database/migrations/v_0017/model.go b/pkg/database/migrations/v_0017/model.go index c57652a53..288bbbf16 100644 --- a/pkg/database/migrations/v_0017/model.go +++ b/pkg/database/migrations/v_0017/model.go @@ -306,8 +306,9 @@ type RoleNamespace struct { type Artifact struct { Base - Iter int64 `gorm:"index"` - Step int64 `gorm:"default:0;not null"` + Name string `gorm:"not null;index"` + Iter int64 `gorm:"index"` + Step int64 `gorm:"default:0;not null"` Run Run RunID string `gorm:"column:run_uuid;not null;index;constraint:OnDelete:CASCADE"` Index int64 diff --git a/pkg/database/model.go b/pkg/database/model.go index 917ffe9c3..1c61a9e3c 100644 --- a/pkg/database/model.go +++ b/pkg/database/model.go @@ -312,8 +312,9 @@ type RoleNamespace struct { type Artifact struct { Base - Iter int64 `gorm:"index"` - Step int64 `gorm:"default:0;not null"` + Name string `gorm:"not null;index"` + Iter int64 `gorm:"index"` + Step int64 `gorm:"default:0;not null"` Run Run RunID string `gorm:"column:run_uuid;not null;index;constraint:OnDelete:CASCADE"` Index int64 diff --git a/pkg/server/server.go b/pkg/server/server.go index b1a2c9945..4865ee30f 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -295,6 +295,7 @@ func createApp( aimRepositories.NewParamRepository(db.GormDB()), aimRepositories.NewMetricRepository(db.GormDB()), aimRepositories.NewExperimentRepository(db.GormDB()), + aimRepositories.NewArtifactRepository(db.GormDB()), config.LiveUpdatesEnabled, ), aimDashboardService.NewService( diff --git a/python/client_test.py b/python/client_test.py index a6a891c3d..2d812b2bd 100644 --- a/python/client_test.py +++ b/python/client_test.py @@ -120,8 +120,11 @@ def test_init_output_logging(client, server, run): def test_log_image(client, server, run): # test logging some images - for i in range(100): + for i in range(10): img_local = posixpath.join(os.path.dirname(__file__), "dice.png") assert ( - client.log_image(run.info.run_id, img_local, "images", "These are dice", 0, 640, 480, "png", i, 0) == None + client.log_image( + run.info.run_id, "sequence name", img_local, "images", "These are dice", 0, 640, 480, "png", i, 0 + ) + == None ) diff --git a/python/fasttrackml/_tracking_service/client.py b/python/fasttrackml/_tracking_service/client.py index 0e6647755..8faa53985 100644 --- a/python/fasttrackml/_tracking_service/client.py +++ b/python/fasttrackml/_tracking_service/client.py @@ -148,6 +148,7 @@ def log_output( def log_image( self, run_id: str, + name: str, filename: str, artifact_path: str, caption: str, @@ -161,4 +162,6 @@ def log_image( # 1. log the artifact self.log_artifact(run_id, filename, artifact_path) # 2. log the image metadata - self.custom_store.log_image(run_id, filename, artifact_path, caption, index, width, height, format, step, iter) + self.custom_store.log_image( + run_id, name, filename, artifact_path, caption, index, width, height, format, step, iter + ) diff --git a/python/fasttrackml/client.py b/python/fasttrackml/client.py index a2fc88ac6..ae79eadab 100644 --- a/python/fasttrackml/client.py +++ b/python/fasttrackml/client.py @@ -465,6 +465,7 @@ def log_output( def log_image( self, run_id: str, + name: str, filename: str, artifact_path: str, caption: str, @@ -481,6 +482,7 @@ def log_image( Args: run_id: String ID of the run + name: String the name for this sequence of images filename: The filename of the image in the local filesystem artifact_path: The optional path to append to the artifact_uri caption: The image caption @@ -512,5 +514,5 @@ def log_image( client.set_terminated(run.info.run_id) """ return self._tracking_client.log_image( - run_id, filename, artifact_path, caption, index, width, height, format, step, iter + run_id, name, filename, artifact_path, caption, index, width, height, format, step, iter ) diff --git a/python/fasttrackml/store/custom_rest_store.py b/python/fasttrackml/store/custom_rest_store.py index d76d2e804..e6c818e2e 100644 --- a/python/fasttrackml/store/custom_rest_store.py +++ b/python/fasttrackml/store/custom_rest_store.py @@ -280,6 +280,7 @@ def log_output(self, run_id, data): def log_image( self, run_id: str, + name: str, filename: str, artifact_path: str, caption: str, @@ -293,6 +294,7 @@ def log_image( storage_path = posixpath.join(artifact_path, os.path.basename(filename)) request_body = { "run_id": run_id, + "name": name, "blob_uri": storage_path, "caption": caption, "index": index, diff --git a/tests/integration/golang/aim/run/search_artifacts_test.go b/tests/integration/golang/aim/run/search_artifacts_test.go index 718974827..916811e0e 100644 --- a/tests/integration/golang/aim/run/search_artifacts_test.go +++ b/tests/integration/golang/aim/run/search_artifacts_test.go @@ -6,6 +6,7 @@ import ( "database/sql" "fmt" "net/http" + "strings" "testing" "github.com/google/uuid" @@ -28,14 +29,15 @@ func TestSearchArtifactsTestSuite(t *testing.T) { func (s *SearchArtifactsTestSuite) Test_Ok() { // create test experiments. experiment, err := s.ExperimentFixtures.CreateExperiment(context.Background(), &models.Experiment{ - Name: uuid.New().String(), - LifecycleStage: models.LifecycleStageActive, - NamespaceID: s.DefaultNamespace.ID, + Name: uuid.New().String(), + LifecycleStage: models.LifecycleStageActive, + NamespaceID: s.DefaultNamespace.ID, + ArtifactLocation: "s3://my-bucket", }) s.Require().Nil(err) run1, err := s.RunFixtures.CreateRun(context.Background(), &models.Run{ - ID: "id1", + ID: strings.ReplaceAll(uuid.New().String(), "-", ""), Name: "TestRun1", UserID: "1", Status: models.StatusRunning, @@ -53,22 +55,27 @@ func (s *SearchArtifactsTestSuite) Test_Ok() { LifecycleStage: models.LifecycleStageActive, }) s.Require().Nil(err) - _, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{ - ID: uuid.New(), - RunID: run1.ID, - BlobURI: "path/filename.png", - Step: 1, - Iter: 1, - Index: 1, - Caption: "caption1", - Format: "png", - Width: 100, - Height: 100, - }) - s.Require().Nil(err) + for i := 0; i < 5; i++ { + for j := 0; j < 5; j++ { + _, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{ + ID: uuid.New(), + Name: "some-name", + RunID: run1.ID, + BlobURI: "path/filename.png", + Step: int64(i), + Iter: 1, + Index: int64(j), + Caption: "caption1", + Format: "png", + Width: 100, + Height: 100, + }) + s.Require().Nil(err) + } + } run2, err := s.RunFixtures.CreateRun(context.Background(), &models.Run{ - ID: "id2", + ID: strings.ReplaceAll(uuid.New().String(), "-", ""), Name: "TestRun2", UserID: "1", Status: models.StatusRunning, @@ -86,30 +93,119 @@ func (s *SearchArtifactsTestSuite) Test_Ok() { LifecycleStage: models.LifecycleStageActive, }) s.Require().Nil(err) - _, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{ - ID: uuid.New(), - RunID: run2.ID, - BlobURI: "path/filename.png", - Step: 1, - Iter: 1, - Index: 1, - Caption: "caption2", - Format: "png", - Width: 100, - Height: 100, - }) - s.Require().Nil(err) + for i := 0; i < 5; i++ { + for j := 0; j < 5; j++ { + _, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{ + ID: uuid.New(), + Name: "other-name", + RunID: run2.ID, + BlobURI: "path/filename.png", + Step: int64(i), + Iter: 1, + Index: int64(j), + Caption: "caption2", + Format: "png", + Width: 100, + Height: 100, + }) + s.Require().Nil(err) + } + } - runs := []*models.Run{run1, run2} tests := []struct { - name string - request request.SearchArtifactsRequest - metrics []*models.LatestMetric + name string + request request.SearchArtifactsRequest + includedRuns []*models.Run + excludedRuns []*models.Run + expectedRecordRangeUsedMax int64 + expectedIndexRangeUsedMax int64 + expectedImageIndexesPresent []int + expectedImageIndexesAbsent []int + expectedValuesIndexesPresent []int + expectedValuesIndexesAbsent []int }{ { - name: "SearchArtifact", - request: request.SearchArtifactsRequest{}, - metrics: []*models.LatestMetric{}, + name: "SearchArtifact", + request: request.SearchArtifactsRequest{ + Query: `((images.name == "some-name") or (images.name == "other-name"))`, + }, + includedRuns: []*models.Run{run1, run2}, + expectedRecordRangeUsedMax: 4, + expectedIndexRangeUsedMax: 4, + expectedImageIndexesPresent: []int{0, 1, 2, 3}, + expectedImageIndexesAbsent: []int{}, + expectedValuesIndexesPresent: []int{0, 1, 2, 3, 4}, + expectedValuesIndexesAbsent: []int{}, + }, + { + name: "SearchArtifactWithNameQuery", + request: request.SearchArtifactsRequest{ + Query: `((images.name == "some-name"))`, + }, + includedRuns: []*models.Run{run1}, + excludedRuns: []*models.Run{run2}, + expectedRecordRangeUsedMax: 4, + expectedIndexRangeUsedMax: 4, + expectedImageIndexesPresent: []int{0, 1, 2, 3}, + expectedImageIndexesAbsent: []int{}, + expectedValuesIndexesPresent: []int{0, 1, 2, 3, 4}, + expectedValuesIndexesAbsent: []int{}, + }, + { + name: "SearchArtifactWithRecordRange", + request: request.SearchArtifactsRequest{ + Query: `((images.name == "some-name"))`, + RecordRange: "0:2", + }, + includedRuns: []*models.Run{run1}, + expectedRecordRangeUsedMax: 2, + expectedIndexRangeUsedMax: 4, + expectedImageIndexesPresent: []int{0, 1, 2, 3}, + expectedImageIndexesAbsent: []int{}, + expectedValuesIndexesPresent: []int{0, 1, 2}, + expectedValuesIndexesAbsent: []int{3, 4}, + }, + { + name: "SearchArtifactWithIndexRange", + request: request.SearchArtifactsRequest{ + Query: `((images.name == "some-name"))`, + IndexRange: "0:2", + }, + includedRuns: []*models.Run{run1}, + expectedRecordRangeUsedMax: 4, + expectedIndexRangeUsedMax: 2, + expectedImageIndexesPresent: []int{0, 1, 2}, + expectedImageIndexesAbsent: []int{3}, + expectedValuesIndexesPresent: []int{0, 1, 2, 3, 4}, + expectedValuesIndexesAbsent: []int{}, + }, + { + name: "SearchArtifactWithIndexDensity", + request: request.SearchArtifactsRequest{ + Query: `((images.name == "other-name"))`, + IndexDensity: 1, + }, + includedRuns: []*models.Run{run2}, + expectedRecordRangeUsedMax: 4, + expectedIndexRangeUsedMax: 4, + expectedImageIndexesPresent: []int{0}, + expectedImageIndexesAbsent: []int{1, 2, 3}, + expectedValuesIndexesPresent: []int{0, 1, 2, 3, 4}, + expectedValuesIndexesAbsent: []int{}, + }, + { + name: "SearchArtifactWithRecordDensity", + request: request.SearchArtifactsRequest{ + Query: `((images.name == "other-name"))`, + RecordDensity: 1, + }, + includedRuns: []*models.Run{run2}, + expectedRecordRangeUsedMax: 4, + expectedIndexRangeUsedMax: 4, + expectedImageIndexesPresent: []int{0}, + expectedImageIndexesAbsent: []int{1, 2, 3}, + expectedValuesIndexesPresent: []int{0}, + expectedValuesIndexesAbsent: []int{1, 2, 3, 4}, }, } for _, tt := range tests { @@ -124,22 +220,54 @@ func (s *SearchArtifactsTestSuite) Test_Ok() { helpers.ResponseTypeBuffer, ).WithResponse( resp, - ).DoRequest("/runs/search/image"), + ).DoRequest("/runs/search/images"), ) decodedData, err := encoding.NewDecoder(resp).Decode() s.Require().Nil(err) - for _, run := range runs { + for _, run := range tt.includedRuns { + traceIndex := 0 + rangesPrefix := fmt.Sprintf("%v.ranges", run.ID) + recordRangeKey := rangesPrefix + ".record_range_used.1" + s.Equal(tt.expectedRecordRangeUsedMax, decodedData[recordRangeKey]) + propsPrefix := fmt.Sprintf("%v.props", run.ID) + artifactLocation := propsPrefix + ".experiment.artifact_location" + s.Equal(experiment.ArtifactLocation, decodedData[artifactLocation]) + indexRangeKey := rangesPrefix + ".index_range_used.1" + s.Equal(tt.expectedIndexRangeUsedMax, decodedData[indexRangeKey]) + tracesPrefix := fmt.Sprintf("%v.traces.%d", run.ID, traceIndex) + for _, valuesIndex := range tt.expectedValuesIndexesPresent { + for _, imgIndex := range tt.expectedImageIndexesPresent { + valuesPrefix := fmt.Sprintf(".values.%d.%d", valuesIndex, imgIndex) + blobUriKey := tracesPrefix + valuesPrefix + ".blob_uri" + s.Contains(decodedData, blobUriKey) + s.Equal("path/filename.png", decodedData[blobUriKey]) + } + } + for _, valuesIndex := range tt.expectedValuesIndexesAbsent { + for _, imgIndex := range tt.expectedImageIndexesAbsent { + valuesPrefix := fmt.Sprintf(".values.%d.%d", valuesIndex, imgIndex) + blobUriKey := tracesPrefix + valuesPrefix + ".blob_uri" + s.NotContains(decodedData, blobUriKey) + } + } + } + for _, run := range tt.excludedRuns { imgIndex := 0 + valuesIndex := 0 rangesPrefix := fmt.Sprintf("%v.ranges", run.ID) - recordRangeKey := rangesPrefix + ".record_range_total.1" - s.Equal(int64(1), decodedData[recordRangeKey]) - indexRangeKey := rangesPrefix + ".index_range_total.1" - s.Equal(int64(1), decodedData[indexRangeKey]) + recordRangeKey := rangesPrefix + ".record_range_used.1" + s.Empty(decodedData[recordRangeKey]) + propsPrefix := fmt.Sprintf("%v.props", run.ID) + artifactLocation := propsPrefix + ".experiment.artifact_location" + s.Empty(decodedData[artifactLocation]) + indexRangeKey := rangesPrefix + ".index_range_used.1" + s.Empty(decodedData[indexRangeKey]) tracesPrefix := fmt.Sprintf("%v.traces.%d", run.ID, imgIndex) - blobUriKey := tracesPrefix + ".blob_uri" - s.Equal("path/filename.png", decodedData[blobUriKey]) + valuesPrefix := fmt.Sprintf(".values.%d", valuesIndex) + blobUriKey := tracesPrefix + valuesPrefix + ".blob_uri" + s.Empty(decodedData[blobUriKey]) } }) } diff --git a/tests/integration/python/aim.patch b/tests/integration/python/aim.patch index 8fab74592..08a0eab2d 100644 --- a/tests/integration/python/aim.patch +++ b/tests/integration/python/aim.patch @@ -1,5 +1,5 @@ diff --git a/tests/api/test_dashboards_api.py b/tests/api/test_dashboards_api.py -index 94bc29c..85000fc 100644 +index 94bc29ce..85000fc4 100644 --- a/tests/api/test_dashboards_api.py +++ b/tests/api/test_dashboards_api.py @@ -1,7 +1,7 @@ @@ -33,7 +33,7 @@ index 94bc29c..85000fc 100644 def test_list_dashboards_api(self): response = self.client.get('/api/dashboards/') diff --git a/tests/api/test_project_api.py b/tests/api/test_project_api.py -index 31d3654..060d528 100644 +index 31d36547..060d528a 100644 --- a/tests/api/test_project_api.py +++ b/tests/api/test_project_api.py @@ -1,21 +1,17 @@ @@ -128,8 +128,229 @@ index 31d3654..060d528 100644 def test_project_images_info_only_api(self): client = self.client +diff --git a/tests/api/test_run_images_api.py b/tests/api/test_run_images_api.py +index 2726dcd9..401484b9 100644 +--- a/tests/api/test_run_images_api.py ++++ b/tests/api/test_run_images_api.py +@@ -1,20 +1,13 @@ + from parameterized import parameterized + import random + +-from tests.base import ApiTestBase +-from tests.utils import decode_encoded_tree_stream, generate_image_set +- +-from aim.storage.treeutils import decode_tree +-from aim.storage.context import Context +-from aim.sdk.run import Run +- ++from tests.fml import ApiTestBase, db_fixtures + ++@db_fixtures() + class TestNoImagesRunQueryApi(ApiTestBase): + def test_query_images_api_empty_result(self): + client = self.client +- +- query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', params={'q': query, 'report_progress': False}) ++ response = client.post('/api/runs/search/images/', json={'q': 'run.name=="not-found"','report_progress': False}) + self.assertEqual(200, response.status_code) + self.assertEqual(b'', response.content) + +@@ -22,13 +15,13 @@ class TestNoImagesRunQueryApi(ApiTestBase): + class RunImagesTestBase(ApiTestBase): + @classmethod + def setUpClass(cls) -> None: +- super().setUpClass() ++ # super().setUpClass() + run = cls.create_run(repo=cls.repo) + run['images_per_step'] = 16 +- for step in range(100): +- images = generate_image_set(img_count=16, caption_prefix=f'Image {step}') +- run.track(images, name='random_images') +- run.track(random.random(), name='random_values') ++ # for step in range(100): ++ # images = generate_image_set(img_count=16, caption_prefix=f'Image {step}') ++ # run.track(images, name='random_images') ++ # run.track(random.random(), name='random_values') + cls.run_hash = run.hash + + +@@ -37,7 +30,7 @@ class TestRunImagesSearchApi(RunImagesTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', params={'q': query, 'report_progress': False}) ++ response = client.post('/api/runs/search/images/', json={'q': query, 'report_progress': False}) + self.assertEqual(200, response.status_code) + + decoded_response = decode_tree(decode_encoded_tree_stream(response.iter_bytes(chunk_size=512 * 1024))) +@@ -70,8 +63,8 @@ class TestRunImagesSearchApi(RunImagesTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', +- params={'q': query, 'record_density': 200, 'index_density': 10, 'report_progress': False}) ++ response = client.post('/api/runs/search/images/', ++ json={'q': query, 'record_density': 200, 'index_density': 10, 'report_progress': False}) + self.assertEqual(200, response.status_code) + + decoded_response = decode_tree(decode_encoded_tree_stream(response.iter_bytes(chunk_size=512 * 1024), +@@ -88,7 +81,7 @@ class TestRunImagesSearchApi(RunImagesTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', params={'q': query, ++ response = client.post('/api/runs/search/images/', json={'q': query, + 'record_density': 10, + 'index_density': 4, + 'report_progress': False}) +@@ -116,7 +109,7 @@ class TestRunImagesSearchApi(RunImagesTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', params={'q': query, ++ response = client.post('/api/runs/search/images/', json={'q': query, + 'record_range': input_range, + 'record_density': 100, + 'report_progress': False}) +@@ -137,7 +130,7 @@ class TestRunImagesSearchApi(RunImagesTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', params={ ++ response = client.post('/api/runs/search/images/', json={ + 'q': query, + 'record_range': '10:20', + 'index_range': '3:6', +@@ -169,7 +162,7 @@ class TestRunImagesSearchApi(RunImagesTestBase): + class RunImagesURIBulkLoadApi(RunImagesTestBase): + @classmethod + def setUpClass(cls) -> None: +- super().setUpClass() ++ # super().setUpClass() + cls.image_blobs = {} + run = Run(run_hash=cls.run_hash, read_only=True) + empty_context = Context({}) +@@ -190,7 +183,7 @@ class RunImagesURIBulkLoadApi(RunImagesTestBase): + self.uri_map = {} + client = self.client + +- response = client.get('/api/runs/search/images/', params={ ++ response = client.post('/api/runs/search/images/', json={ + 'record_range': '0:10', + 'index_range': '0:5', + 'report_progress': False, +@@ -254,7 +247,7 @@ class TestRunImagesBatchApi(RunImagesTestBase): + class TestImageListsAndSingleImagesSearchApi(ApiTestBase): + @classmethod + def setUpClass(cls) -> None: +- super().setUpClass() ++ # super().setUpClass() + + run = cls.create_run(system_tracking_interval=None) + cls.run_hash = run.hash +@@ -268,8 +261,8 @@ class TestImageListsAndSingleImagesSearchApi(ApiTestBase): + client = self.client + + query = self.isolated_query_patch('images.name == "single_images"') +- response = client.get('/api/runs/search/images/', +- params={'q': query, 'report_progress': False}) ++ response = client.post('/api/runs/search/images/', ++ json={'q': query, 'report_progress': False}) + self.assertEqual(200, response.status_code) + + decoded_response = decode_tree(decode_encoded_tree_stream(response.iter_bytes(chunk_size=512 * 1024))) +@@ -291,7 +284,7 @@ class TestImageListsAndSingleImagesSearchApi(ApiTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', params={'q': query, 'report_progress': False}) ++ response = client.post('/api/runs/search/images/', json={'q': query, 'report_progress': False}) + self.assertEqual(200, response.status_code) + + decoded_response = decode_tree(decode_encoded_tree_stream(response.iter_bytes(chunk_size=512 * 1024))) +@@ -319,8 +312,8 @@ class TestImageListsAndSingleImagesSearchApi(ApiTestBase): + client = self.client + + query = self.isolated_query_patch() +- response = client.get('/api/runs/search/images/', +- params={'q': query, 'index_range': '3:5', 'report_progress': False}) ++ response = client.post('/api/runs/search/images/', ++ json={'q': query, 'index_range': '3:5', 'report_progress': False}) + self.assertEqual(200, response.status_code) + + decoded_response = decode_tree(decode_encoded_tree_stream(response.iter_bytes(chunk_size=512 * 1024))) +@@ -348,7 +341,7 @@ class TestImageListsAndSingleImagesSearchApi(ApiTestBase): + class TestRunInfoApi(ApiTestBase): + @classmethod + def setUpClass(cls) -> None: +- super().setUpClass() ++ # super().setUpClass() + + # run1 -> context {'subset': 'train'} -> Image[] + # | -> integers +@@ -375,7 +368,7 @@ class TestRunInfoApi(ApiTestBase): + + def test_run_info_get_images_only_api(self): + client = self.client +- response = client.get(f'api/runs/{self.run1_hash}/info', params={'sequence': 'images'}) ++ response = client.post(f'api/runs/{self.run1_hash}/info', json={'sequence': 'images'}) + self.assertEqual(200, response.status_code) + response_data = response.json() + self.assertEqual(1, len(response_data['traces'])) +@@ -383,7 +376,7 @@ class TestRunInfoApi(ApiTestBase): + self.assertDictEqual({'subset': 'train'}, response_data['traces']['images'][0]['context']) + self.assertEqual('image_lists', response_data['traces']['images'][0]['name']) + +- response = client.get(f'api/runs/{self.run2_hash}/info', params={'sequence': 'images'}) ++ response = client.post(f'api/runs/{self.run2_hash}/info', json={'sequence': 'images'}) + self.assertEqual(200, response.status_code) + response_data = response.json() + self.assertEqual(1, len(response_data['traces'])) +@@ -398,7 +391,7 @@ class TestRunInfoApi(ApiTestBase): + ]) + def test_run_info_get_all_sequences_api(self, qparams, trace_type_count): + client = self.client +- response = client.get(f'api/runs/{self.run1_hash}/info', params=qparams) ++ response = client.post(f'api/runs/{self.run1_hash}/info', json=qparams) + self.assertEqual(200, response.status_code) + response_data = response.json() + self.assertEqual(trace_type_count, len(response_data['traces'])) +@@ -419,7 +412,7 @@ class TestRunInfoApi(ApiTestBase): + self.assertDictEqual({'subset': 'train'}, metrics_data[1]['context']) + self.assertDictEqual({'subset': 'train'}, metrics_data[2]['context']) + +- response = client.get(f'api/runs/{self.run2_hash}/info', params={'sequence': ('images', 'metric')}) ++ response = client.post(f'api/runs/{self.run2_hash}/info', json={'sequence': ('images', 'metric')}) + self.assertEqual(200, response.status_code) + response_data = response.json() + self.assertEqual(2, len(response_data['traces'])) +@@ -436,14 +429,14 @@ class TestRunInfoApi(ApiTestBase): + + def test_run_info_get_metrics_only_api(self): + client = self.client +- response = client.get(f'api/runs/{self.run1_hash}/info', params={'sequence': 'metric'}) ++ response = client.post(f'api/runs/{self.run1_hash}/info', json={'sequence': 'metric'}) + self.assertEqual(200, response.status_code) + response_data = response.json() + self.assertEqual(1, len(response_data['traces'])) + self.assertIn('metric', response_data['traces']) + self.assertEqual(3, len(response_data['traces']['metric'])) + +- response = client.get(f'api/runs/{self.run2_hash}/info', params={'sequence': 'metric'}) ++ response = client.post(f'api/runs/{self.run2_hash}/info', json={'sequence': 'metric'}) + self.assertEqual(200, response.status_code) + response_data = response.json() + self.assertEqual(1, len(response_data['traces'])) +@@ -452,5 +445,5 @@ class TestRunInfoApi(ApiTestBase): + + def test_invalid_sequence_type(self): + client = self.client +- response = client.get(f'api/runs/{self.run1_hash}/info', params={'sequence': 'non-existing-sequence'}) ++ response = client.post(f'api/runs/{self.run1_hash}/info', json={'sequence': 'non-existing-sequence'}) + self.assertEqual(400, response.status_code) diff --git a/tests/conftest.py b/tests/conftest.py -index 8cdd353..e69de29 100644 +index 8cdd353e..e69de29b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,44 +0,0 @@ @@ -179,7 +400,7 @@ index 8cdd353..e69de29 100644 - del os.environ[AIM_REPO_NAME] diff --git a/tests/fml.py b/tests/fml.py new file mode 100644 -index 0000000..099cab0 +index 00000000..19a3387d --- /dev/null +++ b/tests/fml.py @@ -0,0 +1,199 @@ @@ -208,7 +429,7 @@ index 0000000..099cab0 + + +def init_server(backend_uri, root_artifact_uri): -+ port = get_safe_port() ++ port = 5000 #get_safe_port() + address = f"{LOCALHOST}:{port}" + process = Popen( + [ diff --git a/tests/integration/python/config.json b/tests/integration/python/config.json index 60da46c0c..1c3218bed 100644 --- a/tests/integration/python/config.json +++ b/tests/integration/python/config.json @@ -13,12 +13,15 @@ "httpx==0.25.0", "idna==3.4", "iniconfig==2.0.0", + "numpy", "packaging==23.2", "parameterized==0.9.0", + "pillow", "pluggy==1.3.0", "pytest==7.4.3", "pytz==2023.3.post1", - "sniffio==1.3.0" + "sniffio==1.3.0", + "sqlalchemy" ], "tests": [ "tests/api/test_dashboards_api.py", @@ -83,4 +86,4 @@ ] } } -} \ No newline at end of file +}