
NeRFax

In this work we implement NeRF and, to the best of our knowledge, the first fully parallel implementation of pNeRF, in the emerging framework JAX. For single-chain proteins we demonstrate speedups of 35-175x over the fastest public implementation. Using JAX's ability to trivially parallelise functions, we also show a >10,000x speedup, relative to running mp-NeRF serially, for a biomolecular condensate of 1,000 chains of 163 residues each.
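NeRF (Natural Extension Reference Frame) places each atom from the three atoms before it plus three internal coordinates: a bond length, a bond angle and a torsion (dihedral). The sketch below shows that single placement step in JAX using the commonly used local-frame formulation. The function name `place_atom` and its argument order are illustrative assumptions, not the NeRFax API.

```python
import jax.numpy as jnp


def place_atom(a, b, c, bond_length, bond_angle, torsion):
    """Hypothetical helper: one NeRF step placing atom d after atoms a, b, c.

    bond_length: |d - c|
    bond_angle:  angle b-c-d in radians
    torsion:     dihedral a-b-c-d in radians (IUPAC sign convention)
    """
    # Unit vector along the last bond b -> c.
    bc = (c - b) / jnp.linalg.norm(c - b)
    # Unit normal to the plane containing a, b, c.
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)
    # Local reference frame with columns [bc, n x bc, n].
    m = jnp.stack([bc, jnp.cross(n, bc), n], axis=-1)
    # Position of d relative to c, expressed in the local frame.
    d_local = bond_length * jnp.array([
        -jnp.cos(bond_angle),
        jnp.sin(bond_angle) * jnp.cos(torsion),
        jnp.sin(bond_angle) * jnp.sin(torsion),
    ])
    # Rotate into the global frame and translate to c.
    return c + m @ d_local


# Usage with toy coordinates: |cd| = 1.5, angle b-c-d = 110 deg, torsion a-b-c-d = 90 deg.
a = jnp.array([0.0, 1.0, 0.0])
b = jnp.zeros(3)
c = jnp.array([1.0, 0.0, 0.0])
d = place_atom(a, b, c, 1.5, jnp.deg2rad(110.0), jnp.deg2rad(90.0))
```

Because every atom depends on the three placed before it, building a single chain this way is inherently sequential; pNeRF-style methods are designed to relax that serial dependency.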

Benchmarks

Single chain

Runtime of different computational methods for single chains

Speedup, relative to the CPU mp_nerf implementation, of different computational methods for single chains

This can be reproduced with notebooks/benchmark_single_chain_reconstruction.ipynb.
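When timing JAX code yourself, two behaviours matter: the first call to a `jax.jit`-compiled function includes tracing and compilation, and JAX dispatches work asynchronously, so a timer can stop before the computation has actually finished. A minimal sketch that accounts for both is below. `time_jax_fn` is a hypothetical helper, and this is not necessarily how the benchmark notebooks measure runtime.

```python
import time

import jax
import jax.numpy as jnp


def time_jax_fn(fn, *args, n_repeats=10):
    """Hypothetical helper: mean wall-clock seconds per call of jit(fn)(*args),
    excluding the one-off compilation cost."""
    jitted = jax.jit(fn)
    # Warm-up call: triggers tracing and XLA compilation, not timed.
    jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(n_repeats):
        # Wait for the asynchronously dispatched result before continuing.
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / n_repeats


# Usage with a toy function.
mean_s = time_jax_fn(lambda x: jnp.sin(x).sum(), jnp.ones(10_000))
print(f"{mean_s * 1e3:.3f} ms per call")
```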

Multiple chains: Biomolecular condensate reconstruction

Leveraging JAX's automatic vectorization, the reconstruction was parallelized across chains and runs in 3.4 ms on GPU. The previous torch implementation has no parallel multi-chain mode, so chains must be computed serially; extrapolating its runtime gives ~60 seconds, making the JAX version approximately 17,000x faster. This can be reproduced with notebooks/benchmark_multiple_chain_reconstruction.ipynb.
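The across-chain parallelism can be sketched as follows. Within one chain, a per-atom NeRF step is applied sequentially with `jax.lax.scan`; because separate chains share no data, `jax.vmap` maps that whole builder over a leading chain axis so all chains are computed as one batched operation. This is a simplified illustration assuming a generic chain of atoms with one bond length, bond angle and torsion per new atom. `place_atom`, `build_chain` and `build_many_chains` are hypothetical names, and NeRFax's own implementation (including its parallel pNeRF) is not reproduced here.

```python
import jax
import jax.numpy as jnp


def place_atom(a, b, c, bond_length, bond_angle, torsion):
    """Single NeRF step: position of atom d given preceding atoms a, b, c."""
    bc = (c - b) / jnp.linalg.norm(c - b)
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)
    m = jnp.stack([bc, jnp.cross(n, bc), n], axis=-1)  # local frame
    d_local = bond_length * jnp.array([
        -jnp.cos(bond_angle),
        jnp.sin(bond_angle) * jnp.cos(torsion),
        jnp.sin(bond_angle) * jnp.sin(torsion),
    ])
    return c + m @ d_local


def build_chain(init_xyz, bond_lengths, bond_angles, torsions):
    """Extend ONE chain from three seed atoms (shape (3, 3)).

    Each step needs the previous three atoms, so within a chain the loop
    is sequential; lax.scan carries the last three atoms as its state.
    """
    def step(prev3, internal):
        d = place_atom(prev3[0], prev3[1], prev3[2], *internal)
        return jnp.stack([prev3[1], prev3[2], d]), d

    _, new_atoms = jax.lax.scan(step, init_xyz, (bond_lengths, bond_angles, torsions))
    return jnp.concatenate([init_xyz, new_atoms], axis=0)


# vmap adds a leading "chain" axis to every argument; jit compiles the batch once.
build_many_chains = jax.jit(jax.vmap(build_chain))

# Usage with made-up internal coordinates (toy sizes, not the benchmark inputs).
n_chains, n_steps = 1000, 486
seed = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=jnp.float32)
init_xyz = jnp.broadcast_to(seed, (n_chains, 3, 3))
bond_lengths = jnp.full((n_chains, n_steps), 1.5, dtype=jnp.float32)
bond_angles = jnp.full((n_chains, n_steps), jnp.deg2rad(110.0), dtype=jnp.float32)
torsions = jax.random.uniform(
    jax.random.PRNGKey(0), (n_chains, n_steps), minval=-jnp.pi, maxval=jnp.pi
)
coords = build_many_chains(init_xyz, bond_lengths, bond_angles, torsions)  # (1000, 489, 3)
```

The design choice here is that parallelism comes from batching independent chains, not from changing the per-chain algorithm; the serial dependency inside each chain remains.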

Installation

Pip

git clone https://github.com/PeptoneLtd/nerfax.git && pip install "./nerfax[optional]"

Note: to run on GPU, a GPU-enabled version of JAX must be installed; please follow the JAX GPU compatibility instructions.
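A quick way to confirm which backend JAX is using after installation is sketched below. The helper name is hypothetical; `jax.default_backend` and `jax.devices` are standard JAX functions. If only a CPU build of JAX is installed, the default backend will be CPU and the GPU timings reported above will not apply.

```python
import jax


def jax_backend_summary():
    """Hypothetical helper: default backend name (e.g. "cpu" or "gpu") and visible devices."""
    return jax.default_backend(), jax.devices()


backend, devices = jax_backend_summary()
print(f"default backend: {backend}; devices: {devices}")
```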

Docker image

We also provide a Dockerfile that can be used to install NeRFax; it includes the GPU version of JAX.