diff --git a/tests/test_kat.py b/tests/test_kat.py index 7bbe0bea..70585482 100644 --- a/tests/test_kat.py +++ b/tests/test_kat.py @@ -6,8 +6,17 @@ from typing import Tuple, Callable, Dict, Iterable, Any, List -def load_kat_data(): - """Load KAT values from YAML file""" +def load_kat_data() -> Dict[str, Any]: + """Load Known Answer Test (KAT) values from YAML file. + + Returns: + Dict[str, Any]: Dictionary containing test cases and expected results + loaded from the YAML file. + + Note: + The YAML file should be named 'kat.yaml' and located in the same + directory as this test file. + """ test_dir = Path(__file__).parent kat_yaml_path = test_dir / "kat.yaml" with kat_yaml_path.open("r") as file: @@ -17,7 +26,32 @@ def load_kat_data(): def extract_paths( nested_dict: Dict[str, Any], path: Tuple[str, ...] = () ) -> Iterable[Tuple[Tuple[str, ...], List[Tuple[Any, Any]]]]: - """Recursively extracts internal estimators paths and their corresponding inputs/outputs tuples.""" + """Recursively extract internal estimator paths and their corresponding test cases. + + This function traverses the input dictionary and yields tuples containing the + internal estimator path and its input/output test cases for each leaf node + containing an "inputs_with_expected_outputs" key. + + Args: + nested_dict (Dict[str, Any]): The nested dictionary to traverse. + path (Tuple[str, ...]): The current path in the dictionary (used for recursion). + + Yields: + Tuple[Tuple[str, ...], List[Tuple[Any, Any]]]: A tuple containing: + - The path to the internal estimator + - List of (input, expected_output) pairs for testing + + Example: + For a dictionary like: + { + "estimator1": { + "kat_generator": { + "inputs_with_expected_outputs": [(1, 2), (3, 4)] + } + } + } + It yields: (('estimator1', 'kat_generator'), [(1, 2), (3, 4)]) + """ for key, value in nested_dict.items(): current_path = path + (key,) if isinstance(value, dict): @@ -28,36 +62,56 @@ def extract_paths( def import_internal_estimator(internal_estimator_path: Tuple[str, ...]) -> Callable: - """Imports an internal estimator function from a specified path.""" + """Import an internal estimator function from a specified path. + + This function dynamically imports an estimator function based on the provided path, + relative to the caller's location in the package hierarchy. + + Args: + internal_estimator_path (Tuple[str, ...]): Path to the internal estimator function. + The last element is the function name, preceding elements form the module path. + + Returns: + Callable: The imported estimator function. + + Example: + estimator = import_internal_estimator(('sdfq', 'lee-brickell')) + # Imports the lee-brickell function from the sdfq module + """ caller_frame = inspect.currentframe().f_back caller_module = inspect.getmodule(caller_frame) package_parts = caller_module.__name__.split(".") caller_path = ".".join(package_parts[:-1] + ["internal_estimators"]) - single_case_module_path = ".".join(internal_estimator_path[:-1]) + module_path = ".".join(internal_estimator_path[:-1]) function_name = internal_estimator_path[-1] - import_modpath = f"{caller_path}.{single_case_module_path}" - module = importlib.import_module(import_modpath) + import_path = f"{caller_path}.{module_path}" + module = importlib.import_module(import_path) return getattr(module, function_name) -def generate_test_cases(): - """Generate test cases from KAT data.""" +def generate_test_cases() -> List[Tuple[str, Tuple[str, ...], Any, float]]: + """Generate test cases from KAT data for parametrized testing. + + Returns: + List[Tuple[str, Tuple[str, ...], Any, float]]: List of test cases, each containing: + - Estimator name (human-readable) + - Estimator path (for importing) + - Input value + - Expected output value + """ kat = load_kat_data() test_cases = [] - for internal_estimator_path, inputs_with_outputs in extract_paths(kat): - internal_estimator_name = f"{internal_estimator_path[-1]} from {internal_estimator_path[-2].upper()}Estimator" - for input_val, expected_output in inputs_with_outputs: + for estimator_path, test_pairs in extract_paths(kat): + estimator_name = ( + f"{estimator_path[-1]} from {estimator_path[-2].upper()}Estimator" + ) + for input_val, expected_output in test_pairs: test_cases.append( - ( - internal_estimator_name, - internal_estimator_path, - input_val, - expected_output, - ) + (estimator_name, estimator_path, input_val, expected_output) ) return test_cases @@ -71,9 +125,22 @@ def test_kat( estimator_path: Tuple[str, ...], input_val: Any, expected_output: float, -): - """Test each case, including the estimation computation.""" - # Import and run the estimator function here so it's part of the parallel execution +) -> None: + """Execute Known Answer Tests for internal estimators. + + This test verifies that internal estimators produce results within + acceptable tolerance of known correct values. + + Args: + estimator_name (str): Human-readable name of the estimator + estimator_path (Tuple[str, ...]): Import path for the estimator + input_val (Any): Test input value + expected_output (float): Expected output value + + Raises: + AssertionError: If the actual output differs from expected output + by more than the allowed tolerance (epsilon) + """ estimator_func = import_internal_estimator(estimator_path) actual_output, epsilon = estimator_func(input_val)