jax.lax.approx_min_k() on GPU appears to be a brute-force search.
When running on a GPU, the recall_target value has no effect: the execution time scales linearly with the number of elements in the array being searched.
In the source code I can see that approx_min_k() is implemented as an approximate KNN search only for Google TPUs. Please implement it along the same lines for GPUs, if possible.
Note that the performance is quite impressive for a brute-force search, but I need to scale up my database, and then the current code won't cut it.
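For reference, a minimal sketch of the usage described above (the array sizes, key seeds, and recall_target value are illustrative assumptions, not taken from the original report):

```python
import jax
import jax.numpy as jnp

# Illustrative sizes: a synthetic "database" of vectors and one query.
db_size = 10_000
dim = 64

db = jax.random.normal(jax.random.PRNGKey(0), (db_size, dim))
query = jax.random.normal(jax.random.PRNGKey(1), (dim,))

# Squared L2 distance from the query to every database row.
dists = jnp.sum((db - query) ** 2, axis=-1)

# recall_target is expected to trade accuracy for speed; the report is
# that on GPU the runtime scales linearly with db_size regardless of
# this value, i.e. the search is effectively exact/brute-force.
vals, idxs = jax.lax.approx_min_k(dists, k=10, recall_target=0.9)
```

On TPU this lowers to the approximate top-k algorithm; on other backends it falls back to a full scan, which matches the linear scaling observed.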
System info (python version, jaxlib version, accelerator, etc.)
>>> jax.print_environment_info()
jax: 0.4.31
jaxlib: 0.4.31
numpy: 1.26.4
python: 3.12.7 (main, Oct 1 2024, 08:52:11) [GCC 9.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='boxob', release='5.15.0-124-generic', version='#134~20.04.1-Ubuntu SMP Tue Oct 1 15:27:33 UTC 2024', machine='x86_64')
$ nvidia-smi
Sun Oct 20 22:40:50 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01 Driver Version: 535.183.01 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 2070 Off | 00000000:01:00.0 On | N/A |
| 0% 47C P2 38W / 185W | 468MiB / 8192MiB | 5% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1343 G /usr/lib/xorg/Xorg 239MiB |
| 0 N/A N/A 1614 G /usr/bin/gnome-shell 50MiB |
| 0 N/A N/A 132100 G ...onEnabled --variations-seed-version 76MiB |
| 0 N/A N/A 184955 C python 96MiB |
+---------------------------------------------------------------------------------------+