jax.lax.approx_min_k() on GPU appears to be a brute force search #24414

Open
notnot opened this issue Oct 20, 2024 · 0 comments
Labels
bug Something isn't working

notnot commented Oct 20, 2024

Description

On GPU, jax.lax.approx_min_k() appears to perform a brute-force search.
When running on a GPU, the recall_target value has no effect, and the execution time scales linearly with the number of elements in the array being searched.
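A minimal timing sketch for reproducing this. It assumes squared-L2 distances over random float32 data. squared_l2, make_search, time_search and sweep are illustrative helper names, and the sizes and recall targets are arbitrary choices. If recall_target were honored, lower targets should run noticeably faster at a fixed num_db; per this report, on the GPU the times track only num_db.

```python
import time

import jax
import jax.numpy as jnp


def squared_l2(queries, db):
    """Squared L2 distances between every query and every database row."""
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2, shape (num_queries, num_db)
    q2 = jnp.sum(queries * queries, axis=1, keepdims=True)
    x2 = jnp.sum(db * db, axis=1)
    return q2 - 2.0 * queries @ db.T + x2


def make_search(k, recall_target):
    """Compile a search function with k and recall_target baked in as constants."""
    @jax.jit
    def search(queries, db):
        return jax.lax.approx_min_k(squared_l2(queries, db), k=k,
                                    recall_target=recall_target)
    return search


def time_search(num_db, recall_target, k=10, num_queries=128, dim=64, reps=5):
    """Average wall-clock seconds per search call, excluding compilation."""
    kq, kd = jax.random.split(jax.random.PRNGKey(0))
    queries = jax.random.normal(kq, (num_queries, dim), dtype=jnp.float32)
    db = jax.random.normal(kd, (num_db, dim), dtype=jnp.float32)
    search = make_search(k, recall_target)
    jax.block_until_ready(search(queries, db))  # first call compiles
    start = time.perf_counter()
    for _ in range(reps):
        jax.block_until_ready(search(queries, db))
    return (time.perf_counter() - start) / reps


def sweep(sizes=(2**16, 2**18, 2**20), recall_targets=(0.5, 0.9, 0.99)):
    """Print one timing line per (num_db, recall_target) pair."""
    print("backend:", jax.default_backend())
    for num_db in sizes:
        for recall_target in recall_targets:
            ms = 1e3 * time_search(num_db, recall_target)
            print(f"num_db={num_db:>8}  recall_target={recall_target:<5} {ms:8.2f} ms")
```

Calling sweep() on the GPU machine prints the table; jax.default_backend() confirms which backend actually ran the search. The largest default size allocates a 128 x 2^20 float32 score matrix (about 512 MiB), which fits on an 8 GiB card.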

From the source code, I can see that approx_min_k() is only implemented as an approximate KNN search for Google TPUs. Please implement it along the same lines for GPUs, if possible.
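A rough sketch, in pure JAX, of what a bin-then-select search could look like on GPU. It assumes the TPU path works roughly like this: reduce each row to per-bin minima, then run an exact top-k over the much smaller set of bin minima. That is my reading, not a description of the actual XLA kernel. bucketed_approx_min_k and recall are hypothetical helpers, and contiguous bins are a simplification.

```python
import jax
import jax.numpy as jnp
import numpy as np


def bucketed_approx_min_k(scores, k, num_bins):
    """Approximate k smallest entries along the last axis of a (num_queries, n) array.

    Hypothetical helper: split each row into num_bins contiguous bins, keep
    only each bin's minimum, then take an exact top-k over the bin minima.
    Two true neighbours that share a bin knock each other out, which is
    where recall is lost.
    """
    num_queries, n = scores.shape
    if n % num_bins != 0 or num_bins < k:
        raise ValueError("need n % num_bins == 0 and num_bins >= k")
    bin_size = n // num_bins
    binned = scores.reshape(num_queries, num_bins, bin_size)
    bin_min = binned.min(axis=-1)                            # (num_queries, num_bins)
    bin_arg = binned.argmin(axis=-1)                         # offset inside each bin
    global_idx = jnp.arange(num_bins) * bin_size + bin_arg   # index into the row
    # lax.top_k returns the largest entries, so negate to select the smallest.
    neg_vals, pos = jax.lax.top_k(-bin_min, k)
    return -neg_vals, jnp.take_along_axis(global_idx, pos, axis=-1)


def recall(approx_idx, exact_idx):
    """Fraction of the exact k nearest indices that the approximate search found."""
    approx_idx, exact_idx = np.asarray(approx_idx), np.asarray(exact_idx)
    hits = sum(len(np.intersect1d(a, e)) for a, e in zip(approx_idx, exact_idx))
    return hits / exact_idx.size


# Example: compare against the exact result from lax.top_k.
scores = jax.random.normal(jax.random.PRNGKey(0), (8, 4096))
_, exact_idx = jax.lax.top_k(-scores, 10)
fast = jax.jit(bucketed_approx_min_k, static_argnums=(1, 2))
_, approx_idx = fast(scores, 10, 512)
print("recall:", recall(approx_idx, exact_idx))
```

More bins give higher recall at the cost of a larger final top-k; with num_bins equal to the row length the result is exact. Whether this reshape-and-reduce approach actually beats the existing exact path on a GPU would need measuring.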

To be fair, the performance is quite impressive for a brute-force search. But I need to scale up my database, and at that size the current code won't cut it.
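One stopgap for larger databases, assuming memory rather than time is what runs out first: scan the database in fixed-size chunks and merge each chunk into a running top-k. This is still an exact, linear-time search, so it does not address the scaling problem itself. chunked_min_k is a hypothetical helper, and the database size is assumed to be a multiple of chunk_size.

```python
import jax
import jax.numpy as jnp


def chunked_min_k(queries, db, k, chunk_size):
    """Exact k nearest neighbours by squared L2, scanning db in fixed-size chunks.

    Peak memory is O(num_queries * (k + chunk_size)) instead of
    O(num_queries * num_db); total work is still linear in num_db.
    """
    num_db, dim = db.shape
    num_queries = queries.shape[0]
    if num_db % chunk_size != 0:
        raise ValueError("num_db must be a multiple of chunk_size")
    num_chunks = num_db // chunk_size
    chunks = db.reshape(num_chunks, chunk_size, dim)
    offsets = jnp.arange(num_chunks, dtype=jnp.int32) * chunk_size
    q2 = jnp.sum(queries * queries, axis=1, keepdims=True)

    def step(carry, xs):
        best_vals, best_idx = carry
        block, offset = xs
        # Squared distances from every query to this chunk: (num_queries, chunk_size).
        scores = q2 - 2.0 * queries @ block.T + jnp.sum(block * block, axis=1)
        block_idx = jnp.broadcast_to(
            offset + jnp.arange(chunk_size, dtype=jnp.int32), scores.shape)
        # Merge the running best k with this chunk and keep the k smallest.
        cand_vals = jnp.concatenate([best_vals, scores], axis=1)
        cand_idx = jnp.concatenate([best_idx, block_idx], axis=1)
        neg_vals, pos = jax.lax.top_k(-cand_vals, k)
        return (-neg_vals, jnp.take_along_axis(cand_idx, pos, axis=1)), None

    init = (jnp.full((num_queries, k), jnp.inf, dtype=queries.dtype),
            jnp.full((num_queries, k), -1, dtype=jnp.int32))
    (vals, idx), _ = jax.lax.scan(step, init, (chunks, offsets))
    return vals, idx


search = jax.jit(chunked_min_k, static_argnums=(2, 3))
```

chunk_size trades peak memory against per-step overhead; the compiled search(queries, db, k, chunk_size) recompiles for each new (k, chunk_size) pair.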

System info (python version, jaxlib version, accelerator, etc.)

>>> jax.print_environment_info()
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.12.7 (main, Oct  1 2024, 08:52:11) [GCC 9.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='boxob', release='5.15.0-124-generic', version='#134~20.04.1-Ubuntu SMP Tue Oct 1 15:27:33 UTC 2024', machine='x86_64')


$ nvidia-smi
Sun Oct 20 22:40:50 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2070        Off | 00000000:01:00.0  On |                  N/A |
|  0%   47C    P2              38W / 185W |    468MiB /  8192MiB |      5%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1343      G   /usr/lib/xorg/Xorg                          239MiB |
|    0   N/A  N/A      1614      G   /usr/bin/gnome-shell                         50MiB |
|    0   N/A  N/A    132100      G   ...onEnabled --variations-seed-version       76MiB |
|    0   N/A  N/A    184955      C   python                                       96MiB |
+---------------------------------------------------------------------------------------+

>>> 
notnot added the bug (Something isn't working) label on Oct 20, 2024