Skip to content

Commit

Permalink
Merge pull request #1 from criteo-forks/fix/ml-bump
Browse files Browse the repository at this point in the history
fix(ml): adapt experimental ml commands to 8.x ES API
  • Loading branch information
jmbass authored Aug 28, 2024
2 parents e60c21b + a84821c commit 394dfa0
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions detection_rules/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,16 +395,17 @@ def upload_job(ctx: click.Context, job_file, overwrite):
with open(job_file, 'r') as f:
job = json.load(f)

def safe_upload(func):
def safe_upload(func, job_type, **kwargs):
try:
func(name, body)
func(**kwargs)
except (elasticsearch.ConflictError, elasticsearch.RequestError) as err:
if isinstance(err, elasticsearch.RequestError) and err.error != 'resource_already_exists_exception':
client_error(str(err), err, ctx=ctx)

if overwrite:
ctx.invoke(delete_job, job_name=name, job_type=job_type)
func(name, body)
click.echo(f'WARN: Overwriting job {name} if it exists.')
ctx.invoke(delete_job, job_name=kwargs.get('job_id'), job_type=job_type)
func(**kwargs)
else:
client_error(str(err), err, ctx=ctx)

Expand All @@ -414,11 +415,11 @@ def safe_upload(func):
body = job['body']

if job_type == 'anomaly_detection':
safe_upload(ml_client.put_job)
safe_upload(ml_client.put_job, job_type=job_type, job_id=name, body=body)
elif job_type == 'data_frame_analytic':
safe_upload(ml_client.put_data_frame_analytics)
safe_upload(ml_client.put_data_frame_analytics, job_type=job_type, id=name, body=body)
elif job_type == 'datafeed':
safe_upload(ml_client.put_datafeed)
safe_upload(ml_client.put_datafeed, job_type=job_type, datafeed_id=name, body=body)
else:
client_error(f'Unknown ML job type: {job_type}')

Expand All @@ -438,11 +439,11 @@ def delete_job(ctx: click.Context, job_name, job_type, verbose=True):

try:
if job_type == 'anomaly_detection':
ml_client.delete_job(job_name)
ml_client.delete_job(job_id=job_name)
elif job_type == 'data_frame_analytic':
ml_client.delete_data_frame_analytics(job_name)
ml_client.delete_data_frame_analytics(job_id=job_name)
elif job_type == 'datafeed':
ml_client.delete_datafeed(job_name)
ml_client.delete_datafeed(job_id=job_name)
else:
client_error(f'Unknown ML job type: {job_type}')
except (elasticsearch.NotFoundError, elasticsearch.ConflictError) as e:
Expand Down

0 comments on commit 394dfa0

Please sign in to comment.