From eaf293aaadec8a94700da38466798e00905afe5f Mon Sep 17 00:00:00 2001 From: Khaleel Khan Date: Fri, 24 May 2024 15:04:52 +0200 Subject: [PATCH] Release candidate v0.2.0 (#10) * Added some context for EvNN * AdamW, Simpler Thresholds (#5) * Switched to AdamW optimizer, simplified threshold parameterization, slight changes to the training of thresholds * removed wandb from training script * fixed inference script and updated README.md * Improved setup and install (#8) * improve setup and remove makefiles * remove makefile --------- authored-by: KhaleelKhan * bump up version, update readme * include required files in distributed archive * only require nvcc to compile cuda kernels * cleaned LM code from pruning attempts * update changelog and prepare merge --------- Co-authored-by: Anand Co-authored-by: Mark Schoene Co-authored-by: KhaleelKhan --- CHANGELOG.md | 9 + build/MANIFEST.in => MANIFEST.in | 1 - Makefile | 66 ----- README.md | 25 +- benchmarks/lm/README.md | 22 +- benchmarks/lm/eval.py | 5 +- benchmarks/lm/infer.py | 68 +---- benchmarks/lm/models.py | 58 +--- benchmarks/lm/train.py | 269 ++++++++---------- build/common.py | 35 --- build/setup.pytorch.py | 83 ------ docker/Dockerfile | 25 +- frameworks/pytorch/egru.py | 32 ++- ...ackward_gpu.cu.cc => egru_backward_gpu.cu} | 0 ..._forward_gpu.cu.cc => egru_forward_gpu.cu} | 0 setup.py | 132 +++++++++ validation/self_consistency_test.py | 15 +- 17 files changed, 339 insertions(+), 506 deletions(-) rename build/MANIFEST.in => MANIFEST.in (87%) delete mode 100644 Makefile delete mode 100644 build/common.py delete mode 100644 build/setup.pytorch.py rename lib/{egru_backward_gpu.cu.cc => egru_backward_gpu.cu} (100%) rename lib/{egru_forward_gpu.cu.cc => egru_forward_gpu.cu} (100%) create mode 100644 setup.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 98b594d..ad488a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # ChangeLog +## 0.2.0-egru (2024-05-24) +### Changed +- Simplified install and removed makefile +- CUDA compute capability is automatically detected +- Update Readme with the setup instruction +- Update Dockerfile +- Cleaned LM pruning code + + ## 0.1.0-egru (2022-03-01) ### Changed - Project forked from original diff --git a/build/MANIFEST.in b/MANIFEST.in similarity index 87% rename from build/MANIFEST.in rename to MANIFEST.in index a919161..559f30d 100644 --- a/build/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,3 @@ -include Makefile include frameworks/pytorch/*.h include frameworks/pytorch/*.cc include lib/*.cc diff --git a/Makefile b/Makefile deleted file mode 100644 index c008308..0000000 --- a/Makefile +++ /dev/null @@ -1,66 +0,0 @@ -AR ?= ar -CXX ?= g++ -NVCC ?= nvcc -ccbin $(CXX) -PYTHON ?= python - -ifeq ($(OS),Windows_NT) -LIBEVNN := libevnn.lib -CUDA_HOME ?= $(CUDA_PATH) -AR := lib -AR_FLAGS := /nologo /out:$(LIBEVNN) -NVCC_FLAGS := -x cu -Xcompiler "/MD" -else -LIBEVNN := libevnn.a -CUDA_HOME ?= $(CUDA_PATH) -AR ?= ar -AR_FLAGS := -crv $(LIBEVNN) -NVCC_FLAGS := -std=c++11 -x cu -Xcompiler -fPIC -lineinfo -Wno-deprecated-gpu-targets -endif - -LOCAL_CUDA_CFLAGS := -I$(CUDA_HOME)/include -LOCAL_CUDA_LDFLAGS := -L$(CUDA_HOME)/lib64 -lcudart -lcublas -LOCAL_CFLAGS := -Ilib -O3 -g -LOCAL_LDFLAGS := -L. -lcblas -GPU_ARCH_FLAGS := -gencode arch=compute_37,code=compute_37 -gencode arch=compute_60,code=compute_60 -gencode arch=compute_70,code=compute_70 - -# Small enough project that we can just recompile all the time. -.PHONY: all evnn evnn_pytorch examples clean - -all: evnn evnn_pytorch examples - -# Dependencies handled by setup.py -evnn_pytorch: - @$(eval TMP := $(shell mktemp -d)) - @cp -r . $(TMP) - @cat build/common.py build/setup.pytorch.py > $(TMP)/setup.py - @(cd $(TMP); $(PYTHON) setup.py bdist_wheel) - @cp $(TMP)/dist/*.whl . - @rm -rf $(TMP) - -dist: - @$(eval TMP := $(shell mktemp -d)) - @cp -r . $(TMP) - @cp build/MANIFEST.in $(TMP) - @cat build/common.py build/setup.pytorch.py > $(TMP)/setup.py - @(cd $(TMP); $(PYTHON) setup.py -q sdist) - @cp $(TMP)/dist/*.tar.gz . - @rm -rf $(TMP) - -evnn: - $(NVCC) $(GPU_ARCH_FLAGS) -c lib/egru_forward_gpu.cu.cc -o lib/egru_forward_gpu.o $(NVCC_FLAGS) $(LOCAL_CUDA_CFLAGS) $(LOCAL_CFLAGS) - $(NVCC) $(GPU_ARCH_FLAGS) -c lib/egru_backward_gpu.cu.cc -o lib/egru_backward_gpu.o $(NVCC_FLAGS) $(LOCAL_CUDA_CFLAGS) $(LOCAL_CFLAGS) - $(NVCC) $(GPU_ARCH_FLAGS) -c lib/egru_forward_cpu.cc -o lib/egru_forward_cpu.o $(NVCC_FLAGS) $(LOCAL_CUDA_CFLAGS) $(LOCAL_CFLAGS) - $(NVCC) $(GPU_ARCH_FLAGS) -c lib/egru_backward_cpu.cc -o lib/egru_backward_cpu.o $(NVCC_FLAGS) $(LOCAL_CUDA_CFLAGS) $(LOCAL_CFLAGS) - $(AR) $(AR_FLAGS) lib/*.o - -evnn_cpu: - $(CXX) -c lib/egru_forward_cpu.cc $(LOCAL_LDFLAGS) -o lib/egru_forward_cpu.o -fPIC $(LOCAL_CFLAGS) - $(CXX) -c lib/egru_backward_cpu.cc $(LOCAL_LDFLAGS) -o lib/egru_backward_cpu.o -fPIC $(LOCAL_CFLAGS) - $(AR) $(AR_FLAGS) lib/*.o - -examples: evnn - $(CXX) -std=c++11 examples/egru.cc $(LIBEVNN) -Ieigen3 $(LOCAL_CUDA_CFLAGS) $(LOCAL_CFLAGS) $(LOCAL_CUDA_LDFLAGS) $(LOCAL_LDFLAGS) -o evnn_egru -Wno-ignored-attributes - -clean: - rm -fr benchmark_lstm benchmark_gru evnn_egru evnn_*.whl evnn_*.tar.gz - find . \( -iname '*.o' -o -iname '*.so' -o -iname '*.a' -o -iname '*.lib' \) -delete diff --git a/README.md b/README.md index e585aa9..5d7ea95 100644 --- a/README.md +++ b/README.md @@ -30,36 +30,41 @@ Here's what you'll need to get started: - a [CUDA Compute Capability](https://developer.nvidia.com/cuda-gpus) 3.7+ GPU (required only if using GPU) - [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) 11.0+ (required only if using GPU) - [PyTorch](https://pytorch.org) 1.3+ for PyTorch integration (GPU optional) -- [BLAS](https://netlib.org/blas/) or any BLAS-like library for CPU computation. -- [Eigen 3](http://eigen.tuxfamily.org/) to build the C++ examples (optional) +- [OpenBLAS](https://www.openblas.net/) or any BLAS-like library for CPU computation. Once you have the prerequisites, you can install with pip or by building the source code. - +``` ### Building from source > **Note** > > Currenty supported only on Linux, use Docker for building on Windows. +Build and install it with `pip`: ```bash -make evnn_pytorch # Build PyTorch API +pip install . ``` +### Building in Docker -If you built the PyTorch API, install it with `pip`: +Build docker image: ```bash -pip install evnn_pytorch-*.whl +docker build -t evnn -f docker/Dockerfile . ``` -If the CUDA Toolkit that you're building against is not in `/usr/local/cuda`, you must specify the -`$CUDA_HOME` environment variable before running make: +Example usage: ```bash -CUDA_HOME=/usr/local/cuda-10.2 make +docker run --rm --gpus=all evnn python -m unittest discover -p "*_test.py" -s /evnn_src/validation -v ``` +> **Note** +> +> The build script tries to automatically detect GPU compute capability. In case the GPU is not available during compilation, for example when building with docker or when using compute cluster login nodes for compiling, Use enviroment variable `EVNN_CUDA_COMPUTE` to set the required compute capability. +> Example: For CUDA Compute capability 8.0 use ```export EVNN_CUDA_COMPUTE=80``` + ## Performance Code for the experiments and benchmarks presented in the paper are published in ``benchmarks`` directory. diff --git a/benchmarks/lm/README.md b/benchmarks/lm/README.md index 9d3aac2..d7089f1 100644 --- a/benchmarks/lm/README.md +++ b/benchmarks/lm/README.md @@ -4,18 +4,30 @@ To run the language modeling experiments, first download the data ./getdata -Then run Penn Treebank experiments with EGRU (1350 units) +We [provide checkpoints for EGRU](https://cloudstore.zih.tu-dresden.de/index.php/s/NPQ9pLnpZnTsM5X) with 3 layers of hidden size (1350, 1350, 750) - python lm/train.py --data path_to_your_data --scratch ./log --dataset PTB --epochs 2500 --rnn_type egru --layers 3 --hidden_dim 1350 --batch_size=64 --bptt=68 --dropout_connect=0.6788113982442464 --dropout_emb=0.7069992062976298 --dropout_forward=0.2641540030663871 --dropout_words=0.05460274136214911 --emb_dim=788 --learning_rate=0.00044406742918918466 --pseudo_derivative_width=2.179414375864446 --thr_init_mean=-3.76855645544185 --weight_decay=9.005509348932795e-06 --seed 12008 +# Penn Treebank +To train EGRU on Penn Treebank word-level language modeling, run -or EGRU (2000 units) + python benchmarks/lm/train.py --data=/path/to/data --scratch=/your/scratch/directory/Experiments --dataset=PTB --epochs=1000 --batch_size=64 --rnn_type=egru --layer=3 --bptt=70 --scheduler=cosine --weight_decay=0.10 --learning_rate=0.0012 --learning_rate_thresholds 0.0 --emb_dim=750 --dropout_emb=0.6 --dropout_words=0.1 --dropout_forward=0.25 --grad_clip=0.1 --thr_init_mean=0.01 --dropout_connect=0.7 --hidden_dim=1350 --pseudo_derivative_width=3.6 --scheduler_start=700 --seed=9612 - python lm/train.py --data path_to_your_data --scratch ./log --dataset PTB --epochs 2500 --rnn_type egru --layers 3 --hidden_dim 2000 --batch_size=128 --bptt=67 --dropout_connect=0.621405385527356 --dropout_emb=0.7651296208061924 --dropout_forward=0.24131807369801447 --dropout_words=0.14942681962154375 --emb_dim=786 --learning_rate=0.000494172266064804 --pseudo_derivative_width=2.35216907207571 --thr_init_mean=-3.4957794302256007 --weight_decay=6.6878095661652755e-06 --seed 52798 +For inference with the [provided checkpoint](https://cloudstore.zih.tu-dresden.de/index.php/s/NPQ9pLnpZnTsM5X), run + + python benchmarks/lm/infer.py --data /path/to/data --dataset PTB --datasplit test --batch_size 1 --directory /path/to/checkpoint + +# Wikitext-2 +To train EGRU on Wikitext-2, run + + python benchmarks/lm/train.py --data=/your/data/directory --scratch=/your/scratch/directory/Experiments --dataset=WT2 --epochs=800 --batch_size=128 --rnn_type=egru --layer=3 --bptt=70 --scheduler=cosine --weight_decay=0.12 --learning_rate=0.001 --learning_rate_thresholds 0.0 --emb_dim=750 --dropout_emb=0.7 --dropout_words=0.1 --dropout_forward=0.25 --grad_clip=0.1 --thr_init_mean=0.01 --dropout_connect=0.7 --hidden_dim=1350 --pseudo_derivative_width=3.6 --scheduler_start=400 --seed=913420 + +For inference with the [provided checkpoint](https://cloudstore.zih.tu-dresden.de/index.php/s/NPQ9pLnpZnTsM5X), run + + python benchmarks/lm/infer.py --data /path/to/data --dataset WT2 --datasplit test --batch_size 1 --directory /path/to/checkpoint Various flags can be passed to change the defaults parameters. See "train.py" for a list of all available arguments. -This code was tested with PyTorch >= 1.9.0 +This code was tested with PyTorch >= 1.9.0, CUDA 11. A large batch of code stems from Salesforce AWD-LSTM implementation: https://github.com/salesforce/awd-lstm-lm \ No newline at end of file diff --git a/benchmarks/lm/eval.py b/benchmarks/lm/eval.py index a98805f..d55d0e8 100644 --- a/benchmarks/lm/eval.py +++ b/benchmarks/lm/eval.py @@ -14,16 +14,15 @@ # ============================================================================== import torch -import lm.data as d +import data as d -def evaluate(model, eval_data, criterion, batch_size, bptt, ntokens, device, return_hidden=False): +def evaluate(model, eval_data, criterion, batch_size, bptt, ntokens, device, hidden_dims, return_hidden=False): # turn on evaluation mode model.eval() # initialize evaluation metrics iter_range = range(0, eval_data.size(0) - 1, bptt) - hidden_dims = [rnn.hidden_size for rnn in model.rnns] total_loss = 0. mean_activities = torch.zeros(len(iter_range), dtype=torch.float16, device=device) diff --git a/benchmarks/lm/infer.py b/benchmarks/lm/infer.py index 57bf3d1..b65b491 100644 --- a/benchmarks/lm/infer.py +++ b/benchmarks/lm/infer.py @@ -23,9 +23,9 @@ import torch import torch.nn -import lm.data as d -from lm.models import LanguageModel -from lm.eval import evaluate +import data as d +from models import LanguageModel +from eval import evaluate def get_args(): @@ -37,7 +37,6 @@ def get_args(): argparser.add_argument('--batch_size', type=int, default=80) argparser.add_argument('--directory', type=str, required=False, help='model directory for checkpoints and config') argparser.add_argument('--hidden', action='store_true', help='returns the hidden states of the whole dataset to perform analysis') - argparser.add_argument('--prune', type=float, default=0.0) return argparser.parse_args() @@ -85,7 +84,7 @@ def main(args): model = LanguageModel(**model_args).to(device) elif config['rnn_type'] == 'egru': model = LanguageModel(**model_args, - dampening_factor=config['damp_factor'], + dampening_factor=config['pseudo_derivative_width'], pseudo_derivative_support=config['pseudo_derivative_width']).to(device) else: raise RuntimeError("Unknown RNN type: %s" % config['rnn_type']) @@ -93,6 +92,11 @@ def main(args): best_model_path = os.path.join(args.directory, 'checkpoints', f"{config['rnn_type'].upper()}_best_model.cpt") model.load_state_dict(torch.load(best_model_path, map_location=device)) + if model_args['rnn_type'] == 'egru': + hidden_dims = [rnn.hidden_size for rnn in model.rnns] + else: + hidden_dims = [rnn.module.hidden_size if args.dropout_connect > 0 else rnn.hidden_size for rnn in model.rnns] + criterion = torch.nn.CrossEntropyLoss() if args.hidden: @@ -104,6 +108,7 @@ def main(args): bptt=config['bptt'], ntokens=vocab_size, device=device, + hidden_dims=hidden_dims, return_hidden=True) save_file = os.path.join(args.directory, f'hidden_states_{args.datasplit}.hdf') with h5py.File(save_file, 'w') as f: @@ -121,6 +126,7 @@ def main(args): bptt=config['bptt'], ntokens=vocab_size, device=device, + hidden_dims=hidden_dims, return_hidden=False) test_ppl = math.exp(test_loss) @@ -131,58 +137,6 @@ def main(args): print(f'Layerwise activity {test_layerwise_activity_mean.tolist()} +- {test_layerwise_activity_std.tolist()}') print('=' * 89) - if args.prune > 0.0 and args.hidden: - print(f"Model Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") - input_indices = torch.arange(model.rnns[0].input_size).to(device) - for i in range(model.nlayers): - if i < model.nlayers - 1: - # get event frequencies - hid_dim = all_hiddens[i].shape[2] - hid_cells = all_hiddens[i].reshape(-1, hid_dim) - seq_len = hid_cells.shape[0] - spike_frequency = torch.sum(hid_cells != 0, dim=0) / seq_len - print( - f"Layer {i + 1}: " - f"less than 1/100: {torch.sum(spike_frequency < 0.01)} / {spike_frequency.shape} " - f"// never: {torch.sum(hid_cells.sum(dim=0) == 0)} / {spike_frequency.shape}") - - # compute remaining indicies from spike frequencies - topk = int(model.rnns[i].hidden_size * (1 - args.prune)) - hidden_indices, _ = torch.sort(torch.argsort(spike_frequency, descending=True)[:topk], descending=False) - hidden_indices = hidden_indices.to(device) - else: - hidden_indices = torch.arange(model.rnns[i].hidden_size).to(device) - model.rnns[i].prune_units(input_indices, hidden_indices) - input_indices = hidden_indices - - print(f"Model Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") - - test_loss, test_activity, test_layerwise_activity_mean, test_layerwise_activity_std, centered_cell_states, all_hiddens = \ - evaluate(model=model, - eval_data=test_data, - criterion=criterion, - batch_size=args.batch_size, - bptt=config['bptt'], - ntokens=vocab_size, - device=device, - return_hidden=True) - for i in range(model.nlayers - 1): - # get event frequencies - hid_dim = all_hiddens[i].shape[2] - hid_cells = all_hiddens[i].reshape(-1, hid_dim) - seq_len = hid_cells.shape[0] - spike_frequency = torch.sum(hid_cells != 0, dim=0) / seq_len - print( - f"less than 1/100: {torch.sum(spike_frequency < 0.01)} / {spike_frequency.shape} " - f"// never: {torch.sum(hid_cells.sum(dim=0) == 0)} / {spike_frequency.shape}") - test_ppl = math.exp(test_loss) - print('=' * 89) - print(f'| Inference | test loss {test_loss:5.2f} | ' - f'test ppl {test_ppl:8.2f} | ' - f'test mean activity {test_activity}') - print(f'Layerwise activity {test_layerwise_activity_mean.tolist()} +- {test_layerwise_activity_std.tolist()}') - print('=' * 89) - if __name__ == "__main__": args = get_args() diff --git a/benchmarks/lm/models.py b/benchmarks/lm/models.py index 58638c1..f349379 100644 --- a/benchmarks/lm/models.py +++ b/benchmarks/lm/models.py @@ -17,8 +17,8 @@ import torch.nn as nn import torch.nn.functional as F import evnn_pytorch as evnn -from lm.modules import VariationalDropout, WeightDrop -from lm.embedding_dropout import embedded_dropout +from modules import VariationalDropout, WeightDrop +from embedding_dropout import embedded_dropout from typing import Union @@ -64,9 +64,8 @@ def forward(self, x): bs, seq_len, ninp = x.shape if self.project: x = x.view(-1, ninp) - x = F.relu(self.projection(x)) + x = self.projection(x) x = x.view(bs, seq_len, self.nemb) - x = self.variational_dropout(x, self.dropout) x = x.view(-1, self.nemb) x = self.decoder(x) return x @@ -155,57 +154,6 @@ def __init__(self, self.backward_sparsity = torch.zeros(len(self.rnns)) - def prune_embeddings(self, index): - device = next(self.parameters()).device - self.embeddings.weight = nn.Parameter( - self.embeddings.weight[:, index]).to(device) - self.emb_dim = self.embeddings.weight.shape[1] - self.decoder = Decoder(ninp=self.hidden_dim if self.projection else self.emb_dim, ntokens=self.vocab_size, - project=self.projection, nemb=self.emb_dim, - dropout=self.dropout_forward).to(device) - self.decoder.decoder.weight = self.embeddings.weight - - def prune(self, fractions, hiddens, device): - # calculate new hidden dimensions - indicies = [torch.arange(self.rnns[0].input_size).to(device)] - - for i in range(self.nlayers): - if isinstance(fractions, float): - frac = fractions - elif isinstance(fractions, tuple) or isinstance(fractions, list): - frac = fractions[i] - else: - raise NotImplementedError( - f"data type {type(fractions)} not implemented. Use float, tuple or list") - - # get event frequencies - hid_dim = hiddens[i].shape[2] - hid_cells = hiddens[i].reshape(-1, hid_dim) - seq_len = hid_cells.shape[0] - spike_frequency = torch.sum(hid_cells != 0, dim=0) / seq_len - print( - f"Layer {i + 1}: " - f"less than 1/100: {torch.sum(spike_frequency < 0.01)} / {spike_frequency.shape} " - f"// never: {torch.sum(hid_cells.sum(dim=0) == 0)} / {spike_frequency.shape}") - - # compute remaining indicies from spike frequencies - topk = int(self.rnns[i].hidden_size * (1 - frac)) - hidden_indices, _ = torch.sort(torch.argsort( - spike_frequency, descending=True)[:topk], descending=False) - hidden_indices = hidden_indices.to(device) - indicies.append(hidden_indices) - - # input dimension equals embedding dimension for tied weights - indicies[0] = indicies[-1] - - # prune weights - for i in range(self.nlayers): - self.rnns[i].prune_units(indicies[i], indicies[i+1]) - - self.prune_embeddings(indicies[-1]) - print( - f"Final model hidden size: {[rnn.hidden_size for rnn in self.rnns]}") - def init_embedding(self, initrange): nn.init.uniform_(self.embeddings.weight, -initrange, initrange) diff --git a/benchmarks/lm/train.py b/benchmarks/lm/train.py index 39ac1b9..34518a7 100644 --- a/benchmarks/lm/train.py +++ b/benchmarks/lm/train.py @@ -12,21 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - import argparse import math import os import time - +from enum import Enum import numpy as np import torch import yaml from torch import nn -import lm.data as d -from lm import RNNType -from lm.eval import evaluate -from lm.models import LanguageModel +import data as d +from eval import evaluate +from models import LanguageModel + + +class RNNType(Enum): + LSTM = 'lstm' + GRU = 'gru' + EGRU = 'egru' def get_args(): @@ -38,7 +42,7 @@ def get_args(): argparser.add_argument('--epochs', type=int, default=800) argparser.add_argument('--batch_size', type=int, default=80) argparser.add_argument('--learning_rate', type=float, default=0.0003) - argparser.add_argument('--avg_learning_rate', type=float, required=False) + argparser.add_argument('--learning_rate_thresholds', type=float, default=1.0) argparser.add_argument('--bptt', type=int, default=70) argparser.add_argument('--grad_clip', type=float, default=2.0) argparser.add_argument('--rnn_type', type=str, @@ -51,14 +55,12 @@ def get_args(): argparser.add_argument('--dropout_words', type=float, default=0.1) argparser.add_argument('--dropout_forward', type=float, default=0.3) argparser.add_argument('--dropout_connect', type=float, default=0.5) - argparser.add_argument('--damp_factor', type=float, default=0.7) argparser.add_argument('--checkpoint', type=str, required=False, default="") argparser.add_argument('--log_interval', type=int, default=1000) - argparser.add_argument('--optimizer', type=str, default='adam') - argparser.add_argument('--nonmono', type=int, default=5) + argparser.add_argument('--scheduler', type=str, default='lambda', choices=['lambda', 'cosine', 'step']) + argparser.add_argument('--scheduler_start', type=int, default=200) argparser.add_argument('--momentum', type=float, default=0.0) argparser.add_argument('--weight_decay', type=float, default=1.2e-6) - argparser.add_argument('--thr_init_scale', type=float, default=1.0) argparser.add_argument('--alpha', type=float, default=0, help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') argparser.add_argument('--beta', type=float, default=0, @@ -66,9 +68,8 @@ def get_args(): argparser.add_argument('--gamma', type=float, default=0, help='EGRU activity regularization') argparser.add_argument('--pseudo_derivative_width', type=float, default=1.0) - argparser.add_argument('--thr_init_mean', type=float, default=-2.0) - argparser.add_argument('--prune', nargs='*', type=float, default=0.0) - argparser.add_argument('--state_decay', type=float, default=0.001) + argparser.add_argument('--thr_init_mean', type=float, default=0.2) + argparser.add_argument('--weight_init_gain', type=float, default=1.0) return argparser.parse_args() @@ -84,10 +85,12 @@ def main(args): print(device) # load dataset - train_data, val_data, test_data, vocab_size = d.get_data(root=args.data, - dset=args.dataset, - batch_size=args.batch_size, - device=device) + train_data, val_data, test_data, vocab_size = d.get_data( + root=args.data, + dset=args.dataset, + batch_size=args.batch_size, + device=device + ) print(f"Dataset {args.dataset} has {vocab_size} tokens") criterion = nn.CrossEntropyLoss() @@ -112,11 +115,14 @@ def main(args): if args.rnn_type == 'lstm' or args.rnn_type == 'gru': model = LanguageModel(**model_args) elif args.rnn_type == 'egru': - model = LanguageModel(**model_args, - dampening_factor=args.damp_factor, - pseudo_derivative_support=args.pseudo_derivative_width, - grad_clip=args.grad_clip, - thr_mean=args.thr_init_mean) + model = LanguageModel( + **model_args, + dampening_factor=args.pseudo_derivative_width, + pseudo_derivative_support=args.pseudo_derivative_width, + grad_clip=args.grad_clip, + thr_mean=args.thr_init_mean, + weight_initialization_gain=args.weight_init_gain + ) else: raise RuntimeError("Unknown RNN type: %s" % args.rnn_type) print("RNN parameters: ", list(map(lambda x: x[0], model.named_parameters()))) @@ -124,30 +130,19 @@ def main(args): if len(args.checkpoint) > 0: model.load_state_dict(torch.load(args.checkpoint)) model = model.to(device) - - # MODEL PRUNING - pruning = False - print(args.prune) - if isinstance(args.prune, list) and len(args.prune) == 1: - pruning = True - args.prune = args.prune[0] - elif isinstance(args.prune, list) and len(args.prune) == args.layers: - pruning = True - - if pruning: - prune(model=model, - criterion=criterion, - data=train_data, - batch_size=args.batch_size, - sequence_length=args.bptt, - ntokens=vocab_size, - device=device, - fractions=args.prune) - + # get the dimensions of the hidden state if args.rnn_type == 'egru': hidden_dims = [rnn.hidden_size for rnn in model.rnns] else: hidden_dims = [rnn.module.hidden_size if args.dropout_connect > 0 else rnn.hidden_size for rnn in model.rnns] + + config = vars(args) + config.update({'num_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad)}) + config.update({'SLURM_JOB_ID': os.getenv('SLURM_JOB_ID')}) + print(f"Model Parameter Count: {config['num_parameters']}") + + model_signature = 'rnn_type={0}__nlayers{1}_lr={2}_decay={3}'.format(args.rnn_type, args.layers, args.learning_rate, args.weight_decay) + return_bw_sparsity = True if model.rnn_type == 'egru' else False config = vars(args) @@ -155,28 +150,36 @@ def main(args): print(f"Model Parameter Count: {config['num_parameters']}") # setup training - param_groups = [{'params': [param for name, param in model.named_parameters() if - 'thr' not in name and 'layernorm' not in name]}, - {'params' : [param for name, param in model.named_parameters() if - 'thr' in name or 'layernorm' in name], - 'weight_decay': 0} - ] - if args.optimizer == 'sgd': - optimizer = torch.optim.SGD(param_groups, lr=args.learning_rate, momentum=args.momentum, - weight_decay=args.weight_decay) + param_groups = [ + # most parameters + {'params': [param for name, param in model.named_parameters() + if 'thr' not in name and 'layernorm' not in name]}, + # layernorm + {'params': [param for name, param in model.named_parameters() + if 'layernorm' in name], + 'weight_decay': 0}, + # thresholds + {'params': [param for name, param in model.named_parameters() + if 'thr' in name], + 'lr': args.learning_rate * args.learning_rate_thresholds, + 'weight_decay': 0} + ] + + optimizer = torch.optim.AdamW(param_groups, lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.momentum, 0.999)) + + if args.scheduler == "lambda": scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda ep: 1.0) - elif args.optimizer == 'adam': - optimizer = torch.optim.Adam(param_groups, lr=args.learning_rate, betas=(args.momentum, 0.999), - weight_decay=args.weight_decay) - milestone = args.epochs // 2 - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda ep: 1.0) - else: - raise NotImplementedError(f'Optimizer {args.optimizer} not implemented') + + if args.scheduler == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(args.epochs-args.scheduler_start), eta_min=0) + + if args.scheduler == "step": + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=(args.epochs-args.scheduler_start)/4, gamma=args.gamma) best_val_loss = float('inf') output_path = os.path.join(args.scratch, args.dataset, args.rnn_type.upper(), - f"{args.rnn_type.upper()}_{time.strftime('%y-%m-%d-%H:%M:%S')}") + f"{model_signature}_{time.strftime('%y-%m-%d-%H:%M:%S')}") best_model_path = os.path.join(output_path, 'checkpoints', f'{args.rnn_type.upper()}_best_model.cpt') print("Saving model weights to", best_model_path) @@ -190,81 +193,45 @@ def main(args): for epoch in range(1, args.epochs + 1): epoch_start_time = time.time() - train_results = train(model=model, - train_data=train_data, - optimizer=optimizer, - criterion=criterion, - epoch=epoch, - batch_size=args.batch_size, - bptt=args.bptt, - ntokens=vocab_size, - grad_clip=args.grad_clip, - log_interval=args.log_interval, - device=device, - return_backward_sparsity=return_bw_sparsity) + train_results = train( + model=model, + train_data=train_data, + optimizer=optimizer, + criterion=criterion, + epoch=epoch, + batch_size=args.batch_size, + bptt=args.bptt, + ntokens=vocab_size, + grad_clip=args.grad_clip, + log_interval=args.log_interval, + device=device, + return_backward_sparsity=return_bw_sparsity + ) + if return_bw_sparsity: train_loss, bw_sparsity = train_results else: train_loss = train_results - # If already averaging - if args.optimizer == 'sgd' and 't0' in optimizer.param_groups[0]: - if 't0' in optimizer.param_groups[0]: - tmp = {} - for prm in model.parameters(): - tmp[prm] = prm.data.clone() - if 'ax' in optimizer.state[prm]: - prm.data = optimizer.state[prm]['ax'].clone() - - val_loss, mean_activity, layerwise_activity_mean, layerwise_activity_std, centered_cell_states = \ - evaluate(model=model, - eval_data=val_data, - criterion=criterion, - batch_size=args.batch_size, - bptt=args.bptt, - ntokens=vocab_size, - device=device) - - if val_loss < best_val_loss: - best_val_loss = val_loss - checkpoint_model(model.state_dict(), best_model_path) - - for prm in model.parameters(): - prm.data = tmp[prm].clone() - + val_loss, mean_activity, layerwise_activity_mean, layerwise_activity_std, centered_cell_states = \ + evaluate( + model=model, + eval_data=val_data, + criterion=criterion, + batch_size=args.batch_size, + bptt=args.bptt, + ntokens=vocab_size, + hidden_dims=hidden_dims, + device=device + ) + + # save best model + if val_loss < best_val_loss: + best_val_loss = val_loss + checkpoint_model(model.state_dict(), best_model_path) + n = 0 else: - val_loss, mean_activity, layerwise_activity_mean, layerwise_activity_std, centered_cell_states = \ - evaluate(model=model, - eval_data=val_data, - criterion=criterion, - batch_size=args.batch_size, - bptt=args.bptt, - ntokens=vocab_size, - device=device) - - if val_loss < best_val_loss: - best_val_loss = val_loss - checkpoint_model(model.state_dict(), best_model_path) - n = 0 - else: - n += 1 - - if isinstance(optimizer, torch.optim.SGD) and ( - epoch > args.nonmono and n > args.nonmono): - print('Switching to ASGD') - - # param_groups are reset by optimizer, we thus have to overwrite the learning rate - for pg in param_groups: - if args.avg_learning_rate: - pg['lr'] = args.avg_learning_rate - - # set optimizer and learning rate schedule - optimizer = torch.optim.ASGD(param_groups, - lr=args.learning_rate if not args.avg_learning_rate else args.avg_learning_rate, - t0=0, - lambd=0., - weight_decay=args.weight_decay) - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ep: 1.0) + n += 1 val_ppl = math.exp(val_loss) elapsed = time.time() - epoch_start_time @@ -272,16 +239,19 @@ def main(args): print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' f'train loss {train_loss:5.2f} | ' f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f} | mean activity {mean_activity:.4f}') + if return_bw_sparsity: mean_bw_sparsity = np.dot(bw_sparsity, np.array(hidden_dims)) / np.sum(np.array(hidden_dims)) print(f'backward sparsity {mean_bw_sparsity}') print('-' * 89) - scheduler.step() + # if the loss diverged to infinity, stop training + if np.isnan(val_loss).any(): + print(f"EXITING DUE TO NAN LOSS {val_loss}") + break - if isinstance(optimizer, torch.optim.Adam) and epoch == milestone: - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=milestone, - eta_min=args.learning_rate / 10) + if epoch > args.scheduler_start: + scheduler.step() ###################################################################### # Evaluate the best model on the test dataset @@ -294,13 +264,16 @@ def main(args): device=device) model.load_state_dict(torch.load(best_model_path, map_location=device)) test_loss, test_activity, test_layerwise_activity_mean, test_layerwise_activity_std, centered_cell_states = \ - evaluate(model=model, - eval_data=test_data, - criterion=criterion, - batch_size=test_batch_size, - bptt=args.bptt, - ntokens=vocab_size, - device=device) + evaluate( + model=model, + eval_data=test_data, + criterion=criterion, + batch_size=test_batch_size, + bptt=args.bptt, + ntokens=vocab_size, + hidden_dims=hidden_dims, + device=device + ) test_ppl = math.exp(test_loss) print('=' * 89) @@ -321,22 +294,6 @@ def repackage_hidden(h): return tuple(repackage_hidden(v) for v in h) -def prune(model, criterion, data, batch_size, sequence_length, ntokens, device, fractions=0.0): - print("Pruning model...") - test_loss, test_activity, test_layerwise_activity_mean, test_layerwise_activity_std, centered_cell_states, all_hiddens = \ - evaluate(model=model, - eval_data=data, - criterion=criterion, - batch_size=batch_size, - bptt=sequence_length, - ntokens=ntokens, - device=device, - return_hidden=True) - - model.prune(fractions, all_hiddens, device) - print(f"Perplexity before pruning {math.exp(test_loss)}") - - def train(model, train_data, optimizer, diff --git a/build/common.py b/build/common.py deleted file mode 100644 index 219fed3..0000000 --- a/build/common.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2020 LMNT, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -VERSION = '0.1.0' -DESCRIPTION = 'EVNN: a torch extension for custom event based RNN models.' -AUTHOR = 'TUD and RUB' -AUTHOR_EMAIL = 'khaleelulla.khan_nazeer@tu-dresden.de' -URL = 'https://tu-dresden.de/ing/elektrotechnik/iee/hpsn' -LICENSE = 'Apache 2.0' -CLASSIFIERS = [ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: Software Development :: Libraries', -] diff --git a/build/setup.pytorch.py b/build/setup.pytorch.py deleted file mode 100644 index db45f37..0000000 --- a/build/setup.pytorch.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2020 LMNT, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import os -import sys - -from glob import glob -from platform import platform -from torch.utils import cpp_extension -import torch -from setuptools import setup -from setuptools.dist import Distribution - - -with open(f'frameworks/pytorch/_version.py', 'wt') as f: - f.write(f'__version__ = "{VERSION}"') - -base_path = os.path.dirname(os.path.realpath(__file__)) - -if torch.cuda.is_available(): - make_cmd = 'make evnn' - - if 'Windows' in platform(): - CUDA_HOME = os.environ.get('CUDA_HOME', os.environ.get('CUDA_PATH')) - extra_args = [] - else: - CUDA_HOME = os.environ.get('CUDA_HOME', '/usr/local/cuda') - extra_args = ['-Wno-sign-compare'] - - extension = cpp_extension.CUDAExtension( - 'evnn_pytorch_lib', - sources=glob('frameworks/pytorch/*.cc'), - extra_compile_args=extra_args + ['-DWITH_CUDA'], - include_dirs=[os.path.join(base_path, 'lib'), - os.path.join(CUDA_HOME, 'include')], - libraries=['evnn', 'cblas', 'c10'], - library_dirs=['.']) -else: - make_cmd = 'make evnn_cpu' - extra_args = [] - extension = cpp_extension.CppExtension( - 'evnn_pytorch_lib', - sources=glob('frameworks/pytorch/*.cc'), - extra_compile_args=extra_args, - include_dirs=[os.path.join(base_path, 'lib'),], - libraries=['evnn', 'cblas'], - library_dirs=['.', os.path.join('/usr/lib/x86_64-linux-gnu')]) - - -class BuildEVNN(cpp_extension.BuildExtension): - def run(self): - os.system(make_cmd) - super().run() - - -setup(name='evnn_pytorch', - version=VERSION, - description=DESCRIPTION, - long_description=open('README.md', 'r', encoding='utf-8').read(), - long_description_content_type='text/markdown', - author=AUTHOR, - author_email=AUTHOR_EMAIL, - url=URL, - license=LICENSE, - keywords='pytorch machine learning rnn lstm gru custom op', - packages=['evnn_pytorch'], - package_dir={'evnn_pytorch': 'frameworks/pytorch'}, - install_requires=[], - ext_modules=[extension], - cmdclass={'build_ext': BuildEVNN}, - classifiers=CLASSIFIERS) diff --git a/docker/Dockerfile b/docker/Dockerfile index fe5da49..1d1b863 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,20 +1,9 @@ -FROM ubuntu:latest AS build -ENV MAX_JOBS=2 PIP_DEFAULT_TIMEOUT=100 DEBIAN_FRONTEND=noninteractive +FROM nvcr.io/nvidia/pytorch:24.04-py3 -#set up environment -RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y software-properties-common build-essential gpg-agent -RUN add-apt-repository -y ppa:deadsnakes/ppa && apt-get update -RUN apt-get install -y python3.8 python3.8-distutils python3-pip -# RUN wget https://bootstrap.pypa.io/get-pip.py && python3.8 get-pip.py -RUN ln -s /usr/bin/python3.8 /usr/bin/python +ENV EVNN_CUDA_COMPUTE 80 +WORKDIR /evnn_src +COPY . . +RUN pip3 install . -# Nvidia toolkit -RUN apt-get -y install wget -RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb && dpkg -i cuda-keyring_1.0-1_all.deb -RUN apt-get update && apt-get install -y cuda -#set up prerequisites for evnn compilation -RUN apt-get -y install python3-dev libpython3.8-dev libblas64-dev -RUN python -m pip install --upgrade pip -RUN python -m pip install ninja -RUN python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 -RUN apt-get clean && python -m pip cache purge \ No newline at end of file + +WORKDIR /workspace \ No newline at end of file diff --git a/frameworks/pytorch/egru.py b/frameworks/pytorch/egru.py index 6585e40..a3ae009 100644 --- a/frameworks/pytorch/egru.py +++ b/frameworks/pytorch/egru.py @@ -247,7 +247,8 @@ def __init__(self, zoneout=0.0, dampening_factor=0.7, pseudo_derivative_support=1.0, - thr_mean=0.0, + thr_mean=0.3, + weight_initialization_gain=1.0, return_state_sequence=False, grad_clip=None, use_custom_cuda=True): @@ -288,6 +289,8 @@ def __init__(self, self.dropout = dropout self.alpha = torch.tensor(0.9) + self.weight_initialization_gain = weight_initialization_gain + self.kernel = nn.Parameter(torch.empty(input_size, hidden_size * 3)) self.recurrent_kernel = nn.Parameter( torch.empty(hidden_size, hidden_size * 3)) @@ -299,9 +302,13 @@ def __init__(self, torch.Tensor([dampening_factor]), requires_grad=False) self.pseudo_derivative_support = nn.Parameter( torch.Tensor([pseudo_derivative_support]), requires_grad=False) - self.thr_reparam = nn.Parameter(torch.normal(torch.zeros(self.hidden_size) + thr_mean, - math.sqrt(2) * torch.ones(self.hidden_size))) - self.thr = torch.sigmoid(self.thr_reparam) + + # initialize thresholds according to the beta distribution with mean 'thr_mean' + assert 0 < thr_mean < 1, f"thr_mean must be between 0 and 1, but {thr_mean} was given" + beta = 3 + alpha = beta * thr_mean / (1 - thr_mean) + distribution = torch.distributions.beta.Beta(alpha, beta) + self.thr = nn.Parameter(distribution.sample(torch.Size([self.hidden_size]))) def to_native_weights(self): """ @@ -327,7 +334,7 @@ def reorder_weights(w): recurrent_kernel = torch.nn.Parameter(recurrent_kernel) bias1 = torch.nn.Parameter(bias1) bias2 = torch.nn.Parameter(bias2) - thr = torch.nn.Parameter(self.thr_reparam) + thr = torch.nn.Parameter(self.thr) return kernel, recurrent_kernel, bias1, bias2, thr def from_native_weights(self, weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, thr): @@ -354,14 +361,14 @@ def reorder_weights(w): self.recurrent_kernel = nn.Parameter(recurrent_kernel) self.bias = nn.Parameter(bias) self.recurrent_bias = nn.Parameter(recurrent_bias) - self.thr_reparam = nn.Parameter(thr) + self.thr = nn.Parameter(thr) def reset_parameters(self): """Resets this layer's parameters to their initial values.""" for k, v in self.named_parameters(): if k in ['kernel', 'recurrent_kernel', 'bias', 'recurrent_bias']: if v.data.ndimension() >= 2: - nn.init.xavier_normal_(v) + nn.init.xavier_normal_(v, gain=self.weight_initialization_gain) else: nn.init.zeros_(v) @@ -397,10 +404,15 @@ def forward(self, input, state=None, lengths=None): input = self._permute(input) state_shape = [1, input.shape[1], self.hidden_size] h0 = self._get_state(input, state, state_shape) - thr = torch.sigmoid(self.thr_reparam) + + # restrict thresholds to be between 0 and 1 + self.thr.data.clamp_(min=0.0, max=1.0) + + # run forward pass y, h, o, trace = self._impl( - input, h0[0], thr, self._get_zoneout_mask(input)) - state = self._get_final_state(y, lengths) + input, h0[0], self.thr, self._get_zoneout_mask(input)) + + # prepare outputs output = self._permute(y[1:]) h = self._permute(h[1:]) o = self._permute(o[1:]) diff --git a/lib/egru_backward_gpu.cu.cc b/lib/egru_backward_gpu.cu similarity index 100% rename from lib/egru_backward_gpu.cu.cc rename to lib/egru_backward_gpu.cu diff --git a/lib/egru_forward_gpu.cu.cc b/lib/egru_forward_gpu.cu similarity index 100% rename from lib/egru_forward_gpu.cu.cc rename to lib/egru_forward_gpu.cu diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..106a0ff --- /dev/null +++ b/setup.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 Khaleelulla Khan Nazeer +# This file incorporates work covered by the following copyright: +# Copyright 2020 LMNT, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os + +from glob import glob +import warnings +import subprocess +import torch +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension, CUDA_HOME + + +def get_gpu_arch_flags(): + try: + major, minor = torch.cuda.get_device_capability() + return [f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"] + except Exception as e: + warnings.warn(f"Error while detecting GPU architecture: {e}\n \ + Use env var EVNN_CUDA_COMPUTE to set cuda compute capability") + compute_capability = os.getenv("EVNN_CUDA_COMPUTE", None) + if compute_capability is None: + warnings.warn("EVNN_CUDA_COMPUTE not defined, using default: 80") + compute_capability = 80 + + return [f"-gencode=arch=compute_{compute_capability},code=sm_{compute_capability}"] + +def check_nvcc_available(): + try: + subprocess.run(["nvcc", "--version"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + return True + except FileNotFoundError: + warnings.warn( + f"nvcc was not found. Skip compiling GPU kernels" + ) + return False + except subprocess.CalledProcessError: + warnings.warn( + f"nvcc was not found. Skip compiling GPU kernels" + ) + return False + + +arch_flags = get_gpu_arch_flags() + +VERSION = '0.2.0' +DESCRIPTION = 'EVNN: a torch extension for custom event based RNN models.' +AUTHOR = 'TUD and RUB' +AUTHOR_EMAIL = 'khaleelulla.khan_nazeer@tu-dresden.de' +URL = 'https://tu-dresden.de/ing/elektrotechnik/iee/hpsn' +LICENSE = 'Apache 2.0' +CLASSIFIERS = [ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: Software Development :: Libraries', +] + + +with open(f'frameworks/pytorch/_version.py', 'wt') as f: + f.write(f'__version__ = "{VERSION}"') + +base_path = os.path.dirname(os.path.realpath(__file__)) + +if check_nvcc_available(): + + extension = [CUDAExtension( + 'evnn_pytorch_lib', + sources=glob('frameworks/pytorch/*.cc') + glob('lib/*.cu') + glob('lib/*.cc'), + extra_compile_args={ + "cxx": ["-O2", "-std=c++17", "-D_GLIBCXX_USE_CXX11_ABI=0", "-DWITH_CUDA", "-Wno-sign-compare"], + "nvcc": ["-O2", "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-D_GLIBCXX_USE_CXX11_ABI=0", "-DWITH_CUDA", + "-Xcompiler", "-fPIC", "-lineinfo"] + + arch_flags, + }, + include_dirs=[os.path.join(base_path, 'lib'), + os.path.join(CUDA_HOME, 'include'), + os.path.join(CUDA_HOME, 'lib64')], + libraries=['openblas', 'c10', 'cudart', 'cublas'], + library_dirs=['.']), + ] +else: + extension = [CppExtension( + 'evnn_pytorch_lib', + sources=glob('frameworks/pytorch/*.cc') + glob('lib/*.cc'), + extra_compile_args={ + "cxx": ["-O2", "-std=c++17", "-D_GLIBCXX_USE_CXX11_ABI=0", "-Wno-sign-compare"], + }, + include_dirs=[os.path.join(base_path, 'lib'),], + libraries=['openblas'], + library_dirs=['.', os.path.join('/usr/lib/x86_64-linux-gnu')])] + + +setup(name='evnn_pytorch', + version=VERSION, + description=DESCRIPTION, + long_description=open('README.md', 'r', encoding='utf-8').read(), + long_description_content_type='text/markdown', + author=AUTHOR, + author_email=AUTHOR_EMAIL, + url=URL, + license=LICENSE, + keywords='pytorch machine learning rnn lstm gru custom op', + packages=['evnn_pytorch'], + package_dir={'evnn_pytorch': 'frameworks/pytorch'}, + install_requires=['torch'], + ext_modules=extension, + cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False),}, + classifiers=CLASSIFIERS) diff --git a/validation/self_consistency_test.py b/validation/self_consistency_test.py index 1fa3669..559f685 100644 --- a/validation/self_consistency_test.py +++ b/validation/self_consistency_test.py @@ -32,6 +32,7 @@ time_steps = 8 input_size = 4 hidden_size = 8 +error_tol = 1e-3 @unittest.skipUnless(torch.cuda.is_available(), 'CUDA not available') class EGRUCUDAForwardTest(unittest.TestCase): @@ -58,7 +59,7 @@ def test_forward_y(self): with mock.patch.object(self.egru, "use_custom_cuda", False): y2, _ = self.egru.forward(self.x_cuda_torch) - assert torch.allclose(y1, y2) + assert torch.allclose(y1, y2, atol=error_tol) def test_forward_h(self): with torch.no_grad(): @@ -68,7 +69,7 @@ def test_forward_h(self): with mock.patch.object(self.egru, "use_custom_cuda", False): _, (h2, _, _) = self.egru.forward(self.x_cuda_torch) - assert torch.allclose(h1, h2) + assert torch.allclose(h1, h2, atol=error_tol) def test_forward_o(self): with torch.no_grad(): @@ -88,7 +89,7 @@ def test_forward_trace(self): with mock.patch.object(self.egru, "use_custom_cuda", False): _, (_, _, t2) = self.egru.forward(self.x_cuda_torch) - assert torch.allclose(t1, t2) + assert torch.allclose(t1, t2, atol=error_tol) @unittest.skipUnless(torch.cuda.is_available(), 'CUDA not available') @@ -121,7 +122,7 @@ def test_backward_y(self): y2.backward(torch.ones_like(y2), retain_graph=True) assert torch.allclose(self.x_cuda_torch.grad.data, - self.x_cuda.grad.data, atol=1e-06) + self.x_cuda.grad.data, atol=error_tol) def test_backward_h(self): _, (h1, _, _) = self.egru.forward(self.x_cuda) @@ -134,7 +135,7 @@ def test_backward_h(self): h2.backward(torch.ones_like(h2), retain_graph=True) assert torch.allclose(self.x_cuda_torch.grad.data, - self.x_cuda.grad.data, atol=1e-06) + self.x_cuda.grad.data, atol=error_tol) def test_backward_o(self): _, (_, o1, _) = self.egru.forward(self.x_cuda) @@ -147,7 +148,7 @@ def test_backward_o(self): o2.backward(torch.ones_like(o2), retain_graph=True) assert torch.allclose(self.x_cuda_torch.grad.data, - self.x_cuda.grad.data, atol=1e-06) + self.x_cuda.grad.data, atol=error_tol) def test_backward_trace(self): _, (_, _, t1) = self.egru.forward(self.x_cuda) @@ -160,7 +161,7 @@ def test_backward_trace(self): t2.backward(torch.ones_like(t2), retain_graph=True) assert torch.allclose(self.x_cuda_torch.grad.data, - self.x_cuda.grad.data, atol=1e-06) + self.x_cuda.grad.data, atol=error_tol) if __name__ == '__main__':