Skip to content

Commit

Permalink
chore: Refactor get_trial_by_model_and_input to handle multiple JSON …
Browse files Browse the repository at this point in the history
…input URLs
  • Loading branch information
amirnd51 committed Jul 24, 2024
1 parent 2c97869 commit 7ba1ade
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 23 deletions.
5 changes: 3 additions & 2 deletions python_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,10 @@ async def predict(request: PredictRequest):
first_input=inputs[0]


# xxx
trail= get_trial_by_model_and_input( model_id, first_input["src"])

trail= get_trial_by_model_and_input( model_id, inputs)
# print(trail)

# print("trial")
if trail and trail[2] is not None:
# print(trail[2])
Expand Down
84 changes: 63 additions & 21 deletions python_api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,42 +47,84 @@ def create_expriement( cur, conn):
return experiment_id


def get_trial_by_model_and_input(model_id, input_url):
input_url=json.loads(input_url)
input_url=input_url["url"]
input_query = """
SELECT trial_id
FROM trial_inputs
WHERE %s in url
"""
def get_trial_by_model_and_input(model_id, input_urls):
# Check if input_urls is a list of json objects with src and inputType keys
if not all(isinstance(item, dict) and 'src' in item and 'inputType' in item for item in input_urls):
raise ValueError("Each input_url must be a JSON object with 'src' and 'inputType' keys")

# Construct the input_query based on the number of input_urls
if len(input_urls) == 1:
input_url = input_urls[0]["src"]
input_query = "url LIKE %s"
input_values = [f"%{input_url}%"]
elif len(input_urls) == 2:
input_url_1 = input_urls[0]["src"]
input_url_2 = input_urls[1]["src"]
input_query = "url LIKE %s OR url LIKE %s"
input_values = [f"%{input_url_1}%", f"%{input_url_2}%"]
else:
return None

# Main query to get trial details
query = f"""
SELECT trials.*,
experiments.*,
models.*,
trial_inputs.*
experiments.*,
models.*,
trial_inputs.*
FROM trials
JOIN experiments ON trials.experiment_id = experiments.id
JOIN models ON trials.model_id = models.id
JOIN trial_inputs ON trials.id = trial_inputs.trial_id
WHERE trials.completed_at IS NOT NULL
AND trials.model_id = %s
AND trials.id IN ({input_query})
AND ({input_query})
"""

print(f"Debug: SQL Query: {query}")
print(f"Debug: Input Values: {input_values}")



# Main query to get trial details
query = f"""
SELECT trials.*,
experiments.*,
models.*,
trial_inputs.*
FROM trials
JOIN experiments ON trials.experiment_id = experiments.id
JOIN models ON trials.model_id = models.id
JOIN trial_inputs ON trials.id = trial_inputs.trial_id
WHERE trials.completed_at IS NOT NULL
AND trials.model_id = %s
AND ({input_query})
"""
debug_query = query % tuple([model_id] + input_values)
print(f"Debug: SQL Query: {debug_query}")
# Execute the query
cur,conn=get_db_cur_con()
cur.execute(query, ( model_id,input_url))
try:
# Replace with your database connection setup
cur, conn = get_db_cur_con()
cur.execute(query, [model_id] + input_values)

# Fetch the result
trial = cur.fetchone()
# Fetch the result
trial = cur.fetchone()

if trial is None:
return None

print(trial['experiment_id'], trial['trial_id'])
return (trial['experiment_id'], trial['trial_id'],trial["completed_at"])
if trial is None:
print("Debug: No trial found")
return None

# Fetch column names for reference
colnames = [desc[0] for desc in cur.description]
# print(f"Debug: Columns: {colnames}")
# print(f"Debug: Trial Data: {trial}")

# Assuming the columns are returned in the order you expect
# You may need to adjust this depending on your database schema
return (trial['experiment_id'], trial['trial_id'], trial['completed_at'])


except (Exception, psycopg2.DatabaseError) as error:
print(f"Error: {error}")
return None

0 comments on commit 7ba1ade

Please sign in to comment.