From 5735963faee9e6507e69cec6b0c3ce27cc01881c Mon Sep 17 00:00:00 2001 From: DanielUH2019 Date: Mon, 5 Feb 2024 04:22:19 -0500 Subject: [PATCH] Create basic example for pinecone integration --- .../pinecone/pinecone_wikipedia_example.ipynb | 625 ++++++++++++++++++ 1 file changed, 625 insertions(+) create mode 100644 examples/integrations/pinecone/pinecone_wikipedia_example.ipynb diff --git a/examples/integrations/pinecone/pinecone_wikipedia_example.ipynb b/examples/integrations/pinecone/pinecone_wikipedia_example.ipynb new file mode 100644 index 000000000..d522032bd --- /dev/null +++ b/examples/integrations/pinecone/pinecone_wikipedia_example.ipynb @@ -0,0 +1,625 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2031537/1159854371.py:20: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n", + " import pkg_resources # Install the package if it's not installed\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import sys\n", + "import os\n", + "\n", + "try: # When on google Colab, let's clone the notebook so we download the cache.\n", + " import google.colab\n", + " repo_path = 'dspy'\n", + " !git -C $repo_path pull origin || git clone https://github.com/stanfordnlp/dspy $repo_path\n", + "except:\n", + " repo_path = '.'\n", + "\n", + "if repo_path not in sys.path:\n", + " sys.path.append(repo_path)\n", + "\n", + "# Set up the cache for this notebook\n", + "os.environ[\"DSP_NOTEBOOK_CACHEDIR\"] = os.path.join(repo_path, 'cache')\n", + "\n", + "import pkg_resources # Install the package if it's not installed\n", + "if not \"dspy-ai\" in {pkg.key for pkg in pkg_resources.working_set}:\n", + " !pip install -U pip\n", + " !pip install dspy-ai\n", + " # !pip install -e $repo_path\n", + "\n", + "import dspy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset, Dataset\n", + "docs = load_dataset(f\"Cohere/wikipedia-22-12-simple-embeddings\", \"en\", split=\"train[:5%]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['id', 'title', 'text', 'url', 'wiki_id', 'views', 'paragraph_id', 'langs', 'emb'],\n", + " num_rows: 24293\n", + "})" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "docs" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import os\n", + "from pinecone.grpc import PineconeGRPC as Pinecone\n", + "from pinecone import PodSpec\n", + "\n", + "api_key = os.getenv(\"PINECONE_API_KEY\")\n", + "pc = Pinecone(\n", + " api_key=os.environ.get(\"PINECONE_API_KEY\")\n", + ")\n", + "\n", + "# Pick a name for the new index\n", + "index_name = 'wikipedia-articles'\n", + "\n", + "# Check whether the index with the same name already exists - if so, delete it\n", + "if index_name in pc.list_indexes():\n", + " pc.delete_index(index_name)\n", + " \n", + "# Creates new index\n", + "if index_name not in pc.list_indexes().names():\n", + " pc.create_index(\n", + " name=index_name, \n", + " dimension=768, \n", + " metric='dotproduct',\n", + " spec=PodSpec(\n", + " environment=\"gcp-starter\",\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'dimension': 768,\n", + " 'index_fullness': 0.0,\n", + " 'namespaces': {'': {'vector_count': 0}},\n", + " 'total_vector_count': 0}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index = pc.Index(index_name)\n", + "index.describe_index_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 53]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "def chunker(data, batch_size):\n", + " data_iter = iter(data)\n", + " # end = False\n", + " for i in range(0, len(data), batch_size):\n", + " chunk = []\n", + " # if i + batch_size >= len(data):\n", + " # batch_size += len(data) - i\n", + " for x in data_iter:\n", + " if len(chunk) == batch_size:\n", + " break\n", + " chunk.append(x)\n", + " \n", + " chunk_to_insert = []\n", + " for x in chunk:\n", + " item = {}\n", + " item['id'] = str(x['id'])\n", + " item['values'] = x['emb']\n", + " item['metadata'] = {}\n", + " item['metadata']['text'] = x['text']\n", + " chunk_to_insert.append(item)\n", + "\n", + " yield chunk_to_insert\n", + "\n", + "async_results = [\n", + " index.upsert(vectors=chunk, async_req=True)\n", + " for chunk in chunker(docs, batch_size=100) if len(chunk) > 0\n", + "]\n", + "\n", + "# Wait for and retrieve responses (in case of error)\n", + "results = [async_result.result() for async_result in async_results]\n", + "results" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'dimension': 768,\n", + " 'index_fullness': 0.24053,\n", + " 'namespaces': {'': {'vector_count': 24053}},\n", + " 'total_vector_count': 24053}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.describe_index_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from dspy.retrieve.pinecone_rm import PineconeRM, CohereEmbed\n", + "\n", + "cohere_embed = CohereEmbed()\n", + "\n", + "llm = dspy.OllamaLocal(model=\"openhermes2.5-mistral:7b-q5_K_M\", model_type=\"chat\")\n", + "retriever_model = PineconeRM(index_name, cloud_emded_provider=cohere_embed)\n", + "dspy.settings.configure(lm=llm, rm=retriever_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(20, 50)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dspy.datasets import HotPotQA\n", + "\n", + "# Load the dataset.\n", + "dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)\n", + "\n", + "# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.\n", + "trainset = [x.with_inputs('question') for x in dataset.train]\n", + "devset = [x.with_inputs('question') for x in dataset.dev]\n", + "\n", + "len(trainset), len(devset)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "class GenerateAnswer(dspy.Signature):\n", + " \"\"\"Answer questions with short factoid answers.\"\"\"\n", + "\n", + " context = dspy.InputField(desc=\"may contain relevant facts\")\n", + " question = dspy.InputField()\n", + " answer = dspy.OutputField(desc=\"often between 1 and 5 words\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "class RAG(dspy.Module):\n", + " def __init__(self, num_passages=3):\n", + " super().__init__()\n", + "\n", + " self.retrieve = dspy.Retrieve(k=num_passages)\n", + " self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n", + " \n", + " def forward(self, question):\n", + " context = self.retrieve(question).passages\n", + " prediction = self.generate_answer(context=context, question=question)\n", + " return dspy.Prediction(context=context, answer=prediction.answer)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/20 [00:00