Skip to content

Commit

Permalink
[release/0.3] Permit duplicate param logging (#591)
Browse files Browse the repository at this point in the history
Backport of #550

* Rely on the database to detect duplicate parameter logging where the value has changed, and reject it
* Fail the entire batch when an invalid duplicate parameter is detected
* Return HTTP 400 instead of 500 in case of duplication

Co-authored-by: Geoffrey Wilson <[email protected]>
  • Loading branch information
jgiannuzzi and suprjinx authored Nov 15, 2023
1 parent a5f3adc commit b4287c6
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 42 deletions.
32 changes: 32 additions & 0 deletions pkg/api/mlflow/dao/repositories/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package repositories

import (
"fmt"
"strings"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

// makeSqlPlaceholders builds a comma-separated list of sql placeholder groups,
// "(?,?,?),(?,?,?)" and so on, for use as sql parameters.
//
// numberInEachSet is the number of "?" inside each parenthesised group and
// numberOfSets is the number of groups. When either argument is not positive
// there is nothing to render and the empty string is returned.
func makeSqlPlaceholders(numberInEachSet, numberOfSets int) string {
	// strings.Repeat panics on a negative count, which is what the "-1"
	// arithmetic below would produce for zero (or negative) arguments.
	if numberInEachSet <= 0 || numberOfSets <= 0 {
		return ""
	}
	set := fmt.Sprintf("(%s)", strings.Repeat("?,", numberInEachSet-1)+"?")
	return strings.Repeat(set+",", numberOfSets-1) + set
}

// makeParamConflictPlaceholdersAndValues returns the sql placeholders together with
// the flattened values for params: one "(?,?,?)" group per param, and for every param
// its Key, Value and RunID (in that order) for use in sql values replacement.
func makeParamConflictPlaceholdersAndValues(params []models.Param) (string, []interface{}) {
	// three fields are taken from each param
	const fieldsPerParam = 3
	groups := makeSqlPlaceholders(fieldsPerParam, len(params))
	// the final length is known up front, so reserve it once and append in input order
	flattened := make([]interface{}, 0, fieldsPerParam*len(params))
	for i := range params {
		flattened = append(flattened, params[i].Key, params[i].Value, params[i].RunID)
	}
	return groups, flattened
}
55 changes: 55 additions & 0 deletions pkg/api/mlflow/dao/repositories/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package repositories

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

// Test_makeSqlPlaceholders checks the placeholder string produced for several
// combinations of group size and number of groups.
func Test_makeSqlPlaceholders(t *testing.T) {
	type testCase struct {
		perSet int
		sets   int
		want   string
	}
	cases := []testCase{
		{perSet: 1, sets: 1, want: "(?)"},
		{perSet: 2, sets: 1, want: "(?,?)"},
		{perSet: 1, sets: 2, want: "(?),(?)"},
		{perSet: 2, sets: 2, want: "(?,?),(?,?)"},
	}

	for i := range cases {
		tc := cases[i]
		got := makeSqlPlaceholders(tc.perSet, tc.sets)
		assert.Equal(t, tc.want, got)
	}
}

// Test_makeParamConflictPlaceholdersAndValues checks that each param yields one
// "(?,?,?)" group and contributes its Key, Value and RunID, in that order, to the values.
func Test_makeParamConflictPlaceholdersAndValues(t *testing.T) {
	type testCase struct {
		input            []models.Param
		wantPlaceholders string
		wantValues       []interface{}
	}

	// a single param
	single := testCase{
		input:            []models.Param{{Key: "key1", Value: "value1", RunID: "run1"}},
		wantPlaceholders: "(?,?,?)",
		wantValues:       []interface{}{"key1", "value1", "run1"},
	}
	// two params belonging to different runs
	pair := testCase{
		input: []models.Param{
			{Key: "key1", Value: "value1", RunID: "run1"},
			{Key: "key2", Value: "value2", RunID: "run2"},
		},
		wantPlaceholders: "(?,?,?),(?,?,?)",
		wantValues:       []interface{}{"key1", "value1", "run1", "key2", "value2", "run2"},
	}

	for _, tc := range []testCase{single, pair} {
		gotPlaceholders, gotValues := makeParamConflictPlaceholdersAndValues(tc.input)
		assert.Equal(t, tc.wantPlaceholders, gotPlaceholders)
		assert.Equal(t, tc.wantValues, gotValues)
	}
}
96 changes: 59 additions & 37 deletions pkg/api/mlflow/dao/repositories/param.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,34 @@ import (

"github.com/rotisserie/eris"
"gorm.io/gorm"
"gorm.io/gorm/clause"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

// ParamConflictError reports a conflict in the params: the same key was logged
// with a different value. Message holds the text returned by Error.
type ParamConflictError struct{ Message string }

// Error implements the error interface by returning the stored Message as is.
func (e ParamConflictError) Error() string { return e.Message }

// paramConflict describes one conflicting parameter: the run and key it belongs
// to, plus the two differing values.
type paramConflict struct {
	// RunID is mapped by gorm to the run_uuid column.
	RunID    string `gorm:"column:run_uuid"`
	Key      string
	OldValue string
	NewValue string
}

// String formats the paramConflict for inclusion in error messages.
func (pc paramConflict) String() string {
	const layout = "{run_id: %s, key: %s, old_value: %s, new_value: %s}"
	return fmt.Sprintf(layout, pc.RunID, pc.Key, pc.OldValue, pc.NewValue)
}

// ParamRepositoryProvider provides an interface to work with models.Param entity.
type ParamRepositoryProvider interface {
// CreateBatch creates []models.Param entities in batch.
Expand All @@ -32,48 +56,46 @@ func NewParamRepository(db *gorm.DB) *ParamRepository {

// CreateBatch creates []models.Param entities in batch.
func (r ParamRepository) CreateBatch(ctx context.Context, batchSize int, params []models.Param) error {
// try to create params in batch; error condition requires special handling
// to allow certain duplicates
if err := r.db.CreateInBatches(params, batchSize).Error; err != nil {
// remove duplicate rows and try again
dedupedParams, errRemovingMatches := r.removeExactMatches(ctx, params)
if errRemovingMatches != nil {
return eris.Wrap(errRemovingMatches, "error removing duplicates in batch")
if err := r.db.Transaction(func(tx *gorm.DB) error {
if err := tx.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "run_uuid"}, {Name: "key"}},
DoNothing: true,
}).CreateInBatches(params, batchSize).Error; err != nil {
return eris.Wrap(err, "error creating params in batch")
}
if err := r.db.CreateInBatches(dedupedParams, batchSize).Error; err != nil {
return eris.Wrap(err, "error creating params in batch after removing duplicates")
// if there were ignored conflicts, verify to be exact duplicates
if tx.RowsAffected != int64(len(params)) {
conflictingParams, err := findConflictingParams(tx, params)
if err != nil {
return eris.Wrap(err, "error checking for conflicting params")
}
if len(conflictingParams) > 0 {
return ParamConflictError{
Message: fmt.Sprintf("conflicting params found: %v", conflictingParams),
}
}
}
return nil
}); err != nil {
return err
}
return nil
}

// removeExactMatches will return a new slice of params which excludes exact matches
func (r ParamRepository) removeExactMatches(ctx context.Context, params []models.Param) ([]models.Param, error) {
var keys []string
paramMap := map[string]models.Param{}
for _, param := range params {
key := fmt.Sprintf("%v-%v-%v", param.RunID, param.Key, param.Value)
keys = append(keys, key)
paramMap[key] = param
// findConflictingParams checks if there are conflicting values for the input params. If a key does not
// yet exist in the db, or if the same key and value already exist for the run, it is not a conflict.
// If the key already exists for the run but with a different value, it is a conflict. Conflicts are returned.
func findConflictingParams(tx *gorm.DB, params []models.Param) ([]paramConflict, error) {
var conflicts []paramConflict
placeholders, values := makeParamConflictPlaceholdersAndValues(params)
sql := fmt.Sprintf(`WITH new(key, value, run_uuid) AS (VALUES %s)
SELECT current.run_uuid, current.key, current.value as old_value, new.value as new_value
FROM params AS current
INNER JOIN new USING (run_uuid, key)
WHERE new.value != current.value`, placeholders)
if err := tx.Raw(sql, values...).
Find(&conflicts).Error; err != nil {
return nil, eris.Wrap(err, "error fetching params from db")
}

var foundKeys []string
if err := r.db.Raw(`
select run_uuid || '-' || key || '-' || value
from params
where run_uuid || '-' || key || '-' || value in ?`, keys).
Find(&foundKeys).Error; err != nil {
return []models.Param{}, eris.Wrap(err, "problem selecting existing params")
}

for _, foundKey := range foundKeys {
delete(paramMap, foundKey)
}

paramsToReturn := []models.Param{}
for _, v := range paramMap {
paramsToReturn = append(paramsToReturn, v)
}

return paramsToReturn, nil
return conflicts, nil
}
7 changes: 7 additions & 0 deletions pkg/api/mlflow/service/run/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
Expand Down Expand Up @@ -435,6 +436,9 @@ func (s Service) LogParam(ctx context.Context, req *request.LogParamRequest) err

param := convertors.ConvertLogParamRequestToDBModel(run.ID, req)
if err := s.paramRepository.CreateBatch(ctx, 1, []models.Param{*param}); err != nil {
if errors.As(err, &repositories.ParamConflictError{}) {
return api.NewInvalidParameterValueError("unable to insert params for run '%s': %s", run.ID, err)
}
return api.NewInternalError("unable to insert params for run '%s': %s", run.ID, err)
}

Expand Down Expand Up @@ -504,6 +508,9 @@ func (s Service) LogBatch(ctx context.Context, req *request.LogBatchRequest) err
return api.NewInvalidParameterValueError(err.Error())
}
if err := s.paramRepository.CreateBatch(ctx, 100, params); err != nil {
if errors.As(err, &repositories.ParamConflictError{}) {
return api.NewInvalidParameterValueError("unable to insert params for run '%s': %s", run.ID, err)
}
return api.NewInternalError("unable to insert params for run '%s': %s", run.ID, err)
}
if err := s.metricRepository.CreateBatch(ctx, run, 100, metrics); err != nil {
Expand Down
85 changes: 85 additions & 0 deletions pkg/api/mlflow/service/run/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,50 @@ func TestService_LogBatch_Error(t *testing.T) {
)
},
},
{
name: "CreateBatchParamsConflictError",
error: api.NewInvalidParameterValueError(`unable to insert params for run '1': param conflict!`),
request: &request.LogBatchRequest{
RunID: "1",
Params: []request.ParamPartialRequest{
{
Key: "key",
Value: "value",
},
},
},
service: func() *Service {
runRepository := repositories.MockRunRepositoryProvider{}
runRepository.On(
"GetByIDAndLifecycleStage",
context.TODO(),
"1",
models.LifecycleStageActive,
).Return(&models.Run{
ID: "1",
}, nil)
paramRepository := repositories.MockParamRepositoryProvider{}
paramRepository.On(
"CreateBatch",
context.TODO(),
100,
[]models.Param{
{
Key: "key",
Value: "value",
RunID: "1",
},
},
).Return(repositories.ParamConflictError{Message: "param conflict!"})
return NewService(
&repositories.MockTagRepositoryProvider{},
&runRepository,
&paramRepository,
&repositories.MockMetricRepositoryProvider{},
&repositories.MockExperimentRepositoryProvider{},
)
},
},
{
name: "CreateBatchMetricsDatabaseError",
error: api.NewInternalError(`unable to insert metrics for run '1': database error`),
Expand Down Expand Up @@ -1665,6 +1709,47 @@ func TestService_LogParam_Error(t *testing.T) {
)
},
},
{
name: "LogParamConflictError",
error: api.NewInvalidParameterValueError(`unable to insert params for run '1': conflict!`),
request: &request.LogParamRequest{
RunID: "1",
Key: "key",
Value: "value",
},
service: func() *Service {
runRepository := repositories.MockRunRepositoryProvider{}
runRepository.On(
"GetByIDAndLifecycleStage",
context.TODO(),
"1",
models.LifecycleStageActive,
).Return(&models.Run{
ID: "1",
LifecycleStage: models.LifecycleStageActive,
}, nil)
paramRepository := repositories.MockParamRepositoryProvider{}
paramRepository.On(
"CreateBatch",
context.TODO(),
1,
mock.MatchedBy(func(params []models.Param) bool {
assert.Equal(t, 1, len(params))
assert.Equal(t, "key", params[0].Key)
assert.Equal(t, "value", params[0].Value)
assert.Equal(t, "1", params[0].RunID)
return true
}),
).Return(repositories.ParamConflictError{Message: "conflict!"})
return NewService(
&repositories.MockTagRepositoryProvider{},
&runRepository,
&paramRepository,
&repositories.MockMetricRepositoryProvider{},
&repositories.MockExperimentRepositoryProvider{},
)
},
},
}

for _, tt := range testData {
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/golang/fixtures/param.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ func (f ParamFixtures) CreateParam(ctx context.Context, param *models.Param) (*m
}
return param, nil
}

// GetParamsByRunID loads every param whose run_uuid equals runID.
// On a query failure it returns nil and the wrapped database error.
func (f ParamFixtures) GetParamsByRunID(ctx context.Context, runID string) ([]models.Param, error) {
	var found []models.Param
	result := f.baseFixtures.db.WithContext(ctx).Where("run_uuid = ?", runID).Find(&found)
	if result.Error != nil {
		return nil, eris.Wrap(result.Error, "error getting params by run id")
	}
	return found, nil
}
Loading

0 comments on commit b4287c6

Please sign in to comment.