Skip to content

Commit

Permalink
refactor registry modules and enhance doc with pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
Sankalp Gilda committed Oct 7, 2024
1 parent fbae014 commit 2bf41b2
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 184 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ dependencies = [
"scikit-base>=0.8.0,<0.9.0",
"scikit-learn>=1.5.1,<1.6.0",
"scipy>=1.13,<1.14.0",
"packaging",
"packaging>=24.0,<24.2",
"pydantic>=2.0,<3.0",
]

[project.optional-dependencies]
Expand Down
232 changes: 131 additions & 101 deletions src/tsbootstrap/registry/_lookup.py
Original file line number Diff line number Diff line change
@@ -1,125 +1,119 @@
"""Registry lookup methods.
"""
Registry lookup methods.
This module exports the following methods for registry lookup:
all_objects(object_types, filter_tags)
lookup and filtering of objects
- all_objects(object_types: Optional[Union[str, List[str]]] = None,
filter_tags: Optional[Union[str, Dict[str, Union[str, List[str], bool]]]] = None,
exclude_objects: Optional[Union[str, List[str]]] = None,
return_names: bool = True,
as_dataframe: bool = False,
return_tags: Optional[Union[str, List[str]]] = None,
suppress_import_stdout: bool = True) -> Union[List[Any], List[Tuple], pd.DataFrame]
Lookup and filtering of objects in the tsbootstrap registry.
"""

# based on the sktime module of same name

__author__ = ["fkiraly"]
# all_objects is based on the sklearn utility all_estimators


from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
from skbase.base import BaseObject
from skbase.lookup import all_objects as _all_objects

from tsbootstrap.registry._tags import OBJECT_TAG_REGISTER
from tsbootstrap.registry._tags import (
OBJECT_TAG_REGISTER,
check_tag_is_valid,
)

VALID_OBJECT_TYPE_STRINGS = {x[1] for x in OBJECT_TAG_REGISTER}
VALID_OBJECT_TYPE_STRINGS: set = {tag.scitype for tag in OBJECT_TAG_REGISTER}


def all_objects(
object_types=None,
filter_tags=None,
exclude_objects=None,
return_names=True,
as_dataframe=False,
return_tags=None,
suppress_import_stdout=True,
):
"""Get a list of all objects from tsbootstrap.
This function crawls the module and gets all classes that inherit
object_types: Optional[Union[str, List[str]]] = None,
filter_tags: Optional[
Union[str, Dict[str, Union[str, List[str], bool]]]
] = None,
exclude_objects: Optional[Union[str, List[str]]] = None,
return_names: bool = True,
as_dataframe: bool = False,
return_tags: Optional[Union[str, List[str]]] = None,
suppress_import_stdout: bool = True,
) -> Union[List[Any], List[Tuple], pd.DataFrame]:
"""
Get a list of all objects from tsbootstrap.
This function crawls the module and retrieves all classes that inherit
from tsbootstrap's and sklearn's base classes.
Not included are: the base classes themselves, classes defined in test modules.
Excluded from retrieval are:
- The base classes themselves
- Classes defined in test modules
Parameters
----------
object_types: str, list of str, optional (default=None)
Which kind of objects should be returned.
if None, no filter is applied and all objects are returned.
if str or list of str, strings define scitypes specified in search
only objects that are of (at least) one of the scitypes are returned
possible str values are entries of registry.BASE_CLASS_REGISTER (first col)
return_names: bool, optional (default=True)
if True, object class name is included in the all_objects()
return in the order: name, object class, optional tags, either as
a tuple or as pandas.DataFrame columns
if False, object class name is removed from the all_objects()
return.
filter_tags: dict of (str or list of str), optional (default=None)
For a list of valid tag strings, use the registry.all_tags utility.
subsets the returned objects as follows:
each key/value pair is a statement in "and"/conjunction
key is tag name to sub-set on
value str or list of string are tag values
condition is "key must be equal to value, or in set(value)"
exclude_objects: str, list of str, optional (default=None)
Names of objects to exclude.
as_dataframe: bool, optional (default=False)
if True, all_objects will return a pandas.DataFrame with named
columns for all of the attributes being returned.
if False, all_objects will return a list (either a list of
objects or a list of tuples, see Returns)
return_tags: str or list of str, optional (default=None)
Names of tags to fetch and return each object's value of.
For a list of valid tag strings, use the registry.all_tags utility.
if str or list of str,
the tag values named in return_tags will be fetched for each
object and will be appended as either columns or tuple entries.
suppress_import_stdout : bool, optional. Default=True
whether to suppress stdout printout upon import.
object_types : Union[str, List[str]], optional (default=None)
Specifies which types of objects to return.
- If None, no filtering is applied and all objects are returned.
- If str or list of str, only objects matching the specified scitypes are returned.
Valid scitypes are entries in `registry.BASE_CLASS_REGISTER` (first column).
filter_tags : Union[str, Dict[str, Union[str, List[str], bool]]], optional (default=None)
Dictionary or string to filter returned objects based on their tags.
- If a string, it is treated as a boolean tag filter with the value `True`.
- If a dictionary, each key-value pair represents a filter condition in an "AND" conjunction.
- Key is the tag name to filter on.
- Value is a string, list of strings, or boolean that the tag value must match or be within.
- Only objects satisfying all filter conditions are returned.
exclude_objects : Union[str, List[str]], optional (default=None)
Names of objects to exclude from the results.
return_names : bool, optional (default=True)
- If True, the object's class name is included in the returned results.
- If False, the class name is omitted.
as_dataframe : bool, optional (default=False)
- If True, returns a pandas.DataFrame with named columns for all returned attributes.
- If False, returns a list (of objects or tuples).
return_tags : Union[str, List[str]], optional (default=None)
- Names of tags to fetch and include in the returned results.
- If specified, tag values are appended as either columns or tuple entries.
suppress_import_stdout : bool, optional (default=True)
Whether to suppress stdout printout upon import.
Returns
-------
all_objects will return one of the following:
1. list of objects, if return_names=False, and return_tags is None
2. list of tuples (optional object name, class, optional object
tags), if return_names=True or return_tags is not None.
3. pandas.DataFrame if as_dataframe = True
if list of objects:
entries are objects matching the query,
in alphabetical order of object name
if list of tuples:
list of (optional object name, object, optional object
tags) matching the query, in alphabetical order of object name,
where
``name`` is the object name as string, and is an
optional return
``object`` is the actual object
``tags`` are the object's values for each tag in return_tags
and is an optional return.
if dataframe:
all_objects will return a pandas.DataFrame.
column names represent the attributes contained in each column.
"objects" will be the name of the column of objects, "names"
will be the name of the column of object class names and the string(s)
passed in return_tags will serve as column names for all columns of
tags that were optionally requested.
Union[List[Any], List[Tuple], pd.DataFrame]
Depending on the parameters:
1. List of objects:
- Entries are objects matching the query, in alphabetical order of object name.
2. List of tuples:
- Each tuple contains (optional object name, object class, optional object tags).
- Ordered alphabetically by object name.
3. pandas.DataFrame:
- Columns represent the returned attributes.
- Includes "objects", "names", and any specified tag columns.
Examples
--------
>>> from tsbootstrap.registry import all_objects
>>> # return a complete list of objects as pd.DataFrame
>>> # Return a complete list of objects as a DataFrame
>>> all_objects(as_dataframe=True)
>>> # return all bootstrap algorithms by filtering for object type
>>> # Return all bootstrap algorithms by filtering for object type
>>> all_objects("bootstrap", as_dataframe=True)
>>> # return all bootstraps which are block bootstraps
>>> # Return all bootstraps which are block bootstraps
>>> all_objects(
... "bootstrap",
... object_types="bootstrap",
... filter_tags={"bootstrap_type": "block"},
... as_dataframe=True
... )
References
----------
Adapted version of sktime's ``all_estimators``,
which is an evolution of scikit-learn's ``all_estimators``
Adapted version of sktime's `all_estimators`,
which is an evolution of scikit-learn's `all_estimators`.
"""
MODULES_TO_IGNORE = (
"tests",
Expand All @@ -129,33 +123,69 @@ def all_objects(
"all",
)

result = []
result: Union[List[Any], List[Tuple], pd.DataFrame] = []
ROOT = str(
Path(__file__).parent.parent
) # tsbootstrap package root directory

# Prepare filter_tags
if isinstance(filter_tags, str):
# Ensure the tag expects a boolean value
tag = next(
(t for t in OBJECT_TAG_REGISTER if t.name == filter_tags), None
)
if not tag:
raise ValueError(
f"Tag '{filter_tags}' not found in OBJECT_TAG_REGISTER."
)
if tag.value_type != "bool":
raise ValueError(
f"Tag '{filter_tags}' does not expect a boolean value."
)
filter_tags = {filter_tags: True}
filter_tags = filter_tags.copy() if filter_tags else None
elif isinstance(filter_tags, dict):
# Validate each tag in filter_tags
for key, value in filter_tags.items():
try:
if not check_tag_is_valid(key, value):
raise ValueError(
f"Invalid value '{value}' for tag '{key}'."
)
except KeyError as e:
raise ValueError(
f"Tag '{key}' not found in OBJECT_TAG_REGISTER."
) from e
else:
filter_tags = None

if object_types:
if isinstance(object_types, str):
object_types = [object_types]
# Validate object_types
invalid_types = set(object_types) - VALID_OBJECT_TYPE_STRINGS
if invalid_types:
raise ValueError(
f"Invalid object_types: {invalid_types}. Valid types are {VALID_OBJECT_TYPE_STRINGS}."
)
if filter_tags and "object_type" not in filter_tags:
object_tag_filter = {"object_type": object_types}
elif filter_tags:
filter_tags_filter = filter_tags.get("object_type", [])
if isinstance(object_types, str):
object_types = [object_types]
object_tag_update = {
"object_type": object_types + filter_tags_filter
}
filter_tags.update(object_tag_update)
else:
object_tag_filter = {"object_type": object_types}
if filter_tags:
filter_tags.update(object_tag_filter)
elif filter_tags and "object_type" in filter_tags:
existing_filter = filter_tags.get("object_type", [])
if isinstance(existing_filter, str):
existing_filter = [existing_filter]
elif isinstance(existing_filter, list):
pass
else:
raise ValueError(
f"Unexpected type for 'object_type' filter: {type(existing_filter)}"
)
combined_filter = list(set(object_types + existing_filter))
filter_tags["object_type"] = combined_filter
else:
filter_tags = object_tag_filter
filter_tags = {"object_type": object_types}

# Retrieve objects using skbase's all_objects
result = _all_objects(
object_types=[BaseObject],
filter_tags=filter_tags,
Expand Down
Loading

0 comments on commit 2bf41b2

Please sign in to comment.