Skip to content

Commit

Permalink
Improve logic for validating whether metrics are selected in the metrics explorer (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgiannuzzi authored Sep 13, 2023
1 parent 9ce0ae5 commit 28a5bab
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 85 deletions.
13 changes: 10 additions & 3 deletions pkg/api/aim/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ type QueryParser struct {

// ParsedQuery exposes the operations available on a parsed query.
type ParsedQuery interface {
// Filter takes a gorm query and returns a gorm query.
// NOTE(review): presumably it applies the parsed joins and conditions to the
// given query; the implementation is not fully visible here, so confirm.
Filter(*gorm.DB) *gorm.DB
// IsMetricSelected reports whether a metric was selected by the query.
// SearchMetrics rejects queries for which this returns false with
// "No metrics are selected".
IsMetricSelected() bool
}

// parsedQuery is the concrete type whose methods (Filter, IsMetricSelected)
// match the ParsedQuery interface. It keeps a reference to the QueryParser
// alongside the joins and conditions collected for the query.
type parsedQuery struct {
qp *QueryParser
joins map[string]join
conditions []clause.Expression
qp *QueryParser
joins map[string]join
conditions []clause.Expression
// metricSelected is set to true in parseName when the "name" attribute is
// resolved; it is read back through IsMetricSelected.
metricSelected bool
}

// callable is the signature of a function that receives AST argument
// expressions and returns a value or an error.
type callable func(args []ast.Expr) (any, error)
Expand Down Expand Up @@ -162,6 +164,10 @@ func (pq *parsedQuery) Filter(tx *gorm.DB) *gorm.DB {
return tx
}

// IsMetricSelected returns the metricSelected flag recorded on the parsed
// query. The flag is set to true in parseName when the "name" attribute is
// resolved.
func (pq *parsedQuery) IsMetricSelected() bool {
return pq.metricSelected
}

func (pq *parsedQuery) parseNode(node ast.Expr) (any, error) {
ret, err := pq._parseNode(node)
if err != nil && !errors.Is(err, SyntaxError{}) {
Expand Down Expand Up @@ -566,6 +572,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
func(attr string) (any, error) {
switch attr {
case "name":
pq.metricSelected = true
return clause.Column{
Table: table,
Name: "key",
Expand Down
32 changes: 11 additions & 21 deletions pkg/api/aim/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/binary"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"time"
Expand All @@ -23,9 +22,6 @@ import (
"github.com/G-Research/fasttrackml/pkg/database"
)

// metricNameRegExp is the validation rule for SearchMetrics. It matches the
// raw query text when it contains either "in metric.name" (with optional
// whitespace after "in"), or "metric.name" followed by "." or by optional
// whitespace and "==". It is compiled once at package scope.
var metricNameRegExp = regexp.MustCompile(`in\s*metric\.name|metric\.name(?:\.|\s*==)`)

func GetRunInfo(c *fiber.Ctx) error {
q := struct {
// TODO skip_system is unused - should we keep it?
Expand Down Expand Up @@ -335,7 +331,7 @@ func SearchRuns(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, "x-timezone-offset header is not a valid integer")
}

pq := query.QueryParser{
qp := query.QueryParser{
Default: query.DefaultExpression{
Contains: "run.archived",
Expression: "not run.archived",
Expand All @@ -347,7 +343,7 @@ func SearchRuns(c *fiber.Ctx) error {
TzOffset: tzOffset,
Dialector: database.DB.Dialector.Name(),
}
qp, err := pq.Parse(q.Query)
pq, err := qp.Parse(q.Query)
if err != nil {
return err
}
Expand Down Expand Up @@ -390,7 +386,7 @@ func SearchRuns(c *fiber.Ctx) error {
}

var runs []database.Run
qp.Filter(tx).Find(&runs)
pq.Filter(tx).Find(&runs)
if tx.Error != nil {
return fmt.Errorf("error searching runs: %w", tx.Error)
}
Expand Down Expand Up @@ -519,17 +515,12 @@ func SearchMetrics(c *fiber.Ctx) error {
q.Steps = 50
}

// require a metric.name
if !validateMetricNamePresent(q.Query) {
return fiber.NewError(fiber.StatusUnprocessableEntity, "No metrics are selected")
}

tzOffset, err := strconv.Atoi(c.Get("x-timezone-offset", "0"))
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, "x-timezone-offset header is not a valid integer")
}

pq := query.QueryParser{
qp := query.QueryParser{
Default: query.DefaultExpression{
Contains: "run.archived",
Expression: "not run.archived",
Expand All @@ -542,11 +533,15 @@ func SearchMetrics(c *fiber.Ctx) error {
TzOffset: tzOffset,
Dialector: database.DB.Dialector.Name(),
}
qp, err := pq.Parse(q.Query)
pq, err := qp.Parse(q.Query)
if err != nil {
return err
}

if !pq.IsMetricSelected() {
return fiber.NewError(fiber.StatusUnprocessableEntity, "No metrics are selected")
}

var totalRuns int64
if tx := database.DB.Model(&database.Run{}).Count(&totalRuns); tx.Error != nil {
return fmt.Errorf("error searching run metrics: %w", tx.Error)
Expand All @@ -557,7 +552,7 @@ func SearchMetrics(c *fiber.Ctx) error {
Joins("Experiment", database.DB.Select("ID", "Name")).
Preload("Params").
Preload("Tags").
Where("run_uuid IN (?)", qp.Filter(database.DB.
Where("run_uuid IN (?)", pq.Filter(database.DB.
Select("runs.run_uuid").
Table("runs").
Joins("LEFT JOIN experiments USING(experiment_id)").
Expand Down Expand Up @@ -610,7 +605,7 @@ func SearchMetrics(c *fiber.Ctx) error {
Table("metrics").
Joins(
"INNER JOIN (?) runmetrics USING(run_uuid, key)",
qp.Filter(database.DB.
pq.Filter(database.DB.
Select("runs.run_uuid", "runs.row_num", "latest_metrics.key", fmt.Sprintf("(latest_metrics.last_iter + 1)/ %f AS interval", float32(q.Steps))).
Table("runs").
Joins("LEFT JOIN experiments USING(experiment_id)").
Expand Down Expand Up @@ -1045,8 +1040,3 @@ func toNumpy(values []float64) fiber.Map {
"blob": buf.Bytes(),
}
}

// validateMetricNamePresent reports whether query contains a metric.name
// condition, as recognised by metricNameRegExp.
//
// This is a purely textual check on the raw query string; the query is not
// parsed.
func validateMetricNamePresent(query string) bool {
	// MatchString operates on the string directly, avoiding the copy that a
	// []byte(query) conversion would make just to call Match.
	return metricNameRegExp.MatchString(query)
}
150 changes: 89 additions & 61 deletions pkg/api/aim/runs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,69 +4,97 @@ import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/G-Research/fasttrackml/pkg/api/aim/query"
)

func Test_validateMetricNamePresent(t *testing.T) {
tests := []struct {
name string
query string
wantResult bool
}{
{
name: "QueryWithMetricName",
query: `(run.active == true) and ((metric.name == "accuracy"))`,
wantResult: true,
},
{
name: "QueryWithMetricNameNoSpaces",
query: `(run.active == true) and ((metric.name=="accuracy"))`,
wantResult: true,
},
{
name: "QueryWithStringSyntax",
query: `(run.active == true) and (metric.name.startswith("acc"))`,
wantResult: true,
},
{
name: "QueryWithInSyntax",
query: `(run.active == true) and ("accuracy" in metric.name)`,
wantResult: true,
},
{
name: "QueryWithMetricNameAtEnd",
query: `(run.active == true) and "accuracy" in metric.name`,
wantResult: true,
},
{
name: "QueryWithoutMetricName",
query: `(run.active == true)`,
wantResult: false,
},
{
name: "QueryWithoutDot",
query: `(run.active == true) and (metricname === "accuracy")`,
wantResult: false,
},
{
name: "QueryWithOrCharacter",
query: `(run.active == true) and (metric.name| === "accuracy")`,
wantResult: false,
},
{
name: "QueryWithTrickyDictKey",
query: `(run.active == true) and (run.tags["metric.name"] == "foo")`,
wantResult: false,
},
{
name: "QueryWithTrickyDictKeyAndMetricName",
query: `(run.active == true) and (run.tags["mymetric.name"] == "foo") and (metric.name == "accuracy")`,
wantResult: true,
},
func Test_isMetricSelected(t *testing.T) {
dialectors := []string{
"sqlite3",
"postgres",
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validateMetricNamePresent(tt.query)
assert.Equal(t, tt.wantResult, result)
})
for _, dialector := range dialectors {
qp := query.QueryParser{
Default: query.DefaultExpression{
Contains: "run.archived",
Expression: "not run.archived",
},
Tables: map[string]string{
"runs": "runs",
"experiments": "experiments",
"metrics": "latest_metrics",
},
TzOffset: 0,
Dialector: dialector,
}

tests := []struct {
name string
query string
wantResult bool
}{
{
name: "QueryWithMetricName",
query: `(run.active == True) and ((metric.name == "accuracy"))`,
wantResult: true,
},
{
name: "QueryWithMetricNameNoSpaces",
query: `(run.active == True) and ((metric.name=="accuracy"))`,
wantResult: true,
},
{
name: "QueryWithStringSyntax",
query: `(run.active == True) and (metric.name.startswith("acc"))`,
wantResult: true,
},
{
name: "QueryWithInSyntax",
query: `(run.active == True) and ("accuracy" in metric.name)`,
wantResult: true,
},
{
name: "QueryWithMetricNameAtEnd",
query: `(run.active == True) and "accuracy" in metric.name`,
wantResult: true,
},
{
name: "QueryWithoutMetricName",
query: `(run.active == True)`,
wantResult: false,
},
{
name: "QueryWithTrickyDictKey",
query: `(run.active == True) and (run.tags["metric.name"] == "foo")`,
wantResult: false,
},
{
name: "QueryWithTrickyDictKeyAndMetricName",
query: `(run.active == True) and (run.tags["mymetric.name"] == "foo") and (metric.name == "accuracy")`,
wantResult: true,
},
{
name: "QueryWithRegexMatchMetricName",
query: `(run.active == True) and re.match("accuracy", metric.name)`,
wantResult: true,
},
{
name: "QueryWithRegexSearchMetricName",
query: `(run.active == True) and re.search("accuracy", metric.name)`,
wantResult: true,
},
{
name: "QueryWithRegexSearchAndNoMetricName",
query: `(run.active == True) and re.search("accuracy", run.tags["metric.name"])`,
wantResult: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pq, err := qp.Parse(tt.query)
assert.Nil(t, err)
assert.Equal(t, tt.wantResult, pq.IsMetricSelected())
})
}
}
}

0 comments on commit 28a5bab

Please sign in to comment.