From e70291d3c21e6518220cea32399d528becab8bf4 Mon Sep 17 00:00:00 2001
From: Mattie Tesfaldet
Date: Sat, 29 Jul 2023 12:33:23 -0400
Subject: [PATCH] Docstrings revamp (#589)

---
 .pre-commit-config.yaml                   |  30 +++++-
 src/data/mnist_datamodule.py              | 117 ++++++++++++++++------
 src/eval.py                               |  16 +--
 src/models/components/simple_dense_net.py |  20 +++-
 src/models/mnist_module.py                | 112 +++++++++++++++++----
 src/train.py                              |  17 ++--
 src/utils/instantiators.py                |  10 +-
 src/utils/logging_utils.py                |  14 ++-
 src/utils/pylogger.py                     |   8 +-
 src/utils/rich_utils.py                   |  18 ++--
 src/utils/utils.py                        |  37 ++++---
 tests/conftest.py                         |  38 +++++--
 tests/helpers/package_available.py        |   7 +-
 tests/helpers/run_if.py                   |  43 ++++----
 tests/helpers/run_sh_command.py           |   7 +-
 tests/test_configs.py                     |  12 ++-
 tests/test_datamodules.py                 |   8 +-
 tests/test_eval.py                        |  13 ++-
 tests/test_sweeps.py                      |  37 +++++--
 tests/test_train.py                       |  46 ++++++---
 20 files changed, 450 insertions(+), 160 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4c42c8d86..ee45ce194 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -40,10 +40,33 @@ repos:
 
   # python docstring formatting
   - repo: https://github.com/myint/docformatter
-    rev: v1.5.1
+    rev: v1.7.4
     hooks:
       - id: docformatter
-        args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
+        args:
+          [
+            --in-place,
+            --wrap-summaries=99,
+            --wrap-descriptions=99,
+            --style=sphinx,
+            --black,
+          ]
+
+  # python docstring coverage checking
+  - repo: https://github.com/econchick/interrogate
+    rev: 1.5.0 # or master if you're bold
+    hooks:
+      - id: interrogate
+        args:
+          [
+            --verbose,
+            --fail-under=80,
+            --ignore-init-module,
+            --ignore-init-method,
+            --ignore-module,
+            --ignore-nested-functions,
+            -vv,
+          ]
 
   # python check (PEP8), programming errors and code complexity
   - repo: https://github.com/PyCQA/flake8
@@ -53,10 +76,11 @@
         args:
           [
             "--extend-ignore",
-            "E203,E402,E501,F401,F841",
+            "E203,E402,E501,F401,F841,RST2,RST301",
             "--exclude",
             "logs/*,data/*",
           ]
+        additional_dependencies: [flake8-rst-docstrings==0.3.0]
 
   # python security linter
   - repo: https://github.com/PyCQA/bandit
diff --git a/src/data/mnist_datamodule.py b/src/data/mnist_datamodule.py
index 6dd176ec6..a77053034 100644
--- a/src/data/mnist_datamodule.py
+++ b/src/data/mnist_datamodule.py
@@ -8,24 +8,42 @@
 
 
 class MNISTDataModule(LightningDataModule):
-    """Example of LightningDataModule for MNIST dataset.
+    """`LightningDataModule` for the MNIST dataset.
 
-    A DataModule implements 6 key methods:
+    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
+    It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
+    fixed-size image. The original black and white images from NIST were size-normalized to fit in a 20x20 pixel box
+    while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
+    technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center of
+    mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
+
+    A `LightningDataModule` implements 7 key methods:
+
+    ```python
     def prepare_data(self):
-        # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-        # download data, pre-process, split, save to disk, etc...
+        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
+        # Download data, pre-process, split, save to disk, etc...
+
     def setup(self, stage):
-        # things to do on every process in DDP
-        # load data, set variables, etc...
+        # Things to do on every process in DDP.
+        # Load data, set variables, etc...
+
     def train_dataloader(self):
-        # return train dataloader
+        # return train dataloader
+
     def val_dataloader(self):
-        # return validation dataloader
+        # return validation dataloader
+
     def test_dataloader(self):
-        # return test dataloader
-    def teardown(self):
-        # called on every process in DDP
-        # clean up after fit or test
+        # return test dataloader
+
+    def predict_dataloader(self):
+        # return predict dataloader
+
+    def teardown(self, stage):
+        # Called on every process in DDP.
+        # Clean up after fit or test.
+    ```
 
     This allows you to share a full dataset without explaining how to download, split, transform and process the data.
 
@@ -41,7 +59,15 @@ def __init__(
         batch_size: int = 64,
         num_workers: int = 0,
         pin_memory: bool = False,
-    ):
+    ) -> None:
+        """Initialize a `MNISTDataModule`.
+
+        :param data_dir: The data directory. Defaults to `"data/"`.
+        :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
+        :param batch_size: The batch size. Defaults to `64`.
+        :param num_workers: The number of workers. Defaults to `0`.
+        :param pin_memory: Whether to pin memory. Defaults to `False`.
+        """
         super().__init__()
 
         # this line allows to access init params with 'self.hparams' attribute
