Skip to content

Commit

Permalink
fixing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielBG0 committed Jul 5, 2024
1 parent e9e46fb commit 40c6b19
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
8 changes: 2 additions & 6 deletions minerva/models/nets/time_series/cnns.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def __init__(
test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)},
)

def _create_backbone(
self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]]
) -> torch.nn.Module:
def _create_backbone(self, input_shape: Tuple[int, int, int]) -> torch.nn.Module:
return torch.nn.Sequential(
# First 2D convolutional layer
torch.nn.Conv2d(
Expand Down Expand Up @@ -146,9 +144,7 @@ def __init__(
test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)},
)

def _create_backbone(
self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]]
) -> torch.nn.Module:
def _create_backbone(self, input_shape: Tuple[int, int, int]) -> torch.nn.Module:
first_kernel_size = 4
return torch.nn.Sequential(
# Add padding
Expand Down
4 changes: 2 additions & 2 deletions minerva/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Pipeline(HyperparametersMixin):
def __init__(
self,
log_dir: Optional[PathLike] = None,
ignore: Union[str, List[str], None] = None,
ignore: Optional[Union[str, List[str]]] = None,
cache_result: bool = False,
save_run_status: bool = False,
):
Expand All @@ -55,7 +55,7 @@ def __init__(
log_dir : PathLike, optional
The default logging directory where all related pipeline files
should be saved. By default None (uses current working directory)
ignore : str | List[str], optional
ignore : Union[str, List[str]], optional
Pipeline __init__ attributes are saved into config attibute. This
option allows to ignore some attributes from being saved. This is
quite useful when the attributes are not serializable or very large.
Expand Down
4 changes: 2 additions & 2 deletions minerva/pipelines/lightning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _calculate_metrics(
return results

# Private methods
def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]):
def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike] = None):
"""Fit the model using the given data.
Parameters
Expand All @@ -182,7 +182,7 @@ def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]):
model=self._model, datamodule=data, ckpt_path=ckpt_path
)

def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]):
def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike] = None):
"""Test the model using the given data.
Parameters
Expand Down

0 comments on commit 40c6b19

Please sign in to comment.