diff --git a/examples/evaluate_text2sql.py b/examples/evaluate_text2sql.py index ceeeb248bc..19e504009f 100644 --- a/examples/evaluate_text2sql.py +++ b/examples/evaluate_text2sql.py @@ -23,18 +23,14 @@ print_dict( evaluated_dataset[0], - keys_to_print=[ - "source", - "prediction", - "subset", - ], + keys_to_print=["source", "prediction", "subset"], ) print_dict( evaluated_dataset[0]["score"]["global"], ) assert ( - evaluated_dataset[0]["score"]["global"]["score"] >= 0.44 + evaluated_dataset[0]["score"]["global"]["score"] >= 0.43 ), "results have been degraded, something is wrong with the metric" # with llama-3-3-70b-instruct diff --git a/prepare/metrics/text2sql_execution_accuracy.py b/prepare/metrics/text2sql_execution_accuracy.py index ca15409a52..c6465fc763 100644 --- a/prepare/metrics/text2sql_execution_accuracy.py +++ b/prepare/metrics/text2sql_execution_accuracy.py @@ -1,14 +1,23 @@ from unitxt.catalog import add_to_catalog -from unitxt.metrics import ExecutionAccuracy +from unitxt.metrics import SQLExecutionAccuracy from unitxt.test_utils.metrics import test_metric -metric = ExecutionAccuracy() +metric = SQLExecutionAccuracy() predictions = [ "SELECT nme FROM employees WHERE department = 'Sales'", "SELECT name FROM employees WHERE department = 'Sales'", + "SELECT name FROM employees WHERE department = 'Engineering'", + "SELECT id, name FROM employees WHERE department = 'Sales'", + "SELECT name FROM employees WHERE department = 'Non-Existent'", ] # Incorrect column name 'nme' -references = [["SELECT name FROM employees WHERE department = 'Sales';"]] * 2 +references = [ + ["SELECT name FROM employees WHERE department = 'Sales';"], + ["SELECT name FROM employees WHERE department = 'Sales';"], + ["SELECT name FROM employees WHERE department = 'Sales';"], + ["SELECT name FROM employees WHERE department = 'Sales';"], + ["SELECT name FROM employees WHERE department = 'Non-Existent';"], +] task_data = [ { "db": { @@ -26,31 +35,99 @@ }, } } -] * 2 +] * 5 instance_targets = [ { + "error_message": "Error executing SQL: no such column: nme", "execution_accuracy": 0.0, + "gold_df_json": "", + "gold_error": 0.0, + "non_empty_execution_accuracy": 0.0, + "non_empty_gold_df": 0.0, + "predicted_df_json": "", + "predicted_error": 1.0, "score": 0.0, - "score_name": "execution_accuracy", + "score_name": "non_empty_execution_accuracy", + "subset_non_empty_execution_result": 0.0, }, { + "error_message": "", "execution_accuracy": 1.0, + "gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}', + "gold_error": 1.0, + "non_empty_execution_accuracy": 1.0, + "non_empty_gold_df": 1.0, + "predicted_df_json": '{"0":{"0":"Alice","1":"Charlie"}}', + "predicted_error": 0.0, "score": 1.0, - "score_name": "execution_accuracy", + "score_name": "non_empty_execution_accuracy", + "subset_non_empty_execution_result": 1.0, + }, + { + "error_message": "None", + "execution_accuracy": 0.0, + "gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}', + "gold_error": 0.0, + "non_empty_execution_accuracy": 0.0, + "non_empty_gold_df": 1.0, + "predicted_df_json": '{"0":{"0":"Bob"}}', + "predicted_error": 0.0, + "score": 0.0, + "score_name": "non_empty_execution_accuracy", + "subset_non_empty_execution_result": 0.0, + }, + { + "error_message": "None", + "execution_accuracy": 0.0, + "gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}', + "gold_error": 0.0, + "non_empty_execution_accuracy": 0.0, + "non_empty_gold_df": 1.0, + "predicted_df_json": '{"0":{"0":1,"1":3},"1":{"0":"Alice","1":"Charlie"}}', + "predicted_error": 0.0, + "score": 0.0, + 
"score_name": "non_empty_execution_accuracy", + "subset_non_empty_execution_result": 1.0, + }, + { + "error_message": "", + "execution_accuracy": 1.0, + "gold_df_json": "{}", + "gold_error": 1.0, + "non_empty_execution_accuracy": 0.0, + "non_empty_gold_df": 0.0, + "predicted_df_json": "{}", + "predicted_error": 0.0, + "score": 0.0, + "score_name": "non_empty_execution_accuracy", + "subset_non_empty_execution_result": 0.0, }, ] global_target = { - "execution_accuracy": 0.5, - "execution_accuracy_ci_high": 1.0, + "execution_accuracy": 0.4, + "execution_accuracy_ci_high": 0.87, "execution_accuracy_ci_low": 0.0, - "num_of_instances": 2, - "score": 0.5, - "score_ci_high": 1.0, + "gold_error": 0.4, + "gold_sql_runtime_ci_high": 0.0, + "gold_sql_runtime_ci_low": 0.0, + "non_empty_execution_accuracy": 0.2, + "non_empty_execution_accuracy_ci_high": 0.8, + "non_empty_execution_accuracy_ci_low": 0.0, + "non_empty_gold_df": 0.6, + "num_of_instances": 5, + "predicted_error": 0.2, + "predicted_sql_runtime_ci_high": 0.0, + "predicted_sql_runtime_ci_low": 0.0, + "score": 0.2, + "score_ci_high": 0.8, "score_ci_low": 0.0, - "score_name": "execution_accuracy", + "score_name": "non_empty_execution_accuracy", + "subset_non_empty_execution_result": 0.4, + "subset_non_empty_execution_result_ci_high": 1.0, + "subset_non_empty_execution_result_ci_low": 0.0, } outputs = test_metric( @@ -60,6 +137,11 @@ instance_targets=instance_targets, global_target=global_target, task_data=task_data, + score_keys_to_ignore=[ + "predicted_sql_runtime", + "gold_sql_runtime", + "pred_to_gold_runtime_ratio", + ], ) add_to_catalog(metric, "metrics.text2sql.execution_accuracy", overwrite=True) diff --git a/prepare/processors/text2sql.py b/prepare/processors/text2sql.py index 85fb4c663b..dfee6b1138 100644 --- a/prepare/processors/text2sql.py +++ b/prepare/processors/text2sql.py @@ -1,10 +1,11 @@ from unitxt import add_to_catalog from unitxt.operator import SequentialOperator -from unitxt.processors import GetSQL +from unitxt.processors import AddPrefix, GetSQL add_to_catalog( SequentialOperator( steps=[ + AddPrefix(field="prediction", prefix="SELECT "), GetSQL(field="prediction"), ] ), diff --git a/src/unitxt/db_utils.py b/src/unitxt/db_utils.py index 00a600da3c..49da12878a 100644 --- a/src/unitxt/db_utils.py +++ b/src/unitxt/db_utils.py @@ -47,10 +47,10 @@ def execute_query_local(db_path: str, query: str) -> Any: conn = sqlite3.connect(db_path) cursor = conn.cursor() cursor.execute(query) - return cursor.fetchall() + return cursor.fetchall(), None except sqlite3.Error as e: logger.info(f"Error executing SQL: {e}") - return None + return None, f"Error executing SQL: {e}" finally: if conn: conn.close() @@ -178,10 +178,10 @@ def execute_query(self, query: str) -> Any: try: cursor.execute(query) - return cursor.fetchall() + return cursor.fetchall(), None except sqlite3.Error as e: logger.info(f"Error executing SQL: {e}") - return None + return None, f"Error executing SQL: {e}" finally: conn.close() @@ -196,7 +196,7 @@ def execute_query_remote( max_retries: int = 3, retry_delay: int = 5, # seconds timeout: int = 30, # seconds -) -> Optional[dict]: +) -> (Optional[dict], str): """Executes a query against the remote database, with retries for certain exceptions.""" headers = { "Content-Type": "application/json", @@ -214,7 +214,7 @@ def execute_query_remote( timeout=timeout, ) response.raise_for_status() - return response.json() + return response.json(), None except retryable_exceptions as e: retries += 1 @@ -225,7 +225,10 @@ def 
execute_query_remote( time.sleep(retry_delay) else: logger.error(f"Max retries ({max_retries}) exceeded for query: {query}") - return None + return ( + None, + f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}", + ) except requests.exceptions.HTTPError as e: if e.response.status_code >= 500: @@ -239,16 +242,22 @@ def execute_query_remote( logger.error( f"Max retries ({max_retries}) exceeded for query: {query}" ) - return None + return ( + None, + f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}", + ) else: logger.error(f"HTTP Error on attempt {retries}: {e}") - return None + return ( + None, + f"HTTP Error on attempt {retries}: {e}", + ) except Exception as e: logger.error(f"Unexpected error on attempt {retries}: {e}") - return None + return (None, f"Unexpected error on attempt {retries}: {e}") - return None + return None, "Unknown Error in SQL execution" class RemoteDatabaseConnector(DatabaseConnector): diff --git a/src/unitxt/loaders.py b/src/unitxt/loaders.py index 7d6938afab..f34dc43bfe 100644 --- a/src/unitxt/loaders.py +++ b/src/unitxt/loaders.py @@ -180,6 +180,7 @@ def process(self) -> MultiStream: def get_splits(self): return list(self().keys()) + class LazyLoader(Loader): split: Optional[str] = NonPositionalField(default=None) @@ -191,9 +192,7 @@ def get_splits(self) -> List[str]: def split_generator(self, split: str) -> Generator: pass - def load_iterables( - self - ) -> Union[Dict[str, DynamicStream], IterableDatasetDict]: + def load_iterables(self) -> Union[Dict[str, DynamicStream], IterableDatasetDict]: if self.split is not None: splits = [self.split] else: @@ -339,11 +338,12 @@ def get_splits(self): dataset = self.load_dataset( split=None, disable_memory_caching=True, streaming=True ) - except NotImplementedError: # streaming is not supported for zipped files so we load without streaming + except ( + NotImplementedError + ): # streaming is not supported for zipped files so we load without streaming dataset = self.load_dataset(split=None, streaming=False) return list(dataset.keys()) - def split_generator(self, split: str) -> Generator: if self.get_limit() is not None: self.log_limited_loading() @@ -436,16 +436,14 @@ def split_generator(self, split: str) -> Generator: self.log_limited_loading() try: - dataset = reader( - self.files[split], **self.get_args() - ).to_dict("records") + dataset = reader(self.files[split], **self.get_args()).to_dict( + "records" + ) except ValueError: import fsspec with fsspec.open(self.files[split], mode="rt") as f: - dataset = reader( - f, **self.get_args() - ).to_dict("records") + dataset = reader(f, **self.get_args()).to_dict("records") except Exception as e: logger.debug(f"Attempt csv load {attempt + 1} failed: {e}") if attempt < settings.loaders_max_retries - 1: diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index dd1f98b554..307be76968 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -392,7 +392,6 @@ def bootstrap(self, data: List[Any], score_names: List[str]): return result - IntermediateType = TypeVar("IntermediateType") PredictionType = TypeVar("PredictionType") @@ -2296,13 +2295,11 @@ def verify(self): Documentation.HUGGINGFACE_METRICS, ) - assert ( - self.hf_additional_input_fields is None - or isoftype(self.hf_additional_input_fields, List[str]) + assert self.hf_additional_input_fields is None or isoftype( + self.hf_additional_input_fields, List[str] ), f"Argument hf_additional_input_fields should be either None or List[str]. 
It is now: {self.hf_additional_input_fields}." - assert ( - self.hf_additional_input_fields_pass_one_value is None - or isoftype(self.hf_additional_input_fields_pass_one_value, List[str]) + assert self.hf_additional_input_fields_pass_one_value is None or isoftype( + self.hf_additional_input_fields_pass_one_value, List[str] ), f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}." return super().verify() @@ -2876,8 +2873,8 @@ def compute( labels=labels_param, ) if isinstance(result[self.metric], numpy.ndarray): - assert ( - len(result[self.metric]) == len(labels) + assert len(result[self.metric]) == len( + labels ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})" final_result = {self.main_score: nan_mean(result[self.metric])} for i, label in enumerate(labels): @@ -3840,9 +3837,9 @@ class LlamaIndexLLMMetric(InstanceMetric): prediction_type = str reduction_map: Dict[str, List[str]] = None openai_models: List[str] = ["gpt-3.5-turbo"] - anthropic_models: List[ - str - ] = [] # this is here for the sake of documentation for future models + anthropic_models: List[str] = ( + [] + ) # this is here for the sake of documentation for future models mock_models: List[str] = ["mock"] external_api_models = openai_models + anthropic_models data_classification_policy = ["public"] @@ -5636,9 +5633,9 @@ def prepare(self): def create_ensemble_scores(self, instance): score = self.ensemble(instance) - instance[ - "prediction" - ] = score # We use here the prediction field to pass the score to the compute method. + instance["prediction"] = ( + score # We use here the prediction field to pass the score to the compute method. + ) return instance def ensemble(self, instance): @@ -5860,6 +5857,7 @@ class RiskType(str, Enum): AGENTIC = "agentic_risk" CUSTOM_RISK = "custom_risk" + class GraniteGuardianBase(InstanceMetric): """Return metric for different kinds of "risk" from the Granite-3.0 Guardian model.""" @@ -5923,7 +5921,12 @@ def prepare(self): def verify(self): super().verify() - assert self.risk_type == RiskType.CUSTOM_RISK or self.risk_name in self.available_risks[self.risk_type], UnitxtError(f"The risk \'{self.risk_name}\' is not a valid \'{' '.join([word[0].upper() + word[1:] for word in self.risk_type.split('_')])}\'") + assert ( + self.risk_type == RiskType.CUSTOM_RISK + or self.risk_name in self.available_risks[self.risk_type] + ), UnitxtError( + f"The risk '{self.risk_name}' is not a valid '{' '.join([word[0].upper() + word[1:] for word in self.risk_type.split('_')])}'" + ) @abstractmethod def verify_granite_guardian_config(self, task_data): @@ -6026,8 +6029,10 @@ def get_probabilities(self, top_tokens_list): dim=0, ).numpy() + class GraniteGuardianUserRisk(GraniteGuardianBase): risk_type = RiskType.USER_MESSAGE + def verify_granite_guardian_config(self, task_data): # User message risks only require the user message field and are the same as the assistant message risks, except for jailbreak assert self.user_message_field in task_data, UnitxtError( @@ -6039,32 +6044,34 @@ def process_input_fields(self, task_data): messages += self.create_message("user", task_data[self.user_message_field]) return messages + class GraniteGuardianAssistantRisk(GraniteGuardianBase): risk_type = RiskType.ASSISTANT_MESSAGE + def verify_granite_guardian_config(self, task_data): assert ( - self.assistant_message_field in task_data - and self.user_message_field in task_data - ), UnitxtError( - f'Task data 
must contain "{self.assistant_message_field}" and "{self.user_message_field}" fields' - ) + self.assistant_message_field in task_data + and self.user_message_field in task_data + ), UnitxtError( + f'Task data must contain "{self.assistant_message_field}" and "{self.user_message_field}" fields' + ) def process_input_fields(self, task_data): messages = [] messages += self.create_message("user", task_data[self.user_message_field]) messages += self.create_message( - "assistant", task_data[self.assistant_message_field] - ) + "assistant", task_data[self.assistant_message_field] + ) return messages + class GraniteGuardianRagRisk(GraniteGuardianBase): risk_type = RiskType.RAG def verify_granite_guardian_config(self, task_data): if self.risk_name == "context_relevance": assert ( - self.context_field in task_data - and self.user_message_field in task_data + self.context_field in task_data and self.user_message_field in task_data ), UnitxtError( f'Task data must contain "{self.context_field}" and "{self.user_message_field}" fields' ) @@ -6086,55 +6093,53 @@ def verify_granite_guardian_config(self, task_data): def process_input_fields(self, task_data): messages = [] if self.risk_name == "context_relevance": - messages += self.create_message( - "user", task_data[self.user_message_field] - ) - messages += self.create_message( - "context", task_data[self.context_field] - ) + messages += self.create_message("user", task_data[self.user_message_field]) + messages += self.create_message("context", task_data[self.context_field]) elif self.risk_name == "groundedness": - messages += self.create_message( - "context", task_data[self.context_field] - ) + messages += self.create_message("context", task_data[self.context_field]) messages += self.create_message( "assistant", task_data[self.assistant_message_field] ) elif self.risk_name == "answer_relevance": - messages += self.create_message( - "user", task_data[self.user_message_field] - ) + messages += self.create_message("user", task_data[self.user_message_field]) messages += self.create_message( "assistant", task_data[self.assistant_message_field] ) return messages + + class GraniteGuardianAgenticRisk(GraniteGuardianBase): risk_type = RiskType.AGENTIC + def verify_granite_guardian_config(self, task_data): assert ( - self.tools_field in task_data - and self.user_message_field in task_data - and self.assistant_message_field in task_data - ), UnitxtError( - f'Task data must contain "{self.tools_field}", "{self.assistant_message_field}" and "{self.user_message_field}" fields' - ) + self.tools_field in task_data + and self.user_message_field in task_data + and self.assistant_message_field in task_data + ), UnitxtError( + f'Task data must contain "{self.tools_field}", "{self.assistant_message_field}" and "{self.user_message_field}" fields' + ) def process_input_fields(self, task_data): messages = [] messages += self.create_message( - "tools", json.loads(task_data[self.tools_field]) - ) + "tools", json.loads(task_data[self.tools_field]) + ) messages += self.create_message("user", task_data[self.user_message_field]) messages += self.create_message( "assistant", task_data[self.assistant_message_field] ) return messages + class GraniteGuardianCustomRisk(GraniteGuardianBase): risk_type = RiskType.CUSTOM_RISK def verify(self): super().verify() - assert self.risk_type is not None, UnitxtError("In a custom risk, risk_type must be defined") + assert self.risk_type is not None, UnitxtError( + "In a custom risk, risk_type must be defined" + ) def 
verify_granite_guardian_config(self, task_data): # even though this is a custom risks, we will limit the @@ -6142,34 +6147,31 @@ def verify_granite_guardian_config(self, task_data): # was trained with: user, assistant, context & tools. # we just checked whether at least one of them is provided assert ( - self.tools_field in task_data - or self.user_message_field in task_data - or self.assistant_message_field in task_data - or self.context_field in task_data - ), UnitxtError( - f'Task data must contain at least one of"{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.context_field}" fields' - ) + self.tools_field in task_data + or self.user_message_field in task_data + or self.assistant_message_field in task_data + or self.context_field in task_data + ), UnitxtError( + f'Task data must contain at least one of"{self.tools_field}", "{self.assistant_message_field}", "{self.user_message_field}" or "{self.context_field}" fields' + ) def process_input_fields(self, task_data): messages = [] if self.context_field in task_data: - messages += self.create_message( - "context", task_data[self.context_field] - ) + messages += self.create_message("context", task_data[self.context_field]) if self.tools_field in task_data: messages += self.create_message( "tools", json.loads(task_data[self.tools_field]) ) if self.user_message_field in task_data: - messages += self.create_message( - "user", task_data[self.user_message_field] - ) + messages += self.create_message("user", task_data[self.user_message_field]) if self.assistant_message_field in task_data: messages += self.create_message( "assistant", task_data[self.assistant_message_field] ) return messages + RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = { RiskType.USER_MESSAGE: GraniteGuardianUserRisk, RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk, @@ -6177,18 +6179,96 @@ def process_input_fields(self, task_data): RiskType.AGENTIC: GraniteGuardianAgenticRisk, } -class ExecutionAccuracy(InstanceMetric): - reduction_map = {"mean": ["execution_accuracy"]} - main_score = "execution_accuracy" - ci_scores = ["execution_accuracy"] + +class SQLExecutionAccuracy(InstanceMetric): + reduction_map = { + "mean": [ + "execution_accuracy", + "non_empty_execution_accuracy", + "subset_non_empty_execution_result", + "non_empty_gold_df", + "gold_sql_runtime", + "predicted_sql_runtime", + "pred_to_gold_runtime_ratio", + "gold_error", + "predicted_error", + ] + } + main_score = "non_empty_execution_accuracy" + ci_scores = [ + "execution_accuracy", + "non_empty_execution_accuracy", + "subset_non_empty_execution_result", + "gold_sql_runtime", + "predicted_sql_runtime", + ] prediction_type = "Any" # string representation is compared sql_timeout = 100.0 _requirements_list = ["sqlglot", "func_timeout"] + @staticmethod + def compare_dfs_ignore_colnames(df1, df2): + """Compares two DataFrames based on row content, ignoring column names. + + Args: + df1 (pd.DataFrame): Pandas DataFrame 1 to compare. + df2 (pd.DataFrame): Pandas DataFrame 2 to compare. + + Returns: + True if the DataFrames have the same content (ignoring column names), + False otherwise. 
+ """ + df1.fillna(0, inplace=True) + df2.fillna(0, inplace=True) + + if df1.shape != df2.shape: + return False + + # run over all columns of df1, + # and see if there is a column in df2 that matches it, + # if not, return False; if all the columns matched, return True + for df1_col in df1.columns: + col_matched = False + for df2_col in df2.columns: + if all(df1[df1_col].values == df2[df2_col].values): + col_matched = True + if not col_matched: + return False + + return True + + @staticmethod + def is_subset_ignore_colnames(df1, df2): + """Checks if df1 is a subset of df2 based on row content, ignoring column names. + + Args: + df1: Pandas DataFrame 1 to compare. + df2: Pandas DataFrame 2 to compare. + + Returns: + True if df1 is a subset of df2 based on column values, + False otherwise. + """ + if df1.shape[1] > df2.shape[1]: + return False + + # Convert each column to a tuple of values (you could also use a Series.tolist(), etc.) + df1_cols = [tuple(df1.iloc[:, i]) for i in range(df1.shape[1])] + df2_cols = [tuple(df2.iloc[:, j]) for j in range(df2.shape[1])] + df2_cols_count = Counter(df2_cols) + for col in df1_cols: + if df2_cols_count[col] > 0: + df2_cols_count[col] -= 1 + else: + return False + + return True + + @staticmethod def equivalent_sqls(expected: str, generated: str) -> int: + """Checks if SQL queries are equivalent using SQLGlot parsing, so we don't run them.""" from sqlglot import diff, parse_one from sqlglot.optimizer import optimize @@ -6200,61 +6280,161 @@ return 1 if sql_diff == 0 else 0 - def run_sql_and_match(self, predicted_sql: str, gold_sql: str, connector) -> int: - """Runs SQL queries using the provided connector and checks if the results match.""" - if predicted_sql.lower().strip() == gold_sql.lower().strip(): - return 1 # if the SQLs are exactly the same, return 1 + def get_sql_execution_results( + self, predicted_sql: str, gold_sql: str, connector + ) -> (int, int, int, int, int, int, int, int, int, str, str, str): + """Runs SQL queries using the provided connector and gets scores and results. + + Args: + predicted_sql (str): predicted SQL query + gold_sql (str): gold reference SQL query + connector: database connector + + Returns: + a 12-tuple of + 1. execution_result: if df responses match + 2. non_empty_execution_result: if dfs are non-empty and match + 3. subset_non_empty_execution_result: if non-empty dfs and gt df subset of predicted df + 4. non_empty_gold_df: if gt df is non-empty + 5. gold_sql_runtime: ground truth query runtime + 6. predicted_sql_runtime: predicted query runtime + 7. pred_to_gold_runtime_ratio: ratio of predicted query runtime to gt query runtime + 8. gold_error: if gt has an error + 9. predicted_error: if predicted query has an error + 10. ground truth dataframe + 11. predicted query's dataframe + 12. 
error message (if any) + """ + import time + from func_timeout import func_timeout + + gold_res = None + gold_error = "" + gold_sql_runtime = 0 + try: + start_time = time.perf_counter() + gold_res, gold_error = func_timeout( + self.sql_timeout, + connector.execute_query, + args=(gold_sql,), + ) + end_time = time.perf_counter() + gold_sql_runtime = end_time - start_time + except Exception as e: + # raise OSError( + # "Error executing gold SQL, if gold does not execute metric should fail" + # ) from e + gold_error = f"Error executing gold SQL: {e}" + if gold_error is not None: + return ( + 0, + 0, + 0, + 0, + gold_sql_runtime, + 0, + 0, + 0, + 0, + "", + "", + "", + ) + + gold_df = pd.DataFrame(gold_res) + non_empty_gold_df = 0 if gold_df.empty else 1 + + no_execution_match_result = ( + 1, + non_empty_gold_df, + non_empty_gold_df, + non_empty_gold_df, + gold_sql_runtime, + 0, + 0, + 1, + 0, + gold_df.to_json(), + gold_df.to_json(), + "", + ) + if predicted_sql.lower().strip() == gold_sql.lower().strip(): + return no_execution_match_result try: if self.equivalent_sqls(gold_sql, predicted_sql): - return 1 + return no_execution_match_result except Exception as e: # Catch specific exceptions if possible logger.info( f"Error in equivalent_sqls: {e}. Treating as non-equivalent and going to test with the db." ) + pred_res = None + pred_error = "" + pred_sql_runtime = 0 try: - gold_res = connector.execute_query(gold_sql) + start_time = time.perf_counter() + pred_res, pred_error = func_timeout( + self.sql_timeout, + connector.execute_query, + args=(predicted_sql,), + ) + end_time = time.perf_counter() + pred_sql_runtime = end_time - start_time except Exception as e: - raise OSError( - "Error executing gold SQL, if gold does not execute metric should fail" - ) from e + pred_error = f"Error executing predicted SQL: {e}" + logger.info(pred_error) - try: - pred_res = connector.execute_query(predicted_sql) - except Exception as e: - logger.info(f"Error executing predicted SQL: {e}") - return 0 # if the predicted SQL fails to execute, result is 0 + pred_to_gold_runtime_ratio = ( + float(pred_sql_runtime) / gold_sql_runtime if gold_sql_runtime > 0 else 0 + ) if pred_res is None: - if gold_res is None: - return 1 - return 0 - - # if pred_res is dict with results take this as the result - if isinstance(pred_res, dict): - pred_res = pred_res["results"] - gold_res = gold_res["results"] + return ( + 0, + 0, + 0, + 0, + gold_sql_runtime, + pred_sql_runtime, + pred_to_gold_runtime_ratio, + 0, + 1, + "", + "", + pred_error, + ) - def normalize_tuple(tup): - """Normalizes a tuple by sorting its non-None elements. + predicted_df = pd.DataFrame(pred_res) - Args: - tup: The input tuple. + execution_result = ( + 1 if self.compare_dfs_ignore_colnames(predicted_df, gold_df) else 0 + ) - Returns: - A tuple with non-None elements sorted first, followed by None values. 
- """ - return sorted([str(item) for item in tup]) + subset_non_empty_execution_result = 0 + non_empty_execution_result = 0 + if non_empty_gold_df: + if execution_result == 1: + non_empty_execution_result = 1 + if self.is_subset_ignore_colnames(gold_df, predicted_df): + subset_non_empty_execution_result = 1 - return int( - sorted([normalize_tuple(t) for t in pred_res]) - == sorted([normalize_tuple(t) for t in gold_res]) + return ( + execution_result, + non_empty_execution_result, + subset_non_empty_execution_result, + non_empty_gold_df, + gold_sql_runtime, + pred_sql_runtime, + pred_to_gold_runtime_ratio, + 0, + 0, + gold_df.to_json(), + predicted_df.to_json(), + pred_error, ) def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict: - from func_timeout import FunctionTimedOut, func_timeout - predicted_sql = prediction execution_result: float = 0.0 @@ -6266,18 +6446,43 @@ def compute(self, references: List[Any], prediction: str, task_data: Dict) -> di db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"]) - try: - execution_result = func_timeout( - self.sql_timeout, - self.run_sql_and_match, - args=(predicted_sql, references[0], db_connector), - ) # type: ignore - except FunctionTimedOut: - logger.error("QUERY TIMEOUT, returning score=0 for this instance") - execution_result = 0.0 - - result = {self.main_score: float(execution_result)} - logger.debug(f"Result: {result}") + logger.debug( + f"Starting to get SQL execution results over DB: {task_data['db']}" + ) + ( + execution_result, + non_empty_execution_result, + subset_non_empty_execution_result, + non_empty_gold_df, + gold_sql_runtime, + predicted_sql_runtime, + pred_to_gold_runtime_ratio, + gold_error, + predicted_error, + gold_df_json, + predicted_df_json, + error_message, + ) = self.get_sql_execution_results( + predicted_sql, references[0], db_connector + ) + + result = { + "execution_accuracy": float(execution_result), + "non_empty_execution_accuracy": float(non_empty_execution_result), + "subset_non_empty_execution_result": float( + subset_non_empty_execution_result + ), + "non_empty_gold_df": float(non_empty_gold_df), + "gold_sql_runtime": float(gold_sql_runtime), + "predicted_sql_runtime": float(predicted_sql_runtime), + "pred_to_gold_runtime_ratio": float(pred_to_gold_runtime_ratio), + "gold_error": float(gold_error), + "predicted_error": float(predicted_error), + "error_message": str(error_message), + "gold_df_json": str(gold_df_json), + "predicted_df_json": str(predicted_df_json), + } result["score"] = result[self.main_score] result["score_name"] = self.main_score + logger.debug(f"Result: {result}") return result diff --git a/src/unitxt/test_utils/metrics.py b/src/unitxt/test_utils/metrics.py index 25c23c3e6e..430101e38f 100644 --- a/src/unitxt/test_utils/metrics.py +++ b/src/unitxt/test_utils/metrics.py @@ -88,6 +88,7 @@ def test_metric( instance_targets: List[dict], global_target: dict, task_data: Optional[List[dict]] = None, + score_keys_to_ignore: Optional[List[str]] = None, ): if settings.test_metric_disable: logger.info( @@ -110,6 +111,7 @@ def test_metric( instance_targets, global_outputs=outputs[0]["score"]["global"], instance_outputs=[output["score"]["instance"] for output in outputs], + score_keys_to_ignore=score_keys_to_ignore, ) logger.info("Metric tested successfully!") @@ -121,8 +123,17 @@ def check_scores( instance_targets: List[dict], global_outputs: dict, instance_outputs: List[dict], + score_keys_to_ignore: Optional[List[str]] = None, ): errors = [] + if 
score_keys_to_ignore: + for key in score_keys_to_ignore: + global_target.pop(key, None) + global_outputs.pop(key, None) + for instance_output in instance_outputs: + instance_output.pop(key, None) + for instance_target in instance_targets: + instance_target.pop(key, None) global_score = round_floats(global_outputs) if not dict_equal(global_score, global_target): errors.append( diff --git a/tests/library/test_db_utils.py b/tests/library/test_db_utils.py index aae085f9f5..31f2470b1b 100644 --- a/tests/library/test_db_utils.py +++ b/tests/library/test_db_utils.py @@ -66,7 +66,7 @@ def test_execute_query_failure(self, mock_post): mock_post.side_effect = requests.exceptions.RequestException("API Error") connector = RemoteDatabaseConnector(self.db_config) - result = connector.execute_query("SELECT * FROM table1") + result, _ = connector.execute_query("SELECT * FROM table1") self.assertIsNone(result) @@ -142,7 +142,7 @@ def test_execute_query(self): return_value=self.db_path, ): connector = LocalSQLiteConnector(self.db_config) - result = connector.execute_query("SELECT * FROM table1") + result, _ = connector.execute_query("SELECT * FROM table1") self.assertEqual(len(result), 2) self.assertEqual(result[0], ("value1", 1)) self.assertEqual(result[1], ("value2", 2)) @@ -153,7 +153,7 @@ def test_execute_query_error(self): return_value=self.db_path, ): connector = LocalSQLiteConnector(self.db_config) - result = connector.execute_query("SELECT * FROM non_existent_table") + result, _ = connector.execute_query("SELECT * FROM non_existent_table") self.assertIsNone(result) def test_download_database_unsupported_db(self): @@ -224,7 +224,7 @@ def test_get_table_schema_with_selected_tables(self): def test_execute_query_success(self): connector = InMemoryDatabaseConnector(self.db_config) - result = connector.execute_query("SELECT * FROM users WHERE age > 30") + result, _ = connector.execute_query("SELECT * FROM users WHERE age > 30") expected_result = [ (3, "Charlie", "charlie@example.com", 40, "Chicago"), (4, "David", "david@example.com", 35, "New York"), @@ -234,8 +234,10 @@ def test_execute_query_success(self): def test_execute_query_failure(self): connector = InMemoryDatabaseConnector(self.db_config) - result = connector.execute_query("SELECT * FROM non_existent_table") - + result, error = connector.execute_query("SELECT * FROM non_existent_table") + self.assertEqual( + error, "Error executing SQL: no such table: non_existent_table" + ) self.assertIsNone(result) def test_execute_complex_query(self): @@ -247,7 +249,7 @@ def test_execute_complex_query(self): WHERE u.city = 'Los Angeles' ORDER BY o.quantity DESC """ - result = connector.execute_query(query) + result, _ = connector.execute_query(query) expected_result = [ ("Bob", "Keyboard", 3), ("Eva", "Mouse", 2), @@ -265,7 +267,7 @@ def test_execute_query_with_aggregation(self): GROUP BY u.city ORDER BY u.city ASC """ - result = connector.execute_query(query) + result, _ = connector.execute_query(query) expected_result = [ ("Chicago", 40.0), ("Los Angeles", 26.5), @@ -283,7 +285,7 @@ def test_execute_query_with_sum_and_having(self): HAVING SUM(o.price) > 300 ORDER BY u.name DESC """ - result = connector.execute_query(query) + result, _ = connector.execute_query(query) expected_result = [("Eva", 1654.0), ("Charlie", 315.0), ("Alice", 1225.5)] self.assertEqual(result, expected_result) @@ -291,5 +293,6 @@ def test_execute_query_with_sum_and_having(self): def test_execute_query_empty_table(self): self.db_config["data"]["empty_table"] = {"columns": ["id"], "rows": 
[]} connector = InMemoryDatabaseConnector(self.db_config) - result = connector.execute_query("SELECT * FROM empty_table") + result, error = connector.execute_query("SELECT * FROM empty_table") self.assertEqual(result, []) + self.assertIsNone(error) diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index 85986e327f..30511cbeef 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -16,7 +16,6 @@ CharEditDistance, CharEditDistanceAccuracy, Detector, - ExecutionAccuracy, F1Binary, F1BinaryPosOnly, F1Fast, @@ -60,6 +59,7 @@ RelaxedCorrectness, RocAuc, Rouge, + SQLExecutionAccuracy, StringContainment, StringContainmentRatio, TokenOverlap, @@ -1202,7 +1202,9 @@ def test_grouped_instance_metrics(self): score_prefix, metric.main_score, ] - ).replace("__", "_") # for the case of empty score_prefix + ).replace( + "__", "_" + ) # for the case of empty score_prefix self.assertTrue( any( @@ -1391,7 +1393,7 @@ def test_perplexity_with_prefix(self): ) def test_execution_accuracy_correct_query_mock_db(self): - metric = ExecutionAccuracy() + metric = SQLExecutionAccuracy() predictions = ["SELECT name FROM employees WHERE department = 'Sales'"] references = ["SELECT name FROM employees WHERE department = 'Sales';"] task_data = [ @@ -1417,7 +1419,7 @@ def test_execution_accuracy_correct_query_mock_db(self): self.assertEqual(1.0, outputs["score"]) def test_execution_accuracy_different_db_schema(self): - metric = ExecutionAccuracy() + metric = SQLExecutionAccuracy() predictions = [ "SELECT product_name, price FROM products WHERE category = 'Electronics'" ] @@ -1453,7 +1455,7 @@ def test_execution_accuracy_different_db_schema(self): self.assertEqual(1.0, outputs["score"]) def test_execution_accuracy_multiple_tables(self): - metric = ExecutionAccuracy() + metric = SQLExecutionAccuracy() predictions = [ "SELECT o.order_id, c.name FROM orders AS o JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped'" ] @@ -1491,7 +1493,7 @@ def test_execution_accuracy_multiple_tables(self): self.assertEqual(1.0, outputs["score"]) def test_execution_accuracy_empty_result(self): - metric = ExecutionAccuracy() + metric = SQLExecutionAccuracy() predictions = ["SELECT name FROM employees WHERE department = 'HR'"] references = ["SELECT name FROM employees WHERE department = 'HR';"] task_data = [ @@ -1514,10 +1516,10 @@ def test_execution_accuracy_empty_result(self): ] outputs = metric.compute(references, predictions[0], task_data[0]) - self.assertEqual(1.0, outputs["score"]) + self.assertEqual(0.0, outputs["score"]) def test_execution_accuracy_aggregation_query(self): - metric = ExecutionAccuracy() + metric = SQLExecutionAccuracy() predictions = ["SELECT AVG(salary) FROM employees"] references = ["SELECT AVG(salary) FROM employees;"] task_data = [ @@ -1543,7 +1545,7 @@ def test_execution_accuracy_aggregation_query(self): self.assertEqual(1.0, outputs["score"]) def test_execution_accuracy_incorrect_query(self): - metric = ExecutionAccuracy() + metric = SQLExecutionAccuracy() predictions = [ "SELECT nme FROM employees WHERE department = 'Sales'" ] # Incorrect column name 'nme' @@ -1852,7 +1854,9 @@ def _test_grouped_instance_confidence_interval( score_prefix, metric.main_score, ] - ).replace("__", "_") # for the case of empty score_prefix + ).replace( + "__", "_" + ) # for the case of empty score_prefix if input_expected_global_result_is_none: expected_global_result = { @@ -1924,13 +1928,15 @@ def test_task_based_llm_as_judge_metric(self): ) actual_scores = 
[output["score"] for output in outputs] main_score = f"{model_label}_{metric_label}" - instance_targets = [ - { - main_score: 0.0, - "score": 0.0, - "score_name": main_score, - main_score + "_judge_raw_output": "no", - main_score + "_judge_raw_input": """<|begin_of_text|><|start_header_id|>system<|end_header_id|> + instance_targets = ( + [ + { + main_score: 0.0, + "score": 0.0, + "score_name": main_score, + main_score + "_judge_raw_output": "no", + main_score + + "_judge_raw_input": """<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are given a question, the corresponding ground-truth answer and a prediction from a model. Compare the "Ground-truth answer" and the "Prediction" to determine whether the prediction correctly answers the question. There should be no contradicting statements in the prediction. The prediction may contain extra information. If the prediction states something as a possibility, treat it as a definitive answer. @@ -1946,8 +1952,10 @@ def test_task_based_llm_as_judge_metric(self): <|eot_id|><|start_header_id|>assistant<|end_header_id|> Answer: """, - } - ] * 2 + } + ] + * 2 + ) global_target = { main_score: 0.0, "score": 0.0, diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index 1226ff7479..be78bc45ad 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -151,7 +151,7 @@ "filename": "src/unitxt/loaders.py", "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742", "is_verified": false, - "line_number": 595, + "line_number": 593, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2025-02-12T09:37:42Z" + "generated_at": "2025-02-12T13:21:22Z" }