@@ -58,22 +84,33 @@ def __init__(
         self.data_test: Optional[Dataset] = None
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
+        """Get the number of classes.
+
+        :return: The number of MNIST classes (10).
+        """
         return 10
 
-    def prepare_data(self):
-        """Download data if needed.
+    def prepare_data(self) -> None:
+        """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+        within a single process on CPU, so you can safely add your downloading logic within. In
+        case of multi-node training, the execution of this hook depends upon
+        `self.prepare_data_per_node()`.
 
         Do not use it to assign state (self.x = y).
         """
         MNIST(self.hparams.data_dir, train=True, download=True)
         MNIST(self.hparams.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
 
-        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
-        careful not to execute things like random split twice!
+        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+        `self.setup()` once the data is prepared and available for use.
+
+        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
         """
         # load and split datasets only if not loaded already
         if not self.data_train and not self.data_val and not self.data_test:
@@ -86,7 +123,11 @@ def setup(self, stage: Optional[str] = None):
                 generator=torch.Generator().manual_seed(42),
             )
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader[Any]:
+        """Create and return the train dataloader.
+
+        :return: The train dataloader.
+        """
         return DataLoader(
             dataset=self.data_train,
             batch_size=self.hparams.batch_size,
@@ -95,7 +136,11 @@ def train_dataloader(self):
             shuffle=True,
         )
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader[Any]:
+        """Create and return the validation dataloader.
+
+        :return: The validation dataloader.
+        """
         return DataLoader(
             dataset=self.data_val,
             batch_size=self.hparams.batch_size,
@@ -104,7 +149,11 @@ def val_dataloader(self):
             shuffle=False,
         )
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader[Any]:
+        """Create and return the test dataloader.
+
+        :return: The test dataloader.
+        """
         return DataLoader(
             dataset=self.data_test,
             batch_size=self.hparams.batch_size,
@@ -113,16 +162,28 @@ def test_dataloader(self):
             shuffle=False,
         )
 
-    def teardown(self, stage: Optional[str] = None):
-        """Clean up after fit or test."""
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
+        `trainer.test()`, and `trainer.predict()`.
+
+        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+            Defaults to ``None``.
+        """
         pass
 
-    def state_dict(self):
-        """Extra things to save to checkpoint."""
+    def state_dict(self) -> Dict[Any, Any]:
+        """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+        :return: A dictionary containing the datamodule state that you want to save.
+        """
         return {}
 
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        """Things to do when loading checkpoint."""
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
+        `state_dict()`.
+
+        :param state_dict: The datamodule state returned by `self.state_dict()`.
+        """
         pass
diff --git a/src/eval.py b/src/eval.py
index 763dbb65c..2c77474a1 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import hydra
 import pyrootutils
@@ -30,19 +30,15 @@
 
 
 @utils.task_wrapper
-def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
-    """Evaluates given checkpoint on a datamodule testset.
+def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """Evaluates a given checkpoint on a datamodule test set.
 
-    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+    This method is wrapped in an optional @task_wrapper decorator, which controls the behavior during
     failure. Useful for multiruns, saving info about the crash, etc.
 
-    Args:
-        cfg (DictConfig): Configuration composed by Hydra.
-
-    Returns:
-        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    :param cfg: A DictConfig configuration composed by Hydra.
+    :return: A tuple containing a dict with metrics and a dict with all instantiated objects.
     """
-
     assert cfg.ckpt_path
 
     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
@@ -82,6 +78,10 @@
 
 @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
 def main(cfg: DictConfig) -> None:
+    """Main entry point for evaluation.
+
+    :param cfg: A DictConfig configuration composed by Hydra.
+    """
     # apply extra utilities
     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
     utils.extras(cfg)
diff --git a/src/models/components/simple_dense_net.py b/src/models/components/simple_dense_net.py
index a276b9e94..ae817c951 100644
--- a/src/models/components/simple_dense_net.py
+++ b/src/models/components/simple_dense_net.py
@@ -1,7 +1,10 @@
+import torch
 from torch import nn
 
 
 class SimpleDenseNet(nn.Module):
+    """A simple fully-connected neural net for computing predictions."""
+
     def __init__(
         self,
         input_size: int = 784,
@@ -9,7 +12,15 @@ def __init__(
         lin2_size: int = 256,
         lin3_size: int = 256,
         output_size: int = 10,
-    ):
+    ) -> None:
+        """Initialize a `SimpleDenseNet` module.
+
+        :param input_size: The number of input features.
+        :param lin1_size: The number of output features of the first linear layer.
+        :param lin2_size: The number of output features of the second linear layer.
+        :param lin3_size: The number of output features of the third linear layer.
+        :param output_size: The number of output features of the final linear layer.
+        """
         super().__init__()
 
         self.model = nn.Sequential(
@@ -25,7 +36,12 @@ def __init__(
             nn.Linear(lin3_size, output_size),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a single forward pass through the network.
+
+        :param x: The input tensor.
+        :return: A tensor of predictions.
+        """
         batch_size, channels, width, height = x.size()
 
         # (batch, 1, width, height) -> (batch, 1*width*height)
diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index d27cc9f22..5c1049725 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Dict, Tuple
 
 import torch
 from lightning import LightningModule
@@ -7,15 +7,33 @@
 
 
 class MNISTLitModule(LightningModule):
-    """Example of LightningModule for MNIST classification.
+    """Example of a `LightningModule` for MNIST classification.
 
-    A LightningModule organizes your PyTorch code into 6 sections:
-        - Initialization (__init__)
-        - Train Loop (training_step)
-        - Validation loop (validation_step)
-        - Test loop (test_step)
-        - Prediction Loop (predict_step)
-        - Optimizers and LR Schedulers (configure_optimizers)
+    A `LightningModule` implements 7 key methods:
+
+    ```python
+    def __init__(self):
+        # Define initialization code here.
+
+    def setup(self, stage):
+        # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
+        # This hook is called on every process when using DDP.
+
+    def training_step(self, batch, batch_idx):
+        # The complete training step.
+
+    def validation_step(self, batch, batch_idx):
+        # The complete validation step.
+
+    def test_step(self, batch, batch_idx):
+        # The complete test step.
+
+    def predict_step(self, batch, batch_idx):
+        # The complete predict step.
+
+    def configure_optimizers(self):
+        # Define and configure optimizers and LR schedulers.
+    ```
 
     Docs: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
     """
@@ -26,7 +44,13 @@ def __init__(
         net: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
         scheduler: torch.optim.lr_scheduler,
-    ):
+    ) -> None:
+        """Initialize a `MNISTLitModule`.
+
+        :param net: The model to train.
+        :param optimizer: The optimizer to use for training.
+        :param scheduler: The learning rate scheduler to use for training.
+        """
         super().__init__()
 
         # this line allows to access init params with 'self.hparams' attribute
@@ -51,24 +75,50 @@ def __init__(
         # for tracking best so far validation accuracy
         self.val_acc_best = MaxMetric()
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a forward pass through the model `self.net`.
+
+        :param x: A tensor of images.
+        :return: A tensor of logits.
+        """
         return self.net(x)
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
+        """Lightning hook that is called when training begins."""
         # by default lightning executes validation step sanity checks before training starts,
         # so it's worth to make sure validation metrics don't store results from these checks
         self.val_loss.reset()
         self.val_acc.reset()
         self.val_acc_best.reset()
 
-    def model_step(self, batch: Any):
+    def model_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Perform a single model step on a batch of data.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+
+        :return: A tuple containing (in order):
+            - A tensor of losses.
+            - A tensor of predictions.
+            - A tensor of target labels.
+        """
         x, y = batch
         logits = self.forward(x)
         loss = self.criterion(logits, y)
         preds = torch.argmax(logits, dim=1)
         return loss, preds, y
 
-    def training_step(self, batch: Any, batch_idx: int):
+    def training_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+    ) -> torch.Tensor:
+        """Perform a single training step on a batch of data from the training set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        :return: A tensor of losses between model predictions and targets.
+        """
         loss, preds, targets = self.model_step(batch)
 
         # update and log metrics
@@ -80,10 +130,17 @@ def training_step(self, batch: Any, batch_idx: int):
         # return loss or backpropagation will fail
         return loss
 
-    def on_train_epoch_end(self):
+    def on_train_epoch_end(self) -> None:
+        """Lightning hook that is called when a training epoch ends."""
         pass
 
-    def validation_step(self, batch: Any, batch_idx: int):
+    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+        """Perform a single validation step on a batch of data from the validation set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
         loss, preds, targets = self.model_step(batch)
 
         # update and log metrics
@@ -92,14 +149,21 @@ def validation_step(self, batch: Any, batch_idx: int):
         self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
 
-    def on_validation_epoch_end(self):
+    def on_validation_epoch_end(self) -> None:
+        """Lightning hook that is called when a validation epoch ends."""
         acc = self.val_acc.compute()  # get current val acc
         self.val_acc_best(acc)  # update best so far val acc
         # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
         # otherwise metric would be reset by lightning after each epoch
         self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
 
