From 0a61fc04692bd27decaa13999c1e907e98f67767 Mon Sep 17 00:00:00 2001 From: Micka Date: Fri, 8 Nov 2024 16:03:41 +0100 Subject: [PATCH] Add Question Retrieval notebook using Milvus --- ...ectorSearch_QuestionRetrieval_Milvus.ipynb | 732 ++++++++++++++++++ 1 file changed, 732 insertions(+) create mode 100644 notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb diff --git a/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb b/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb new file mode 100644 index 000000000..09a6cca43 --- /dev/null +++ b/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb @@ -0,0 +1,732 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f5499b54", + "metadata": {}, + "source": [ + "\n", + "# Similar Questions Retrieval - Milvus - CAGRA-HNSW\n", + "\n", + "This notebook is inspired by the [similar search example of Sentence-Transformers](https://www.sbert.net/examples/applications/semantic-search/README.html#similar-questions-retrieval), and adapted to be used with [Milvus](https://milvus.io) and [cuVS](https://rapids.ai/cuvs/).\n", + "\n", + "The model was pre-trained on the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions). It consists of about 100k real Google search queries, together with an annotated passage from Wikipedia that provides the answer. It is an example of an asymmetric search task. As corpus, we use the smaller [Simple English Wikipedia](http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz) so that it fits easily into memory.\n", + "\n", + "The steps to install the latest Milvus package are available in the [Milvus documentation](https://milvus.io/docs/quickstart.md)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8d55ede", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:47:21.149465Z", + "iopub.status.busy": "2024-11-08T14:47:21.149218Z", + "iopub.status.idle": "2024-11-08T14:47:23.440275Z", + "shell.execute_reply": "2024-11-08T14:47:23.439436Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "!pip install sentence_transformers torch pymilvus pymilvus[bulk_writer] dask dask[distributed]\n", + "\n", + "# Note: if you have a Hopper based GPU, like an H100, use these to install:\n", + "# pip install torch --index-url https://download.pytorch.org/whl/cu118\n", + "# pip install sentence_transformers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb1e81c3", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:47:23.444058Z", + "iopub.status.busy": "2024-11-08T14:47:23.443683Z", + "iopub.status.idle": "2024-11-08T14:47:24.219903Z", + "shell.execute_reply": "2024-11-08T14:47:24.219228Z" + } + }, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee4c5cc0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:47:24.223131Z", + "iopub.status.busy": "2024-11-08T14:47:24.222874Z", + "iopub.status.idle": "2024-11-08T14:47:34.024085Z", + "shell.execute_reply": "2024-11-08T14:47:34.023435Z" + } + }, + "outputs": [], + "source": [ + "import dask.array as da\n", + "import gzip\n", + "import json\n", + "import math\n", + "import numpy as np\n", + "import os\n", + "import pymilvus\n", + "import time\n", + "import torch\n", + "\n", + "from minio import Minio\n", + "from multiprocessing import Process\n", + "from sentence_transformers import SentenceTransformer, CrossEncoder, util\n", + "from typing import List\n", + "\n", + "\n", + "from pymilvus import (\n", + " connections, utility\n", + ")\n", + "from pymilvus.bulk_writer import LocalBulkWriter, BulkFileType # pip install pymilvus[bulk_writer]\n", + "\n", + "if not torch.cuda.is_available():\n", + " print(\"Warning: No GPU found. Please add GPU to your notebook\")" + ] + }, + { + "cell_type": "markdown", + "id": "47cabaca", + "metadata": {}, + "source": [ + "# Setup Milvus Collection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fcd259c", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:47:34.027677Z", + "iopub.status.busy": "2024-11-08T14:47:34.027288Z", + "iopub.status.idle": "2024-11-08T14:47:34.109212Z", + "shell.execute_reply": "2024-11-08T14:47:34.108609Z" + } + }, + "outputs": [], + "source": [ + "DIM = 768\n", + "MILVUS_PORT = 30004\n", + "MILVUS_HOST = f\"http://localhost:{MILVUS_PORT}\"\n", + "ID_FIELD=\"id\"\n", + "EMBEDDING_FIELD=\"embedding\"\n", + "\n", + "collection_name = \"simple_wiki\"\n", + "\n", + "def get_milvus_client():\n", + " return pymilvus.MilvusClient(uri=MILVUS_HOST)\n", + "\n", + "client = get_milvus_client()\n", + "\n", + "fields = [\n", + " pymilvus.FieldSchema(name=ID_FIELD, dtype=pymilvus.DataType.INT64, is_primary=True),\n", + " pymilvus.FieldSchema(name=EMBEDDING_FIELD, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=DIM)\n", + "]\n", + "\n", + "schema = pymilvus.CollectionSchema(fields)\n", + "schema.verify()\n", + "\n", + "if collection_name in client.list_collections():\n", + " print(f\"Collection '{collection_name}' already exists. Deleting collection...\")\n", + " client.drop_collection(collection_name)\n", + "\n", + "client.create_collection(collection_name, schema=schema, dimension=DIM, vector_field_name=EMBEDDING_FIELD)\n", + "collection = pymilvus.Collection(name=collection_name, using=client._using)\n", + "collection.release()\n", + "collection.drop_index()\n" + ] + }, + { + "cell_type": "markdown", + "id": "00bd20f5", + "metadata": {}, + "source": [ + "# Setup Sentence Transformer model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a1a6307", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:47:34.111782Z", + "iopub.status.busy": "2024-11-08T14:47:34.111556Z", + "iopub.status.idle": "2024-11-08T14:47:39.654323Z", + "shell.execute_reply": "2024-11-08T14:47:39.653386Z" + } + }, + "outputs": [], + "source": [ + "# We use the Bi-Encoder to encode all passages, so that we can use it with semantic search\n", + "model_name = 'nq-distilbert-base-v1'\n", + "bi_encoder = SentenceTransformer(model_name)\n", + "\n", + "# As dataset, we use Simple English Wikipedia. Compared to the full English wikipedia, it has only\n", + "# about 170k articles. We split these articles into paragraphs and encode them with the bi-encoder\n", + "\n", + "wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'\n", + "\n", + "if not os.path.exists(wikipedia_filepath):\n", + " util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)\n", + "\n", + "passages = []\n", + "with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:\n", + " for line in fIn:\n", + " data = json.loads(line.strip())\n", + " for paragraph in data['paragraphs']:\n", + " # We encode the passages as [title, text]\n", + " passages.append([data['title'], paragraph])\n", + "\n", + "# If you like, you can also limit the number of passages you want to use\n", + "print(\"Passages:\", len(passages))\n", + "\n", + "# To speed things up, pre-computed embeddings are downloaded.\n", + "# The provided file encoded the passages with the model 'nq-distilbert-base-v1'\n", + "if model_name == 'nq-distilbert-base-v1':\n", + " embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'\n", + " if not os.path.exists(embeddings_filepath):\n", + " util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)\n", + "\n", + " corpus_embeddings = torch.load(embeddings_filepath, map_location='cpu', weights_only=True).float() # Convert embedding file to float\n", + " #if torch.cuda.is_available():\n", + " # corpus_embeddings = corpus_embeddings.to('cuda')\n", + "else: # Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)\n", + " corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True).to('cpu')" + ] + }, + { + "cell_type": "markdown", + "id": "1f4e9b9d", + "metadata": {}, + "source": [ + "# Vector Search using Milvus and RAPIDS cuVS \n", + "Now that our embeddings are ready to be indexed and that the model has been loaded, we can use Milvus and RAPIDS cuVS to do our vector search.\n", + "\n", + "This is done in 3 steps: First we ingest all the vectors in the Milvus collection, then we build the Milvus index, to finally search it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "563751c1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:47:39.658832Z", + "iopub.status.busy": "2024-11-08T14:47:39.658374Z", + "iopub.status.idle": "2024-11-08T14:49:47.244768Z", + "shell.execute_reply": "2024-11-08T14:49:47.244162Z" + } + }, + "outputs": [], + "source": [ + "# minio\n", + "MINIO_PORT = 30009\n", + "MINIO_URL = f\"localhost:{MINIO_PORT}\"\n", + "MINIO_SECRET_KEY = \"minioadmin\"\n", + "MINIO_ACCESS_KEY = \"minioadmin\"\n", + "\n", + "def upload_to_minio(file_paths: List[List[str]], remote_paths: List[List[str]], bucket_name=\"milvus-bucket\"):\n", + " minio_client = Minio(endpoint=MINIO_URL, access_key=MINIO_ACCESS_KEY, secret_key=MINIO_SECRET_KEY, secure=False)\n", + " if not minio_client.bucket_exists(bucket_name):\n", + " minio_client.make_bucket(bucket_name)\n", + "\n", + " for local_batch, remote_batch in zip(file_paths, remote_paths):\n", + " for local_file, remote_file in zip(local_batch, remote_batch):\n", + " minio_client.fput_object(bucket_name, \n", + " object_name=remote_file,\n", + " file_path=local_file,\n", + " part_size=512 * 1024 * 1024,\n", + " num_parallel_uploads=5)\n", + " \n", + " \n", + "def ingest_data_bulk(collection_name, vectors, schema: pymilvus.CollectionSchema, log_times=True, bulk_writer_type=\"milvus\", debug=False):\n", + " print(f\"- Ingesting {len(vectors) // 1000}k vectors, Bulk\")\n", + " tic = time.perf_counter()\n", + " collection = pymilvus.Collection(collection_name, using=get_milvus_client()._using)\n", + " remote_path = None\n", + "\n", + " if bulk_writer_type == 'milvus':\n", + " # # Prepare source data for faster ingestion\n", + " writer = LocalBulkWriter(\n", + " schema=schema,\n", + " local_path='bulk_data',\n", + " segment_size=512 * 1024 * 1024, # Default value\n", + " file_type=BulkFileType.NPY\n", + " )\n", + " for id, vec in enumerate(vectors):\n", + " writer.append_row({ID_FIELD: id, EMBEDDING_FIELD: vec})\n", + "\n", + " if debug:\n", + " print(writer.batch_files)\n", + " def callback(file_list):\n", + " if debug:\n", + " print(f\" - Commit successful\")\n", + " print(file_list)\n", + " writer.commit(call_back=callback)\n", + " files_to_upload = writer.batch_files\n", + " elif bulk_writer_type == 'dask':\n", + " # Prepare source data for faster ingestion\n", + " if not os.path.isdir(\"bulk_data\"):\n", + " os.mkdir(\"bulk_data\")\n", + "\n", + " from dask.distributed import Client, LocalCluster\n", + " cluster = LocalCluster(n_workers=1, threads_per_worker=1)\n", + " client = Client(cluster)\n", + "\n", + " chunk_size = 100000\n", + " da_vectors = da.from_array(vectors, chunks=(chunk_size, vectors.shape[1]))\n", + " da_ids = da.arange(len(vectors), chunks=(chunk_size,))\n", + " da.to_npy_stack(\"bulk_data/da_embedding/\", da_vectors)\n", + " da.to_npy_stack(\"bulk_data/da_id/\", da_ids)\n", + " files_to_upload = []\n", + " remote_path = []\n", + " for chunk_nb in range(math.ceil(len(vectors) / chunk_size)):\n", + " files_to_upload.append([f\"bulk_data/da_embedding/{chunk_nb}.npy\", f\"bulk_data/da_id/{chunk_nb}.npy\"])\n", + " remote_path.append([f\"bulk_data/da_{chunk_nb}/embedding.npy\", f\"bulk_data/da__{chunk_nb}/id.npy\"])\n", + "\n", + " elif bulk_writer_type == 'numpy':\n", + " # Directly save NPY files\n", + " np.save(\"bulk_data/embedding.npy\", vectors)\n", + " np.save(\"bulk_data/id.npy\", np.arange(len(vectors)))\n", + " files_to_upload = [[\"bulk_data/embedding.npy\", \"bulk_data/id.npy\"]]\n", + " else:\n", + " raise ValueError(\"Invalid bulk writer type\")\n", + " \n", + " toc = time.perf_counter()\n", + " if log_times:\n", + " print(f\" - File save time: {toc - tic:.2f} seconds\")\n", + " # Import data\n", + " if remote_path is None:\n", + " remote_path = files_to_upload\n", + " upload_to_minio(files_to_upload, remote_path)\n", + " \n", + " job_ids = [utility.do_bulk_insert(collection_name, batch, using=get_milvus_client()._using) for batch in remote_path]\n", + "\n", + " while True:\n", + " tasks = [utility.get_bulk_insert_state(job_id, using=get_milvus_client()._using) for job_id in job_ids]\n", + " success = all(task.state_name == \"Completed\" for task in tasks)\n", + " failure = any(task.state_name == \"Failed\" for task in tasks)\n", + " for i in range(len(tasks)):\n", + " task = tasks[i]\n", + " if debug:\n", + " print(f\" - Task {i}/{len(tasks)} state: {task.state_name}, Progress percent: {task.infos['progress_percent']}, Imported row count: {task.row_count}\")\n", + " if task.state_name == \"Failed\":\n", + " print(task)\n", + " if success or failure:\n", + " break\n", + " time.sleep(2)\n", + "\n", + " added_entities = str(sum([task.row_count for task in tasks]))\n", + " failure = failure or added_entities != str(len(vectors))\n", + " if failure:\n", + " print(f\"- Ingestion failed. Added entities: {added_entities}\")\n", + " toc = time.perf_counter()\n", + " if log_times:\n", + " datasize = vectors.nbytes / 1024 / 1024\n", + " print(f\"- Ingestion time: {toc - tic:.2f} seconds. ({(datasize / (toc-tic)):.2f}MB/s)\")\n", + "\n", + "ingest_data_bulk(collection_name, np.array(corpus_embeddings), schema, bulk_writer_type='dask', log_times=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad90b4be", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:49:47.247498Z", + "iopub.status.busy": "2024-11-08T14:49:47.247268Z", + "iopub.status.idle": "2024-11-08T14:50:00.737502Z", + "shell.execute_reply": "2024-11-08T14:50:00.736808Z" + } + }, + "outputs": [], + "source": [ + "# Setups the IVFPQ index\n", + "\n", + "index_params = dict(\n", + " index_type=\"GPU_IVF_PQ\",\n", + " metric_type=\"L2\",\n", + " params={\"nlist\": 150, # Number of clusters\n", + " \"m\": 96}) # Product Quantization dimension\n", + "\n", + "# Drop the index if it exists\n", + "if collection.has_index():\n", + " collection.release()\n", + " collection.drop_index()\n", + "\n", + "# Create the index\n", + "tic = time.perf_counter()\n", + "collection.create_index(field_name=EMBEDDING_FIELD, index_params=index_params)\n", + "collection.load()\n", + "toc = time.perf_counter()\n", + "print(f\"- Index creation time: {toc - tic:.4f} seconds. ({index_params})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c75acea7", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:00.740443Z", + "iopub.status.busy": "2024-11-08T14:50:00.740142Z", + "iopub.status.idle": "2024-11-08T14:50:00.745403Z", + "shell.execute_reply": "2024-11-08T14:50:00.744672Z" + } + }, + "outputs": [], + "source": [ + "# Search the index\n", + "def search_cuvs_pq(query, top_k = 5, n_probe = 30):\n", + " # Encode the query using the bi-encoder and find potentially relevant passages\n", + " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", + "\n", + " search_params = {\"nprobe\": n_probe}\n", + " tic = time.perf_counter()\n", + " hits = collection.search(\n", + " data=np.array(question_embedding[None].cpu()), anns_field=EMBEDDING_FIELD, param=search_params, limit=top_k\n", + " )\n", + " toc = time.perf_counter()\n", + "\n", + " # Output of top-k hits\n", + " print(\"Input question:\", query)\n", + " print(\"Results (after {:.3f} ms):\".format((toc - tic)*1000))\n", + " for k in range(top_k):\n", + " print(\"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id]))" + ] + }, + { + "cell_type": "markdown", + "id": "07935bca", + "metadata": {}, + "source": [ + "The ideal use-case for the IVF-PQ algorithm is when there is a need to reduce the memory footprint while keeping a good accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c27d4715", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:00.748001Z", + "iopub.status.busy": "2024-11-08T14:50:00.747783Z", + "iopub.status.idle": "2024-11-08T14:50:01.785914Z", + "shell.execute_reply": "2024-11-08T14:50:01.785223Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_pq(query=\"Who was Grace Hopper?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc375518", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:01.788877Z", + "iopub.status.busy": "2024-11-08T14:50:01.788640Z", + "iopub.status.idle": "2024-11-08T14:50:01.813820Z", + "shell.execute_reply": "2024-11-08T14:50:01.813153Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_pq(query=\"Who was Alan Turing?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab154181", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:01.816625Z", + "iopub.status.busy": "2024-11-08T14:50:01.816362Z", + "iopub.status.idle": "2024-11-08T14:50:01.839593Z", + "shell.execute_reply": "2024-11-08T14:50:01.838986Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_pq(query = \"What is creating tides?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "836344ec", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:01.842319Z", + "iopub.status.busy": "2024-11-08T14:50:01.842022Z", + "iopub.status.idle": "2024-11-08T14:50:15.969324Z", + "shell.execute_reply": "2024-11-08T14:50:15.968562Z" + } + }, + "outputs": [], + "source": [ + "# Drop the current index if it exists\n", + "if collection.has_index():\n", + " collection.release()\n", + " collection.drop_index()\n", + "\n", + "# Create the IVF Flat index\n", + "index_params = dict(\n", + " index_type=\"GPU_IVF_FLAT\",\n", + " metric_type=\"L2\",\n", + " params={\"nlist\": 150}) # Number of clusters)\n", + "tic = time.perf_counter()\n", + "collection.create_index(field_name=EMBEDDING_FIELD, index_params=index_params)\n", + "collection.load()\n", + "toc = time.perf_counter()\n", + "print(f\"- Index creation time: {toc - tic:.4f} seconds. ({index_params})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2d6017ed", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:15.972764Z", + "iopub.status.busy": "2024-11-08T14:50:15.972368Z", + "iopub.status.idle": "2024-11-08T14:50:15.977806Z", + "shell.execute_reply": "2024-11-08T14:50:15.977064Z" + } + }, + "outputs": [], + "source": [ + "def search_cuvs_flat(query, top_k = 5, n_probe = 30):\n", + " # Encode the query using the bi-encoder and find potentially relevant passages\n", + " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", + " \n", + " search_params = {\"nprobe\": n_probe}\n", + " tic = time.perf_counter()\n", + " hits = collection.search(\n", + " data=np.array(question_embedding[None].cpu()), anns_field=EMBEDDING_FIELD, param=search_params, limit=top_k\n", + " )\n", + " toc = time.perf_counter()\n", + "\n", + " # Output of top-k hits\n", + " print(\"Input question:\", query)\n", + " print(\"Results (after {:.3f} ms):\".format((toc - tic)*1000))\n", + " for k in range(top_k):\n", + " print(\"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5cfb644", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:15.980796Z", + "iopub.status.busy": "2024-11-08T14:50:15.980408Z", + "iopub.status.idle": "2024-11-08T14:50:16.009271Z", + "shell.execute_reply": "2024-11-08T14:50:16.008579Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_flat(query=\"Who was Grace Hopper?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5694d00", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:16.012253Z", + "iopub.status.busy": "2024-11-08T14:50:16.011924Z", + "iopub.status.idle": "2024-11-08T14:50:16.043432Z", + "shell.execute_reply": "2024-11-08T14:50:16.042751Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_flat(query=\"Who was Alan Turing?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcfc3c5b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:16.046439Z", + "iopub.status.busy": "2024-11-08T14:50:16.046093Z", + "iopub.status.idle": "2024-11-08T14:50:16.071322Z", + "shell.execute_reply": "2024-11-08T14:50:16.070614Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_flat(query = \"What is creating tides?\")" + ] + }, + { + "cell_type": "markdown", + "id": "a59d7b32-0832-4c3a-864e-aeb2e6e7fe1f", + "metadata": {}, + "source": [ + "## Using CAGRA: Hybrid GPU-CPU graph-based Vector Search\n", + "\n", + "CAGRA is a graph-based nearest neighbors implementation with state-of-the art performance for both small- and large-batch sized vector searches. \n", + "\n", + "CAGRA follows the same steps as IVF-FLAT and IVF-PQ in Milvus, but is also able to be adapted for querying on CPU.\n", + "This means that CAGRA is able to profit from a high training speed on GPU, as well as a low inference time on CPU, that minimize latency even on the smallest queries." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5ce4dab", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:16.074449Z", + "iopub.status.busy": "2024-11-08T14:50:16.074128Z", + "iopub.status.idle": "2024-11-08T14:50:30.479027Z", + "shell.execute_reply": "2024-11-08T14:50:30.478265Z" + } + }, + "outputs": [], + "source": [ + "# Drop the current index if it exists\n", + "if collection.has_index():\n", + " collection.release()\n", + " collection.drop_index()\n", + "\n", + "# Create the IVF Flat index\n", + "index_params = dict(\n", + " index_type=\"GPU_CAGRA\",\n", + " metric_type=\"L2\",\n", + " params={\"graph_degree\": 64, \"intermediate_graph_degree\": 128, \"build_algo\": \"NN_DESCENT\", \"adapt_for_cpu\": True})\n", + "tic = time.perf_counter()\n", + "collection.create_index(field_name=EMBEDDING_FIELD, index_params=index_params)\n", + "collection.load()\n", + "toc = time.perf_counter()\n", + "print(f\"- Index creation time: {toc - tic:.4f} seconds. ({index_params})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "df229e21-f6b6-4d6c-ad54-2724f8738934", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:30.481748Z", + "iopub.status.busy": "2024-11-08T14:50:30.481474Z", + "iopub.status.idle": "2024-11-08T14:50:30.486324Z", + "shell.execute_reply": "2024-11-08T14:50:30.485696Z" + } + }, + "outputs": [], + "source": [ + "def search_cuvs_cagra(query, top_k = 5, itopk = 32):\n", + " # Encode the query using the bi-encoder and find potentially relevant passages\n", + " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", + "\n", + " search_params = {\"params\": {\"itopk\": itopk, \"ef\": 35}}\n", + " tic = time.perf_counter()\n", + " hits = collection.search(\n", + " data=np.array(question_embedding[None].cpu()), anns_field=EMBEDDING_FIELD, param=search_params, limit=top_k\n", + " )\n", + " toc = time.perf_counter()\n", + "\n", + " # Output of top-k hits\n", + " print(\"Input question:\", query)\n", + " print(\"Results (after {:.3f} ms):\".format((toc - tic)*1000))\n", + " for k in range(top_k):\n", + " print(\"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5e862fd-b7e5-4423-8fbf-36918f02c8f3", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:30.489077Z", + "iopub.status.busy": "2024-11-08T14:50:30.488790Z", + "iopub.status.idle": "2024-11-08T14:50:30.513998Z", + "shell.execute_reply": "2024-11-08T14:50:30.513319Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_cagra(query=\"Who was Grace Hopper?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb8a5b7b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:30.516748Z", + "iopub.status.busy": "2024-11-08T14:50:30.516521Z", + "iopub.status.idle": "2024-11-08T14:50:30.538982Z", + "shell.execute_reply": "2024-11-08T14:50:30.538269Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_cagra(query=\"Who was Alan Turing?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c89810a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-08T14:50:30.541508Z", + "iopub.status.busy": "2024-11-08T14:50:30.541287Z", + "iopub.status.idle": "2024-11-08T14:50:30.562722Z", + "shell.execute_reply": "2024-11-08T14:50:30.562085Z" + } + }, + "outputs": [], + "source": [ + "search_cuvs_cagra(query=\"What is creating tides?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}