change default config, update readme
janumiko committed Feb 12, 2024
1 parent 8d0c2d6 commit c720961
Showing 7 changed files with 174 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ __pycache__
models
datasets
outputs
wandb
wandb
multirun
27 changes: 21 additions & 6 deletions README.md
@@ -4,9 +4,24 @@
This repository contains code for the analysis of various pruning methods.
The code is written in PyTorch and uses its built-in pruning methods.

## Project Structure
- Iterative pruning is implemented in `iterative_pruning.py`.
- One-shot pruning is implemented in `one_shot_pruning.py`.
- Base model learning loop and evaluation is implemented in `base_workflow.ipynb`.
- Visualizations and results are stored in `validate_models.ipynb`.
- Utility directory contains code for training, evaluation, saving... of the models.
### Installation
To install the required packages, run the following command:
```bash
conda env create -f environment.yml
```
Conda/Anaconda is required to run the above command. If you don't have it installed, you can install it from [here](https://www.anaconda.com/products/distribution).
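
After the environment is created, activate it; the environment name `ml-pruning` comes from the `environment.yml` added in this commit:
```bash
conda activate ml-pruning
```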

### Usage
The pruning code is located in the `pruning` directory.
The entry point for the program is `pruning_entry.py`, which runs the pruning analysis.
You can configure a run from the CLI or modify the configs in the `conf` directory. The configuration uses [Hydra](https://hydra.cc/), a configuration framework for Python apps.

To run a single pruning experiment, you need to provide the `pruning.iterations`, `pruning.finetune_epochs`, and `pruning.iteration_rate` parameters. For example:
```bash
python pruning_entry.py pruning.iterations=3 pruning.finetune_epochs=1 pruning.iteration_rate=0.02
```

You can do a multi-run using the following Hydra syntax:
```bash
python pruning_entry.py --multirun pruning.iterations=1,2,3 pruning.finetune_epochs=1 pruning.iteration_rate=0.01,0.02
```
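
Other config values can be overridden the same way. For example, a hypothetical invocation that also downloads the CIFAR-10 dataset on the first run (the `dataset.download` flag is introduced in `conf/dataset/cifar10.yaml` below):
```bash
python pruning_entry.py pruning.iterations=3 pruning.finetune_epochs=1 pruning.iteration_rate=0.02 dataset.download=True
```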
136 changes: 136 additions & 0 deletions environment.yml
@@ -0,0 +1,136 @@
name: ml-pruning
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1
- appdirs=1.4.4=pyh9f0ad1d_0
- blas=2.116=mkl
- blas-devel=3.9.0=16_linux64_mkl
- brotli-python=1.1.0=py311hb755f60_1
- bzip2=1.0.8=hd590300_5
- ca-certificates=2024.2.2=hbcca054_0
- certifi=2024.2.2=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- cuda-cudart=12.1.105=0
- cuda-cupti=12.1.105=0
- cuda-libraries=12.1.0=0
- cuda-nvrtc=12.1.105=0
- cuda-nvtx=12.1.105=0
- cuda-opencl=12.3.101=0
- cuda-runtime=12.1.0=0
- docker-pycreds=0.4.0=py_0
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=pyhd8ed1ab_0
- freetype=2.12.1=h267a509_2
- gitdb=4.0.11=pyhd8ed1ab_0
- gitpython=3.1.41=pyhd8ed1ab_0
- gmp=6.3.0=h59595ed_0
- gmpy2=2.1.2=py311h6a5fa03_1
- gnutls=3.6.13=h85f3911_1
- hydra-core=1.3.2=pyhd8ed1ab_0
- icu=73.2=h59595ed_0
- idna=3.6=pyhd8ed1ab_0
- importlib_resources=6.1.1=pyhd8ed1ab_0
- jinja2=3.1.3=pyhd8ed1ab_0
- jpeg=9e=h166bdaf_2
- lame=3.100=h166bdaf_1003
- lcms2=2.15=hfd0df8a_0
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libabseil=20240116.0=cxx17_h59595ed_1
- libblas=3.9.0=16_linux64_mkl
- libcblas=3.9.0=16_linux64_mkl
- libcublas=12.1.0.26=0
- libcufft=11.0.2.4=0
- libcufile=1.8.1.2=0
- libcurand=10.3.4.107=0
- libcusolver=11.4.4.55=0
- libcusparse=12.0.2.55=0
- libdeflate=1.17=h0b41bf4_0
- libexpat=2.5.0=hcb278e6_1
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_5
- libgfortran-ng=13.2.0=h69a702a_5
- libgfortran5=13.2.0=ha4646dd_5
- libgomp=13.2.0=h807b86a_5
- libhwloc=2.9.3=default_h554bfaf_1009
- libiconv=1.17=hd590300_2
- libjpeg-turbo=2.0.0=h9bf148f_0
- liblapack=3.9.0=16_linux64_mkl
- liblapacke=3.9.0=16_linux64_mkl
- libnpp=12.0.2.50=0
- libnsl=2.0.1=hd590300_0
- libnvjitlink=12.1.105=0
- libnvjpeg=12.1.1.14=0
- libpng=1.6.42=h2797004_0
- libprotobuf=4.25.2=h08a7969_0
- libsqlite=3.44.2=h2797004_0
- libstdcxx-ng=13.2.0=h7e041cc_5
- libtiff=4.5.0=h6adf6a1_2
- libuuid=2.38.1=h0b41bf4_0
- libwebp-base=1.3.2=hd590300_0
- libxcb=1.13=h7f98852_1004
- libxcrypt=4.4.36=hd590300_1
- libxml2=2.12.5=h232c23b_0
- libzlib=1.2.13=hd590300_5
- llvm-openmp=15.0.7=h0cdce71_0
- markupsafe=2.1.5=py311h459d7ec_0
- mkl=2022.1.0=h84fe81f_915
- mkl-devel=2022.1.0=ha770c72_916
- mkl-include=2022.1.0=h84fe81f_915
- mpc=1.3.1=hfe3b2da_0
- mpfr=4.2.1=h9458935_0
- mpmath=1.3.0=pyhd8ed1ab_0
- ncurses=6.4=h59595ed_2
- nettle=3.6=he412f7d_0
- networkx=3.2.1=pyhd8ed1ab_0
- numpy=1.26.4=py311h64a7726_0
- omegaconf=2.3.0=pyhd8ed1ab_0
- openh264=2.1.1=h780b84a_0
- openjpeg=2.5.0=hfec8fc6_2
- openssl=3.2.1=hd590300_0
- packaging=23.2=pyhd8ed1ab_0
- pathtools=0.1.2=py_1
- pillow=9.4.0=py311h50def17_1
- pip=24.0=pyhd8ed1ab_0
- protobuf=4.25.2=py311h7b78aeb_0
- psutil=5.9.8=py311h459d7ec_0
- pthread-stubs=0.4=h36c2ea0_1001
- pysocks=1.7.1=pyha2e5f31_6
- python=3.11.7=hab00c5b_1_cpython
- python_abi=3.11=4_cp311
- pytorch=2.2.0=py3.11_cuda12.1_cudnn8.9.2_0
- pytorch-cuda=12.1=ha16c6d3_5
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py311h459d7ec_1
- readline=8.2=h8228510_1
- requests=2.31.0=pyhd8ed1ab_0
- sentry-sdk=1.40.3=pyhd8ed1ab_0
- setproctitle=1.3.3=py311h459d7ec_0
- setuptools=69.0.3=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- smmap=5.0.0=pyhd8ed1ab_0
- sympy=1.12=pypyh9d50eac_103
- tbb=2021.11.0=h00ab1b0_1
- tk=8.6.13=noxft_h4845f30_101
- torchaudio=2.2.0=py311_cu121
- torchtriton=2.2.0=py311
- torchvision=0.17.0=py311_cu121
- typing_extensions=4.9.0=pyha770c72_0
- tzdata=2024a=h0c530f3_0
- urllib3=2.2.0=pyhd8ed1ab_0
- wandb=0.16.3=pyhd8ed1ab_0
- wheel=0.42.0=pyhd8ed1ab_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
prefix: /home/bubuss/miniforge3/envs/ml-pruning
7 changes: 5 additions & 2 deletions pruning/architecture/dataloaders.py
@@ -15,13 +15,16 @@ def get_cifar10(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
    )

    train_dataset = datasets.CIFAR10(
        root=cfg.dataset.path, train=True, download=True, transform=normalize_tensor
        root=cfg.dataset.path,
        train=True,
        download=cfg.dataset.download,
        transform=normalize_tensor,
    )
    train_dataset, validate_dataset = random_split(train_dataset, [0.8, 0.2])
    test_dataset = datasets.CIFAR10(
        root=cfg.dataset.path,
        train=False,
        download=True,
        download=cfg.dataset.download,
        transform=normalize_tensor,
    )

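
For context, a minimal standalone sketch of the same loading pattern (torchvision CIFAR-10 plus a fractional `random_split`); the root path and transform below are placeholders rather than the repository's actual values, and fractional split ratios require PyTorch >= 1.13:
```python
from torch.utils.data import random_split
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # placeholder; the repository normalizes the tensors
train_full = datasets.CIFAR10(root="datasets/cifar10", train=True, download=True, transform=transform)
# Split the training set 80/20 into train and validation subsets.
train_set, val_set = random_split(train_full, [0.8, 0.2])
test_set = datasets.CIFAR10(root="datasets/cifar10", train=False, download=True, transform=transform)
```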
6 changes: 3 additions & 3 deletions pruning/conf/config.yaml
@@ -3,9 +3,9 @@ defaults:
  - optimizer: adamw

pruning:
  rate: 0.05
  iterations: 10
  finetune_epochs: 1
  iteration_rate: ???
  iterations: ???
  finetune_epochs: ???
  batch_size: 64

model:
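
The `???` values above are OmegaConf's marker for mandatory missing values, so the run fails unless they are supplied, e.g. via the CLI overrides shown in the README. A minimal sketch of the behaviour, not repository code:
```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"pruning": {"iteration_rate": "???"}})
print(OmegaConf.is_missing(cfg.pruning, "iteration_rate"))  # True

try:
    _ = cfg.pruning.iteration_rate  # raises until a value is supplied
except MissingMandatoryValue:
    OmegaConf.update(cfg, "pruning.iteration_rate", 0.02)
print(cfg.pruning.iteration_rate)  # 0.02
```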
1 change: 1 addition & 0 deletions pruning/conf/dataset/cifar10.yaml
@@ -1,2 +1,3 @@
name: cifar10
path: datasets/cifar10
download: False
7 changes: 6 additions & 1 deletion pruning/pruning_entry.py
@@ -1,6 +1,7 @@
import hydra
import torch
import torch.nn.utils.prune as prune
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from architecture.dataloaders import get_dataloaders
from architecture.construct_model import construct_model
@@ -26,7 +27,7 @@ def main(cfg: DictConfig) -> None:
    pruning_amount = int(
        round(
            utility.pruning.calculate_parameters_amount(pruning_parameters)
            * cfg.pruning.rate
            * cfg.pruning.iteration_rate
        )
    )

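For intuition, a hedged arithmetic sketch of this computation; the parameter count below is invented for illustration, not a value from the repository:
```python
# Hypothetical numbers: 61,706 prunable parameters, iteration_rate = 0.02.
n_prunable = 61_706
iteration_rate = 0.02  # cfg.pruning.iteration_rate
pruning_amount = int(round(n_prunable * iteration_rate))
print(pruning_amount)  # 1234 weights pruned per iteration
```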
@@ -54,6 +55,10 @@ def main(cfg: DictConfig) -> None:

print(f"Test accuracy: {test_accuracy:.2f}%")

# Save the model to the Hydra output directory
output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
torch.save(pruned_model.state_dict(), output_dir / f"{cfg.model.name}.pth")


if __name__ == "__main__":
main()

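A minimal standalone sketch of how the Hydra output directory is obtained (assuming `config_path="conf"` and `config_name="config"`, matching this repository's `conf/config.yaml`); with Hydra's defaults, single runs write under `outputs/` and `--multirun` sweeps under `multirun/`, which matches the `outputs` and `multirun` entries in `.gitignore`:
```python
from pathlib import Path

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Per-run directory created by Hydra, e.g. outputs/<date>/<time>/ for single runs
    # or multirun/<date>/<time>/<job-number>/ for --multirun sweeps.
    output_dir = Path(HydraConfig.get().runtime.output_dir)
    print(output_dir)


if __name__ == "__main__":
    main()
```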