-    def test_step(self, batch: Any, batch_idx: int):
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+        """Perform a single test step on a batch of data from the test set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
         loss, preds, targets = self.model_step(batch)
 
         # update and log metrics
@@ -108,15 +172,19 @@ def test_step(self, batch: Any, batch_idx: int):
         self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
         self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
 
-    def on_test_epoch_end(self):
+    def on_test_epoch_end(self) -> None:
+        """Lightning hook that is called when a test epoch ends."""
         pass
 
-    def configure_optimizers(self):
-        """Choose what optimizers and learning-rate schedulers to use in your optimization.
-        Normally you'd need one. But in the case of GANs or similar you might have multiple.
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Configures optimizers and learning-rate schedulers to be used for training.
+
+        Normally you'd need one, but in the case of GANs or similar you might need multiple.
 
         Examples:
             https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
         """
         optimizer = self.hparams.optimizer(params=self.parameters())
         if self.hparams.scheduler is not None:
diff --git a/src/train.py b/src/train.py
index dad481dd2..df741c8cb 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import hydra
 import lightning as L
@@ -32,20 +32,16 @@
 
 
 @utils.task_wrapper
-def train(cfg: DictConfig) -> Tuple[dict, dict]:
+def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Trains the model. Can additionally evaluate on a test set, using best weights obtained during
     training.
 
-    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+    This method is wrapped in an optional @task_wrapper decorator, which controls the behavior during
     failure. Useful for multiruns, saving info about the crash, etc.
 
-    Args:
-        cfg (DictConfig): Configuration composed by Hydra.
-
-    Returns:
-        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    :param cfg: A DictConfig configuration composed by Hydra.
+    :return: A tuple containing a dict with metrics and a dict with all instantiated objects.
     """
-
     # set seed for random number generators in pytorch, numpy and python.random
     if cfg.get("seed"):
         L.seed_everything(cfg.seed, workers=True)
@@ -107,6 +103,11 @@
 
 @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
 def main(cfg: DictConfig) -> Optional[float]:
+    """Main entry point for training.
+
+    :param cfg: A DictConfig configuration composed by Hydra.
+    :return: Optional[float] with optimized metric value.
+    """
     # apply extra utilities
     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
     utils.extras(cfg)
diff --git a/src/utils/instantiators.py b/src/utils/instantiators.py
index adabbf059..ada7a5253 100644
--- a/src/utils/instantiators.py
+++ b/src/utils/instantiators.py
@@ -11,8 +11,11 @@
 
 
 def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
-    """Instantiates callbacks from config."""
+    """Instantiates callbacks from config.
 
+    :param callbacks_cfg: A DictConfig object containing callback configurations.
+    :return: A list of instantiated callbacks.
+    """
     callbacks: List[Callback] = []
 
     if not callbacks_cfg:
@@ -31,8 +34,11 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
 
 
 def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
-    """Instantiates loggers from config."""
+    """Instantiates loggers from config.
 
+    :param logger_cfg: A DictConfig object containing logger configurations.
+    :return: A list of instantiated loggers.
+    """
     logger: List[Logger] = []
 
     if not logger_cfg:
diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py
index 3a8ca8e5c..899defd84 100644
--- a/src/utils/logging_utils.py
+++ b/src/utils/logging_utils.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import OmegaConf
 
@@ -7,13 +9,17 @@
 
 
 @rank_zero_only
-def log_hyperparameters(object_dict: dict) -> None:
-    """Controls which config parts are saved by lightning loggers.
+def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+    """Controls which config parts are saved by Lightning loggers.
 
     Additionally saves:
-    - Number of model parameters
-    """
+        - Number of model parameters
 
+    :param object_dict: A dictionary containing the following objects:
+        - `"cfg"`: A DictConfig object containing the main config.
+        - `"model"`: The Lightning model.
+        - `"trainer"`: The Lightning trainer.
+    """
     hparams = {}
 
     cfg = OmegaConf.to_container(object_dict["cfg"])
diff --git a/src/utils/pylogger.py b/src/utils/pylogger.py
index 62176cad9..616006780 100644
--- a/src/utils/pylogger.py
+++ b/src/utils/pylogger.py
@@ -3,9 +3,13 @@
 from lightning.pytorch.utilities import rank_zero_only
 
 
-def get_pylogger(name=__name__) -> logging.Logger:
-    """Initializes multi-GPU-friendly python command line logger."""
+def get_pylogger(name: str = __name__) -> logging.Logger:
+    """Initializes a multi-GPU-friendly python command line logger.
 
+    :param name: The name of the logger, defaults to ``__name__``.
+
+    :return: A logger object.
+    """
     logger = logging.getLogger(name)
 
     # this ensures all logging levels get marked with the rank zero decorator
diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py
index 6df129aae..430590adf 100644
--- a/src/utils/rich_utils.py
+++ b/src/utils/rich_utils.py
@@ -29,15 +29,14 @@ def print_config_tree(
     resolve: bool = False,
     save_to_file: bool = False,
 ) -> None:
-    """Prints content of DictConfig using Rich library and its tree structure.
+    """Prints the contents of a DictConfig as a tree structure using the Rich library.
 
-    Args:
-        cfg (DictConfig): Configuration composed by Hydra.
-        print_order (Sequence[str], optional): Determines in what order config components are printed.
-        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
-        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    :param cfg: A DictConfig composed by Hydra.
+    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
+        "callbacks", "logger", "trainer", "paths", "extras")``.
+    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
+    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
     """
-
     style = "dim"
     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
 
@@ -77,8 +76,11 @@ def print_config_tree(
 
 @rank_zero_only
 def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
-    """Prompts user to input tags from command line if no tags are provided in config."""
+    """Prompts user to input tags from command line if no tags are provided in config.
 
+    :param cfg: A DictConfig composed by Hydra.
+    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
+    """
     if not cfg.get("tags"):
         if "id" in HydraConfig().cfg.hydra.job:
             raise ValueError("Specify tags before launching a multirun!")
diff --git a/src/utils/utils.py b/src/utils/utils.py
index b0a81c3f8..b3b404e36 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,6 +1,6 @@
 import warnings
 from importlib.util import find_spec
-from typing import Callable
+from typing import Any, Callable, Dict, Tuple
 
 from omegaconf import DictConfig
 
@@ -13,11 +13,12 @@ def extras(cfg: DictConfig) -> None:
     """Applies optional utilities before the task is started.
 
     Utilities:
-    - Ignoring python warnings
-    - Setting tags from command line
-    - Rich config printing
-    """
+        - Ignoring python warnings
+        - Setting tags from command line
+        - Rich config printing
 
+    :param cfg: A DictConfig object containing the config tree.
+    """
     # return if no `extras` config
     if not cfg.get("extras"):
         log.warning("Extras config not found! ")
@@ -43,23 +44,25 @@ def task_wrapper(task_func: Callable) -> Callable:
     """Optional decorator that controls the failure behavior when executing the task function.
 
     This wrapper can be used to:
-    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
-    - save the exception to a `.log` file
-    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
-    - etc. (adjust depending on your needs)
+        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+        - save the exception to a `.log` file
+        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+        - etc. (adjust depending on your needs)
 
     Example:
     ```
     @utils.task_wrapper
-    def train(cfg: DictConfig) -> Tuple[dict, dict]:
-
+    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         ...
-
         return metric_dict, object_dict
     ```
+
+    :param task_func: The task function to be wrapped.
+
+    :return: The wrapped task function.
     """
 
-    def wrap(cfg: DictConfig):
+    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         # execute the task
         try:
             metric_dict, object_dict = task_func(cfg=cfg)
@@ -92,9 +95,13 @@ def wrap(cfg: DictConfig):
     return wrap
 
 
-def get_metric_value(metric_dict: dict, metric_name: str) -> float:
-    """Safely retrieves value of the metric logged in LightningModule."""
+def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
+    """Safely retrieves value of the metric logged in LightningModule.
 
+    :param metric_dict: A dict containing metric values.
+    :param metric_name: The name of the metric to retrieve.
+    :return: The value of the metric.
+    """
     if not metric_name:
         log.info("Metric name is None! Skipping metric value retrieval...")
         return None
diff --git a/tests/conftest.py b/tests/conftest.py
index 8fda2e0c1..81c0542f6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,7 @@
 """This file prepares config fixtures for other tests."""
 
+from pathlib import Path
+
 import pyrootutils
 import pytest
 from hydra import compose, initialize
@@ -9,6 +11,10 @@
 
 @pytest.fixture(scope="package")
 def cfg_train_global() -> DictConfig:
+    """A pytest fixture for setting up a default Hydra DictConfig for training.
+
+    :return: A DictConfig object containing a default Hydra configuration for training.
+    """
     with initialize(version_base="1.3", config_path="../configs"):
         cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])
 
@@ -32,6 +38,10 @@ def cfg_train_global() -> DictConfig:
 
 @pytest.fixture(scope="package")
 def cfg_eval_global() -> DictConfig:
+    """A pytest fixture for setting up a default Hydra DictConfig for evaluation.
+
+    :return: A DictConfig containing a default Hydra configuration for evaluation.
+    """
     with initialize(version_base="1.3", config_path="../configs"):
         cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."])
 
@@ -51,10 +61,18 @@ def cfg_eval_global() -> DictConfig:
     return cfg
 
 
-# this is called by each test which uses `cfg_train` arg
-# each test generates its own temporary logging path
 @pytest.fixture(scope="function")
