Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: scikit-base registry and testing framework #26

Merged
merged 79 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
88501dc
skbase
fkiraly Nov 23, 2023
ce630e3
Merge branch 'main' into skbase
astrogilda Jan 1, 2024
68b2091
Merge remote-tracking branch 'upstream/main' into skbase
fkiraly Jan 5, 2024
ff35602
Update base_bootstrap.py
fkiraly Jan 5, 2024
3bc4226
Update pyproject.toml
fkiraly Jan 5, 2024
b5e2aa1
Update pyproject.toml
fkiraly Jan 5, 2024
b040f4c
Delete poetry.lock
fkiraly Jan 5, 2024
0e0c922
Merge branch 'remove-lock' into skbase
fkiraly Jan 5, 2024
6b9581c
remove repr etc
fkiraly Jan 6, 2024
7c08090
registry and lookup
fkiraly Jan 6, 2024
a402256
test class basic
fkiraly Jan 6, 2024
cba73bc
scenarios
fkiraly Jan 6, 2024
cea9d70
test all bootstraps
fkiraly Jan 6, 2024
b0a1e81
Merge branch 'main' into skbase
astrogilda Jan 19, 2024
c39b740
Merge remote-tracking branch 'upstream/main' into skbase
fkiraly Jan 19, 2024
b1caa82
fix kwargs, remove repr
fkiraly Jan 19, 2024
60607f2
Update base_bootstrap_configs.py
fkiraly Jan 19, 2024
593188a
Merge branch 'main' into skbase
astrogilda Jan 19, 2024
2a47471
Update base_bootstrap_configs.py
fkiraly Jan 19, 2024
fa8010c
Merge branch 'skbase' of https://github.com/fkiraly/tsbootstrap into …
fkiraly Jan 19, 2024
59070f9
Update base_bootstrap_configs.py
fkiraly Jan 19, 2024
8162463
Update base_bootstrap_configs.py
fkiraly Jan 20, 2024
e2f15ed
Update base_bootstrap_configs.py
fkiraly Jan 20, 2024
cfb9e3a
Merge branch 'skbase' into skbase-machinery
fkiraly Jan 20, 2024
2455fbb
Update CI.yml
fkiraly Jan 20, 2024
8165df7
Merge remote-tracking branch 'upstream/main' into skbase-machinery
fkiraly Jan 20, 2024
5972034
pytest folder spec
fkiraly Jan 20, 2024
6d121e1
Merge remote-tracking branch 'upstream/main' into skbase-machinery
fkiraly Feb 5, 2024
40e3c31
Update _tags.py
fkiraly Feb 5, 2024
e1dd886
Update _tags.py
fkiraly Feb 5, 2024
c97c040
Update CI.yml
fkiraly Feb 5, 2024
11ad881
check soft deps
fkiraly Feb 5, 2024
f65b54a
Revert "check soft deps"
fkiraly Feb 5, 2024
9ba53b3
Revert "Revert "check soft deps""
fkiraly Feb 5, 2024
efcaed5
Update dependencies.py
fkiraly Feb 5, 2024
912f823
Update test_all_estimators.py
fkiraly Feb 9, 2024
da34200
Merge remote-tracking branch 'upstream/main' into skbase-machinery
fkiraly Feb 9, 2024
b621adc
Merge remote-tracking branch 'upstream/main' into skbase-machinery
fkiraly Feb 16, 2024
8622a4a
Update bootstrap.py
fkiraly Feb 16, 2024
9e99667
Update block_bootstrap_configs.py
fkiraly Feb 16, 2024
398970f
exclude configs from test
fkiraly Feb 16, 2024
8d296e3
exclude samplers
fkiraly Feb 16, 2024
a2ca4d5
super
fkiraly Feb 16, 2024
5c212a5
Update test_all_estimators.py
fkiraly Feb 16, 2024
6112377
remove inadmissible types from init values
fkiraly Feb 16, 2024
0ecceb9
Update base_bootstrap.py
fkiraly Feb 16, 2024
a7f2e42
exog
fkiraly Feb 16, 2024
1f02965
Update scenarios_bootstrap.py
fkiraly Feb 16, 2024
3edbc24
Update scenarios_bootstrap.py
fkiraly Feb 16, 2024
94c67bc
Update test_all_bootstraps.py
fkiraly Feb 17, 2024
a23302d
Update test_all_bootstraps.py
fkiraly Feb 17, 2024
646951c
Update test_all_bootstraps.py
fkiraly Feb 17, 2024
93aa384
Update test_all_bootstraps.py
fkiraly Feb 17, 2024
92cca77
Update test_all_bootstraps.py
fkiraly Feb 17, 2024
439a959
Update bootstrap.py
fkiraly Feb 17, 2024
e929a70
multivariate tag
fkiraly Feb 19, 2024
c62baac
Update scenarios_bootstrap.py
fkiraly Feb 19, 2024
a99a51d
Update scenarios_bootstrap.py
fkiraly Feb 19, 2024
6efcc4a
Merge remote-tracking branch 'upstream/main' into skbase-machinery
fkiraly Feb 19, 2024
dcdb602
remove random_state
fkiraly Feb 19, 2024
6e947ff
Update base_bootstrap.py
fkiraly Feb 23, 2024
ba3f539
Update test_bootstrap.py
fkiraly Feb 23, 2024
1802b04
swap data, indices
fkiraly Feb 23, 2024
bbf488b
exclude sbb
fkiraly Feb 23, 2024
59d3054
Update test_bootstrap.py
fkiraly Feb 23, 2024
319b956
Update test_bootstrap.py
fkiraly Feb 23, 2024
3982c92
fix scenario tags
fkiraly Feb 23, 2024
1708b1a
Update base_bootstrap.py
fkiraly Feb 23, 2024
2e3e0cd
Update base_bootstrap.py
fkiraly Feb 23, 2024
e7e226b
super call
fkiraly Feb 23, 2024
c71ec68
Update block_bootstrap.py
fkiraly Feb 23, 2024
714be25
Update base_bootstrap.py
fkiraly Feb 23, 2024
191acab
fix model_params passing
fkiraly Feb 23, 2024
de4cbd4
Update test_base_bootstrap_configs.py
fkiraly Feb 23, 2024
82c17ca
Update test_base_bootstrap_configs.py
fkiraly Feb 23, 2024
4ee37f1
Update test_base_bootstrap_configs.py
fkiraly Feb 23, 2024
e2ff589
Update base_bootstrap.py
fkiraly Feb 23, 2024
3144eaf
check_estimator
fkiraly Feb 23, 2024
83d5b65
fix seed
fkiraly Feb 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@ jobs:
run: git branch -a

