Skip to content

Commit

Permalink
Fix GetRunMetrics endpoint. (#913)
Browse files Browse the repository at this point in the history
backport of #899
  • Loading branch information
dsuhinin authored Feb 16, 2024
1 parent 7cea4dc commit 360e94d
Showing 1 changed file with 41 additions and 45 deletions.
86 changes: 41 additions & 45 deletions pkg/api/aim/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,23 +171,25 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

metricKeysMap, contexts := make(fiber.Map, len(b)), make([]types.JSONB, 0, len(b))
// this is a temporary map which provides uniqueness inside the metricKeysMap map.
type metric struct {
name string
context string
}

// collect unique metrics. uniqueness provides metricKeysMap + metric struct.
metricKeysMap := make(map[metric]any, len(b))
for _, m := range b {
if m.Context != nil {
serializedContext, err := json.Marshal(m.Context)
if err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}
contexts = append(contexts, serializedContext)
metricKeysMap[metric{
name: m.Name,
context: string(serializedContext),
}] = nil
}
metricKeysMap[m.Name] = nil
}
metricKeys := make([]string, len(metricKeysMap))

i := 0
for k := range metricKeysMap {
metricKeys[i] = k
i++
}

// check that requested run actually exists.
Expand All @@ -209,59 +211,53 @@ func GetRunMetrics(c *fiber.Ctx) error {
return fmt.Errorf("unable to find run %q: %w", p.ID, err)
}

subQuery := database.DB
for metricKey := range metricKeysMap {
subQuery = subQuery.Or("key = ? AND json = ?", metricKey.name, types.JSONB(metricKey.context))
}

// fetch run metrics based on provided criteria.
var data []database.Metric
if err := database.DB.Where(
"run_uuid = ?", p.ID,
).InnerJoins(
if err := database.DB.InnerJoins(
"Context",
func() *gorm.DB {
query := database.DB
for _, context := range contexts {
query = query.Or("json = ?", context)
}
return query
}(),
).Where(
"key IN ?", metricKeys,
).Order(
"iter",
).Find(
&data,
).Error; err != nil {
).Where(
"run_uuid = ?", p.ID,
).Where(
subQuery,
).Find(&data).Error; err != nil {
return fmt.Errorf("unable to find run metrics: %w", err)
}

metrics := make(map[string]struct {
name string
iters []int
values []*float64
context json.RawMessage
}, len(metricKeys))
metrics := make(map[metric]struct {
iters []int
values []*float64
}, len(metricKeysMap))

for _, m := range data {
v := m.Value
for _, item := range data {
v := item.Value
pv := &v
if m.IsNan {
if item.IsNan {
pv = nil
}

key := fmt.Sprintf("%s%d", m.Key, m.ContextID)
k := metrics[key]
k.name = m.Key
k.iters = append(k.iters, int(m.Iter))
k.values = append(k.values, pv)
k.context = json.RawMessage(m.Context.Json)
metrics[key] = k
key := metric{
name: item.Key,
context: string(item.Context.Json),
}
m := metrics[key]
m.iters = append(m.iters, int(item.Iter))
m.values = append(m.values, pv)
metrics[key] = m
}

resp := make([]fiber.Map, 0, len(metrics))
for _, m := range metrics {
for key, m := range metrics {
resp = append(resp, fiber.Map{
"name": m.name,
"name": key.name,
"iters": m.iters,
"values": m.values,
"context": m.context,
"context": json.RawMessage(key.context),
})
}

Expand Down

0 comments on commit 360e94d

Please sign in to comment.