
ENH: Support concurrent embedding, update LangChain QA demo with multithreaded embedding creation #348

Open
wants to merge 13 commits into base: main
Binary file added assets/langchainQA.png
129 changes: 118 additions & 11 deletions examples/LangChain_QA.ipynb
@@ -11,7 +11,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This demo walks through how to build an LLM-driven question-answering (QA) application with Xinference, Milvus, and LangChain."
"This demo walks through how to build an LLM-driven question-answering (QA) application with Xinference, Milvus, and LangChain. It uses Falcon 40B Instruct model for embedding creation and Llama 2 70B Chat model for inference. Both of the models are fully supported by Xinference."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![LangChain QA Visualization by Dreamsome](../assets/langchainQA.png)"
]
},
{
@@ -34,19 +41,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model uid: 19c73cee-3506-11ee-b286-fa163e74fa2d\n"
"Model uid: 46bf725e-3a5e-11ee-9dcd-fa163e74fa2d\n"
]
}
],
"source": [
"!xinference launch --model-name \"falcon-instruct\" --model-format pytorch --size-in-billions 40 -e \"http://127.0.0.1:56256\""
"!xinference launch --model-name \"falcon-instruct\" --model-format pytorch --size-in-billions 40 -e \"http://127.0.0.1:55950\""
]
},
{
@@ -65,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -93,15 +100,15 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import XinferenceEmbeddings\n",
"\n",
"xinference_embeddings = XinferenceEmbeddings(\n",
" server_url=\"http://127.0.0.1:56256\", \n",
" model_uid = \"19c73cee-3506-11ee-b286-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n",
" server_url=\"http://127.0.0.1:55950\", \n",
" model_uid = \"46bf725e-3a5e-11ee-9dcd-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n",
")"
]
},
@@ -205,7 +212,7 @@
}
],
"source": [
"!xinference launch --model-name \"llama-2-chat\" --model-format ggmlv3 --size-in-billions 70 -e \"http://127.0.0.1:56256\""
"!xinference launch --model-name \"llama-2-chat\" --model-format ggmlv3 --size-in-billions 70 -e \"http://127.0.0.1:55950\""
]
},
{
@@ -217,7 +224,7 @@
"from langchain.llms import Xinference\n",
"\n",
"xinference_llm = Xinference(\n",
" server_url=\"http://127.0.0.1:56256\",\n",
" server_url=\"http://127.0.0.1:55950\",\n",
" model_uid = \"333e1d68-3507-11ee-a0d6-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n",
")"
]
@@ -380,6 +387,106 @@
"source": [
"We can see the impressive capabilities of the LLM, and LangChain's \"chaining\" feature also allows for more coherent and context-aware interactions with the model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Concurrent Embedding Creation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Xinference also supports creating embeddings concurrently. This will speed up the process of storing the document into the vector database. To run the following code, first restart the milvus server. Here, we still use the 40B Falcon-instruct model we launched before."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.embeddings import XinferenceEmbeddings\n",
"from langchain.vectorstores import Milvus\n",
"\n",
"import threading\n",
"\n",
"def process_chunk(chunk):\n",
" vector_db.add_documents(documents=chunk)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All chunks processed successfully.\n"
]
}
],
"source": [
"num_chunks = 10 # replace this with the number of the threads you want to execute in parallel + 1\n",
"\n",
"loader = TextLoader(\"/home/nijiayi/inference/examples/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=100, length_function=len)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"\n",
"xinference_embeddings = XinferenceEmbeddings(\n",
" server_url=\"http://127.0.0.1:55950\", \n",
" model_uid = \"46bf725e-3a5e-11ee-9dcd-fa163e74fa2d\" # model_uid is the uid returned from launching the model\n",
")\n",
"\n",
"chunks = [docs[i::num_chunks] for i in range(num_chunks)] \n",
"\n",
"vector_db = Milvus.from_documents(\n",
" chunks[0],\n",
" xinference_embeddings,\n",
" connection_args={\"host\": \"0.0.0.0\", \"port\": \"19530\"},\n",
")\n",
"\n",
"# add chunks of document to the vector_db in parallel\n",
"threads = [threading.Thread(target=process_chunk, args=(chunk,)) for chunk in chunks[1:]]\n",
"\n",
"for thread in threads:\n",
" thread.start()\n",
"\n",
"for thread in threads:\n",
" thread.join()\n",
"\n",
"print(\"All chunks processed successfully.\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
]
}
],
"source": [
"query = \"what does the president say about Ketanji Brown Jackson\"\n",
"docs = vector_db.similarity_search(query, k=10)\n",
"print(docs[0].page_content) "
]
}
],
"metadata": {
@@ -404,4 +511,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
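A side note on the concurrent-embedding cell above: it manages `threading.Thread` objects by hand and seeds the collection with `chunks[0]` before spawning workers for the remaining chunks. Below is a sketch of an equivalent approach with `concurrent.futures.ThreadPoolExecutor`, which also re-raises exceptions from worker threads; it assumes the same `docs`, `xinference_embeddings`, and Milvus connection used in the notebook.

```python
from concurrent.futures import ThreadPoolExecutor

from langchain.vectorstores import Milvus

num_workers = 9               # threads adding documents in parallel
num_chunks = num_workers + 1  # first chunk seeds the collection

# Round-robin split: chunk i holds docs[i], docs[i + num_chunks], ...
chunks = [docs[i::num_chunks] for i in range(num_chunks)]

# Create the collection from the first chunk, then add the rest in parallel.
vector_db = Milvus.from_documents(
    chunks[0],
    xinference_embeddings,
    connection_args={"host": "0.0.0.0", "port": "19530"},
)

with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(vector_db.add_documents, chunk) for chunk in chunks[1:]]
    for future in futures:
        future.result()  # re-raises any exception from a worker thread

print("All chunks processed successfully.")
```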
30 changes: 18 additions & 12 deletions xinference/core/model.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import inspect
from typing import TYPE_CHECKING, Any, Generic, Iterator, List, Optional, TypeVar, Union

@@ -54,7 +55,7 @@ async def __anext__(self) -> T:
raise


class ModelActor(xo.Actor):
class ModelActor(xo.StatelessActor):
@classmethod
def gen_uid(cls, model: "LLM"):
return f"{model.__class__}-model-actor"
@@ -81,6 +82,7 @@ def __init__(self, model: "LLM"):
super().__init__()
self._model = model
self._generator: Optional[Iterator] = None
self._lock = asyncio.Lock()

def load(self):
self._model.load()
@@ -95,20 +97,24 @@ async def _wrap_generator(self, ret: Any):
return ret

async def generate(self, prompt: str, *args, **kwargs):
if not hasattr(self._model, "generate"):
raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

return self._wrap_generator(
getattr(self._model, "generate")(prompt, *args, **kwargs)
)
async with self._lock:
if not hasattr(self._model, "generate"):
raise AttributeError(
f"Model {self._model.model_spec} is not for generate."
)

return self._wrap_generator(
getattr(self._model, "generate")(prompt, *args, **kwargs)
)

async def chat(self, prompt: str, *args, **kwargs):
if not hasattr(self._model, "chat"):
raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
async with self._lock:
if not hasattr(self._model, "chat"):
raise AttributeError(f"Model {self._model.model_spec} is not for chat.")

return self._wrap_generator(
getattr(self._model, "chat")(prompt, *args, **kwargs)
)
return self._wrap_generator(
getattr(self._model, "chat")(prompt, *args, **kwargs)
)

async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
if not hasattr(self._model, "create_embedding"):
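For readers skimming the diff: switching `ModelActor` to `xo.StatelessActor` allows the actor to handle messages concurrently, and the new `asyncio.Lock` then serializes `generate`/`chat` so the stateful generator is never shared between overlapping calls. The `create_embedding` hunk is truncated above, so whether it also takes the lock is not visible here; the sketch below leaves embedding unlocked, matching the PR's goal of concurrent embedding. The class, method bodies, and sleeps are illustrative stand-ins, not the project's API.

```python
import asyncio
import time


class ModelHandle:
    def __init__(self):
        self._lock = asyncio.Lock()  # serializes stateful generate/chat calls

    async def generate(self, prompt: str) -> str:
        # Only one generate/chat call may run at a time.
        async with self._lock:
            await asyncio.sleep(0.5)  # stand-in for token generation
            return f"generated: {prompt}"

    async def create_embedding(self, text: str) -> list:
        # No lock: embedding requests may overlap with each other
        # and with a running generate call.
        await asyncio.sleep(0.5)  # stand-in for the embedding computation
        return [0.0, 1.0]


async def main():
    handle = ModelHandle()
    start = time.perf_counter()
    # The two generate calls queue on the lock (about 1.0 s total); the two
    # embedding calls overlap with them and add no extra wall-clock time.
    await asyncio.gather(
        handle.generate("a"),
        handle.generate("b"),
        handle.create_embedding("c"),
        handle.create_embedding("d"),
    )
    print(f"elapsed: {time.perf_counter() - start:.2f}s")  # about 1.0 s, not 2.0 s


if __name__ == "__main__":
    asyncio.run(main())
```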
63 changes: 63 additions & 0 deletions xinference/tests/test_concurrent_embedding.py
@@ -0,0 +1,63 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Simple test for multithreaded embedding creation
"""

import threading
import time

from xinference.client import RESTfulClient


def embedding_thread(model, text):
model.create_embedding(text)


def nonconcurrent_embedding(model, texts):
for text in texts:
model.create_embedding(text)


def main():
client = RESTfulClient("http://127.0.0.1:35819")
model_uid = client.launch_model(model_name="orca", quantization="q4_0")
model = client.get_model(model_uid)

texts = ["Once upon a time", "Hello, world!", "Hi"]

start_time = time.time()

threads = []
for text in texts:
thread = threading.Thread(target=embedding_thread, args=(model, text))
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

end_time = time.time()
print(f"Concurrent Time: {end_time - start_time:.4f} seconds")

start_time = time.time()
nonconcurrent_embedding(model, texts)
end_time = time.time()
print(f"Nonconcurrent Time: {end_time - start_time:.4f} seconds")


if __name__ == "__main__":
main()
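If this script is later folded into the automated test suite, it needs a running Xinference endpoint. Below is a sketch of a pytest variant that skips when none is reachable; the endpoint URL and the skip-on-failure behaviour are assumptions, and the timings are printed rather than asserted because they depend on hardware.

```python
import threading
import time

import pytest

from xinference.client import RESTfulClient

ENDPOINT = "http://127.0.0.1:9997"  # hypothetical local endpoint


def test_concurrent_vs_sequential_embedding():
    try:
        client = RESTfulClient(ENDPOINT)
        model_uid = client.launch_model(model_name="orca", quantization="q4_0")
    except Exception as exc:  # endpoint unreachable or launch failed
        pytest.skip(f"Xinference endpoint not available: {exc}")

    model = client.get_model(model_uid)
    texts = ["Once upon a time", "Hello, world!", "Hi"]

    # Concurrent: one thread per text.
    start = time.time()
    threads = [
        threading.Thread(target=model.create_embedding, args=(text,)) for text in texts
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    concurrent_elapsed = time.time() - start

    # Sequential baseline.
    start = time.time()
    for text in texts:
        model.create_embedding(text)
    sequential_elapsed = time.time() - start

    print(f"concurrent={concurrent_elapsed:.4f}s sequential={sequential_elapsed:.4f}s")
```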