- name: Run tests
run: python -m pytest tests -vv

- name: Publish code coverage
uses: codecov/codecov-action@v4
run: python -m pytest src/ tests/ -vv

test-all-softdeps:
needs: test-no-softdeps
Expand Down Expand Up @@ -107,7 +104,7 @@ jobs:
run: git branch -a

- name: Run tests
run: python -m pytest tests -vv
run: python -m pytest src/ tests/ -vv --cov=src/

- name: Publish code coverage
uses: codecov/codecov-action@v4
Expand Down
43 changes: 37 additions & 6 deletions src/tsbootstrap/base_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ class BaseTimeSeriesBootstrap(BaseObject):
If n_bootstraps is not greater than 0.
"""

_tags = {"object_type": "bootstrap"}
_tags = {
"object_type": "bootstrap",
"bootstrap_type": "other",
"capability:multivariate": True,
}

def __init__(
self,
Expand Down Expand Up @@ -117,7 +121,7 @@ def bootstrap(
X_train, X_test = time_series_split(X, test_ratio=test_ratio)

if y is not None:
self._check_input(y)
self._check_input(y, enforce_univariate=False)
exog_train, _ = time_series_split(y, test_ratio=test_ratio)
else:
exog_train = None
Expand Down Expand Up @@ -150,8 +154,14 @@ def _generate_samples(
for _ in range(self.config.n_bootstraps):
indices, data = self._generate_samples_single_bootstrap(X=X, y=y)
data = np.concatenate(data, axis=0)

# hack to fix known issue with non-concatenated index sets
# see bug issue #81
if isinstance(indices, list):
indices = np.concatenate(indices, axis=0)

if return_indices:
yield indices, data # type: ignore
yield data, indices # type: ignore
else:
yield data

Expand All @@ -162,11 +172,20 @@ def _generate_samples_single_bootstrap(self, X: np.ndarray, y=None):
"""
raise NotImplementedError("abstract method")