-def cfg_train(cfg_train_global, tmp_path) -> DictConfig:
+def cfg_train(cfg_train_global: DictConfig, tmp_path: Path) -> DictConfig:
+    """A pytest fixture built on top of the `cfg_train_global()` fixture, which accepts a
+    temporary logging path `tmp_path` supplied by pytest.
+
+    This is called by each test which uses the `cfg_train` arg. Each test generates its own temporary logging path.
+
+    :param cfg_train_global: The input DictConfig object to be modified.
+    :param tmp_path: The temporary logging path.
+
+    :return: A DictConfig with updated output and log directories corresponding to `tmp_path`.
+    """
     cfg = cfg_train_global.copy()
 
     with open_dict(cfg):
@@ -66,10 +84,18 @@ def cfg_train(cfg_train_global, tmp_path) -> DictConfig:
 
     GlobalHydra.instance().clear()
 
 
-# this is called by each test which uses `cfg_eval` arg
-# each test generates its own temporary logging path
 @pytest.fixture(scope="function")
-def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig:
+def cfg_eval(cfg_eval_global: DictConfig, tmp_path: Path) -> DictConfig:
+    """A pytest fixture built on top of the `cfg_eval_global()` fixture, which accepts a
+    temporary logging path `tmp_path` supplied by pytest.
+
+    This is called by each test which uses the `cfg_eval` arg. Each test generates its own temporary logging path.
+
+    :param cfg_eval_global: The input DictConfig object to be modified.
+    :param tmp_path: The temporary logging path.
+
+    :return: A DictConfig with updated output and log directories corresponding to `tmp_path`.
+    """
     cfg = cfg_eval_global.copy()
 
     with open_dict(cfg):
diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py
index 614778fef..0afdba8dc 100644
--- a/tests/helpers/package_available.py
+++ b/tests/helpers/package_available.py
@@ -5,7 +5,12 @@
 
 
 def _package_available(package_name: str) -> bool:
-    """Check if a package is available in your environment."""
+    """Check if a package is available in your environment.
+
+    :param package_name: The name of the package to be checked.
+
+    :return: `True` if the package is available, `False` otherwise.
+    """
     try:
         return pkg_resources.require(package_name) is not None
     except pkg_resources.DistributionNotFound:
diff --git a/tests/helpers/run_if.py b/tests/helpers/run_if.py
index a9de9e848..9703af425 100644
--- a/tests/helpers/run_if.py
+++ b/tests/helpers/run_if.py
@@ -4,12 +4,13 @@
 """
 
 import sys
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import pytest
 import torch
 from packaging.version import Version
 from pkg_resources import get_distribution
+from pytest import MarkDecorator
 
 from tests.helpers.package_available import (
     _COMET_AVAILABLE,
@@ -31,14 +32,16 @@ class RunIf:
 
     Example:
 
+    ```python
     @RunIf(min_torch="1.8")
     @pytest.mark.parametrize("arg1", [1.0, 2.0])
     def test_wrapper(arg1):
         assert arg1 > 0
+    ```
     """
 
     def __new__(
-        self,
+        cls,
         min_gpus: int = 0,
         min_torch: Optional[str] = None,
         max_torch: Optional[str] = None,
@@ -52,24 +55,24 @@ def __new__(
         neptune: bool = False,
         comet: bool = False,
         mlflow: bool = False,
-        **kwargs,
-    ):
-        """
-        Args:
-            min_gpus: min number of GPUs required to run test
-            min_torch: minimum pytorch version to run test
-            max_torch: maximum pytorch version to run test
-            min_python: minimum python version required to run test
-            skip_windows: skip test for Windows platform
-            tpu: if TPU is available
-            sh: if `sh` module is required to run the test
-            fairscale: if `fairscale` module is required to run the test
-            deepspeed: if `deepspeed` module is required to run the test
-            wandb: if `wandb` module is required to run the test
-            neptune: if `neptune` module is required to run the test
-            comet: if `comet` module is required to run the test
-            mlflow: if `mlflow` module is required to run the test
-            kwargs: native pytest.mark.skipif keyword arguments
+        **kwargs: Any,
+    ) -> MarkDecorator:
+        """Creates a new `@RunIf` `MarkDecorator` decorator.
+
+        :param min_gpus: Min number of GPUs required to run test.
+        :param min_torch: Minimum pytorch version to run test.
+        :param max_torch: Maximum pytorch version to run test.
+        :param min_python: Minimum python version required to run test.
+        :param skip_windows: Skip test for Windows platform.
+        :param tpu: If TPU is available.
+        :param sh: If `sh` module is required to run the test.
+        :param fairscale: If `fairscale` module is required to run the test.
+        :param deepspeed: If `deepspeed` module is required to run the test.
+        :param wandb: If `wandb` module is required to run the test.
+        :param neptune: If `neptune` module is required to run the test.
+        :param comet: If `comet` module is required to run the test.
+        :param mlflow: If `mlflow` module is required to run the test.
+        :param kwargs: Native `pytest.mark.skipif` keyword arguments.
         """
         conditions = []
         reasons = []
diff --git a/tests/helpers/run_sh_command.py b/tests/helpers/run_sh_command.py
index b0afcbb73..fdd2ed633 100644
--- a/tests/helpers/run_sh_command.py
+++ b/tests/helpers/run_sh_command.py
@@ -8,8 +8,11 @@
 import sh
 
 
-def run_sh_command(command: List[str]):
-    """Default method for executing shell commands with pytest and sh package."""
+def run_sh_command(command: List[str]) -> None:
+    """Default method for executing shell commands with `pytest` and the `sh` package.
+
+    :param command: The shell command as a list of strings.
+    """
     msg = None
     try:
         sh.python(command)
diff --git a/tests/test_configs.py b/tests/test_configs.py
index e7189df0f..d7041dc78 100644
--- a/tests/test_configs.py
+++ b/tests/test_configs.py
@@ -3,7 +3,11 @@
 from omegaconf import DictConfig
 
 
-def test_train_config(cfg_train: DictConfig):
+def test_train_config(cfg_train: DictConfig) -> None:
+    """Tests the training configuration provided by the `cfg_train` pytest fixture.
+
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     assert cfg_train
     assert cfg_train.data
     assert cfg_train.model
@@ -16,7 +20,11 @@ def test_train_config(cfg_train: DictConfig):
     hydra.utils.instantiate(cfg_train.trainer)
 
 
-def test_eval_config(cfg_eval: DictConfig):
+def test_eval_config(cfg_eval: DictConfig) -> None:
+    """Tests the evaluation configuration provided by the `cfg_eval` pytest fixture.
+
+    :param cfg_eval: A DictConfig containing a valid evaluation configuration.
+    """
     assert cfg_eval
     assert cfg_eval.data
     assert cfg_eval.model
diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py
index 407b949a1..901f3d6bb 100644
--- a/tests/test_datamodules.py
+++ b/tests/test_datamodules.py
@@ -7,7 +7,13 @@
 
 
 @pytest.mark.parametrize("batch_size", [32, 128])
-def test_mnist_datamodule(batch_size):
+def test_mnist_datamodule(batch_size: int) -> None:
+    """Tests `MNISTDataModule` to verify that the data can be downloaded correctly, that the
+    necessary attributes were created (e.g., the dataloader objects), and that dtypes and batch
+    sizes correctly match.
+
+    :param batch_size: Batch size of the data to be loaded by the dataloader.
+    """
     data_dir = "data/"
 
     dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)
