diff --git a/python_api/api.py b/python_api/api.py index 4bdd740..d381800 100644 --- a/python_api/api.py +++ b/python_api/api.py @@ -335,9 +335,10 @@ async def predict(request: PredictRequest): first_input=inputs[0] - # xxx - trail= get_trial_by_model_and_input( model_id, first_input["src"]) + + trail= get_trial_by_model_and_input( model_id, inputs) # print(trail) + # print("trial") if trail and trail[2] is not None: # print(trail[2]) diff --git a/python_api/db.py b/python_api/db.py index 16682c8..41de19a 100644 --- a/python_api/db.py +++ b/python_api/db.py @@ -47,42 +47,84 @@ def create_expriement( cur, conn): return experiment_id -def get_trial_by_model_and_input(model_id, input_url): - input_url=json.loads(input_url) - input_url=input_url["url"] - input_query = """ - SELECT trial_id - FROM trial_inputs - WHERE %s in url - """ +def get_trial_by_model_and_input(model_id, input_urls): + # Check if input_urls is a list of json objects with src and inputType keys + if not all(isinstance(item, dict) and 'src' in item and 'inputType' in item for item in input_urls): + raise ValueError("Each input_url must be a JSON object with 'src' and 'inputType' keys") + + # Construct the input_query based on the number of input_urls + if len(input_urls) == 1: + input_url = input_urls[0]["src"] + input_query = "url LIKE %s" + input_values = [f"%{input_url}%"] + elif len(input_urls) == 2: + input_url_1 = input_urls[0]["src"] + input_url_2 = input_urls[1]["src"] + input_query = "url LIKE %s OR url LIKE %s" + input_values = [f"%{input_url_1}%", f"%{input_url_2}%"] + else: + return None # Main query to get trial details query = f""" SELECT trials.*, - experiments.*, - models.*, - trial_inputs.* + experiments.*, + models.*, + trial_inputs.* FROM trials JOIN experiments ON trials.experiment_id = experiments.id JOIN models ON trials.model_id = models.id JOIN trial_inputs ON trials.id = trial_inputs.trial_id WHERE trials.completed_at IS NOT NULL AND trials.model_id = %s - AND trials.id IN ({input_query}) + AND ({input_query}) """ + + print(f"Debug: SQL Query: {query}") + print(f"Debug: Input Values: {input_values}") + + + # Main query to get trial details + query = f""" + SELECT trials.*, + experiments.*, + models.*, + trial_inputs.* + FROM trials + JOIN experiments ON trials.experiment_id = experiments.id + JOIN models ON trials.model_id = models.id + JOIN trial_inputs ON trials.id = trial_inputs.trial_id + WHERE trials.completed_at IS NOT NULL + AND trials.model_id = %s + AND ({input_query}) + """ + debug_query = query % tuple([model_id] + input_values) + print(f"Debug: SQL Query: {debug_query}") # Execute the query - cur,conn=get_db_cur_con() - cur.execute(query, ( model_id,input_url)) + try: + # Replace with your database connection setup + cur, conn = get_db_cur_con() + cur.execute(query, [model_id] + input_values) - # Fetch the result - trial = cur.fetchone() + # Fetch the result + trial = cur.fetchone() - if trial is None: - return None - - print(trial['experiment_id'], trial['trial_id']) - return (trial['experiment_id'], trial['trial_id'],trial["completed_at"]) + if trial is None: + print("Debug: No trial found") + return None + + # Fetch column names for reference + colnames = [desc[0] for desc in cur.description] + # print(f"Debug: Columns: {colnames}") + # print(f"Debug: Trial Data: {trial}") + # Assuming the columns are returned in the order you expect + # You may need to adjust this depending on your database schema + return (trial['experiment_id'], trial['trial_id'], trial['completed_at']) + except (Exception, psycopg2.DatabaseError) as error: + print(f"Error: {error}") + return None + \ No newline at end of file