feat: add T5 acceleration support (#58)
* fix: fix ONNX shape fixed dim size

* fix: fix ONNX shape fixed dim size

* fix: update docker file base

* feat: add script to generate T5 model

* feat: add support for provided output shape on ORT inference

* fix: disable symbolic shape inference (not working with ONNX last version 1.11.0)

* fix: improve documentation (TRT supported op link)

* fix: add ipython as dependency in Docker image to run Jupyter notebook easily

* feat: better support for dynamic axis (T5)

* feat: change shape management in TensorRT inference

* feat: T5 test script

* fix: error message in tensorrt inference function

* update Pytorch dependencies

* feat: convert manually onnx to fp16

* fix: add generic support of fp16 on trt

* fix: refactoring

* fix: refactoring

* fix: refactoring

* feat: apply new fp32 node detector to t5 dec module

* fix: refactoring

* fix: fix tests, refactoring

* fix: update ORT dependencies

* feat: add cache toy model + cache notebook

* fix: display output correctly

* fix: export T5 graph with cache support

* fix: end to end conversion process

* fix: working inference

* fix: text generation works but is slow

* fix: add some explanations, clean scripts

* fix: clean scripts

* fix: script a bit more stable

* fix: add graphs

* feat: capture pytorch timings

* feat: no more output shape guess for ORT

* fix: fix tests

* fix: fix tests

* feat: no memory copy during ORT inference + refactoring

* fix: mixed precision with several models

* fix: fix link to demo folder

* fix: fix code for CPU only execution

* fix: fp16 conversion works

* fix: refactoring to hide most of the fp16 logic from the lib user's eyes. works on t5-base

* feat: improve copy less and FP16 transfo

* fix: linter

* fix: update text and results

* fix: extend gitignore

* fix: add doc dependency

* feat: add log severity setup to ORT loading

* feat: update notebook text

* feat: script for tensorrt + t5

* feat: script for tensorrt + t5

* feat: start to support ONNX external data

* fix: fix linter

* fix: better wording in notebook

* fix: notebook works on 3b

* fix: notebook works on large

* fix: fix encoder

* feat: update notebook text

* feat: notebook text

* feat: text

* fix: linter

* fix: timings

* fix: reduce memory footprint

* fix: move
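Several of the commits above ("convert manually onnx to fp16", "apply new fp32 node detector to t5 dec module", "mixed precision with several models") revolve around one idea: during fp16 conversion, nodes whose activations overflow the fp16 range must be kept in fp32. A minimal numpy-only sketch of that detection idea follows; it is illustrative only and not the library's actual detector, and all names are made up:

```python
import numpy as np

# fp16 cannot represent magnitudes above ~65504; anything larger overflows.
FP16_MAX = float(np.finfo(np.float16).max)

def keep_in_fp32(activations: np.ndarray, threshold: float = FP16_MAX) -> bool:
    """Return True when a node's observed activations overflow the fp16
    range, meaning the node should stay in fp32 during mixed precision."""
    return bool(np.abs(activations).max() > threshold)

# Large magnitudes overflow fp16 and must stay in fp32...
print(keep_in_fp32(np.array([1.2e5], dtype=np.float32)))      # True
# ...while small ones are safe to cast down.
print(keep_in_fp32(np.array([1.0, -2.0], dtype=np.float32)))  # False
```

In practice such a detector would be run on activations captured from representative inputs, and the flagged nodes excluded from the fp16 cast.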
pommedeterresautee authored May 24, 2022
1 parent a7594e3 commit d397869
Showing 18 changed files with 8,328 additions and 5,553 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -217,3 +217,8 @@ cython_debug/
 .idea/
 TensorRT/
 triton_models/
+*.whl
+.vscode
+to_delete/
+test-*/
+.history/
4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,9 +1,9 @@
-FROM nvcr.io/nvidia/tritonserver:22.01-py3
+FROM nvcr.io/nvidia/tritonserver:22.02-py3
 
 # see .dockerignore to check what is transferred
 COPY . ./
 
 RUN pip3 install -U pip && \
     pip3 install nvidia-pyindex && \
     pip3 install ".[GPU]" -f https://download.pytorch.org/whl/cu113/torch_stable.html --extra-index-url https://pypi.ngc.nvidia.com --no-cache-dir && \
-    pip3 install sentence-transformers notebook pytorch-quantization
+    pip3 install sentence-transformers notebook pytorch-quantization ipywidgets
2 changes: 1 addition & 1 deletion demo/generative-model/gpt2.ipynb
@@ -816,7 +816,7 @@
 "\n",
 "def inference_tensorrt(input_ids: torch.Tensor) -> torch.Tensor:\n",
 "    data = {\"input_ids\": input_ids}\n",
-"    return tensorrt_model(data)[0]\n",
+"    return tensorrt_model(data)\n",
 "\n",
 "\n",
 "gpt2_model = GPTModelWrapper(config=model.config, device=torch.device(\"cuda\"), inference=inference_tensorrt)\n",
1,876 changes: 1,876 additions & 0 deletions demo/generative-model/t5.ipynb

Large diffs are not rendered by default.
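The new t5.ipynb builds on decoder key/value caching (see the "add cache toy model" and "export T5 graph with cache support" commits above). Since the notebook's diff is not rendered here, a toy numpy sketch of why caching speeds up autoregressive decoding may help; the names are illustrative and this is not the notebook's code:

```python
import numpy as np

def attention(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    scores = keys @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ values

rng = np.random.default_rng(0)
d = 4
k_cache = np.empty((0, d))
v_cache = np.empty((0, d))
for step in range(3):
    # With a KV cache, each decoding step computes only the NEW token's
    # key/value and appends them, instead of recomputing keys/values for
    # the whole sequence at every step.
    k_cache = np.vstack([k_cache, rng.normal(size=(1, d))])
    v_cache = np.vstack([v_cache, rng.normal(size=(1, d))])
    out = attention(rng.normal(size=d), k_cache, v_cache)

print(k_cache.shape)
```

Exporting a T5 decoder with cache support means the ONNX graph takes these cached tensors as extra inputs and returns the updated ones as extra outputs, which is what makes per-token generation fast.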

