Skip to content

Commit

Permalink
epoch sum, env refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Kowalski1024 committed May 9, 2024
1 parent 2878f29 commit 2bdfd6b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 210 deletions.
216 changes: 10 additions & 206 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,211 +1,15 @@
name: ml-pruning
channels:
- pytorch
- nvidia
- conda-forge
- pytorch
- nvidia
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=4.5
- aiohttp=3.9.3
- aiosignal=1.3.1
- antlr-python-runtime=4.9.3
- appdirs=1.4.4
- attrs=23.2.0
- aws-c-auth=0.7.15
- aws-c-cal=0.6.9
- aws-c-common=0.9.12
- aws-c-compression=0.2.17
- aws-c-event-stream=0.4.1
- aws-c-http=0.8.0
- aws-c-io=0.14.3
- aws-c-mqtt=0.10.1
- aws-c-s3=0.5.0
- aws-c-sdkutils=0.1.14
- aws-checksums=0.1.17
- aws-crt-cpp=0.26.1
- aws-sdk-cpp=1.11.242
- blas=2.116
- blas-devel=3.9.0
- brotli-python=1.1.0
- bzip2=1.0.8
- c-ares=1.26.0
- ca-certificates=2024.2.2
- certifi=2024.2.2
- charset-normalizer=3.3.2
- click=8.1.7
- colorama=0.4.6
- cuda-cudart=12.1.105
- cuda-cupti=12.1.105
- cuda-libraries=12.1.0
- cuda-nvrtc=12.1.105
- cuda-nvtx=12.1.105
- cuda-opencl=12.3.101
- cuda-runtime=12.1.0
- datasets=2.17.0
- dill=0.3.8
- docker-pycreds=0.4.0
- ffmpeg=4.3
- filelock=3.13.1
- freetype=2.12.1
- frozenlist=1.4.1
- fsspec=2023.10.0
- gflags=2.2.2
- gitdb=4.0.11
- gitpython=3.1.41
- glog=0.6.0
- gmp=6.3.0
- gmpy2=2.1.2
- gnutls=3.6.13
- huggingface_hub=0.20.2
- hydra-core=1.3.2
- icu=73.2
- idna=3.6
- importlib_resources=6.1.1
- jinja2=3.1.3
- jpeg=9e
- keyutils=1.6.1
- krb5=1.21.2
- lame=3.100
- lcms2=2.15
- ld_impl_linux-64=2.40
- lerc=4.0.0
- libabseil=20230802.1
- libarrow=15.0.0
- libarrow-acero=15.0.0
- libarrow-dataset=15.0.0
- libarrow-flight=15.0.0
- libarrow-flight-sql=15.0.0
- libarrow-gandiva=15.0.0
- libarrow-substrait=15.0.0
- libblas=3.9.0
- libbrotlicommon=1.1.0
- libbrotlidec=1.1.0
- libbrotlienc=1.1.0
- libcblas=3.9.0
- libcrc32c=1.1.2
- libcublas=12.1.0.26
- libcufft=11.0.2.4
- libcufile=1.8.1.2
- libcurand=10.3.4.107
- libcurl=8.5.0
- libcusolver=11.4.4.55
- libcusparse=12.0.2.55
- libdeflate=1.17
- libedit=3.1.20191231
- libev=4.33
- libevent=2.1.12
- libexpat=2.5.0
- libffi=3.4.2
- libgcc-ng=13.2.0
- libgfortran-ng=13.2.0
- libgfortran5=13.2.0
- libgomp=13.2.0
- libgoogle-cloud=2.12.0
- libgrpc=1.60.1
- libhwloc=2.9.3
- libiconv=1.17
- libjpeg-turbo=2.0.0
- liblapack=3.9.0
- liblapacke=3.9.0
- libllvm15=15.0.7
- libnghttp2=1.58.0
- libnl=3.9.0
- libnpp=12.0.2.50
- libnsl=2.0.1
- libnuma=2.0.16
- libnvjitlink=12.1.105
- libnvjpeg=12.1.1.14
- libparquet=15.0.0
- libpng=1.6.42
- libprotobuf=4.25.1
- libre2-11=2023.06.02
- libsqlite=3.44.2
- libssh2=1.11.0
- libstdcxx-ng=13.2.0
- libthrift=0.19.0
- libtiff=4.5.0
- libutf8proc=2.8.0
- libuuid=2.38.1
- libwebp-base=1.3.2
- libxcb=1.13
- libxcrypt=4.4.36
- libxml2=2.12.5
- libzlib=1.2.13
- llvm-openmp=15.0.7
- lz4-c=1.9.4
- markupsafe=2.1.5
- mkl=2022.1.0
- mkl-devel=2022.1.0
- mkl-include=2022.1.0
- mpc=1.3.1
- mpfr=4.2.1
- mpmath=1.3.0
- multidict=6.0.5
- multiprocess=0.70.16
- ncurses=6.4
- nettle=3.6
- networkx=3.2.1
- numpy=1.26.4
- omegaconf=2.3.0
- openh264=2.1.1
- openjpeg=2.5.0
- openssl=3.2.1
- orc=1.9.2
- packaging=23.2
- pandas=2.2.0
- pathtools=0.1.2
- pillow=9.4.0
- pip=24.0
- protobuf=4.25.1
- psutil=5.9.8
- pthread-stubs=0.4
- pyarrow=15.0.0
- pyarrow-hotfix=0.6
- pysocks=1.7.1
- python=3.11.7
- python-dateutil=2.8.2
- python-tzdata=2024.1
- python-xxhash=3.4.1
- python_abi=3.11
- pytorch=2.2.0
- python=3.11
- pytorch
- torchvision
- pytorch-cuda=12.1
- pytorch-mutex=1.0
- pytz=2024.1
- pyyaml=6.0.1
- rdma-core=50.0
- re2=2023.06.02
- readline=8.2
- requests=2.31.0
- s2n=1.4.3
- safetensors=0.4.2
- scipy=1.12.0
- sentry-sdk=1.40.3
- setproctitle=1.3.3
- setuptools=69.0.3
- six=1.16.0
- smmap=5.0.0
- snappy=1.1.10
- sympy=1.12
- tbb=2021.11.0
- timm=0.9.12
- tk=8.6.13
- torchaudio=2.2.0
- torchtriton=2.2.0
- torchvision=0.17.0
- tqdm=4.66.2
- typing-extensions=4.9.0
- typing_extensions=4.9.0
- tzdata=2024a
- ucx=1.15.0
- urllib3=2.2.0
- wandb=0.16.3
- wheel=0.42.0
- xorg-libxau=1.0.11
- xorg-libxdmcp=1.1.3
- xxhash=0.8.2
- xz=5.2.6
- yaml=0.2.5
- yarl=1.9.4
- zipp=3.17.0
- zlib=1.2.13
- zstd=1.5.5
- hydra-core
- wandb
- timm
- scipy
- pandas
7 changes: 3 additions & 4 deletions pruning/architecture/pruning_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def prune_model(
checkpoints_data = pd.DataFrame(
columns=["pruned_precent", "top1_accuracy", "top5_accuracy", "epoch_mean", "epoch_std"]
)
epochs = []
epoch_sum = 0

for iteration, step in enumerate(pruning_steps):
logger.info(f"Pruning iteration {iteration + 1}/{len(pruning_steps)}")
Expand All @@ -193,6 +193,7 @@ def prune_model(

for epoch in range(finetune_epochs):
logger.info(f"Epoch {epoch + 1}/{finetune_epochs}")
epoch_sum += 1

train_loss = utility.training.train_epoch(
module=model,
Expand Down Expand Up @@ -223,16 +224,14 @@ def prune_model(
if early_stopper and early_stopper.check_stop(metrics["validation_loss"]):
logger.info(f"Early stopping after {epoch+1} epochs")
early_stopper.reset()
epochs.append(epoch + 1)
break

if (
checkpoints_interval.start * 100 <= pruned <= checkpoints_interval.end * 100
and finetune_epochs > 0
):
# post epoch metrics
metrics["epoch_mean"] = np.mean(epochs) if epochs else finetune_epochs
metrics["epoch_std"] = np.std(epochs) if epochs else 0
metrics["epoch_sum"] = epoch_sum

checkpoints_data.loc[iteration] = {
key: metrics[key] for key in checkpoints_data.columns
Expand Down

0 comments on commit 2bdfd6b

Please sign in to comment.