diff --git a/tests/test_eval.py b/tests/test_eval.py
index 291776fef..423c9d295 100644
--- a/tests/test_eval.py
+++ b/tests/test_eval.py
@@ -1,16 +1,23 @@
 import os
+from pathlib import Path
 
 import pytest
 from hydra.core.hydra_config import HydraConfig
-from omegaconf import open_dict
+from omegaconf import DictConfig, open_dict
 
 from src.eval import evaluate
 from src.train import train
 
 
 @pytest.mark.slow
-def test_train_eval(tmp_path, cfg_train, cfg_eval):
-    """Train for 1 epoch with `train.py` and evaluate with `eval.py`"""
+def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None:
+    """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with
+    `eval.py`.
+
+    :param tmp_path: The temporary logging path.
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    :param cfg_eval: A DictConfig containing a valid evaluation configuration.
+    """
     assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir
 
     with open_dict(cfg_train):
diff --git a/tests/test_sweeps.py b/tests/test_sweeps.py
index b03c459f5..7856b1551 100644
--- a/tests/test_sweeps.py
+++ b/tests/test_sweeps.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import pytest
 
 from tests.helpers.run_if import RunIf
@@ -9,8 +11,11 @@
 
 @RunIf(sh=True)
 @pytest.mark.slow
