diff --git a/assets/langchainQA.png b/assets/langchainQA.png new file mode 100644 index 0000000000..890e654989 Binary files /dev/null and b/assets/langchainQA.png differ diff --git a/examples/LangChain_QA.ipynb b/examples/LangChain_QA.ipynb index 82cb06f118..09978f9c1f 100644 --- a/examples/LangChain_QA.ipynb +++ b/examples/LangChain_QA.ipynb @@ -11,7 +11,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This demo walks through how to build an LLM-driven question-answering (QA) application with Xinference, Milvus, and LangChain." + "This demo walks through how to build an LLM-driven question-answering (QA) application with Xinference, Milvus, and LangChain. It uses Falcon 40B Instruct model for embedding creation and Llama 2 70B Chat model for inference. Both of the models are fully supported by Xinference." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![LangChain QA Visualization by Dreamsome](../assets/langchainQA.png)" ] }, { @@ -34,19 +41,19 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Model uid: 19c73cee-3506-11ee-b286-fa163e74fa2d\n" + "Model uid: 46bf725e-3a5e-11ee-9dcd-fa163e74fa2d\n" ] } ], "source": [ - "!xinference launch --model-name \"falcon-instruct\" --model-format pytorch --size-in-billions 40 -e \"http://127.0.0.1:56256\"" + "!xinference launch --model-name \"falcon-instruct\" --model-format pytorch --size-in-billions 40 -e \"http://127.0.0.1:55950\"" ] }, { @@ -65,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -93,15 +100,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain.embeddings import XinferenceEmbeddings\n", "\n", "xinference_embeddings = XinferenceEmbeddings(\n", - " server_url=\"http://127.0.0.1:56256\", \n", - " model_uid = \"19c73cee-3506-11ee-b286-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n", + " server_url=\"http://127.0.0.1:55950\", \n", + " model_uid = \"46bf725e-3a5e-11ee-9dcd-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n", ")" ] }, @@ -205,7 +212,7 @@ } ], "source": [ - "!xinference launch --model-name \"llama-2-chat\" --model-format ggmlv3 --size-in-billions 70 -e \"http://127.0.0.1:56256\"" + "!xinference launch --model-name \"llama-2-chat\" --model-format ggmlv3 --size-in-billions 70 -e \"http://127.0.0.1:55950\"" ] }, { @@ -217,7 +224,7 @@ "from langchain.llms import Xinference\n", "\n", "xinference_llm = Xinference(\n", - " server_url=\"http://127.0.0.1:56256\",\n", + " server_url=\"http://127.0.0.1:55950\",\n", " model_uid = \"333e1d68-3507-11ee-a0d6-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n", ")" ] @@ -380,6 +387,106 @@ "source": [ "We can see the impressive capabilities of the LLM, and LangChain's \"chaining\" feature also allows for more coherent and context-aware interactions with the model." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Concurrent Embedding Creation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Xinference also supports creating embeddings concurrently. This will speed up the process of storing the document into the vector database. To run the following code, first restart the milvus server. 
Here, we still use the 40B Falcon-instruct model we launched before." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import TextLoader\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", + "from langchain.embeddings import XinferenceEmbeddings\n", + "from langchain.vectorstores import Milvus\n", + "\n", + "import threading\n", + "\n", + "def process_chunk(chunk):\n", + " vector_db.add_documents(documents=chunk)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All chunks processed successfully.\n" + ] + } + ], + "source": [ + "num_chunks = 10 # replace this with the number of the threads you want to execute in parallel + 1\n", + "\n", + "loader = TextLoader(\"/home/nijiayi/inference/examples/state_of_the_union.txt\")\n", + "documents = loader.load()\n", + "\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=100, length_function=len)\n", + "docs = text_splitter.split_documents(documents)\n", + "\n", + "\n", + "xinference_embeddings = XinferenceEmbeddings(\n", + " server_url=\"http://127.0.0.1:55950\", \n", + " model_uid = \"46bf725e-3a5e-11ee-9dcd-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n", + ")\n", + "\n", + "chunks = [docs[i::num_chunks] for i in range(num_chunks)] \n", + "\n", + "vector_db = Milvus.from_documents(\n", + " chunks[0],\n", + " xinference_embeddings,\n", + " connection_args={\"host\": \"0.0.0.0\", \"port\": \"19530\"},\n", + ")\n", + "\n", + "# add chunks of document to the vector_db in parallel\n", + "threads = [threading.Thread(target=process_chunk, args=(chunk,)) for chunk in chunks[1:]]\n", + "\n", + "for thread in threads:\n", + " thread.start()\n", + "\n", + "for thread in threads:\n", + " thread.join()\n", + "\n", + "print(\"All chunks processed successfully.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n" + ] + } + ], + "source": [ + "query = \"what does the president say about Ketanji Brown Jackson\"\n", + "docs = vector_db.similarity_search(query, k=10)\n", + "print(docs[0].page_content) " + ] } ], "metadata": { @@ -404,4 +511,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/xinference/core/model.py b/xinference/core/model.py index 64240f12cb..0674959ee6 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import inspect from typing import TYPE_CHECKING, Any, Generic, Iterator, List, Optional, TypeVar, Union @@ -54,7 +55,7 @@ async def __anext__(self) -> T: raise -class ModelActor(xo.Actor): +class ModelActor(xo.StatelessActor): @classmethod def gen_uid(cls, model: "LLM"): return f"{model.__class__}-model-actor" @@ -81,6 +82,7 @@ def __init__(self, model: "LLM"): super().__init__() self._model = model self._generator: Optional[Iterator] = None + self._lock = asyncio.Lock() def load(self): self._model.load() @@ -95,20 +97,28 @@ async def _wrap_generator(self, ret: Any): return ret async def generate(self, prompt: str, *args, **kwargs): - if not hasattr(self._model, "generate"): - raise AttributeError(f"Model {self._model.model_spec} is not for generate.") + async with self._lock: + if not hasattr(self._model, "generate"): + raise AttributeError( + f"Model {self._model.model_spec} is not for generate." + ) + + result = await self._wrap_generator( + getattr(self._model, "generate")(prompt, *args, **kwargs) + ) - return self._wrap_generator( - getattr(self._model, "generate")(prompt, *args, **kwargs) - ) + return result async def chat(self, prompt: str, *args, **kwargs): - if not hasattr(self._model, "chat"): - raise AttributeError(f"Model {self._model.model_spec} is not for chat.") + async with self._lock: + if not hasattr(self._model, "chat"): + raise AttributeError(f"Model {self._model.model_spec} is not for chat.") + + result = await self._wrap_generator( + getattr(self._model, "chat")(prompt, *args, **kwargs) + ) - return self._wrap_generator( - getattr(self._model, "chat")(prompt, *args, **kwargs) - ) + return result async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): if not hasattr(self._model, "create_embedding"): diff --git a/xinference/tests/test_concurrent_embedding.py b/xinference/tests/test_concurrent_embedding.py new file mode 100644 index 0000000000..b86663987b --- /dev/null +++ b/xinference/tests/test_concurrent_embedding.py @@ -0,0 +1,84 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +Simple test for multithreaded embedding creation +""" +import threading +import time + +from xinference.client import RESTfulClient + +lock = threading.Lock() +concurrent_results = {} +nonconcurrent_results = {} + + +def embedding_thread(model, text): + global concurrent_results + embedding = model.create_embedding(text) + with lock: + concurrent_results[text] = embedding + + +def nonconcurrent_embedding(model, texts): + global nonconcurrent_results + for text in texts: + embedding = model.create_embedding(text) + nonconcurrent_results[text] = embedding + + +def main(): + client = RESTfulClient("http://127.0.0.1:20881") + model_uid = client.launch_model( + model_name="opt", + model_size_in_billions=1, + model_format="pytorch", + quantization="8-bit", + ) + model = client.get_model(model_uid) + + texts = ["Once upon a time", "Hello, world!", "Hi"] + + start_time = time.time() + + threads = [] + for text in texts: + thread = threading.Thread(target=embedding_thread, args=(model, text)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + end_time = time.time() + print(f"Concurrent Time: {end_time - start_time:.4f} seconds") + + start_time = time.time() + nonconcurrent_embedding(model, texts) + end_time = time.time() + print(f"Nonconcurrent Time: {end_time - start_time:.4f} seconds") + + print("Comparing embeddings...") + + for text in texts: + if concurrent_results[text] == nonconcurrent_results[text]: + print(f"Embedding for '{text}' matches.") + else: + print(f"Embedding for '{text}' does not match.") + + +if __name__ == "__main__": + main()
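
For reference, the thread-per-text pattern exercised by the new test can also be written with concurrent.futures.ThreadPoolExecutor from the standard library, which takes care of thread creation, joining, and result ordering. The sketch below is illustrative only and not part of the change: the endpoint address and model choice are placeholders, and it reuses only the client calls the test above already exercises (RESTfulClient, launch_model, get_model, create_embedding).

"""Illustrative sketch: concurrent embedding creation via a thread pool."""
from concurrent.futures import ThreadPoolExecutor

from xinference.client import RESTfulClient


def main():
    # Placeholder endpoint; point this at your running Xinference supervisor.
    client = RESTfulClient("http://127.0.0.1:9997")
    model_uid = client.launch_model(
        model_name="opt",
        model_size_in_billions=1,
        model_format="pytorch",
        quantization="8-bit",
    )
    model = client.get_model(model_uid)

    texts = ["Once upon a time", "Hello, world!", "Hi"]

    # Each create_embedding call is an independent HTTP request, so the calls
    # can overlap in separate worker threads. executor.map returns results in
    # the order of `texts`, so no lock-protected shared dict is needed.
    with ThreadPoolExecutor(max_workers=len(texts)) as executor:
        embeddings = list(executor.map(model.create_embedding, texts))

    print(f"created {len(embeddings)} embeddings for {len(texts)} texts")


if __name__ == "__main__":
    main()

Because executor.map preserves input order and joins the workers when the context manager exits, the lock and shared dictionaries used in the thread-per-text version above become unnecessary; which variant is preferable is a matter of taste, not behavior.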