Hello! First of all, thank you very much for the very useful resource and for the effort in making it accessible to the community!
I have installed keypoint-moseq in a dedicated conda environment created from the environment.win64_gpu.yml file. My computer runs Windows 10 (just updated to the latest version) with a GeForce 1070 Ti graphics card (8 GB VRAM) and the latest available drivers (546.33).
If I run python -c "import jax; print(jax.default_backend())" I get the expected gpu result.
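For reference, this is the slightly more verbose check I used to confirm the card is visible (plain JAX calls, nothing keypoint-moseq-specific):

```python
import jax

# Backend selected by default ("gpu" if CUDA/cuDNN were picked up correctly)
print(jax.default_backend())

# Devices JAX can see; should list the GeForce 1070 Ti
print(jax.devices())
```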
I tried out the tutorial workflow. Memory usage is <0.3 GB before starting, rises to 7.3 GB after JAX initialization, and stays there. Initialization and the AR-HMM model fit with 50 iterations run smoothly in ~13 min. When I start fitting the full model, it crashes silently.
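I assume the jump to 7.3 GB is mostly JAX preallocating GPU memory rather than the data itself; in case it matters, this is the standard way to rein that in (plain JAX/XLA environment variables, not keypoint-moseq options, and they must be set before importing jax):

```python
import os

# Standard XLA flags; set these before jax is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of up front
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.75"  # or cap the fraction of VRAM JAX may use

import jax
print(jax.default_backend())
```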
Assuming an OOM error, I set parallel_message_passing=False (see the sketch below) and now it runs. (Is it correct that the test dataset is 643,911 frames? In that case, shouldn't I be fine, since 643,911 frames × 0.01 MB/frame ≈ 6.5 GB < 8 GB, as per the FAQ?)
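For completeness, this is where I set the flag, following the tutorial notebook; treat it as a sketch, since the exact fit_model signature and iteration counts may differ between keypoint-moseq versions:

```python
import keypoint_moseq as kpms

# Continuing from the tutorial: `model`, `data`, `metadata`, `project_dir`
# and `model_name` come from the earlier initialization / AR-HMM steps.
# parallel_message_passing=False trades speed for lower GPU memory use.
model, model_name = kpms.fit_model(
    model, data, metadata, project_dir, model_name,
    ar_only=False, start_iter=50, num_iters=550,
    parallel_message_passing=False,
)
```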
With that, it runs on the test dataset in approximately 8 hours; is this what you would expect? (It sounds plausible for the upper edge of the 2-5x slowdown relative to the 1.5 h estimate on Google Colab.) I am just a bit surprised that this is significantly slower than what I get with the CPU version (approximately 4.5 hours).
I guess this is fine, since my actual data is of comparable size (~500k frames); I am reporting it here just to make sure I'm not missing anything, and in case it is useful for others.
Thank you very much for your clarifications and for all the nice and useful work!