-def test_experiments(tmp_path):
-    """Test running all available experiment configs with fast_dev_run=True."""
+def test_experiments(tmp_path: Path) -> None:
+    """Test running all available experiment configs with `fast_dev_run=True`.
+
+    :param tmp_path: The temporary logging path.
+    """
     command = [
         startfile,
         "-m",
@@ -23,8 +28,11 @@ def test_experiments(tmp_path):
 
 @RunIf(sh=True)
 @pytest.mark.slow
-def test_hydra_sweep(tmp_path):
-    """Test default hydra sweep."""
+def test_hydra_sweep(tmp_path: Path) -> None:
+    """Test default hydra sweep.
+
+    :param tmp_path: The temporary logging path.
+    """
     command = [
         startfile,
         "-m",
@@ -38,8 +46,11 @@ def test_hydra_sweep(tmp_path):
 
 @RunIf(sh=True)
 @pytest.mark.slow
-def test_hydra_sweep_ddp_sim(tmp_path):
-    """Test default hydra sweep with ddp sim."""
+def test_hydra_sweep_ddp_sim(tmp_path: Path) -> None:
+    """Test default hydra sweep with ddp sim.
+
+    :param tmp_path: The temporary logging path.
+    """
     command = [
         startfile,
         "-m",
@@ -56,8 +67,11 @@ def test_hydra_sweep_ddp_sim(tmp_path):
 
 @RunIf(sh=True)
 @pytest.mark.slow
-def test_optuna_sweep(tmp_path):
-    """Test optuna sweep."""
+def test_optuna_sweep(tmp_path: Path) -> None:
+    """Test Optuna hyperparameter sweeping.
+
+    :param tmp_path: The temporary logging path.
+    """
     command = [
         startfile,
         "-m",
@@ -72,8 +86,11 @@ def test_optuna_sweep(tmp_path):
 
 @RunIf(wandb=True, sh=True)
 @pytest.mark.slow
-def test_optuna_sweep_ddp_sim_wandb(tmp_path):
-    """Test optuna sweep with wandb and ddp sim."""
+def test_optuna_sweep_ddp_sim_wandb(tmp_path: Path) -> None:
+    """Test Optuna sweep with wandb logging and ddp sim.
+
+    :param tmp_path: The temporary logging path.
+    """
     command = [
         startfile,
         "-m",
diff --git a/tests/test_train.py b/tests/test_train.py
index bbe93f4fa..f0a8fdf76 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -1,15 +1,19 @@
 import os
+from pathlib import Path
 
 import pytest
 from hydra.core.hydra_config import HydraConfig
-from omegaconf import open_dict
+from omegaconf import DictConfig, open_dict
 
 from src.train import train
 from tests.helpers.run_if import RunIf
 
 
-def test_train_fast_dev_run(cfg_train):
-    """Run for 1 train, val and test step."""
+def test_train_fast_dev_run(cfg_train: DictConfig) -> None:
+    """Run for 1 train, val and test step.
+
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     HydraConfig().set_config(cfg_train)
     with open_dict(cfg_train):
         cfg_train.trainer.fast_dev_run = True
@@ -18,8 +22,11 @@ def test_train_fast_dev_run(cfg_train):
 
 
 @RunIf(min_gpus=1)
-def test_train_fast_dev_run_gpu(cfg_train):
-    """Run for 1 train, val and test step on GPU."""
+def test_train_fast_dev_run_gpu(cfg_train: DictConfig) -> None:
+    """Run for 1 train, val and test step on GPU.
+
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     HydraConfig().set_config(cfg_train)
     with open_dict(cfg_train):
         cfg_train.trainer.fast_dev_run = True
@@ -29,8 +36,11 @@ def test_train_fast_dev_run_gpu(cfg_train):
 
 @RunIf(min_gpus=1)
 @pytest.mark.slow
-def test_train_epoch_gpu_amp(cfg_train):
-    """Train 1 epoch on GPU with mixed-precision."""
+def test_train_epoch_gpu_amp(cfg_train: DictConfig) -> None:
+    """Train 1 epoch on GPU with mixed-precision.
+
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     HydraConfig().set_config(cfg_train)
     with open_dict(cfg_train):
         cfg_train.trainer.max_epochs = 1
@@ -40,8 +50,11 @@ def test_train_epoch_gpu_amp(cfg_train):
 
 
 @pytest.mark.slow
-def test_train_epoch_double_val_loop(cfg_train):
-    """Train 1 epoch with validation loop twice per epoch."""
+def test_train_epoch_double_val_loop(cfg_train: DictConfig) -> None:
+    """Train 1 epoch with validation loop twice per epoch.
+
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     HydraConfig().set_config(cfg_train)
     with open_dict(cfg_train):
         cfg_train.trainer.max_epochs = 1
@@ -50,8 +63,11 @@ def test_train_epoch_double_val_loop(cfg_train):
 
 
 @pytest.mark.slow
-def test_train_ddp_sim(cfg_train):
-    """Simulate DDP (Distributed Data Parallel) on 2 CPU processes."""
+def test_train_ddp_sim(cfg_train: DictConfig) -> None:
+    """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.
+
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     HydraConfig().set_config(cfg_train)
     with open_dict(cfg_train):
         cfg_train.trainer.max_epochs = 2
@@ -62,8 +78,12 @@ def test_train_ddp_sim(cfg_train):
 
 
 @pytest.mark.slow
-def test_train_resume(tmp_path, cfg_train):
-    """Run 1 epoch, finish, and resume for another epoch."""
+def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None:
+    """Run 1 epoch, finish, and resume for another epoch.
+
+    :param tmp_path: The temporary logging path.
+    :param cfg_train: A DictConfig containing a valid training configuration.
+    """
     with open_dict(cfg_train):
         cfg_train.trainer.max_epochs = 1