def _check_input(self, X):
def _check_input(self, X, enforce_univariate=True):
"""Checks if the input is valid."""
if np.any(np.diff([len(x) for x in X]) != 0):
raise ValueError("All time series must be of the same length.")

self_can_only_univariate = not self.get_tag("capability:multivariate")
check_univariate = enforce_univariate and self_can_only_univariate
if check_univariate and X.shape[1] > 1:
raise ValueError(
f"Unsupported input type: the estimator {type(self)} "
"does not support multivariate endogeneous time series (X argument). "
"Pass an 1D np.array, or a 2D np.array with a single column."
)

def get_n_bootstraps(
self,
X=None,
Expand Down Expand Up @@ -220,6 +239,12 @@ class BaseResidualBootstrap(BaseTimeSeriesBootstrap):
_fit_model : Fits the model to the data and stores the residuals.
"""

_tags = {
"python_dependencies": "statsmodels",
"bootstrap_type": "residual",
"capability:multivariate": False,
}

def __init__(
self,
n_bootstraps: Integral = 10, # type: ignore
Expand Down Expand Up @@ -299,11 +324,14 @@ def _fit_model(self, X: np.ndarray, y=None) -> None:
or self.fit_model is None
or self.coefs is None
):
model_params = self.config.model_params
if model_params is None:
model_params = {}
fit_obj = TSFitBestLag(
model_type=self.config.model_type,
order=self.config.order,
save_models=self.config.save_models,
**self.config.model_params,
**model_params,
)
self.fit_model = fit_obj.fit(X=X, y=y).model
self.X_fitted = fit_obj.get_fitted_X()
Expand Down Expand Up @@ -486,7 +514,7 @@ class BaseStatisticPreservingBootstrap(BaseTimeSeriesBootstrap):
def __init__(
self,
n_bootstraps: Integral = 10, # type: ignore
statistic: Callable = np.mean,
statistic: Callable = None,
statistic_axis: Integral = 0, # type: ignore
statistic_keepdims: bool = False,
rng=None,
Expand All @@ -505,6 +533,9 @@ def __init__(
self.statistic_axis = statistic_axis
self.statistic_keepdims = statistic_keepdims

if statistic is None:
statistic = np.mean

self.config = BaseStatisticPreservingBootstrapConfig(
n_bootstraps=n_bootstraps,
rng=rng,
Expand Down
5 changes: 4 additions & 1 deletion src/tsbootstrap/base_bootstrap_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class BaseTimeSeriesBootstrapConfig(BaseObject):
that are common to all time series bootstrapping methods.
"""

_tags = {"object_type": "config"}

def __init__(
self,
n_bootstraps: Integral = 10, # type: ignore
Expand Down Expand Up @@ -99,6 +101,7 @@ def __init__(
model_type: ModelTypesWithoutArch = "ar",
order=None,
save_models: bool = False,
model_params=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -148,7 +151,7 @@ def __init__(
self.model_type = model_type
self.order = order
self.save_models = save_models
self.model_params = kwargs
self.model_params = model_params

super().__init__(n_bootstraps=n_bootstraps, rng=rng)

Expand Down
6 changes: 4 additions & 2 deletions src/tsbootstrap/block_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class BlockBootstrap(BaseTimeSeriesBootstrap):
If block_length is not greater than 0.
"""

_tags = {"bootstrap_type": "block"}

def __init__(
self,
n_bootstraps: Integral = 10, # type: ignore
Expand Down Expand Up @@ -109,8 +111,8 @@ def __init__(
self.blocks = None
self.block_resampler = None

def _check_input(self, X: np.ndarray) -> None:
super()._check_input(X)
def _check_input(self, X: np.ndarray, enforce_univariate=True) -> None:
super()._check_input(X=X, enforce_univariate=enforce_univariate)
if self.config.block_length is not None and self.config.block_length > X.shape[0]: # type: ignore
raise ValueError(
"block_length cannot be greater than the size of the input array X."
Expand Down
4 changes: 2 additions & 2 deletions src/tsbootstrap/block_bootstrap_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class MovingBlockBootstrapConfig(BlockBootstrapConfig):

def __init__(
self,
block_length: Integral,
block_length: Integral = None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -383,7 +383,7 @@ class CircularBlockBootstrapConfig(BlockBootstrapConfig):

def __init__(
self,
block_length: Integral,
block_length: Integral = None,
**kwargs,
) -> None:
"""
Expand Down
4 changes: 4 additions & 0 deletions src/tsbootstrap/block_length_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class BlockLengthSampler(BaseObject):
Sample a block length from the selected distribution.
"""

_tags = {"object_type": "sampler"}

def __init__(self, avg_block_length: Integral = DEFAULT_AVG_BLOCK_LENGTH, block_length_distribution: str = None, rng: RngTypes = None): # type: ignore
"""
Initialize the BlockLengthSampler with the selected distribution and average block length.
Expand All @@ -76,6 +78,8 @@ def __init__(self, avg_block_length: Integral = DEFAULT_AVG_BLOCK_LENGTH, block_
self.avg_block_length = avg_block_length
self.rng = rng

super().__init__()

@property
def block_length_distribution(self) -> str:
"""Getter for block_length_distribution."""
Expand Down
17 changes: 11 additions & 6 deletions src/tsbootstrap/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def _generate_samples_single_bootstrap(
bootstrap_samples = self.X_fitted + np.concatenate(block_data, axis=0)
return block_indices, [bootstrap_samples]

def get_test_params(self):
@classmethod
def get_test_params(cls, parameter_set="default"):
from tsbootstrap.block_bootstrap import MovingBlockBootstrap
bs = MovingBlockBootstrap()
return {"block_bootstrap": bs}
Expand Down Expand Up @@ -416,7 +417,8 @@ def _generate_samples_single_bootstrap(

return block_indices, [bootstrap_samples]

def get_test_params(self):
@classmethod
def get_test_params(cls, parameter_set="default"):
from tsbootstrap.block_bootstrap import MovingBlockBootstrap
bs = MovingBlockBootstrap()
return {"block_bootstrap": bs}
Expand Down Expand Up @@ -500,7 +502,7 @@ def __init__(
self,
block_bootstrap,
n_bootstraps: Integral = 10, # type: ignore
statistic=np.mean,
statistic=None,
statistic_axis: Integral = 0, # type: ignore
statistic_keepdims: bool = False,
rng=None,
Expand Down Expand Up @@ -543,7 +545,8 @@ def _generate_samples_single_bootstrap(
bootstrap_samples = block_data_concat + bias
return block_indices, [bootstrap_samples]

def get_test_params(self):
@classmethod
def get_test_params(cls, parameter_set="default"):
from tsbootstrap.block_bootstrap import MovingBlockBootstrap
bs = MovingBlockBootstrap()
return {"block_bootstrap": bs}
Expand Down Expand Up @@ -763,7 +766,8 @@ def _generate_samples_single_bootstrap(
bootstrap_samples = self.X_fitted + bootstrap_residuals
return block_indices, [bootstrap_samples]

def get_test_params(self):
@classmethod
def get_test_params(cls, parameter_set="default"):
from tsbootstrap.block_bootstrap import MovingBlockBootstrap
bs = MovingBlockBootstrap()
return {"block_bootstrap": bs}
Expand Down Expand Up @@ -947,7 +951,8 @@ def _generate_samples_single_bootstrap(

return block_indices, [bootstrapped_samples]

def get_test_params(self):
@classmethod
def get_test_params(cls, parameter_set="default"):
from tsbootstrap.block_bootstrap import MovingBlockBootstrap
bs = MovingBlockBootstrap()
return {"block_bootstrap": bs}
13 changes: 13 additions & 0 deletions src/tsbootstrap/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Registry and lookup functionality."""

from tsbootstrap.registry._lookup import all_objects
from tsbootstrap.registry._tags import (
OBJECT_TAG_LIST,
OBJECT_TAG_REGISTER,
)

__all__ = [
"OBJECT_TAG_LIST",
"OBJECT_TAG_REGISTER",
"all_objects",
]
Loading
Loading