diff --git a/.gitignore b/.gitignore
index d462496df..b0c4afce4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -78,6 +78,7 @@ fabric.properties
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
+**/.vscode
### Python template
# Byte-compiled / optimized / DLL files
@@ -225,7 +226,5 @@ ipython_config.py
# git rm -r .ipynb_checkpoints/
.idea/
-project/data/
-project/lightning_logs/
-project/wandb/
-project/logs/
+data/
+logs/
diff --git a/README.md b/README.md
index 4bd9676ba..b6956d3dc 100644
--- a/README.md
+++ b/README.md
@@ -1,92 +1,146 @@
-## PyTorch Lightning + Hydra template
-### A clean and simple template to kickstart your deep learning project 🚀⚡🔥
-- structures ML code the same so that work can easily be extended and replicated
-- allows for rapid experimentation by automating pipeline with config files
-- extends functionality of popular experiment loggers like Weights&Biases (mostly with dedicated callbacks)
+
-This template tries to be as generic as possible - you should be able to easily modify behavior in [train.py](project/train.py) in case you need some unconventional configuration wiring.
+# PyTorch Lightning + Hydra Template
+A clean and scalable template to kickstart your deep learning project 🚀⚡🔥
+Click on the `Use this template` button above to initialize a new repository.
-Click on `Use this template` button above to initialize new repository.
+This template tries to be as generic as possible. You should be able to easily modify behavior in [train.py](train.py) in case you need some unconventional configuration wiring.
+
+
+## Contents
+- [PyTorch Lightning + Hydra Template](#pytorch-lightning--hydra-template)
+ - [Contents](#contents)
+ - [Main Ideas](#main-ideas)
+ - [Some Notes](#some-notes)
+ - [Why Lightning + Hydra?](#why-lightning--hydra)
+ - [Features](#features)
+ - [Project Structure](#project-structure)
+ - [Workflow](#workflow)
+ - [Main Project Configuration](#main-project-configuration)
+ - [Experiment Configuration](#experiment-configuration)
+ - [Logs](#logs)
+ - [Experiment Tracking](#experiment-tracking)
+ - [Distributed Training](#distributed-training)
+ - [Tricks](#tricks)
+
+- [Your Project Name](#your-project-name)
+ - [Description](#description)
+ - [How to run](#how-to-run)
+ - [Installing project as a package](#installing-project-as-a-package)
+
-### Why Lightning + Hydra?
-- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) provides great abstractions for well structured ML code and advanced features like checkpointing, gradient accumulation, distributed training, etc.
-- [Hydra](https://github.com/facebookresearch/hydra) provides convenient way to manage experiment configurations and advanced features like overriding any config parameter from command line, scheduling execution of many runs, etc.
+## Main Ideas
+- Predefined Structure: clean and scalable so that work can easily be extended and replicated (see [#Project Structure](#project-structure))
+- Modularity: all abstractions are split into different submodules
+- Rapid Experimentation: thanks to automating the pipeline with config files and Hydra command line superpowers
+- Little Boilerplate: so the pipeline can be easily modified (see [train.py](train.py))
+- Main Configuration: the main config file specifies the default training configuration (see [#Main Project Configuration](#main-project-configuration))
+- Experiment Configurations: stored in a separate folder, they can be composed out of smaller configs, override chosen parameters or define everything from scratch (see [#Experiment Configuration](#experiment-configuration))
+- Experiment Tracking: most logging frameworks can be easily integrated! (see [#Experiment Tracking](#experiment-tracking))
+- Tests: simple bash scripts to check that your model doesn't crash under different training conditions (see [tests/](tests/))
+- Logs: all logs (checkpoints, data from loggers, chosen hparams, etc.) are stored in a convenient folder structure imposed by Hydra (see [#Logs](#logs))
+- Hyperparameter Search: made easier with Hydra built in plugins like [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper)
+- Workflow: comes down to 4 simple steps (see [#Workflow](#workflow))
+
+
+## Some Notes
+- ***Warning: this template currently uses a development version of Hydra, which might be unstable (we are waiting until Hydra 1.1 is released).***
+- *Based on:
+[deep-learning-project-template](https://github.com/PyTorchLightning/deep-learning-project-template),
+[cookiecutter-data-science](https://github.com/drivendata/cookiecutter-data-science),
+[hydra-torch](https://github.com/pytorch/hydra-torch),
+[hydra-lightning](https://github.com/romesco/hydra-lightning),
+[lightning-hydra-seed](https://github.com/tchaton/lightning-hydra-seed),
+[pytorch_tempest](https://github.com/Erlemar/pytorch_tempest),
+[pytorch-project-template](https://github.com/ryul99/pytorch-project-template).*
+- *To learn how to configure PyTorch with Hydra take a look at [this detailed MNIST tutorial](https://github.com/pytorch/hydra-torch/blob/master/examples/mnist_00.md).*
+- *Suggestions are always welcome!*
+
-### Some Notes
-***\*warning: this template currently uses development version of hydra which might be unstable (we wait until version 1.1 is released)***
-*\*based on [deep-learninig-project-template](https://github.com/PyTorchLightning/deep-learning-project-template) by PyTorchLightning organization.*
-*\*Suggestions are always welcome!*
+## Why Lightning + Hydra?
+- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) provides great abstractions for well-structured ML code and advanced features like checkpointing, gradient accumulation, distributed training, etc.
+- [Hydra](https://github.com/facebookresearch/hydra) provides a convenient way to manage experiment configurations and advanced features like overriding any config parameter from the command line, scheduling execution of many runs, etc.
+
## Features
-- Predefined folder structure
-- Modularity: all abstractions are splitted into different submodules
-- Automates PyTorch Lightning training pipeline with little boilerplate, so it can be easily modified (see [train.py](project/train.py))
-- All advantages of Hydra
- - Main config file contains default training configuration (see [config.yaml](project/configs/config.yaml))
- - Storing many experiment configurations in a convenient way (see [project/configs/experiment](project/configs/experiment))
- - Command line features (see [#How to run](README.md#How-to-run) for examples)
- - Override any config parameter from command line
- - Schedule execution of many experiments from command line
- - Sweep over hyperparameters from command line
- - Convenient logging of run history, ckpts, etc. (see [#Logs](README.md#Logs))
- - ~~Validating correctness of config with schemas~~ (TODO)
-- Optional Weights&Biases utilities for experiment tracking
- - Callbacks (see [wandb_callbacks.py](project/src/callbacks/wandb_callbacks.py))
- - Automatically store all code files and model checkpoints as artifacts in W&B cloud
- - Generate confusion matrices and f1/precision/recall heatmaps
- - ~~Hyperparameter search with Weights&Biases sweeps ([execute_sweep.py](project/template_utils/execute_sweep.py))~~ (TODO)
-- Example of inference with trained model ([inference_example.py](project/src/utils/inference_example.py))
+- Hydra superpowers
+ - Override any config parameter from command line
+ - Easily switch between different loggers, callbacks sets, optimizers, etc. from command line
+ - Sweep over hyperparameters from command line
+ - Automatic logging of run history
+ - Sweeper integrations for Optuna, Ray and others
+- Optional callbacks for Weights&Biases ([wandb_callbacks.py](src/callbacks/wandb_callbacks.py))
+ - To support reproducibility:
+ - UploadCodeToWandbAsArtifact
+ - UploadCheckpointsToWandbAsArtifact
+ - WatchModelWithWandb
+ - To provide examples of logging custom visualisations and metrics with callbacks:
+ - LogBestMetricScoresToWandb
+ - LogF1PrecisionRecallHeatmapToWandb
+ - LogConfusionMatrixToWandb
+- ~~Validating correctness of config with Hydra schemas~~ (TODO)
+- Method to pretty print the configuration composed by Hydra at the start of the run, using the [Rich](https://github.com/willmcgugan/rich/) library ([template_utils.py](src/utils/template_utils.py)); see the sketch after this list
+- Method to log chosen parts of Hydra config to all loggers ([template_utils.py](src/utils/template_utils.py))
+- Example of hyperparameter search with Optuna sweeps ([config_optuna.yaml](configs/config_optuna.yaml))
+- ~~Example of hyperparameter search with Weights&Biases sweeps~~ (TODO)
+- Examples of simple bash scripts to check that your model doesn't crash under different training conditions ([tests/](tests/))
+- Example of inference with trained model ([inference_example.py](src/utils/inference_example.py))
- Built in requirements ([requirements.txt](requirements.txt))
-- Built in conda environment initialization ([conda_env.yaml](conda_env.yaml))
+- Built in conda environment initialization ([conda_env_gpu.yaml](conda_env_gpu.yaml), [conda_env_cpu.yaml](conda_env_cpu.yaml))
- Built in python package setup ([setup.py](setup.py))
-- Example with MNIST digits classification ([mnist_model.py](project/src/models/mnist_model.py), [mnist_datamodule.py](project/src/datamodules/mnist_datamodule.py))
+- Example with MNIST classification ([mnist_model.py](src/models/mnist_model.py), [mnist_datamodule.py](src/datamodules/mnist_datamodule.py))
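+
+As a rough illustration of the pretty-printing utility mentioned above, here is a minimal sketch assuming the standard Rich and OmegaConf APIs (the function name and exact formatting are illustrative, not necessarily the template's actual implementation):
+
+```python
+from omegaconf import DictConfig, OmegaConf
+from rich.console import Console
+from rich.syntax import Syntax
+
+
+def print_config(config: DictConfig) -> None:
+    """Pretty print the composed Hydra config as syntax-highlighted YAML."""
+    yaml_str = OmegaConf.to_yaml(config, resolve=True)  # resolve interpolations
+    Console().print(Syntax(yaml_str, "yaml"))
+```
+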
-## Project structure
+## Project Structure
The directory structure of new project looks like this:
```
-├── project
-│ ├── configs <- Hydra configuration files
-│ │ ├── trainer <- Configurations of lightning trainers
-│ │ ├── model <- Configurations of lightning models
-│ │ ├── datamodule <- Configurations of lightning datamodules
-│ │ ├── callbacks <- Configurations of lightning callbacks
-│ │ ├── logger <- Configurations of lightning loggers
-│ │ ├── seeds <- Configurations of seeds
-│ │ ├── experiment <- Configurations of experiments
-│ │ │
-│ │ └── config.yaml <- Main project configuration file
-│ │
-│ ├── data <- Project data
-│ │
-│ ├── logs <- Logs generated by hydra and pytorch lightning loggers
-│ │
-│ ├── notebooks <- Jupyter notebooks
-│ │
-│ ├── src
-│ │ ├── architectures <- PyTorch model architectures
-│ │ ├── callbacks <- PyTorch Lightning callbacks
-│ │ ├── datamodules <- PyTorch Lightning datamodules
-│ │ ├── datasets <- PyTorch datasets
-│ │ ├── models <- PyTorch Lightning models
-│ │ ├── transforms <- Data transformations
-│ │ └── utils <- Utility scripts
-│ │ ├── inference_example.py <- Example of inference with trained model
-│ │ └── template_utils.py <- Some extra template utilities
-│ │
-│ └── train.py <- Train model with chosen experiment configuration
+├── configs <- Hydra configuration files
+│ ├── trainer <- Configurations of Lightning trainers
+│ ├── model <- Configurations of Lightning models
+│ ├── datamodule <- Configurations of Lightning datamodules
+│ ├── callbacks <- Configurations of Lightning callbacks
+│ ├── logger <- Configurations of Lightning loggers
+│ ├── experiment <- Configurations of experiments
+│ │
+│ ├── config.yaml <- Main project configuration file
+│ └── config_optuna.yaml <- Configuration of Optuna hyperparameter search
+│
+├── data <- Project data
+│
+├── logs <- Logs generated by Hydra and PyTorch Lightning loggers
+│
+├── notebooks <- Jupyter notebooks
+│
+├── tests <- Tests of any kind
+│ ├── quick_tests.sh <- A couple of quick experiments to test if your model
+│ │ doesn't crash under different training conditions
+│ └── ...
+│
+├── src
+│ ├── architectures <- PyTorch model architectures
+│ ├── callbacks <- PyTorch Lightning callbacks
+│ ├── datamodules <- PyTorch Lightning datamodules
+│ ├── datasets <- PyTorch datasets
+│ ├── models <- PyTorch Lightning models
+│ ├── transforms <- Data transformations
+│ └── utils <- Utility scripts
+│ ├── inference_example.py <- Example of inference with trained model
+│ └── template_utils.py <- Some extra template utilities
+│
+├── train.py <- Train model with chosen experiment configuration
│
├── .gitignore
├── LICENSE
├── README.md
-├── conda_env.yaml <- File for installing conda environment
+├── conda_env_gpu.yaml <- File for installing conda env for GPU
+├── conda_env_cpu.yaml <- File for installing conda env for CPU
├── requirements.txt <- File for installing python dependencies
└── setup.py <- File for installing project as a package
```
@@ -104,8 +158,9 @@ The directory structure of new project looks like this:
-## Main project configuration file ([config.yaml](project/configs/config.yaml))
-Main config contains default training configuration.
+## Main Project Configuration
+Location: [configs/config.yaml](configs/config.yaml)
+The main project config contains the default training configuration.
It determines how config is composed when simply executing command: `python train.py`
```yaml
# to execute run with default training configuration simply run:
@@ -117,20 +172,23 @@ defaults:
- trainer: default_trainer.yaml
- model: mnist_model.yaml
- datamodule: mnist_datamodule.yaml
- - seeds: default_seeds.yaml # set this to null if you don't want to use seeds
- callbacks: default_callbacks.yaml # set this to null if you don't want to use callbacks
- logger: null # set logger here or use command line (e.g. `python train.py logger=wandb`)
-# path to original working directory (the directory that `train.py` was executed from in command line)
+# path to original working directory (that `train.py` was executed from in command line)
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have path to original working directory as a special variable
# read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
-original_work_dir: ${hydra:runtime.cwd}
+work_dir: ${hydra:runtime.cwd}
# path to folder with data
-data_dir: ${original_work_dir}/data/
+data_dir: ${work_dir}/data/
+
+
+# pretty print config at the start of the run using Rich library
+print_config: True
# output paths for hydra logs
@@ -144,7 +202,8 @@ hydra:
-## Experiment configuration ([project/configs/experiment](project/configs/experiment))
+## Experiment Configuration
+Location: [configs/experiment](configs/experiment)
You can store many experiment configurations in this folder.
Example experiment configuration:
```yaml
@@ -155,15 +214,13 @@ defaults:
- override /trainer: default_trainer.yaml
- override /model: mnist_model.yaml
- override /datamodule: mnist_datamodule.yaml
- - override /seeds: default_seeds.yaml
- override /callbacks: default_callbacks.yaml
- override /logger: null
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
-seeds:
- pytorch_seed: 12345
+seed: 12345
trainer:
max_epochs: 10
@@ -184,25 +241,24 @@ datamodule:
More advanced experiment configuration:
```yaml
# to execute this experiment run:
-# python train.py +experiment=exp_example_with_paths
+# python train.py +experiment=exp_example_full
defaults:
- override /trainer: null
- override /model: null
- override /datamodule: null
- - override /seeds: null
- - override /callbacks: default_callbacks.yaml
+ - override /callbacks: null
- override /logger: null
# we override default configurations with nulls to prevent them from loading at all
# instead we define all modules and their paths directly in this config,
# so everything is stored in one place for more readability
-seeds:
- pytorch_seed: 12345
+seed: 12345
trainer:
_target_: pytorch_lightning.Trainer
+ gpus: 0
min_epochs: 1
max_epochs: 10
gradient_clip_val: 0.5
@@ -211,7 +267,7 @@ model:
_target_: src.models.mnist_model.LitModelMNIST
optimizer: adam
lr: 0.001
- weight_decay: 0.000001
+ weight_decay: 0.00005
architecture: SimpleDenseNet
input_size: 784
lin1_size: 256
@@ -227,17 +283,24 @@ datamodule:
data_dir: ${data_dir}
batch_size: 64
train_val_test_split: [55_000, 5_000, 10_000]
- num_workers: 1
+ num_workers: 0
pin_memory: False
+
+logger:
+ wandb:
+ tags: ["best_model", "uwu"]
+ notes: "Description of this model."
```
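+
+Note that the flat `seed` parameter above replaces the old `seeds.pytorch_seed` group. A minimal sketch of how it might be consumed in [train.py](train.py), assuming the pipeline simply calls PyTorch Lightning's `seed_everything` (an assumption, since train.py is not shown here):
+
+```python
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+
+
+def apply_seed(config: DictConfig) -> None:
+    """Seed python, numpy and torch RNGs when a seed is set in the config."""
+    if config.get("seed") is not None:
+        pl.seed_everything(config.seed)
+```
+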
## Logs
+Hydra creates a new working directory for every executed run.
By default, logs have the following structure:
```
+│
├── logs
-│ ├── runs # Folder for logs generated from single runs
+│ ├── runs # Folder for logs generated from single runs
│ │ ├── 2021-02-15 # Date of executing run
│ │ │ ├── 16-50-49 # Hour of executing run
│ │ │ │ ├── .hydra # Hydra logs
@@ -249,23 +312,49 @@ By default, logs have the following structure:
│ │ ├── ...
│ │ └── ...
│ │
-│ ├── multiruns # Folder for logs generated from sweeps
-│ │ ├── 2021-02-15_16-50-49 # Date and hour of executing sweep
-│ │ │ ├── 0 # Job number
-│ │ │ │ ├── .hydra # Hydra logs
-│ │ │ │ ├── wandb # Weights&Biases logs
-│ │ │ │ ├── checkpoints # Training checkpoints
-│ │ │ │ └── ... # Any other thing saved during training
-│ │ │ ├── 1
-│ │ │ ├── 2
-│ │ │ └── ...
-│ │ ├── ...
-│ │ └── ...
-│ │
+│ └── multiruns # Folder for logs generated from multiruns (sweeps)
+│ ├── 2021-02-15_16-50-49 # Date and hour of executing sweep
+│ │ ├── 0 # Job number
+│ │ │ ├── .hydra # Hydra logs
+│ │ │ ├── wandb # Weights&Biases logs
+│ │ │ ├── checkpoints # Training checkpoints
+│ │ │ └── ... # Any other thing saved during training
+│ │ ├── 1
+│ │ ├── 2
+│ │ └── ...
+│ ├── ...
+│ └── ...
+│
```
+You can change this structure by modifying paths in [config.yaml](configs/config.yaml).
+
+
+
+## Experiment Tracking
+PyTorch Lightning provides built-in loggers for Weights&Biases, Neptune, Comet, MLFlow, Tensorboard, TestTube and CSV. To use one of them, simply add its configuration to [configs/logger/](configs/logger/) and run:
+ `python train.py logger=logger_config.yaml`
+You can use many of them at once (see [configs/logger/many_loggers.yaml](configs/logger/many_loggers.yaml) for an example).
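+
+Each logger entry is a plain Hydra config with a `_target_` field, so a minimal sketch of how [train.py](train.py) might instantiate everything under the `logger` group looks like this (assuming the common `hydra.utils.instantiate` pattern; the helper name is illustrative):
+
+```python
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning.loggers import LightningLoggerBase
+
+
+def init_loggers(config: DictConfig) -> List[LightningLoggerBase]:
+    """Instantiate every logger defined under the `logger` config group."""
+    loggers: List[LightningLoggerBase] = []
+    if config.get("logger"):
+        for _, lg_conf in config.logger.items():
+            if "_target_" in lg_conf:  # skip keys that are not logger configs
+                loggers.append(hydra.utils.instantiate(lg_conf))
+    return loggers
+```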
+
+
+
+## Distributed Training
+(TODO)
+
+
+
+## Tricks
+(TODO)
+
+
+
+
+
+
### DELETE EVERYTHING ABOVE FOR YOUR PROJECT
---
@@ -273,11 +362,7 @@ By default, logs have the following structure:
# Your Project Name
-
-[![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)
-[![Conference](http://img.shields.io/badge/NeurIPS-2019-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018)
-[![Conference](http://img.shields.io/badge/ICLR-2019-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018)
-[![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018)
+Some short description.
@@ -292,8 +377,7 @@ git clone https://github.com/YourGithubName/your-repo-name
cd your-repo-name
# optionally create conda environment
-conda update conda
-conda env create -f conda_env.yaml -n your_env_name
+conda env create -f conda_env_gpu.yaml -n your_env_name
conda activate your_env_name
# install requirements
@@ -302,13 +386,12 @@ pip install -r requirements.txt
Next, you can train model with default configuration without logging:
```yaml
-cd project
python train.py
```
Or you can train model with chosen logger like Weights&Biases:
```yaml
-# set project and entity names in 'project/configs/logger/wandb.yaml'
+# set project and entity names in `configs/logger/wandb.yaml`
wandb:
project: "your_project_name"
entity: "your_wandb_team_name"
@@ -321,14 +404,14 @@ python train.py logger=wandb
Or you can train model with chosen experiment config:
```yaml
-# experiment configurations are placed in 'project/configs/experiment' folder
+# experiment configurations are placed in folder `configs/experiment/`
python train.py +experiment=exp_example_simple
```
To execute all experiments from folder run:
```yaml
-# execute all experiments from folder `project/configs/experiment`
-python train.py --multirun '+experiment=glob(*)'
+# execute all experiments from folder `configs/experiment/`
+python train.py -m '+experiment=glob(*)'
```
You can override any parameter from command line like this:
@@ -343,30 +426,48 @@ python train.py trainer.gpus=1
Attach some callback set to run:
```yaml
-# callback sets configurations are placed in 'project/configs/callbacks' folder
+# callback set configurations are placed in `configs/callbacks/`
python train.py callbacks=default_callbacks
```
Combining it all:
```yaml
-python train.py --multirun '+experiment=glob(*)' trainer.max_epochs=10 logger=wandb
+python train.py -m '+experiment=glob(*)' trainer.max_epochs=10 logger=wandb
```
To create a sweep over some hyperparameters run:
```yaml
# this will run 6 experiments one after the other,
# each with different combination of batch_size and learning rate
-python train.py --multirun datamodule.batch_size=32,64,128 model.lr=0.001,0.0005
+python train.py -m datamodule.batch_size=32,64,128 model.lr=0.001,0.0005
+```
+
+To sweep with Optuna:
+```yaml
+# this will run hyperparameter search defined in `configs/config_optuna.yaml`
+python train.py -m --config-name config_optuna.yaml +experiment=exp_example_simple
```
+Resume from checkpoint:
+```yaml
+# checkpoint can be either a path or a URL
+# the path should be either absolute or prefixed with `${work_dir}/`
+# use quotes '' around the argument, otherwise the $ symbol breaks it
+python train.py '+trainer.resume_from_checkpoint=${work_dir}/logs/runs/2021-02-28/16-50-49/checkpoints/last.ckpt'
+```
+
+
## Installing project as a package
Optionally you can install project as a package with [setup.py](setup.py):
```yaml
+# install from local files
pip install -e .
+
+# or install from git repo
+pip install git+git://github.com/YourGithubName/your-repo-name.git --upgrade
```
So you can easily import any file into any other file like so:
```python
-from project.src.datasets.img_test_dataset import TestDataset
-from project.src.models.mnist_model import LitModelMNIST
-from project.src.datamodules.mnist_datamodule import MNISTDataModule
+from src.models.mnist_model import LitModelMNIST
+from src.datamodules.mnist_datamodule import MNISTDataModule
```
diff --git a/conda_env_cpu.yaml b/conda_env_cpu.yaml
new file mode 100644
index 000000000..7cc25bf63
--- /dev/null
+++ b/conda_env_cpu.yaml
@@ -0,0 +1,14 @@
+#name: conda_env_name
+
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+
+dependencies:
+ - python=3.8
+ - pip
+ - notebook
+ - pytorch
+ - torchvision
+ - torchaudio
diff --git a/conda_env.yaml b/conda_env_gpu.yaml
similarity index 100%
rename from conda_env.yaml
rename to conda_env_gpu.yaml
diff --git a/project/configs/callbacks/default_callbacks.yaml b/configs/callbacks/default_callbacks.yaml
similarity index 80%
rename from project/configs/callbacks/default_callbacks.yaml
rename to configs/callbacks/default_callbacks.yaml
index ebe31531e..48803ccd9 100644
--- a/project/configs/callbacks/default_callbacks.yaml
+++ b/configs/callbacks/default_callbacks.yaml
@@ -1,16 +1,16 @@
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
- monitor: "val_acc" # name of the logged metric which determines when model is improving
+ monitor: "val/acc" # name of the logged metric which determines when model is improving
save_top_k: 2 # save k best models (determined by above metric)
save_last: True # additionally always save model from last epoch
mode: "max" # can be "max" or "min"
dirpath: 'checkpoints/'
- filename: 'sample-mnist-{epoch:02d}'
+ filename: '{epoch:02d}'
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
- monitor: "val_acc" # name of the logged metric which determines when model is improving
+ monitor: "val/acc" # name of the logged metric which determines when model is improving
patience: 100 # how many epochs of not improving until training stops
mode: "max" # can be "max" or "min"
min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/project/data/.gitkeep b/configs/callbacks/none.yaml
similarity index 100%
rename from project/data/.gitkeep
rename to configs/callbacks/none.yaml
diff --git a/configs/callbacks/wandb_callbacks.yaml b/configs/callbacks/wandb_callbacks.yaml
new file mode 100644
index 000000000..9a7e1737d
--- /dev/null
+++ b/configs/callbacks/wandb_callbacks.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - default_callbacks.yaml
+
+
+upload_code_to_wandb_as_artifact:
+ _target_: src.callbacks.wandb_callbacks.UploadCodeToWandbAsArtifact
+ code_dir: ${work_dir}
+
+
+upload_ckpts_to_wandb_as_artifact:
+ _target_: src.callbacks.wandb_callbacks.UploadCheckpointsToWandbAsArtifact
+ ckpt_dir: "checkpoints/"
+ upload_best_only: False
+
+
+watch_model_with_wandb:
+ _target_: src.callbacks.wandb_callbacks.WatchModelWithWandb
+ log: "all"
+ log_freq: 100
+
+
+# BUGGED :(
+# save_best_metric_scores_to_wandb:
+# _target_: src.callbacks.wandb_callbacks.LogBestMetricScoresToWandb
+
+
+save_f1_precision_recall_heatmap_to_wandb:
+ _target_: src.callbacks.wandb_callbacks.LogF1PrecisionRecallHeatmapToWandb
+ class_names: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
+
+
+save_confusion_matrix_to_wandb:
+ _target_: src.callbacks.wandb_callbacks.LogConfusionMatrixToWandb
+ class_names: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
diff --git a/project/configs/config.yaml b/configs/config.yaml
similarity index 69%
rename from project/configs/config.yaml
rename to configs/config.yaml
index 8291a5b8c..55d8cb72c 100644
--- a/project/configs/config.yaml
+++ b/configs/config.yaml
@@ -5,24 +5,27 @@ defaults:
- trainer: default_trainer.yaml
- model: mnist_model.yaml
- datamodule: mnist_datamodule.yaml
- - seeds: default_seeds.yaml # set this to null if you don't want to use seeds
- callbacks: default_callbacks.yaml # set this to null if you don't want to use callbacks
- logger: null # set logger here or use command line (e.g. `python train.py logger=wandb`)
- # we add this just to enable color logging
- # - hydra/hydra_logging: colorlog
- # - hydra/job_logging: colorlog
+ # enable color logging
+ # - override hydra/hydra_logging: colorlog
+ # - override hydra/job_logging: colorlog
-# path to original working directory (the directory that `train.py` was executed from in command line)
+# path to original working directory (that `train.py` was executed from in command line)
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have path to original working directory as a special variable
# read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
-original_work_dir: ${hydra:runtime.cwd}
+work_dir: ${hydra:runtime.cwd}
# path to folder with data
-data_dir: ${original_work_dir}/data/
+data_dir: ${work_dir}/data/
+
+
+# pretty print config at the start of the run using Rich library
+print_config: True
# output paths for hydra logs
diff --git a/configs/config_optuna.yaml b/configs/config_optuna.yaml
new file mode 100644
index 000000000..cf9394a68
--- /dev/null
+++ b/configs/config_optuna.yaml
@@ -0,0 +1,65 @@
+# @package _global_
+
+# example hyperparameter optimization of some experiment with optuna:
+# python train.py -m --config-name config_optuna.yaml +experiment=exp_example_simple logger=wandb
+
+defaults:
+ # load everything from main config file
+ - config.yaml
+
+ # override sweeper to optuna!
+ - override hydra/sweeper: optuna
+
+
+# choose metric which will be optimized by optuna
+optimized_metric: "val/acc_best"
+
+
+hydra:
+ # here we define optuna objective
+ # it optimizes the value returned by the function decorated with @hydra.main
+ # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
+ sweeper:
+ optuna_config:
+ study_name: null
+ storage: null
+ n_jobs: 1
+ seed: 12345
+
+ # 'minimize' or 'maximize' the objective
+ direction: maximize
+
+ # number of experiments that will be executed
+ n_trials: 30
+
+ # choose optuna hyperparameter sampler ('tpe', 'random', 'cmaes', 'nsgaii' or 'motpe')
+ # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
+ sampler: tpe
+
+ # define range of hyperparameters
+ search_space:
+ datamodule.batch_size:
+ type: categorical
+ choices: [32, 64, 128]
+ model.lr:
+ type: float
+ low: 0.0001
+ high: 0.2
+ model.lin1_size:
+ type: categorical
+ choices: [64, 128, 256, 512]
+ model.dropout1:
+ type: categorical
+ choices: [0.05, 0.1, 0.25, 0.5]
+ model.lin2_size:
+ type: categorical
+ choices: [64, 128, 256, 512]
+ model.dropout2:
+ type: categorical
+ choices: [0.05, 0.1, 0.25, 0.5]
+ model.lin3_size:
+ type: categorical
+ choices: [32, 64, 128, 256]
+ model.dropout3:
+ type: categorical
+ choices: [0.05, 0.1, 0.25, 0.5]
diff --git a/project/configs/datamodule/mnist_datamodule.yaml b/configs/datamodule/mnist_datamodule.yaml
similarity index 100%
rename from project/configs/datamodule/mnist_datamodule.yaml
rename to configs/datamodule/mnist_datamodule.yaml
diff --git a/configs/experiment/exp_example_full.yaml b/configs/experiment/exp_example_full.yaml
new file mode 100644
index 000000000..a664cd924
--- /dev/null
+++ b/configs/experiment/exp_example_full.yaml
@@ -0,0 +1,74 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py +experiment=exp_example_full
+
+defaults:
+ - override /trainer: null # override trainer to null so it's not loaded from main config defaults...
+ - override /model: null
+ - override /datamodule: null
+ - override /callbacks: null
+ - override /logger: null
+
+# we override default configurations with nulls to prevent them from loading at all
+# instead we define all modules and their paths directly in this config,
+# so everything is stored in one place for more readability
+
+seed: 12345
+
+trainer:
+ _target_: pytorch_lightning.Trainer
+ gpus: 0
+ min_epochs: 1
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ accumulate_grad_batches: 2
+ weights_summary: null
+ # resume_from_checkpoint: ${work_dir}/last.ckpt
+
+model:
+ _target_: src.models.mnist_model.LitModelMNIST
+ optimizer: adam
+ lr: 0.001
+ weight_decay: 0.00005
+ architecture: SimpleDenseNet
+ input_size: 784
+ lin1_size: 256
+ dropout1: 0.30
+ lin2_size: 256
+ dropout2: 0.25
+ lin3_size: 128
+ dropout3: 0.20
+ output_size: 10
+
+datamodule:
+ _target_: src.datamodules.mnist_datamodule.MNISTDataModule
+ data_dir: ${data_dir}
+ batch_size: 64
+ train_val_test_split: [55_000, 5_000, 10_000]
+ num_workers: 0
+ pin_memory: False
+
+callbacks:
+ model_checkpoint:
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
+ monitor: "val/acc"
+ save_top_k: 2
+ save_last: True
+ mode: "max"
+ dirpath: 'checkpoints/'
+ filename: 'sample-mnist-{epoch:02d}'
+ early_stopping:
+ _target_: pytorch_lightning.callbacks.EarlyStopping
+ monitor: "val/acc"
+ patience: 100
+ mode: "max"
+
+logger:
+ wandb:
+ tags: ["best_model", "uwu"]
+ notes: "Description of this model."
+ neptune:
+ tags: ["best_model"]
+ csv_logger:
+ save_dir: "."
diff --git a/project/configs/experiment/exp_example_simple.yaml b/configs/experiment/exp_example_simple.yaml
similarity index 74%
rename from project/configs/experiment/exp_example_simple.yaml
rename to configs/experiment/exp_example_simple.yaml
index bf0dc01b9..32af6daea 100644
--- a/project/configs/experiment/exp_example_simple.yaml
+++ b/configs/experiment/exp_example_simple.yaml
@@ -7,19 +7,18 @@ defaults:
- override /trainer: default_trainer.yaml # choose trainer from 'configs/trainer/' folder or set to null
- override /model: mnist_model.yaml # choose model from 'configs/model/' folder or set to null
- override /datamodule: mnist_datamodule.yaml # choose datamodule from 'configs/datamodule/' folder or set to null
- - override /seeds: default_seeds.yaml # choose seeds from 'configs/seeds/' folder or set to null
- override /callbacks: default_callbacks.yaml # choose callback set from 'configs/callbacks/' folder or set to null
- - override /logger: null # choose logger from 'configs/logger/' folder or set it from console when running experiment:
- # `python train.py +experiment=exp_example_simple logger=wandb`
+ - override /logger: null # choose logger from 'configs/logger/' folder or set to null
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
-seeds:
- pytorch_seed: 12345
+seed: 12345
trainer:
+ min_epochs: 1
max_epochs: 10
+ gradient_clip_val: 0.5
model:
lr: 0.001
diff --git a/project/configs/logger/comet.yaml b/configs/logger/comet.yaml
similarity index 71%
rename from project/configs/logger/comet.yaml
rename to configs/logger/comet.yaml
index 511f2de91..3447b5ef2 100644
--- a/project/configs/logger/comet.yaml
+++ b/configs/logger/comet.yaml
@@ -1,5 +1,7 @@
-# Comet logger config
+# https://www.comet.ml
+
comet:
_target_: pytorch_lightning.loggers.comet.CometLogger
api_key: ???
project_name: "project_template_test"
+ experiment_name: null
diff --git a/project/configs/logger/csv_logger.yaml b/configs/logger/csv.yaml
similarity index 55%
rename from project/configs/logger/csv_logger.yaml
rename to configs/logger/csv.yaml
index 393fd438b..c76f1f19c 100644
--- a/project/configs/logger/csv_logger.yaml
+++ b/configs/logger/csv.yaml
@@ -1,5 +1,6 @@
-# Csv logger config
-csv_logger:
+# CSVLogger built into PyTorch Lightning
+
+csv:
_target_: pytorch_lightning.loggers.csv_logs.CSVLogger
save_dir: "."
- name: "csv_logger/"
+ name: "csv/"
diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml
new file mode 100644
index 000000000..87569f2b5
--- /dev/null
+++ b/configs/logger/many_loggers.yaml
@@ -0,0 +1,8 @@
+# train with many loggers at once
+
+defaults:
+ - csv.yaml
+ - wandb.yaml
+ # - neptune.yaml
+ # - comet.yaml
+ # - tensorboard.yaml
diff --git a/project/configs/logger/neptune.yaml b/configs/logger/neptune.yaml
similarity index 55%
rename from project/configs/logger/neptune.yaml
rename to configs/logger/neptune.yaml
index b1e649bef..4fa773e09 100644
--- a/project/configs/logger/neptune.yaml
+++ b/configs/logger/neptune.yaml
@@ -1,6 +1,7 @@
-# Neptune logger config
+# https://neptune.ai
+
neptune:
_target_: pytorch_lightning.loggers.neptune.NeptuneLogger
- project_name: "hobogalaxy/lightning-hydra-template-test"
+ project_name: "your_name/lightning-hydra-template-test"
api_key: ${env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
-# experiment_name: "some_experiment"
+ # experiment_name: "some_experiment"
diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml
new file mode 100644
index 000000000..2bb8deece
--- /dev/null
+++ b/configs/logger/tensorboard.yaml
@@ -0,0 +1,6 @@
+# TensorBoard
+
+tensorboard:
+ _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
+ save_dir: "tensorboard/"
+ name: "default"
diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml
new file mode 100644
index 000000000..54382f3b5
--- /dev/null
+++ b/configs/logger/wandb.yaml
@@ -0,0 +1,10 @@
+# https://wandb.ai (Weights&Biases)
+
+wandb:
+ _target_: pytorch_lightning.loggers.wandb.WandbLogger
+ project: "env_tests"
+ # entity: "" # set to name of your wandb team or just remove it
+ # offline: False # set True to store all logs only locally
+ job_type: "train"
+ group: ""
+ save_dir: "."
diff --git a/project/configs/model/mnist_model.yaml b/configs/model/mnist_model.yaml
similarity index 90%
rename from project/configs/model/mnist_model.yaml
rename to configs/model/mnist_model.yaml
index f1dd0d618..a879add2a 100644
--- a/project/configs/model/mnist_model.yaml
+++ b/configs/model/mnist_model.yaml
@@ -1,7 +1,7 @@
_target_: src.models.mnist_model.LitModelMNIST
optimizer: adam
lr: 0.001
-weight_decay: 0.000001
+weight_decay: 0.00005
architecture: SimpleDenseNet
input_size: 784
lin1_size: 256
diff --git a/configs/trainer/debug_trainer.yaml b/configs/trainer/debug_trainer.yaml
new file mode 100644
index 000000000..f73b0d10f
--- /dev/null
+++ b/configs/trainer/debug_trainer.yaml
@@ -0,0 +1,62 @@
+# Trainer args for debugging model
+
+_target_: pytorch_lightning.Trainer
+
+# set -1 to train on all GPUs available, set 0 to train on CPU only
+gpus: 0
+# auto_select_gpus: True
+
+min_epochs: 3
+max_epochs: 3
+
+# overfit on 10 of the same training set batches
+# overfit_batches: 10
+
+# overfit on 5% of the training data
+# overfit_batches: 0.05
+
+# run validation loop every 5 training epochs
+# check_val_every_n_epoch: 5
+
+# run validation loop 2 times during a training epoch
+val_check_interval: 0.5
+
+# run validation loop every 100 training batches
+# val_check_interval: 100
+
+# run for 1 train, 1 val and 1 test batch
+# fast_dev_run: True
+
+# use only 20% of the data
+limit_train_batches: 0.2
+limit_val_batches: 0.2
+limit_test_batches: 0.2
+
+# number of sanity validation steps
+num_sanity_val_steps: 3
+
+# print execution time profiling after training end
+# profiler: "simple"
+
+# print full weight summary of all modules and submodules
+# weights_summary: "full"
+# weights_summary: "top"
+
+# use gradient clipping to avoid exploding gradients
+gradient_clip_val: 0.5
+
+# perform optimisation after accumulating gradient from 5 batches
+accumulate_grad_batches: 5
+
+# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
+# accumulate_grad_batches: {5: 3, 10: 20}
+
+# automatically find the largest batch size that fits into memory and is a power of 2
+# (requires calling trainer.tune(model=model, datamodule=datamodule))
+# auto_scale_batch_size: 'power'
+
+# set tensor precision to 16 (default is 32 bits)
+# precision: 16
+
+# apex backend for mixed precision training https://github.com/NVIDIA/apex
+# amp_backend: 'apex'
diff --git a/project/configs/trainer/default_trainer.yaml b/configs/trainer/default_trainer.yaml
similarity index 68%
rename from project/configs/trainer/default_trainer.yaml
rename to configs/trainer/default_trainer.yaml
index ecf7e82ee..682b748ea 100644
--- a/project/configs/trainer/default_trainer.yaml
+++ b/configs/trainer/default_trainer.yaml
@@ -1,8 +1,9 @@
_target_: pytorch_lightning.Trainer
gpus: 0 # set -1 to train on all GPUs available, set 0 to train on CPU only
+min_epochs: 1
max_epochs: 10
gradient_clip_val: 0.5
num_sanity_val_steps: 3
-progress_bar_refresh_rate: 10
-weights_summary: null # null in yaml represents python None value
+progress_bar_refresh_rate: 20
+weights_summary: null
default_root_dir: "lightning_logs/"
diff --git a/project/logs/.gitkeep b/data/.gitkeep
similarity index 100%
rename from project/logs/.gitkeep
rename to data/.gitkeep
diff --git a/project/notebooks/.gitkeep b/logs/.gitkeep
similarity index 100%
rename from project/notebooks/.gitkeep
rename to logs/.gitkeep
diff --git a/notebooks/.gitkeep b/notebooks/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/project/configs/callbacks/mnist_callbacks.yaml b/project/configs/callbacks/mnist_callbacks.yaml
deleted file mode 100644
index 2b9b9ab92..000000000
--- a/project/configs/callbacks/mnist_callbacks.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-defaults:
- - default_callbacks.yaml
- - wandb_callbacks.yaml
-
-
-save_confusion_matrix_to_wandb:
- _target_: src.callbacks.wandb_callbacks.SaveConfusionMatrixToWandb
- class_names: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
-
-
-save_f1_precision_recall_heatmap_to_wandb:
- _target_: src.callbacks.wandb_callbacks.SaveMetricsHeatmapToWandb
- class_names: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
diff --git a/project/configs/callbacks/wandb_callbacks.yaml b/project/configs/callbacks/wandb_callbacks.yaml
deleted file mode 100644
index b525227a0..000000000
--- a/project/configs/callbacks/wandb_callbacks.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-defaults:
- - default_callbacks.yaml
-
-
-save_best_metric_scores_to_wandb:
- _target_: src.callbacks.wandb_callbacks.SaveBestMetricScoresToWandb
-
-
-upload_code_to_wandb_as_artifact:
- _target_: src.callbacks.wandb_callbacks.SaveCodeToWandb
- code_dir: ${original_work_dir}/
-
-
-upload_ckpts_to_wandb_as_artifact:
- _target_: src.callbacks.wandb_callbacks.UploadAllCheckpointsToWandb
- ckpt_dir: "checkpoints/"
- upload_best_only: False
diff --git a/project/configs/experiment/exp_example_advanced.yaml b/project/configs/experiment/exp_example_advanced.yaml
deleted file mode 100644
index 66839e1b7..000000000
--- a/project/configs/experiment/exp_example_advanced.yaml
+++ /dev/null
@@ -1,56 +0,0 @@
-# @package _global_
-
-# to execute this experiment run:
-# python train.py +experiment=exp_example_advanced
-
-defaults:
- - override /trainer: default_trainer.yaml # choose trainer from 'configs/trainer/' folder
- - override /model: mnist_model.yaml # choose model from 'configs/model/' folder
- - override /datamodule: mnist_datamodule.yaml # choose datamodule from 'configs/datamodule/' folder
- - override /seeds: default_seeds.yaml # choose seeds from 'configs/seeds/' folder
- - override /callbacks: default_callbacks.yaml # choose callback set from 'configs/callbacks/' folder
- - override /logger: null # choose logger from 'configs/logger/' folder or set it from console when running experiment:
- # `python train.py +experiment=exp_example_advanced logger=wandb`
-
-# all parameters below will be merged with parameters from default configurations set above
-# this allows you to overwrite only specified parameters
-
-seeds:
- pytorch_seed: 12345 # pytorch seed for this experiment (affects torch.utils.data.random_split() method used in mnist_datamodule)
-
-trainer:
- min_epochs: 1 # train for at least this many epochs (denies early stopping)
- max_epochs: 10 # train for maximum this many epochs
- gradient_clip_val: 0.5 # gradient clipping (helps with exploding gradient issues)
- accumulate_grad_batches: 2 # perform optimization step after accumulating gradient from 2 batches
- fast_dev_run: False # execute 1 training, 1 validation and 1 test epoch only
- limit_train_batches: 0.6 # train on 60% of training data
- limit_val_batches: 0.9 # validate on 90% of validation data
- limit_test_batches: 1.0 # test on 100% of test data
- val_check_interval: 0.5 # perform validation twice per epoch
- # resume_from_checkpoint: ${work_dir}/last.ckpt # path to checkpoint (this can be also url for download)
-
-model: # you can add here any params you want and then access them in lightning model
- lr: 0.001
- weight_decay: 0.00001
- input_size: 784 # img size is 1*28*28
- output_size: 10 # there are 10 digit classes
- lin1_size: 256
- lin2_size: 256
- lin3_size: 128
-
-datamodule: # you can add here any params you want and then access them in lightning datamodule
- batch_size: 64
- train_val_test_split: [55_000, 5_000, 10_000]
- num_workers: 1 # num of processes used for loading data in parallel
- pin_memory: False # dataloaders will copy tensors into CUDA pinned memory before returning them
-
-logger: # you can add here additional logger arguments specific for this experiment
- wandb:
- tags: ["best_model", "uwu"]
- notes: "Description of this model."
- group: "mnist"
- neptune:
- tags: ["best_model"]
- csv_logger:
- save_dir: "."
diff --git a/project/configs/experiment/exp_example_with_paths.yaml b/project/configs/experiment/exp_example_with_paths.yaml
deleted file mode 100644
index c9fc37aeb..000000000
--- a/project/configs/experiment/exp_example_with_paths.yaml
+++ /dev/null
@@ -1,50 +0,0 @@
-# @package _global_
-
-# to execute this experiment run:
-# python train.py +experiment=exp_example_with_paths
-
-defaults:
- - override /trainer: null # override trainer to null so it's not loaded from main config defaults
- - override /model: null # override model to null so it's not loaded from main config defaults
- - override /datamodule: null # override datamodel to null so it's not loaded from main config defaults
- - override /seeds: null # override seeds to null so it's not loaded from main config defaults
- - override /callbacks: default_callbacks.yaml # choose callback set from 'configs/callbacks/' folder
- - override /logger: null # choose logger from 'configs/logger/' folder or set it from console when running experiment:
- # `python train.py +experiment=exp_example_with_paths logger=wandb`
-
-# we override default configurations with nulls to prevent them from loading at all - instead we define all modules
-# and their paths directly in this config so everything is stored in one place and we have more readibility
-
-seeds:
- pytorch_seed: 12345
-
-trainer:
- _target_: pytorch_lightning.Trainer
- min_epochs: 1
- max_epochs: 10
- gradient_clip_val: 0.5
- weights_summary: null
- gpus: 0
-
-model:
- _target_: src.models.mnist_model.LitModelMNIST
- optimizer: adam
- lr: 0.001
- weight_decay: 0.000001
- architecture: SimpleDenseNet
- input_size: 784
- lin1_size: 256
- dropout1: 0.30
- lin2_size: 256
- dropout2: 0.25
- lin3_size: 128
- dropout3: 0.20
- output_size: 10
-
-datamodule:
- _target_: src.datamodules.mnist_datamodule.MNISTDataModule
- data_dir: ${data_dir}
- batch_size: 64
- train_val_test_split: [55_000, 5_000, 10_000]
- num_workers: 1
- pin_memory: False
diff --git a/project/configs/logger/all_loggers.yaml b/project/configs/logger/all_loggers.yaml
deleted file mode 100644
index 155e96dd4..000000000
--- a/project/configs/logger/all_loggers.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-# Train with many loggers at once
-defaults:
- - csv_logger.yaml
- - wandb.yaml
- - neptune.yaml
-# - comet.yaml
-# - tensorboard.yaml
diff --git a/project/configs/logger/tensorboard.yaml b/project/configs/logger/tensorboard.yaml
deleted file mode 100644
index 47eb6c7db..000000000
--- a/project/configs/logger/tensorboard.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-# Tensorboard logger config
-tensorboard:
- _target_: pytorch_lightning.loggers.tensorboard.TensorboardLogger
diff --git a/project/configs/logger/wandb.yaml b/project/configs/logger/wandb.yaml
deleted file mode 100644
index 6aac90036..000000000
--- a/project/configs/logger/wandb.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# Weights&Biases logger config
-wandb:
- _target_: pytorch_lightning.loggers.wandb.WandbLogger
- project: "env_tests"
-# entity: "" # set to name of your wandb team or just remove it
- offline: False # set True to store all logs only locally
- job_type: "train"
- save_dir: "."
diff --git a/project/configs/seeds/default_seeds.yaml b/project/configs/seeds/default_seeds.yaml
deleted file mode 100644
index d69af1618..000000000
--- a/project/configs/seeds/default_seeds.yaml
+++ /dev/null
@@ -1 +0,0 @@
-pytorch_seed: 12345
diff --git a/project/configs/trainer/debug_trainer.yaml b/project/configs/trainer/debug_trainer.yaml
deleted file mode 100644
index 2d348ec1a..000000000
--- a/project/configs/trainer/debug_trainer.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# trainer args for debugging model
-_target_: pytorch_lightning.Trainer
-gpus: 0 # set -1 to train on all GPUs abailable, set 0 to train on CPU only
-# auto_select_gpus: True
-
-gradient_clip_val: 0.5
-
-# fast_dev_run: True # bugged :( (probably wee need to wait for lightning patch)
-limit_train_batches: 1.0
-limit_val_batches: 1.0
-limit_test_batches: 1.0
-val_check_interval: 1.0
-profiler: "simple" # use profiler to print execution time profiling after training ends
-
-progress_bar_refresh_rate: 10
-weights_summary: "full"
-default_root_dir: "lightning_logs/"
diff --git a/project/src/__init__.py b/project/src/__init__.py
deleted file mode 100644
index e24e3b1a3..000000000
--- a/project/src/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# package file
diff --git a/project/src/callbacks/wandb_callbacks.py b/project/src/callbacks/wandb_callbacks.py
deleted file mode 100644
index 9164dc373..000000000
--- a/project/src/callbacks/wandb_callbacks.py
+++ /dev/null
@@ -1,215 +0,0 @@
-from sklearn.metrics import precision_score, recall_score, f1_score
-from pytorch_lightning.loggers import WandbLogger
-from wandb.sdk.wandb_run import Run as wandb_run
-from pytorch_lightning import Callback
-import pytorch_lightning as pl
-import torch
-import wandb
-import glob
-import os
-
-
-def get_wandb_logger(trainer: pl.Trainer) -> wandb_run:
- logger = None
- for some_logger in trainer.logger.experiment:
- if isinstance(some_logger, wandb_run):
- logger = some_logger
-
- if not logger:
- raise Exception("You're using wandb related callback, "
- "but wandb logger was not initialized for some reason...")
-
- return logger
-
-
-class SaveCodeToWandb(Callback):
- """
- Upload all *.py files to wandb as an artifact at the beginning of the run.
- """
- def __init__(self, code_dir: str):
- self.code_dir = code_dir
-
- def on_sanity_check_end(self, trainer, pl_module):
- """Upload files when all validation sanity checks end."""
- logger = get_wandb_logger(trainer=trainer)
-
- code = wandb.Artifact('project-source', type='code')
- for path in glob.glob(os.path.join(self.code_dir, '**/*.py'), recursive=True):
- code.add_file(path)
- wandb.run.use_artifact(code)
-
-
-class UploadAllCheckpointsToWandb(Callback):
- """
- Upload experiment checkpoints to wandb as an artifact at the end of training.
- """
- def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
- self.ckpt_dir = ckpt_dir
- self.upload_best_only = upload_best_only
-
- def on_train_end(self, trainer, pl_module):
- """Upload ckpts when training ends."""
- logger = get_wandb_logger(trainer=trainer)
-
- ckpts = wandb.Artifact('experiment-ckpts', type='checkpoints')
- if self.upload_best_only:
- ckpts.add_file(trainer.checkpoint_callback.best_model_path)
- else:
- for path in glob.glob(os.path.join(self.ckpt_dir, '**/*.ckpt'), recursive=True):
- ckpts.add_file(path)
- wandb.run.use_artifact(ckpts)
-
-
-class SaveMetricsHeatmapToWandb(Callback):
- """
- Generate f1, precision and recall heatmap from validation step outputs.
- Expects validation step to return predictions and targets.
- Works only for single label classification!
- """
- def __init__(self, class_names=None):
- self.class_names = class_names
- self.preds = []
- self.targets = []
- self.ready = False
-
- def on_sanity_check_end(self, trainer, pl_module):
- """Start executing this callback only after all validation sanity checks end."""
- self.ready = True
-
- def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
- """Gather data from single batch."""
- if self.ready:
- preds, targets = outputs["batch_val_preds"], outputs["batch_val_y"]
- self.preds.append(preds)
- self.targets.append(targets)
-
- def on_validation_epoch_end(self, trainer, pl_module):
- """Generate f1, precision and recall heatmap."""
- if self.ready:
- logger = get_wandb_logger(trainer=trainer)
-
- self.preds = torch.cat(self.preds).cpu()
- self.targets = torch.cat(self.targets).cpu()
- f1 = f1_score(self.preds, self.targets, average=None)
- r = recall_score(self.preds, self.targets, average=None)
- p = precision_score(self.preds, self.targets, average=None)
-
- logger.log({
- f"f1_p_r_heatmap_{trainer.current_epoch}_{logger.id}": wandb.plots.HeatMap(
- x_labels=self.class_names,
- y_labels=["f1", "precision", "recall"],
- matrix_values=[f1, p, r],
- show_text=True,
- )}, commit=False)
-
- self.preds = []
- self.targets = []
-
-
-class SaveConfusionMatrixToWandb(Callback):
- """
- Generate Confusion Matrix.
- Expects validation step to return predictions and targets.
- Works only for single label classification!
- """
- def __init__(self, class_names=None):
- self.class_names = class_names
- self.preds = []
- self.targets = []
- self.ready = False
-
- def on_sanity_check_end(self, trainer, pl_module):
- """Start executing this callback only after all validation sanity checks end."""
- self.ready = True
-
- def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
- """Gather data from single batch."""
- if self.ready:
- preds, targets = outputs["batch_val_preds"], outputs["batch_val_y"]
- self.preds.append(preds)
- self.targets.append(targets)
-
- def on_validation_epoch_end(self, trainer, pl_module):
- """Generate f1, precision and recall heatmap."""
- if self.ready:
- logger = get_wandb_logger(trainer=trainer)
-
- self.preds = torch.cat(self.preds).tolist()
- self.targets = torch.cat(self.targets).tolist()
-
- logger.log({
- f"conf_mat_{trainer.current_epoch}_{logger.id}": wandb.plot.confusion_matrix(
- preds=self.preds,
- y_true=self.targets,
- class_names=self.class_names)
- }, commit=False)
-
- self.preds = []
- self.targets = []
-
-
-class SaveBestMetricScoresToWandb(Callback):
- """
- Store in wandb:
- - max train acc
- - min train loss
- - max val acc
- - min val loss
- Useful for comparing runs in table views, as wandb doesn't currently supports column aggregation.
- """
- def __init__(self):
- self.train_loss_best = None
- self.train_acc_best = None
- self.val_loss_best = None
- self.val_acc_best = None
- self.ready = False
-
- def on_sanity_check_end(self, trainer, pl_module):
- """Start executing this callback only after all validation sanity checks end."""
- self.ready = True
-
- def on_epoch_end(self, trainer, pl_module):
- if self.ready:
- logger = get_wandb_logger(trainer=trainer)
-
- metrics = trainer.callback_metrics
- if self.train_loss_best is None or metrics["train_loss"] < self.train_loss_best:
- self.train_loss_best = metrics["train_loss"]
- if self.train_acc_best is None or metrics["train_acc"] > self.train_acc_best:
- self.train_acc_best = metrics["train_acc"]
- if self.val_loss_best is None or metrics["val_loss"] < self.val_loss_best:
- self.val_loss_best = metrics["val_loss"]
- if self.val_acc_best is None or metrics["val_acc"] > self.val_acc_best:
- self.val_acc_best = metrics["val_acc"]
-
- logger.log({"train_loss_best": self.train_loss_best}, commit=False)
- logger.log({"train_acc_best": self.train_acc_best}, commit=False)
- logger.log({"val_loss_best": self.val_loss_best}, commit=False)
- logger.log({"val_acc_best": self.val_acc_best}, commit=False)
-
-
-# class SaveImagePredictionsToWandb(Callback):
-# """
-# Each epoch upload to wandb a couple of the same images with predicted labels.
-# """
-# def __init__(self, datamodule, num_samples=8):
-# first_batch = next(iter(datamodule.train_dataloader()))
-# self.imgs, self.labels = first_batch
-# self.imgs, self.labels = self.imgs[:num_samples], self.labels[:num_samples]
-# self.ready = True
-#
-# def on_sanity_check_end(self, trainer, pl_module):
-# """Start executing this callback only after all validation sanity checks end."""
-# self.ready = True
-#
-# def on_validation_epoch_end(self, trainer, pl_module):
-# if self.ready:
-# imgs = self.imgs.to(device=pl_module.device)
-# logits = pl_module(imgs)
-# preds = torch.argmax(logits, -1)
-# trainer.logger.experiment.log({f"img_examples": [
-# wandb.Image(
-# x,
-# caption=f"Epoch: {trainer.current_epoch} Pred:{pred}, Label:{y}"
-# ) for x, pred, y in zip(imgs, preds, self.labels)
-# ]}, commit=False)
diff --git a/project/src/models/mnist_model.py b/project/src/models/mnist_model.py
deleted file mode 100644
index 0a9e8a583..000000000
--- a/project/src/models/mnist_model.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from pytorch_lightning.metrics.classification import Accuracy
-import pytorch_lightning as pl
-import torch.nn.functional as F
-import torch
-
-# import custom architectures
-from src.architectures.simple_dense_net import SimpleDenseNet
-
-
-class LitModelMNIST(pl.LightningModule):
- """
- This is example of lightning model for MNIST classification.
- To learn how to create lightning models visit:
- https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__()
- self.save_hyperparameters()
- self.accuracy = Accuracy()
-
- # Initialize model architecture
- if self.hparams.architecture == "SimpleDenseNet":
- self.architecture = SimpleDenseNet(hparams=self.hparams)
- else:
- raise Exception("Invalid architecture name")
-
- def forward(self, x):
- return self.architecture(x)
-
- # logic for a single training step
- def training_step(self, batch, batch_idx):
- x, y = batch
- logits = self.architecture(x)
- loss = F.nll_loss(logits, y)
-
- # training metrics
- preds = torch.argmax(logits, dim=1)
- acc = self.accuracy(preds, y)
- self.log('train_loss', loss, on_step=False, on_epoch=True)
- self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
-
- return loss
-
- # logic for a single validation step
- def validation_step(self, batch, batch_idx):
- x, y = batch
- logits = self.architecture(x)
- loss = F.nll_loss(logits, y)
-
- # validation metrics
- preds = torch.argmax(logits, dim=1)
- acc = self.accuracy(preds, y)
- self.log('val_loss', loss, on_step=False, on_epoch=True)
- self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
-
- # we can return here anything and then read it in some callback
- return {"batch_val_loss": loss, "batch_val_acc": acc, "batch_val_preds": preds, "batch_val_y": y}
-
- # logic for a single testing step
- def test_step(self, batch, batch_idx):
- x, y = batch
- logits = self.architecture(x)
- loss = F.nll_loss(logits, y)
-
- # test metrics
- preds = torch.argmax(logits, dim=1)
- acc = self.accuracy(preds, y)
- self.log('test_loss', loss, on_step=False, on_epoch=True)
- self.log('test_acc', acc, on_step=False, on_epoch=True)
-
- return loss
-
- def configure_optimizers(self):
- return torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
diff --git a/project/src/transforms/mnist_transforms.py b/project/src/transforms/mnist_transforms.py
deleted file mode 100644
index 7ab835b71..000000000
--- a/project/src/transforms/mnist_transforms.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""
-Example file containing data transformations, which can be used by datamodule.
-"""
-from torchvision import transforms
-
-
-mnist_train_transforms = transforms.Compose([
- transforms.ToTensor(),
-])
-
-mnist_test_transforms = transforms.Compose([
- transforms.ToTensor(),
-])
diff --git a/project/src/utils/template_utils.py b/project/src/utils/template_utils.py
deleted file mode 100644
index 7262afc2d..000000000
--- a/project/src/utils/template_utils.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# pytorch lightning imports
-from pytorch_lightning.loggers.wandb import WandbLogger
-from pytorch_lightning.loggers.neptune import NeptuneLogger
-import pytorch_lightning as pl
-
-# hydra imports
-from omegaconf import DictConfig, OmegaConf
-
-# normal imports
-from typing import List
-import logging
-import neptune
-import wandb
-
-log = logging.getLogger(__name__)
-
-
-def print_config(config: DictConfig):
- log.info(f"\n{OmegaConf.to_yaml(config, resolve=True)}")
-
-
-def print_module_init_info(model, datamodule, callbacks, loggers, trainer):
- message = "Model initialised:" + "\n" + model.__module__ + "." + model.__class__.__name__ + "\n"
- log.info(message)
-
- message = "Datamodule initialised:" + "\n" + datamodule.__module__ + "." + datamodule.__class__.__name__ + "\n"
- log.info(message)
-
- message = "Callbacks initialised:" + "\n"
- for cb in callbacks:
- message += cb.__module__ + "." + cb.__class__.__name__ + "\n"
- log.info(message)
-
- message = "Loggers initialised:" + "\n"
- for logger in loggers:
- message += logger.__module__ + "." + logger.__class__.__name__ + "\n"
- log.info(message)
-
- message = "Trainer initialised:" + "\n" + trainer.__module__ + "." + trainer.__class__.__name__ + "\n"
- log.info(message)
-
-
-def make_wandb_watch_model(loggers: List[pl.loggers.LightningLoggerBase], model: pl.LightningModule):
- for logger in loggers:
- if isinstance(logger, WandbLogger):
- if hasattr(model, 'architecture'):
- logger.watch(model.architecture)
- else:
- logger.watch(model)
-
-
-def send_hparams_to_loggers(loggers: List[pl.loggers.LightningLoggerBase], hparams: dict):
- for logger in loggers:
- logger.log_hyperparams(hparams)
-
-
-def log_hparams(config, model, datamodule, callbacks, loggers, trainer):
- hparams = {
- "_class_model": config["model"]["_target_"],
- "_class_datamodule": config["datamodule"]["_target_"]
- }
-
- if hasattr(model, "architecture"):
- obj = model.architecture
- hparams["_class_model_architecture"] = obj.__module__ + "." + obj.__class__.__name__
-
- hparams.update(config["seeds"])
- hparams.update(config["model"])
- hparams.update(config["datamodule"])
- hparams.update(config["trainer"])
- hparams.pop("_target_")
-
- if hasattr(datamodule, 'data_train') and datamodule.data_train is not None:
- hparams["train_size"] = len(datamodule.data_train)
- if hasattr(datamodule, 'data_val') and datamodule.data_val is not None:
- hparams["val_size"] = len(datamodule.data_val)
- if hasattr(datamodule, 'data_test') and datamodule.data_test is not None:
- hparams["test_size"] = len(datamodule.data_test)
-
- send_hparams_to_loggers(loggers=loggers, hparams=hparams)
-
-
-def extras(config, model, datamodule, callbacks, loggers, trainer):
- # Print info about which modules were initialized
- print_module_init_info(model, datamodule, callbacks, loggers, trainer)
-
- # Log extra hyperparameters to loggers
- log_hparams(config, model, datamodule, callbacks, loggers, trainer)
-
- # If WandbLogger was initialized, make it watch the model
- make_wandb_watch_model(loggers=loggers, model=model)
-
-
-def finish():
- wandb.finish()
- # neptune.stop()
diff --git a/project/train.py b/project/train.py
deleted file mode 100644
index bb3afe3fa..000000000
--- a/project/train.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# pytorch lightning imports
-from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
-import torch
-
-# hydra imports
-from omegaconf import DictConfig
-import hydra
-
-# normal imports
-from typing import List
-
-# template utils imports
-from src.utils import template_utils as utils
-
-
-def train(config):
- # Set global PyTorch seed
- if "seeds" in config and "pytorch_seed" in config["seeds"]:
- torch.manual_seed(seed=config["seeds"]["pytorch_seed"])
-
-    # Init PyTorch Lightning model ⚡
- model: LightningModule = hydra.utils.instantiate(config["model"])
-
-    # Init PyTorch Lightning datamodule ⚡
- datamodule: LightningDataModule = hydra.utils.instantiate(config["datamodule"])
- datamodule.prepare_data()
- datamodule.setup()
-
-    # Init PyTorch Lightning callbacks ⚡
- callbacks: List[Callback] = [
- hydra.utils.instantiate(callback_conf)
- for callback_name, callback_conf in config["callbacks"].items()
- ] if "callbacks" in config else []
-
-    # Init PyTorch Lightning loggers ⚡
- loggers: List[LightningLoggerBase] = [
- hydra.utils.instantiate(logger_conf)
- for logger_name, logger_conf in config["logger"].items()
- if "_target_" in logger_conf # ignore logger conf if there's no target
- ] if "logger" in config else []
-
-    # Init PyTorch Lightning trainer ⚡
- trainer: Trainer = hydra.utils.instantiate(config["trainer"], callbacks=callbacks, logger=loggers)
-
- # Magic
- utils.extras(config, model, datamodule, callbacks, loggers, trainer)
-
- # Train the model
- trainer.fit(model=model, datamodule=datamodule)
-
- # Evaluate model on test set after training
- trainer.test()
-
- # Finish run
- utils.finish()
-
-
-@hydra.main(config_path="configs/", config_name="config.yaml")
-def main(config: DictConfig):
- utils.print_config(config)
- train(config)
-
-
-if __name__ == "__main__":
- main()
diff --git a/requirements.txt b/requirements.txt
index 2bce92cee..eac912c63 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,28 @@
+# --------- pytorch --------- #
torch>=1.7.1
torchvision>=0.8.2
torchaudio>=0.7.2
-pytorch-lightning>=1.1.6
-# hydra-core>=1.0.5
-hydra-core>=1.1.0.dev2
-wandb>=0.10.15
-neptune-client>=0.5.0
-scikit-learn>=0.24.0
-pandas>=1.2.0
+pytorch-lightning>=1.2.1
+
+# --------- hydra --------- #
+hydra-core==1.1.0.dev3
hydra_colorlog>=1.0.0
+hydra-optuna-sweeper>=0.9.0rc2
+
+# --------- loggers --------- #
+wandb>=0.10.20
+# neptune-client
+# comet_ml
+# mlflow
+# tensorboard
+
+# --------- linters --------- #
+# black>=20.8b1
+# flake8>=3.8.4
+# pylint>=2.7.1
+# isort>=5.7.0
+
+# --------- others --------- #
+rich>=9.12.3
+scikit-learn>=0.24.1
+pandas>=1.2.2
diff --git a/setup.py b/setup.py
index 11739d58b..dd9269ced 100644
--- a/setup.py
+++ b/setup.py
@@ -3,12 +3,12 @@
setup(
- name='project',
- version='0.0.0',
- description='Describe Your Cool Project',
- author='',
- author_email='',
- url='https://github.com/hobogalaxy/lightning-hydra-wandb-template', # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
- install_requires=['pytorch-lightning', 'hydra-core'],
+ name="src",
+ version="0.0.0",
+ description="Describe Your Cool Project",
+ author="",
+ author_email="",
+ url="https://github.com/hobogalaxy/lightning-hydra-template", # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
+ install_requires=["pytorch-lightning>=1.2.1", "hydra-core>=1.0.6"],
packages=find_packages(),
)
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 000000000..c4a3ec875
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1 @@
+# makes 'src' a Python package
diff --git a/project/src/architectures/simple_dense_net.py b/src/architectures/simple_dense_net.py
similarity index 96%
rename from project/src/architectures/simple_dense_net.py
rename to src/architectures/simple_dense_net.py
index 5bd49e607..05ae143a3 100644
--- a/project/src/architectures/simple_dense_net.py
+++ b/src/architectures/simple_dense_net.py
@@ -2,7 +2,6 @@
class SimpleDenseNet(nn.Module):
-
def __init__(self, hparams):
super().__init__()
@@ -20,7 +19,7 @@ def __init__(self, hparams):
nn.ReLU(),
nn.Dropout(p=hparams["dropout3"]),
nn.Linear(hparams["lin3_size"], hparams["output_size"]),
- nn.LogSoftmax(dim=1)
+ nn.LogSoftmax(dim=1),
)
def forward(self, x):
diff --git a/project/src/callbacks/custom_callbacks.py b/src/callbacks/custom_callbacks.py
similarity index 67%
rename from project/src/callbacks/custom_callbacks.py
rename to src/callbacks/custom_callbacks.py
index 7c906dfc8..1b3ce4dca 100644
--- a/project/src/callbacks/custom_callbacks.py
+++ b/src/callbacks/custom_callbacks.py
@@ -6,20 +6,20 @@ def __init__(self):
pass
def on_init_start(self, trainer):
- print('Starting to initialize trainer!')
+ print("Starting to initialize trainer!")
def on_init_end(self, trainer):
- print('Trainer is initialized now.')
+ print("Trainer is initialized now.")
def on_train_end(self, trainer, pl_module):
- print('Do something when training ends.')
+ print("Do something when training ends.")
class UnfreezeModelCallback(Callback):
"""
- Unfreeze model after a few epochs.
- It currently unfreezes every possible parameter in model, probably shouldn't work that way...
+ Unfreeze all model parameters after a few epochs.
"""
+
def __init__(self, wait_epochs=5):
self.wait_epochs = wait_epochs
diff --git a/src/callbacks/wandb_callbacks.py b/src/callbacks/wandb_callbacks.py
new file mode 100644
index 000000000..17de76874
--- /dev/null
+++ b/src/callbacks/wandb_callbacks.py
@@ -0,0 +1,234 @@
+# wandb
+from pytorch_lightning.loggers import WandbLogger
+import wandb
+
+# pytorch
+from pytorch_lightning import Callback
+import pytorch_lightning as pl
+import torch
+
+# others
+from sklearn.metrics import precision_score, recall_score, f1_score
+from typing import List
+import glob
+import os
+
+
+def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger:
+ logger = None
+ for lg in trainer.logger:
+ if isinstance(lg, WandbLogger):
+ logger = lg
+
+ if not logger:
+ raise Exception(
+ "You're using wandb related callback, "
+ "but WandbLogger was not found for some reason..."
+ )
+
+ return logger
+
+
+class UploadCodeToWandbAsArtifact(Callback):
+ """Upload all *.py files to wandb as an artifact at the beginning of the run."""
+
+ def __init__(self, code_dir: str):
+ self.code_dir = code_dir
+
+ def on_train_start(self, trainer, pl_module):
+ logger = get_wandb_logger(trainer=trainer)
+ experiment = logger.experiment
+
+ code = wandb.Artifact("project-source", type="code")
+ for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True):
+ code.add_file(path)
+
+ experiment.use_artifact(code)
+
+
+class UploadCheckpointsToWandbAsArtifact(Callback):
+ """Upload experiment checkpoints to wandb as an artifact at the end of training."""
+
+ def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
+ self.ckpt_dir = ckpt_dir
+ self.upload_best_only = upload_best_only
+
+ def on_train_end(self, trainer, pl_module):
+ logger = get_wandb_logger(trainer=trainer)
+ experiment = logger.experiment
+
+ ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
+
+ if self.upload_best_only:
+ ckpts.add_file(trainer.checkpoint_callback.best_model_path)
+ else:
+ for path in glob.glob(
+ os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True
+ ):
+ ckpts.add_file(path)
+
+ experiment.use_artifact(ckpts)
+
+
+class WatchModelWithWandb(Callback):
+ """Make WandbLogger watch model at the beginning of the run."""
+
+ def __init__(self, log: str = "gradients", log_freq: int = 100):
+ self.log = log
+ self.log_freq = log_freq
+
+ def on_train_start(self, trainer, pl_module):
+ logger = get_wandb_logger(trainer=trainer)
+ logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+
+
+class LogF1PrecisionRecallHeatmapToWandb(Callback):
+ """
+ Generate f1, precision and recall heatmap from validation step outputs.
+ Expects validation step to return predictions and targets.
+ Works only for single label classification!
+ """
+
+ def __init__(self, class_names: List[str] = None):
+ self.class_names = class_names
+ self.preds = []
+ self.targets = []
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer, pl_module):
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_validation_batch_end(
+ self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+ ):
+ """Gather data from single batch."""
+ if self.ready:
+ preds, targets = outputs["preds"], outputs["targets"]
+ self.preds.append(preds)
+ self.targets.append(targets)
+
+ def on_validation_epoch_end(self, trainer, pl_module):
+ """Generate f1, precision and recall heatmap."""
+ if self.ready:
+ logger = get_wandb_logger(trainer=trainer)
+ experiment = logger.experiment
+
+ self.preds = torch.cat(self.preds).cpu()
+ self.targets = torch.cat(self.targets).cpu()
+            f1 = f1_score(self.targets, self.preds, average=None)
+            r = recall_score(self.targets, self.preds, average=None)
+            p = precision_score(self.targets, self.preds, average=None)
+
+ experiment.log(
+ {
+ f"f1_p_r_heatmap/{trainer.current_epoch}_{experiment.id}": wandb.plots.HeatMap(
+ x_labels=self.class_names,
+ y_labels=["f1", "precision", "recall"],
+ matrix_values=[f1, p, r],
+ show_text=True,
+ )
+ },
+ commit=False,
+ )
+
+ self.preds = []
+ self.targets = []
+
+
+class LogConfusionMatrixToWandb(Callback):
+ """
+ Generate Confusion Matrix.
+ Expects validation step to return predictions and targets.
+ Works only for single label classification!
+ """
+
+ def __init__(self, class_names: List[str] = None):
+ self.class_names = class_names
+ self.preds = []
+ self.targets = []
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer, pl_module):
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_validation_batch_end(
+ self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+ ):
+ """Gather data from single batch."""
+ if self.ready:
+ preds, targets = outputs["preds"], outputs["targets"]
+ self.preds.append(preds)
+ self.targets.append(targets)
+
+ def on_validation_epoch_end(self, trainer, pl_module):
+ """Generate confusion matrix."""
+ if self.ready:
+ logger = get_wandb_logger(trainer=trainer)
+ experiment = logger.experiment
+
+ self.preds = torch.cat(self.preds).tolist()
+ self.targets = torch.cat(self.targets).tolist()
+
+ experiment.log(
+ {
+ f"confusion_matrix/{trainer.current_epoch}_{experiment.id}": wandb.plot.confusion_matrix(
+ preds=self.preds,
+ y_true=self.targets,
+ class_names=self.class_names,
+ )
+ },
+ commit=False,
+ )
+
+ self.preds = []
+ self.targets = []
+
+
+''' BUGGED :(
+class LogBestMetricScoresToWandb(Callback):
+ """
+ Store in wandb:
+ - max train acc
+ - min train loss
+ - max val acc
+ - min val loss
+ Useful for comparing runs in table views, as wandb doesn't currently support column aggregation.
+ """
+
+ def __init__(self):
+ self.train_loss_best = None
+ self.train_acc_best = None
+ self.val_loss_best = None
+ self.val_acc_best = None
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer, pl_module):
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_epoch_end(self, trainer, pl_module):
+ if self.ready:
+ logger = get_wandb_logger(trainer=trainer)
+ experiment = logger.experiment
+
+ metrics = trainer.callback_metrics
+
+ if not self.train_loss_best or metrics["train/loss"] < self.train_loss_best:
+                self.train_loss_best = metrics["train/loss"]
+
+ if not self.train_acc_best or metrics["train/acc"] > self.train_acc_best:
+ self.train_acc_best = metrics["train/acc"]
+
+ if not self.val_loss_best or metrics["val/loss"] < self.val_loss_best:
+ self.val_loss_best = metrics["val/loss"]
+
+ if not self.val_acc_best or metrics["val/acc"] > self.val_acc_best:
+ self.val_acc_best = metrics["val/acc"]
+
+ experiment.log({"train/loss_best": self.train_loss_best}, commit=False)
+ experiment.log({"train/acc_best": self.train_acc_best}, commit=False)
+ experiment.log({"val/loss_best": self.val_loss_best}, commit=False)
+ experiment.log({"val/acc_best": self.val_acc_best}, commit=False)
+'''
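
For reference, a minimal sketch (not part of this diff) of attaching the wandb callbacks above to a `Trainer` by hand; in the template they would normally be instantiated from the Hydra `callbacks` config instead. Note that `LogConfusionMatrixToWandb` relies on `validation_step` returning `preds` and `targets`, as the new `LitModelMNIST` does.

```python
# Sketch only: wiring the callbacks defined above into a Trainer alongside a WandbLogger.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from src.callbacks.wandb_callbacks import (
    LogConfusionMatrixToWandb,
    WatchModelWithWandb,
)

logger = WandbLogger(project="env_tests")  # project name reused from the test scripts
callbacks = [
    WatchModelWithWandb(log="gradients", log_freq=100),
    LogConfusionMatrixToWandb(class_names=[str(i) for i in range(10)]),
]
trainer = Trainer(max_epochs=1, logger=logger, callbacks=callbacks)
# trainer.fit(model, datamodule=datamodule)  # model/datamodule as elsewhere in the template
```
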
diff --git a/project/src/datamodules/mnist_datamodule.py b/src/datamodules/mnist_datamodule.py
similarity index 63%
rename from project/src/datamodules/mnist_datamodule.py
rename to src/datamodules/mnist_datamodule.py
index 370fd5ff3..2bb2ae34c 100644
--- a/project/src/datamodules/mnist_datamodule.py
+++ b/src/datamodules/mnist_datamodule.py
@@ -6,8 +6,7 @@
class MNISTDataModule(LightningDataModule):
"""
- This is example of lightning datamodule for MNIST dataset.
- To learn how to create datamodules visit:
+    Example of a LightningDataModule for the MNIST dataset.
https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
"""
@@ -20,10 +19,9 @@ def __init__(self, *args, **kwargs):
self.num_workers = kwargs["num_workers"]
self.pin_memory = kwargs["pin_memory"]
- self.transforms = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
+ self.transforms = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+ )
# self.dims is returned when you call datamodule.size()
self.dims = (1, 28, 28)
@@ -43,16 +41,33 @@ def setup(self, stage=None):
trainset = MNIST(self.data_dir, train=True, transform=self.transforms)
testset = MNIST(self.data_dir, train=False, transform=self.transforms)
dataset = ConcatDataset(datasets=[trainset, testset])
- self.data_train, self.data_val, self.data_test = random_split(dataset, self.train_val_test_split)
+ self.data_train, self.data_val, self.data_test = random_split(
+ dataset, self.train_val_test_split
+ )
def train_dataloader(self):
- return DataLoader(dataset=self.data_train, batch_size=self.batch_size, num_workers=self.num_workers,
- pin_memory=self.pin_memory, shuffle=True)
+ return DataLoader(
+ dataset=self.data_train,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ shuffle=True,
+ )
def val_dataloader(self):
- return DataLoader(dataset=self.data_val, batch_size=self.batch_size, num_workers=self.num_workers,
- pin_memory=self.pin_memory, shuffle=False)
+ return DataLoader(
+ dataset=self.data_val,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ shuffle=False,
+ )
def test_dataloader(self):
- return DataLoader(dataset=self.data_test, batch_size=self.batch_size, num_workers=self.num_workers,
- pin_memory=self.pin_memory, shuffle=False)
+ return DataLoader(
+ dataset=self.data_test,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ shuffle=False,
+ )
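
As a quick sanity check (not part of this diff), the datamodule can also be driven directly; the keyword names mirror the attributes read in `__init__` and `setup`, and the concrete values below are illustrative assumptions.

```python
# Sketch only: constructor kwargs are inferred from the attributes the datamodule reads
# (data_dir, train_val_test_split, batch_size, num_workers, pin_memory);
# the values are assumptions for illustration.
from src.datamodules.mnist_datamodule import MNISTDataModule

dm = MNISTDataModule(
    data_dir="data/",
    train_val_test_split=(55_000, 5_000, 10_000),  # MNIST train + test = 70k images
    batch_size=64,
    num_workers=0,
    pin_memory=False,
)
dm.prepare_data()
dm.setup()
x, y = next(iter(dm.train_dataloader()))  # x: (64, 1, 28, 28), y: (64,)
```
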
diff --git a/project/src/datasets/img_test_dataset.py b/src/datasets/img_test_dataset.py
similarity index 87%
rename from project/src/datasets/img_test_dataset.py
rename to src/datasets/img_test_dataset.py
index 14076d6af..6c10289f6 100644
--- a/project/src/datasets/img_test_dataset.py
+++ b/src/datasets/img_test_dataset.py
@@ -7,7 +7,6 @@ class TestDataset(Dataset):
"""
Example dataset class for loading images from folder and converting them to monochromatic.
Can be used to perform inference with trained MNIST model.
- 'Dataset' type classes can also be used to create 'DataLoader' type classes which are used by datamodules.
"""
def __init__(self, img_dir, transform):
diff --git a/src/models/mnist_model.py b/src/models/mnist_model.py
new file mode 100644
index 000000000..a6eca4adf
--- /dev/null
+++ b/src/models/mnist_model.py
@@ -0,0 +1,94 @@
+from pytorch_lightning.metrics.classification import Accuracy
+import pytorch_lightning as pl
+import torch.nn.functional as F
+import torch
+
+# import custom architectures
+from src.architectures.simple_dense_net import SimpleDenseNet
+
+
+class LitModelMNIST(pl.LightningModule):
+ """
+    Example of a LightningModule for MNIST classification.
+ https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.save_hyperparameters()
+ self.accuracy = Accuracy()
+ self.architecture = SimpleDenseNet(hparams=self.hparams)
+
+ self.train_acc_hist = []
+ self.train_loss_hist = []
+ self.val_acc_hist = []
+ self.val_loss_hist = []
+
+ def forward(self, x):
+ return self.architecture(x)
+
+ # logic for a single training step
+ def training_step(self, batch, batch_idx):
+ x, y = batch
+ logits = self.forward(x)
+ loss = F.nll_loss(logits, y)
+
+ # training metrics
+ preds = torch.argmax(logits, dim=1)
+ acc = self.accuracy(preds, y)
+ self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+ self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+        # we can return anything here and then read it in some callback or in training_epoch_end() below
+ return {"loss": loss, "preds": preds, "targets": y}
+
+ # logic for a single validation step
+ def validation_step(self, batch, batch_idx):
+ x, y = batch
+ logits = self.forward(x)
+ loss = F.nll_loss(logits, y)
+
+ # validation metrics
+ preds = torch.argmax(logits, dim=1)
+ acc = self.accuracy(preds, y)
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+ self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+        # we can return anything here and then read it in some callback or in validation_epoch_end() below
+ return {"loss": loss, "preds": preds, "targets": y}
+
+ # logic for a single testing step
+ def test_step(self, batch, batch_idx):
+ x, y = batch
+ logits = self.forward(x)
+ loss = F.nll_loss(logits, y)
+
+ # test metrics
+ preds = torch.argmax(logits, dim=1)
+ acc = self.accuracy(preds, y)
+ self.log("test/loss", loss, on_step=False, on_epoch=True)
+ self.log("test/acc", acc, on_step=False, on_epoch=True)
+
+ return loss
+
+ def training_epoch_end(self, outputs):
+ self.train_acc_hist.append(self.trainer.callback_metrics["train/acc"])
+ self.train_loss_hist.append(self.trainer.callback_metrics["train/loss"])
+ self.log("train/acc_best", max(self.train_acc_hist), prog_bar=False)
+ self.log("train/loss_best", min(self.train_loss_hist), prog_bar=False)
+
+ def validation_epoch_end(self, outputs):
+ self.val_acc_hist.append(self.trainer.callback_metrics["val/acc"])
+ self.val_loss_hist.append(self.trainer.callback_metrics["val/loss"])
+ self.log("val/acc_best", max(self.val_acc_hist), prog_bar=False)
+ self.log("val/loss_best", min(self.val_loss_hist), prog_bar=False)
+
+ def configure_optimizers(self):
+ if self.hparams.optimizer == "adam":
+ return torch.optim.Adam(
+ self.parameters(),
+ lr=self.hparams.lr,
+ weight_decay=self.hparams.weight_decay,
+ )
+ else:
+ raise Exception("Invalid optimizer name")
diff --git a/src/transforms/mnist_transforms.py b/src/transforms/mnist_transforms.py
new file mode 100644
index 000000000..0a806f420
--- /dev/null
+++ b/src/transforms/mnist_transforms.py
@@ -0,0 +1,10 @@
+from torchvision import transforms
+
+
+mnist_train_transforms = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+)
+
+mnist_test_transforms = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+)
diff --git a/project/src/utils/inference_example.py b/src/utils/inference_example.py
similarity index 59%
rename from project/src/utils/inference_example.py
rename to src/utils/inference_example.py
index c3da5a75a..91d6cc00c 100644
--- a/project/src/utils/inference_example.py
+++ b/src/utils/inference_example.py
@@ -1,20 +1,16 @@
+from src.models.mnist_model import LitModelMNIST
from src.transforms import mnist_transforms
from PIL import Image
-# the LitModel you import should be the same as the one you used for training!
-from src.models.mnist_model import LitModelMNIST
-
-# ckpt can be a url!
-
def predict():
"""
- This method is example of inference with a trained model.
- It Loads trained image classification model from checkpoint.
- Then it loads example image and predicts its label.
- Model used in mnist_model.py should be the same as during training!!!
+    Example of inference with a trained model.
+    It loads a trained image classification model from a checkpoint,
+    then loads an example image and predicts its label.
"""
+ # ckpt can be also a URL!
CKPT_PATH = "epoch=0.ckpt"
# load model from checkpoint
@@ -25,8 +21,8 @@ def predict():
trained_model.freeze()
# load data
- img = Image.open("data/example_img.png").convert("L") # for monochromatic conversion
- # img = Image.open("data/example_img.png").convert("RGB") # for RGB conversion
+ img = Image.open("data/example_img.png").convert("L") # convert to black and white
+ # img = Image.open("data/example_img.png").convert("RGB") # convert to RGB
# preprocess
img = mnist_transforms.mnist_test_transforms(img)
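
The remainder of the function is elided from this hunk; a hypothetical continuation would batch the preprocessed tensor and read off the predicted class, roughly along these lines.

```python
# Hypothetical continuation (not shown in the hunk above), assuming `img` is the
# transformed tensor and `trained_model` the frozen checkpointed LitModelMNIST.
import torch

img = img.unsqueeze(0)  # add batch dimension -> (1, 1, 28, 28)
with torch.no_grad():
    logits = trained_model(img)
pred = torch.argmax(logits, dim=1).item()
print(f"Predicted digit: {pred}")
```
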
diff --git a/src/utils/template_utils.py b/src/utils/template_utils.py
new file mode 100644
index 000000000..f1f1e544c
--- /dev/null
+++ b/src/utils/template_utils.py
@@ -0,0 +1,154 @@
+# pytorch lightning imports
+import pytorch_lightning as pl
+
+# hydra imports
+from omegaconf import DictConfig, OmegaConf
+from hydra.utils import get_original_cwd, to_absolute_path
+
+# loggers
+import wandb
+from pytorch_lightning.loggers.wandb import WandbLogger
+
+# from pytorch_lightning.loggers.neptune import NeptuneLogger
+# from pytorch_lightning.loggers.comet import CometLogger
+# from pytorch_lightning.loggers.mlflow import MLFlowLogger
+# from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+
+# rich imports
+from rich import print
+from rich.syntax import Syntax
+from rich.tree import Tree
+
+# normal imports
+from typing import List
+
+
+def print_config(config: DictConfig):
+ """Prints content of Hydra config using Rich library.
+
+ Args:
+        config (DictConfig): Composed Hydra config to pretty print.
+ """
+
+ # TODO print main config path and experiment config path
+ # directory = to_absolute_path("configs/config.yaml")
+ # print(f"Main config path: [link file://{directory}]{directory}")
+
+ style = "dim"
+
+ tree = Tree(f":gear: FULL HYDRA CONFIG", style=style, guide_style=style)
+
+ trainer = OmegaConf.to_yaml(config["trainer"], resolve=True)
+ trainer_branch = tree.add("Trainer", style=style, guide_style=style)
+ trainer_branch.add(Syntax(trainer, "yaml"))
+
+ model = OmegaConf.to_yaml(config["model"], resolve=True)
+ model_branch = tree.add("Model", style=style, guide_style=style)
+ model_branch.add(Syntax(model, "yaml"))
+
+ datamodule = OmegaConf.to_yaml(config["datamodule"], resolve=True)
+ datamodule_branch = tree.add("Datamodule", style=style, guide_style=style)
+ datamodule_branch.add(Syntax(datamodule, "yaml"))
+
+ callbacks_branch = tree.add("Callbacks", style=style, guide_style=style)
+ if "callbacks" in config:
+ for cb_name, cb_conf in config["callbacks"].items():
+ cb = callbacks_branch.add(cb_name, style=style, guide_style=style)
+ cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=True), "yaml"))
+ else:
+ callbacks_branch.add("None")
+
+ logger_branch = tree.add("Logger", style=style, guide_style=style)
+ if "logger" in config:
+ for lg_name, lg_conf in config["logger"].items():
+ lg = logger_branch.add(lg_name, style=style, guide_style=style)
+ lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=True), "yaml"))
+ else:
+ logger_branch.add("None")
+
+    seed = config.get("seed", "None")
+    seed_branch = tree.add("Seed", style=style, guide_style=style)
+    seed_branch.add(str(seed), style=style, guide_style=style)
+
+ print(tree)
+
+
+def log_hparams_to_all_loggers(
+ config: DictConfig,
+ model: pl.LightningModule,
+ datamodule: pl.LightningDataModule,
+ trainer: pl.Trainer,
+ callbacks: List[pl.Callback],
+ logger: List[pl.loggers.LightningLoggerBase],
+):
+ """This method controls which parameters from Hydra config are saved by Lightning loggers.
+
+ Args:
+        config (DictConfig): Composed Hydra config for the run.
+        model (pl.LightningModule): Instantiated Lightning model.
+        datamodule (pl.LightningDataModule): Instantiated Lightning datamodule.
+        trainer (pl.Trainer): Instantiated Lightning trainer.
+        callbacks (List[pl.Callback]): Instantiated Lightning callbacks.
+        logger (List[pl.loggers.LightningLoggerBase]): Instantiated Lightning loggers.
+ """
+
+ hparams = {}
+
+ # save all params of model, datamodule and trainer
+ hparams.update(config["model"])
+ hparams.update(config["datamodule"])
+ hparams.update(config["trainer"])
+ hparams.pop("_target_")
+
+ # save seed
+ hparams["seed"] = config.get("seed", "None")
+
+ # save targets
+ hparams["_class_model"] = config["model"]["_target_"]
+ hparams["_class_datamodule"] = config["datamodule"]["_target_"]
+
+ # save sizes of each dataset
+ if hasattr(datamodule, "data_train") and datamodule.data_train:
+ hparams["train_size"] = len(datamodule.data_train)
+ if hasattr(datamodule, "data_val") and datamodule.data_val:
+ hparams["val_size"] = len(datamodule.data_val)
+ if hasattr(datamodule, "data_test") and datamodule.data_test:
+ hparams["test_size"] = len(datamodule.data_test)
+
+ # save number of model parameters
+ hparams["#params_total"] = sum(p.numel() for p in model.parameters())
+ hparams["#params_trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["#params_not_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ # send hparams to all loggers
+ for lg in logger:
+ lg.log_hyperparams(hparams)
+
+
+def finish(
+ config: DictConfig,
+ model: pl.LightningModule,
+ datamodule: pl.LightningDataModule,
+ trainer: pl.Trainer,
+ callbacks: List[pl.Callback],
+ logger: List[pl.loggers.LightningLoggerBase],
+):
+ """Makes sure everything closed properly.
+
+ Args:
+        config (DictConfig): Composed Hydra config for the run.
+        model (pl.LightningModule): Instantiated Lightning model.
+        datamodule (pl.LightningDataModule): Instantiated Lightning datamodule.
+        trainer (pl.Trainer): Instantiated Lightning trainer.
+        callbacks (List[pl.Callback]): Instantiated Lightning callbacks.
+        logger (List[pl.loggers.LightningLoggerBase]): Instantiated Lightning loggers.
+ """
+
+ # without this sweeps with wandb logger might crash!
+ for lg in logger:
+ if isinstance(lg, WandbLogger):
+ wandb.finish()
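
A tiny self-contained check (not part of this diff) of the parameter-counting expressions used in `log_hparams_to_all_loggers`; the model here is an arbitrary stand-in, not a module from the template.

```python
# Stand-in module, just to exercise the counting expressions from the utility above.
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 784*256 + 256 + 256*10 + 10 = 203530 for both
```
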
diff --git a/tests/hydra_wandb_test.py b/tests/hydra_wandb_test.py
new file mode 100644
index 000000000..cc6b1da9d
--- /dev/null
+++ b/tests/hydra_wandb_test.py
@@ -0,0 +1,20 @@
+# import os, sys
+# sys.path.insert(1, os.path.join(sys.path[0], ".."))
+# print(os.path.abspath(os.curdir))
+
+#####################################################
+# python hydra_wandb_test.py -m +some_param=1,2,3,4
+#####################################################
+
+import hydra
+import wandb
+
+
+@hydra.main(config_path="../configs/", config_name="config.yaml")
+def main(config):
+ wandb.init(project="env_tests")
+ wandb.finish()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/logger_tests.sh b/tests/logger_tests.sh
new file mode 100644
index 000000000..11166d0b3
--- /dev/null
+++ b/tests/logger_tests.sh
@@ -0,0 +1,21 @@
+# TESTS FOR DIFFERENT LOGGERS
+# TO EXECUTE:
+# bash tests/logger_tests.sh
+
+# conda activate testenv
+
+# Test CSV logger
+echo TEST 1
+python train.py logger=csv_logger trainer.min_epochs=3 trainer.max_epochs=3 trainer.gpus=1
+
+# Test Weights&Biases logger
+echo TEST 2
+python train.py logger=wandb logger.wandb.project="env_tests" trainer.min_epochs=10 trainer.max_epochs=10 trainer.gpus=1
+
+# Test TensorBoard logger
+echo TEST 3
+python train.py logger=tensorboard trainer.min_epochs=10 trainer.max_epochs=10 trainer.gpus=1
+
+# Test many loggers at once
+echo TEST 4
+python train.py logger=many_loggers trainer.min_epochs=10 trainer.max_epochs=10 trainer.gpus=1
diff --git a/tests/quick_tests.sh b/tests/quick_tests.sh
new file mode 100644
index 000000000..c6cdab68c
--- /dev/null
+++ b/tests/quick_tests.sh
@@ -0,0 +1,35 @@
+# A COUPLE OF QUICK EXPERIMENTS TO CHECK THAT YOUR MODEL DOESN'T CRASH UNDER DIFFERENT CONDITIONS
+# TO EXECUTE:
+# bash tests/quick_tests.sh
+
+# conda activate testenv
+
+# Test for CPU
+echo TEST 1
+python train.py trainer.gpus=0 trainer.max_epochs=1
+
+# Test for GPU
+echo TEST 2
+python train.py trainer.gpus=1 trainer.max_epochs=1
+
+# Test multiple workers and cuda pinned memory
+echo TEST 3
+python train.py trainer.gpus=1 trainer.max_epochs=2 \
+datamodule.num_workers=4 datamodule.pin_memory=True
+
+# Test all experiment configs
+echo TEST 4
+python train.py -m '+experiment=glob(*)' trainer.gpus=1 trainer.max_epochs=3
+
+# Test with debug trainer
+echo TEST 5
+python train.py trainer=debug_trainer
+
+# Overfit to 10 batches
+echo TEST 6
+python train.py trainer.min_epochs=20 trainer.max_epochs=20 +trainer.overfit_batches=10
+
+# Test default hydra sweep over hyperparameters (runs 4 different combinations for 1 epoch)
+echo TEST 7
+python train.py -m datamodule.batch_size=32,64 model.lr=0.001,0.003 \
+trainer.gpus=1 trainer.max_epochs=1
diff --git a/tests/sweep_tests.sh b/tests/sweep_tests.sh
new file mode 100644
index 000000000..791c047a3
--- /dev/null
+++ b/tests/sweep_tests.sh
@@ -0,0 +1,28 @@
+# TESTS FOR HYPERPARAMETER SWEEPS
+# TO EXECUTE:
+# bash tests/sweep_tests.sh
+
+# conda activate testenv
+
+
+# currently there are some issues with running sweeps alongside wandb
+# https://github.com/wandb/client/issues/1314
+# this env variable fixes that
+export WANDB_START_METHOD=thread
+
+
+# Test default hydra sweep with wandb logging
+echo TEST 1
+python train.py -m datamodule.batch_size=64,128 model.lr=0.001,0.003 \
++experiment=exp_example_simple \
+trainer.gpus=1 trainer.max_epochs=2 seed=12345 \
+datamodule.num_workers=12 datamodule.pin_memory=True \
+logger=wandb logger.wandb.project="env_tests" logger.wandb.group="DefaultSweep_MNIST_SimpleDenseNet"
+
+# Test optuna sweep with wandb logging
+echo TEST 2
+python train.py -m --config-name config_optuna.yaml \
++experiment=exp_example_simple \
+trainer.gpus=1 trainer.max_epochs=5 seed=12345 \
+datamodule.num_workers=12 datamodule.pin_memory=True \
+logger=wandb logger.wandb.project="env_tests" logger.wandb.group="Optuna_MNIST_SimpleDenseNet"
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..7ccfe9b01
--- /dev/null
+++ b/train.py
@@ -0,0 +1,92 @@
+# pytorch lightning imports
+from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
+from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning import seed_everything
+
+# hydra imports
+from omegaconf import DictConfig
+import hydra
+
+# normal imports
+from typing import List
+
+# src imports
+from src.utils import template_utils as utils
+
+
+def train(config: DictConfig):
+
+ # Pretty print config using Rich library
+ if config["print_config"]:
+ utils.print_config(config)
+
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if "seed" in config:
+ seed_everything(config["seed"])
+
+    # Init PyTorch Lightning model ⚡
+ model: LightningModule = hydra.utils.instantiate(config["model"])
+
+    # Init PyTorch Lightning datamodule ⚡
+ datamodule: LightningDataModule = hydra.utils.instantiate(config["datamodule"])
+ datamodule.prepare_data()
+ datamodule.setup()
+
+    # Init PyTorch Lightning callbacks ⚡
+ callbacks: List[Callback] = []
+ if "callbacks" in config:
+ for _, cb_conf in config["callbacks"].items():
+ if "_target_" in cb_conf:
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    # Init PyTorch Lightning loggers ⚡
+ logger: List[LightningLoggerBase] = []
+ if "logger" in config:
+ for _, lg_conf in config["logger"].items():
+ if "_target_" in lg_conf:
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+    # Init PyTorch Lightning trainer ⚡
+ trainer: Trainer = hydra.utils.instantiate(
+ config["trainer"], callbacks=callbacks, logger=logger
+ )
+
+ # Send some parameters from config to all lightning loggers
+ utils.log_hparams_to_all_loggers(
+ config=config,
+ model=model,
+ datamodule=datamodule,
+ trainer=trainer,
+ callbacks=callbacks,
+ logger=logger,
+ )
+
+ # Train the model
+ trainer.fit(model=model, datamodule=datamodule)
+
+ # Evaluate model on test set after training
+ trainer.test()
+
+ # Make sure everything closed properly
+ utils.finish(
+ config=config,
+ model=model,
+ datamodule=datamodule,
+ trainer=trainer,
+ callbacks=callbacks,
+ logger=logger,
+ )
+
+ # Return best achieved metric score for optuna
+ optimized_metric = config.get("optimized_metric", None)
+ if optimized_metric:
+ return trainer.callback_metrics[optimized_metric]
+
+
+@hydra.main(config_path="configs/", config_name="config.yaml")
+def main(config: DictConfig):
+ return train(config)
+
+
+if __name__ == "__main__":
+ main()
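
For reference, a minimal standalone illustration (not part of this diff) of what `hydra.utils.instantiate` does with the `_target_` entries that the callback and logger loops in `train.py` iterate over.

```python
# Standalone illustration of _target_-based instantiation as used in train.py;
# EarlyStopping is just an example callback, not something this diff configures.
import hydra
from omegaconf import OmegaConf

cb_conf = OmegaConf.create(
    {
        "_target_": "pytorch_lightning.callbacks.EarlyStopping",
        "monitor": "val/loss",
        "patience": 3,
    }
)
early_stopping = hydra.utils.instantiate(cb_conf)
print(type(early_stopping).__name__)  # EarlyStopping
```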