Question on performance on the test dataset #124

Open

vigji opened this issue Jan 17, 2024 · 0 comments

vigji commented Jan 17, 2024

Hello! First of all, thank you very much for the very useful resource and for the effort in making it accessible to the community!

I have installed keypoint-moseq in a dedicated conda environment created from the environment.win64_gpu.yml file. My computer runs Windows 10, just updated to the latest version, with a GeForce 1070 Ti graphics card (8 GB VRAM) and the latest available drivers (546.33).

From conda list:

...
cuda-nvcc                 12.3.107                      0    nvidia
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h09e9e62_12    conda-forge
cudnn                     8.8.0.121            h84bb9a4_4    conda-forge
...
jax                       0.3.22                   pypi_0    pypi
jax-moseq                 0.2.1                    pypi_0    pypi
jaxlib                    0.3.22                   pypi_0    pypi
...
keypoint-moseq            0.4.2                    pypi_0    pypi

If I run python -c "import jax; print(jax.default_backend())" I get the expected gpu output.
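
For completeness, the slightly expanded version of that check (plain JAX calls, nothing project-specific):

import jax

print(jax.default_backend())  # prints "gpu" on my machine
print(jax.devices())          # lists the GeForce 1070 Ti as a GpuDevice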

I tried out the tutorial workflow. Memory usage is <0.3 GB before starting, jumps to 7.3 GB after JAX initialization, and remains stable there. The initialization and the AR-HMM model fit with 50 iterations run smoothly in ~13 min. When I start the fit of the full model, it crashes silently.
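
(My assumption is that the 7.3 GB jump is JAX preallocating ~90% of GPU memory at import time rather than actual model usage. While diagnosing the crash, I believe preallocation can be disabled like this:)

import os

# Must be set before jax is first imported; disables the default
# ~90% GPU-memory preallocation so real usage becomes visible.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax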

Assuming an OOM error, I set parallel_message_passing=False and now it runs. (Is it correct that the test dataset is 643,911 frames? In that case, shouldn't I be fine, since 8 GB > ~6.5 GB (643,911 frames × 0.01 MB/frame), as per the FAQ?)
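
For reference, this is roughly where I set the flag. A sketch following the tutorial, assuming fit_model forwards parallel_message_passing to the resampling step as the FAQ suggests; `model`, `data`, `metadata`, and `project_dir` come from the earlier tutorial cells:

import keypoint_moseq as kpms

# full-model fit; the flag trades speed for lower GPU memory use
# during message passing, which avoids the silent crash for me
model, model_name = kpms.fit_model(
    model, data, metadata, project_dir,
    ar_only=False, num_iters=500,
    parallel_message_passing=False,
)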

With that, the full-model fit on the test dataset takes approximately 8 hours; is this what you would expect? (It sounds plausible as the upper edge of the 2-5x slowdown relative to the ~1.5 h Google Colab estimate.) I am just a bit surprised that it is significantly slower than what I get with the CPU version (approximately 4.5 hours).
I guess this is fine, as my actual data is of comparable size (~500k frames); I am reporting it here just to make sure I'm not missing anything and in case it is useful for others.

Thank you very much for your clarifications and for all the nice and useful work!
