diff --git a/pyproject.toml b/pyproject.toml index ccdf1b0..e6b3408 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,8 @@ dependencies = [ "scikit-base>=0.8.0,<0.9.0", "scikit-learn>=1.5.1,<1.6.0", "scipy>=1.13,<1.14.0", - "packaging", + "packaging>=24.0,<24.2", + "pydantic>=2.0,<3.0", ] [project.optional-dependencies] diff --git a/src/tsbootstrap/registry/_lookup.py b/src/tsbootstrap/registry/_lookup.py index 29fb38d..204d532 100644 --- a/src/tsbootstrap/registry/_lookup.py +++ b/src/tsbootstrap/registry/_lookup.py @@ -1,125 +1,119 @@ -"""Registry lookup methods. +""" +Registry lookup methods. This module exports the following methods for registry lookup: -all_objects(object_types, filter_tags) - lookup and filtering of objects +- all_objects(object_types: Optional[Union[str, List[str]]] = None, + filter_tags: Optional[Dict[str, Union[str, List[str], bool]]] = None, + exclude_objects: Optional[Union[str, List[str]]] = None, + return_names: bool = True, + as_dataframe: bool = False, + return_tags: Optional[Union[str, List[str]]] = None, + suppress_import_stdout: bool = True) -> Union[List[Any], List[Tuple], pd.DataFrame] + Lookup and filtering of objects in the tsbootstrap registry. """ -# based on the sktime module of same name - -__author__ = ["fkiraly"] -# all_objects is based on the sklearn utility all_estimators - - from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +import pandas as pd from skbase.base import BaseObject from skbase.lookup import all_objects as _all_objects -from tsbootstrap.registry._tags import OBJECT_TAG_REGISTER +from tsbootstrap.registry._tags import ( + OBJECT_TAG_REGISTER, + check_tag_is_valid, +) -VALID_OBJECT_TYPE_STRINGS = {x[1] for x in OBJECT_TAG_REGISTER} +VALID_OBJECT_TYPE_STRINGS: set = {tag.scitype for tag in OBJECT_TAG_REGISTER} def all_objects( - object_types=None, - filter_tags=None, - exclude_objects=None, - return_names=True, - as_dataframe=False, - return_tags=None, - suppress_import_stdout=True, -): - """Get a list of all objects from tsbootstrap. - - This function crawls the module and gets all classes that inherit + object_types: Optional[Union[str, List[str]]] = None, + filter_tags: Optional[ + Union[str, Dict[str, Union[str, List[str], bool]]] + ] = None, + exclude_objects: Optional[Union[str, List[str]]] = None, + return_names: bool = True, + as_dataframe: bool = False, + return_tags: Optional[Union[str, List[str]]] = None, + suppress_import_stdout: bool = True, +) -> Union[List[Any], List[Tuple], pd.DataFrame]: + """ + Get a list of all objects from tsbootstrap. + + This function crawls the module and retrieves all classes that inherit from tsbootstrap's and sklearn's base classes. - Not included are: the base classes themselves, classes defined in test modules. + Excluded from retrieval are: + - The base classes themselves + - Classes defined in test modules Parameters ---------- - object_types: str, list of str, optional (default=None) - Which kind of objects should be returned. - if None, no filter is applied and all objects are returned. - if str or list of str, strings define scitypes specified in search - only objects that are of (at least) one of the scitypes are returned - possible str values are entries of registry.BASE_CLASS_REGISTER (first col) - return_names: bool, optional (default=True) - if True, object class name is included in the all_objects() - return in the order: name, object class, optional tags, either as - a tuple or as pandas.DataFrame columns - if False, object class name is removed from the all_objects() - return. - filter_tags: dict of (str or list of str), optional (default=None) - For a list of valid tag strings, use the registry.all_tags utility. - subsets the returned objects as follows: - each key/value pair is statement in "and"/conjunction - key is tag name to sub-set on - value str or list of string are tag values - condition is "key must be equal to value, or in set(value)" - exclude_objects: str, list of str, optional (default=None) - Names of objects to exclude. - as_dataframe: bool, optional (default=False) - if True, all_objects will return a pandas.DataFrame with named - columns for all of the attributes being returned. - if False, all_objects will return a list (either a list of - objects or a list of tuples, see Returns) - return_tags: str or list of str, optional (default=None) - Names of tags to fetch and return each object's value of. - For a list of valid tag strings, use the registry.all_tags utility. - if str or list of str, - the tag values named in return_tags will be fetched for each - object and will be appended as either columns or tuple entries. - suppress_import_stdout : bool, optional. Default=True - whether to suppress stdout printout upon import. + object_types : Union[str, List[str]], optional (default=None) + Specifies which types of objects to return. + - If None, no filtering is applied and all objects are returned. + - If str or list of str, only objects matching the specified scitypes are returned. + Valid scitypes are entries in `registry.BASE_CLASS_REGISTER` (first column). + + filter_tags : Union[str, Dict[str, Union[str, List[str], bool]]], optional (default=None) + Dictionary or string to filter returned objects based on their tags. + - If a string, it is treated as a boolean tag filter with the value `True`. + - If a dictionary, each key-value pair represents a filter condition in an "AND" conjunction. + - Key is the tag name to filter on. + - Value is a string, list of strings, or boolean that the tag value must match or be within. + - Only objects satisfying all filter conditions are returned. + + exclude_objects : Union[str, List[str]], optional (default=None) + Names of objects to exclude from the results. + + return_names : bool, optional (default=True) + - If True, the object's class name is included in the returned results. + - If False, the class name is omitted. + + as_dataframe : bool, optional (default=False) + - If True, returns a pandas.DataFrame with named columns for all returned attributes. + - If False, returns a list (of objects or tuples). + + return_tags : Union[str, List[str]], optional (default=None) + - Names of tags to fetch and include in the returned results. + - If specified, tag values are appended as either columns or tuple entries. + + suppress_import_stdout : bool, optional (default=True) + Whether to suppress stdout printout upon import. Returns ------- - all_objects will return one of the following: - 1. list of objects, if return_names=False, and return_tags is None - 2. list of tuples (optional object name, class, ~optional object - tags), if return_names=True or return_tags is not None. - 3. pandas.DataFrame if as_dataframe = True - if list of objects: - entries are objects matching the query, - in alphabetical order of object name - if list of tuples: - list of (optional object name, object, optional object - tags) matching the query, in alphabetical order of object name, - where - ``name`` is the object name as string, and is an - optional return - ``object`` is the actual object - ``tags`` are the object's values for each tag in return_tags - and is an optional return. - if dataframe: - all_objects will return a pandas.DataFrame. - column names represent the attributes contained in each column. - "objects" will be the name of the column of objects, "names" - will be the name of the column of object class names and the string(s) - passed in return_tags will serve as column names for all columns of - tags that were optionally requested. + Union[List[Any], List[Tuple], pd.DataFrame] + Depending on the parameters: + 1. List of objects: + - Entries are objects matching the query, in alphabetical order of object name. + 2. List of tuples: + - Each tuple contains (optional object name, object class, optional object tags). + - Ordered alphabetically by object name. + 3. pandas.DataFrame: + - Columns represent the returned attributes. + - Includes "objects", "names", and any specified tag columns. Examples -------- >>> from tsbootstrap.registry import all_objects - >>> # return a complete list of objects as pd.Dataframe + >>> # Return a complete list of objects as a DataFrame >>> all_objects(as_dataframe=True) - >>> # return all bootstrap algorithms by filtering for object type + >>> # Return all bootstrap algorithms by filtering for object type >>> all_objects("bootstrap", as_dataframe=True) - >>> # return all bootstraps which are block bootstraps + >>> # Return all bootstraps which are block bootstraps >>> all_objects( - ... "bootstrap", + ... object_types="bootstrap", ... filter_tags={"bootstrap_type": "block"}, ... as_dataframe=True ... ) References ---------- - Adapted version of sktime's ``all_estimators``, - which is an evolution of scikit-learn's ``all_estimators`` + Adapted version of sktime's `all_estimators`, + which is an evolution of scikit-learn's `all_estimators`. """ MODULES_TO_IGNORE = ( "tests", @@ -129,33 +123,69 @@ def all_objects( "all", ) - result = [] + result: Union[List[Any], List[Tuple], pd.DataFrame] = [] ROOT = str( Path(__file__).parent.parent ) # tsbootstrap package root directory + # Prepare filter_tags if isinstance(filter_tags, str): + # Ensure the tag expects a boolean value + tag = next( + (t for t in OBJECT_TAG_REGISTER if t.name == filter_tags), None + ) + if not tag: + raise ValueError( + f"Tag '{filter_tags}' not found in OBJECT_TAG_REGISTER." + ) + if tag.value_type != "bool": + raise ValueError( + f"Tag '{filter_tags}' does not expect a boolean value." + ) filter_tags = {filter_tags: True} - filter_tags = filter_tags.copy() if filter_tags else None + elif isinstance(filter_tags, dict): + # Validate each tag in filter_tags + for key, value in filter_tags.items(): + try: + if not check_tag_is_valid(key, value): + raise ValueError( + f"Invalid value '{value}' for tag '{key}'." + ) + except KeyError as e: + raise ValueError( + f"Tag '{key}' not found in OBJECT_TAG_REGISTER." + ) from e + else: + filter_tags = None if object_types: + if isinstance(object_types, str): + object_types = [object_types] + # Validate object_types + invalid_types = set(object_types) - VALID_OBJECT_TYPE_STRINGS + if invalid_types: + raise ValueError( + f"Invalid object_types: {invalid_types}. Valid types are {VALID_OBJECT_TYPE_STRINGS}." + ) if filter_tags and "object_type" not in filter_tags: object_tag_filter = {"object_type": object_types} - elif filter_tags: - filter_tags_filter = filter_tags.get("object_type", []) - if isinstance(object_types, str): - object_types = [object_types] - object_tag_update = { - "object_type": object_types + filter_tags_filter - } - filter_tags.update(object_tag_update) - else: - object_tag_filter = {"object_type": object_types} - if filter_tags: filter_tags.update(object_tag_filter) + elif filter_tags and "object_type" in filter_tags: + existing_filter = filter_tags.get("object_type", []) + if isinstance(existing_filter, str): + existing_filter = [existing_filter] + elif isinstance(existing_filter, list): + pass + else: + raise ValueError( + f"Unexpected type for 'object_type' filter: {type(existing_filter)}" + ) + combined_filter = list(set(object_types + existing_filter)) + filter_tags["object_type"] = combined_filter else: - filter_tags = object_tag_filter + filter_tags = {"object_type": object_types} + # Retrieve objects using skbase's all_objects result = _all_objects( object_types=[BaseObject], filter_tags=filter_tags, diff --git a/src/tsbootstrap/registry/_tags.py b/src/tsbootstrap/registry/_tags.py index 998ad63..6a58769 100644 --- a/src/tsbootstrap/registry/_tags.py +++ b/src/tsbootstrap/registry/_tags.py @@ -1,102 +1,232 @@ -"""Register of estimator and object tags. +""" +Register of estimator and object tags. -Note for extenders: new tags should be entered in OBJECT_TAG_REGISTER. -No other place is necessary to add new tags. +Note for extenders: + New tags should be entered in `OBJECT_TAG_REGISTER`. + No other place is necessary to add new tags. This module exports the following: ---- -OBJECT_TAG_REGISTER - list of tuples - -each tuple corresponds to a tag, elements as follows: - 0 : string - name of the tag as used in the _tags dictionary - 1 : string - name of the scitype this tag applies to - must be in _base_classes.BASE_CLASS_SCITYPE_LIST - 2 : string - expected type of the tag value - should be one of: - "bool" - valid values are True/False - "int" - valid values are all integers - "str" - valid values are all strings - "list" - valid values are all lists of arbitrary elements - ("str", list_of_string) - any string in list_of_string is valid - ("list", list_of_string) - any individual string and sub-list is valid - ("list", "str") - any individual string or list of strings is valid - validity can be checked by check_tag_is_valid (see below) - 3 : string - plain English description of the tag - ---- - -OBJECT_TAG_TABLE - pd.DataFrame - OBJECT_TAG_REGISTER in table form, as pd.DataFrame - rows of OBJECT_TABLE correspond to elements in OBJECT_TAG_REGISTER - -OBJECT_TAG_LIST - list of string - elements are 0-th entries of OBJECT_TAG_REGISTER, in same order - ---- - -check_tag_is_valid(tag_name, tag_value) - checks whether tag_value is valid for tag_name +- OBJECT_TAG_REGISTER : List[Tag] + A list of Tag instances, each representing a tag with its attributes. + +- OBJECT_TAG_TABLE : pd.DataFrame + `OBJECT_TAG_REGISTER` in table form. + +- OBJECT_TAG_LIST : List[str] + List of tag names extracted from `OBJECT_TAG_REGISTER`. + +- check_tag_is_valid(tag_name: str, tag_value: Any) -> bool + Function to validate if a tag value is valid for a given tag name. """ -OBJECT_TAG_REGISTER = [ +from typing import Any, List, Tuple, Union + +import pandas as pd +from pydantic import BaseModel, field_validator + + +class Tag(BaseModel): + """ + Represents a single tag with its properties. + + Attributes + ---------- + name : str + Name of the tag as used in the _tags dictionary. + scitype : str + Name of the scitype this tag applies to. + value_type : Union[str, Tuple[str, Union[List[str], str]]] + Expected type of the tag value. + description : str + Plain English description of the tag. + """ + + name: str + scitype: str + value_type: Union[str, Tuple[str, Union[List[str], str]]] + description: str + + @field_validator("value_type") + def validate_value_type(self, v): + valid_base_types = {"bool", "int", "str", "list", "dict"} + if isinstance(v, str): + if v not in valid_base_types: + raise ValueError( + f"Invalid value_type: { + v}. Must be one of {valid_base_types}." + ) + elif isinstance(v, tuple): + if len(v) != 2: + raise ValueError( + "Tuple value_type must have exactly two elements." + ) + base, subtype = v + if base not in {"str", "list"}: + raise ValueError( + "First element of tuple must be 'str' or 'list'." + ) + if base == "str": + if not isinstance(subtype, list) or not all( + isinstance(item, str) for item in subtype + ): + raise ValueError( + "Second element must be a list of strings when base is 'str'." + ) + elif ( + base == "list" + and not ( + isinstance(subtype, list) + and all(isinstance(item, str) for item in subtype) + ) + and not isinstance(subtype, str) + ): + raise ValueError( + "Second element must be a list of strings or 'str' when base is 'list'." + ) + else: + raise TypeError("value_type must be either a string or a tuple.") + return v + + +# Define the OBJECT_TAG_REGISTER with Tag instances +OBJECT_TAG_REGISTER: List[Tag] = [ # -------------------------- - # all objects and estimators + # All objects and estimators # -------------------------- - ( - "object_type", - "object", - "str", - "type of object, e.g., 'regressor', 'transformer'", + Tag( + name="object_type", + scitype="object", + value_type="str", + description="Type of object, e.g., 'regressor', 'transformer'.", ), - ( - "python_version", - "object", - "str", - "python version specifier (PEP 440) for estimator, or None = all versions ok", + Tag( + name="python_version", + scitype="object", + value_type="str", + description="Python version specifier (PEP 440) for estimator, or None for all versions.", ), - ( - "python_dependencies", - "object", - ("list", "str"), - "python dependencies of estimator as str or list of str", + Tag( + name="python_dependencies", + scitype="object", + value_type=("list", "str"), + description="Python dependencies of estimator as string or list of strings.", ), - ( - "python_dependencies_alias", - "object", - "dict", - "should be provided if import name differs from package name, \ - key-value pairs are package name, import name", + Tag( + name="python_dependencies_alias", + scitype="object", + value_type="dict", + description="Alias for Python dependencies if import name differs from package name. Key-value pairs are package name and import name.", ), # ----------------------- # BaseTimeSeriesBootstrap # ----------------------- - ( - "bootstrap_type", - "bootstrap", - ("list", "str"), - "which type of bootstrap the algorithm is", + Tag( + name="bootstrap_type", + scitype="bootstrap", + value_type=("list", "str"), + description="Type(s) of bootstrap the algorithm supports.", ), - ( - "capability:multivariate", - "bootstrap", - "bool", - "whether the bootstrap algorithm supports multivariate data", + Tag( + name="capability:multivariate", + scitype="bootstrap", + value_type="bool", + description="Whether the bootstrap algorithm supports multivariate data.", ), # ---------------------------- # BaseMetaObject reserved tags # ---------------------------- - ( - "named_object_parameters", - "object", - "str", - "name of component list attribute for meta-objects", + Tag( + name="named_object_parameters", + scitype="object", + value_type="str", + description="Name of component list attribute for meta-objects.", ), - ( - "fitted_named_object_parameters", - "estimator", - "str", - "name of fitted component list attribute for meta-objects", + Tag( + name="fitted_named_object_parameters", + scitype="estimator", + value_type="str", + description="Name of fitted component list attribute for meta-objects.", ), ] -OBJECT_TAG_LIST = [x[0] for x in OBJECT_TAG_REGISTER] +# Create OBJECT_TAG_TABLE as a DataFrame +OBJECT_TAG_TABLE: pd.DataFrame = pd.DataFrame( + [ + { + "name": tag.name, + "scitype": tag.scitype, + "value_type": tag.value_type, + "description": tag.description, + } + for tag in OBJECT_TAG_REGISTER + ] +) + +# Create OBJECT_TAG_LIST as a list of tag names +OBJECT_TAG_LIST: List[str] = [tag.name for tag in OBJECT_TAG_REGISTER] + + +def check_tag_is_valid(tag_name: str, tag_value: Any) -> bool: + """ + Check whether a tag value is valid for a given tag name. + + Parameters + ---------- + tag_name : str + The name of the tag to validate. + tag_value : Any + The value to validate against the tag's expected type. + + Returns + ------- + bool + True if the tag value is valid for the tag name, False otherwise. + + Raises + ------ + KeyError + If the tag_name is not found in OBJECT_TAG_REGISTER. + """ + try: + tag = next(tag for tag in OBJECT_TAG_REGISTER if tag.name == tag_name) + except StopIteration as e: + raise KeyError( + f"Tag name '{tag_name}' not found in OBJECT_TAG_REGISTER." + ) from e + + value_type = tag.value_type + + if isinstance(value_type, str): + expected_type = value_type + if expected_type == "bool": + return isinstance(tag_value, bool) + elif expected_type == "int": + return isinstance(tag_value, int) + elif expected_type == "str": + return isinstance(tag_value, str) + elif expected_type == "list": + return isinstance(tag_value, list) + elif expected_type == "dict": + return isinstance(tag_value, dict) + else: + return False + elif isinstance(value_type, tuple): + base_type, subtype = value_type + if base_type == "str": + if isinstance(tag_value, str): + return tag_value in subtype + return False + elif base_type == "list": + if not isinstance(tag_value, list): + return False + if isinstance(subtype, list): + return all( + isinstance(item, str) and item in subtype + for item in tag_value + ) + elif subtype == "str": + return all(isinstance(item, str) for item in tag_value) + return False + else: + return False diff --git a/src/tsbootstrap/tests/scenarios/scenarios_bootstrap.py b/src/tsbootstrap/tests/scenarios/scenarios_bootstrap.py index fb1e90e..933882a 100644 --- a/src/tsbootstrap/tests/scenarios/scenarios_bootstrap.py +++ b/src/tsbootstrap/tests/scenarios/scenarios_bootstrap.py @@ -54,10 +54,7 @@ def scitype(obj): obj_can_handle_multivariate = get_tag(obj, "capability:multivariate") - if is_multivariate and not obj_can_handle_multivariate: - return False - - return True + return not (is_multivariate and not obj_can_handle_multivariate) X_np_uni = rng.random((20, 1))