diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py
index e9ace60885..81dedb063c 100644
--- a/ChatQnA/chatqna.py
+++ b/ChatQnA/chatqna.py
@@ -58,11 +58,17 @@ def generate_rag_prompt(question, documents):
 LLM_SERVER_HOST_IP = os.getenv("LLM_SERVER_HOST_IP", "0.0.0.0")
 LLM_SERVER_PORT = int(os.getenv("LLM_SERVER_PORT", 80))
 LLM_MODEL = os.getenv("LLM_MODEL", "Intel/neural-chat-7b-v3-3")
+EMBEDDINGS_MODEL_ID = os.getenv("EMBEDDINGS_MODEL_ID", "BAAI/bge-base-en-v1.5")
+EMBEDDINGS_USE_VLLM = os.getenv("EMBEDDINGS_USE_VLLM", "false")
 
 
 def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
     if self.services[cur_node].service_type == ServiceType.EMBEDDING:
-        inputs["inputs"] = inputs["text"]
+        if EMBEDDINGS_USE_VLLM == "true":
+            inputs["input"] = inputs["text"]
+            inputs["model"] = EMBEDDINGS_MODEL_ID
+        else:
+            inputs["inputs"] = inputs["text"]
         del inputs["text"]
     elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
         # prepare the retriever params
@@ -88,8 +94,12 @@ def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **k
 def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
     next_data = {}
     if self.services[cur_node].service_type == ServiceType.EMBEDDING:
-        assert isinstance(data, list)
-        next_data = {"text": inputs["inputs"], "embedding": data[0]}
+        if EMBEDDINGS_USE_VLLM == "true":
+            assert isinstance(data["data"][0]["embedding"], list)
+            next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]}
+        else:
+            assert isinstance(data, list)
+            next_data = {"text": inputs["inputs"], "embedding": data[0]}
 
     elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
         docs = [doc["text"] for doc in data["retrieved_docs"]]
@@ -329,6 +339,37 @@ def add_remote_service_with_guardrails(self):
         self.megaservice.flow_to(rerank, llm)
         # self.megaservice.flow_to(llm, guardrail_out)
 
+    def add_remote_service_with_vllm_embeddings(self):
+        embedding = MicroService(
+            name="embedding",
+            host=EMBEDDING_SERVER_HOST_IP,
+            port=EMBEDDING_SERVER_PORT,
+            endpoint="/v1/embeddings",
+            use_remote_service=True,
+            service_type=ServiceType.EMBEDDING,
+        )
+
+        retriever = MicroService(
+            name="retriever",
+            host=RETRIEVER_SERVICE_HOST_IP,
+            port=RETRIEVER_SERVICE_PORT,
+            endpoint="/v1/retrieval",
+            use_remote_service=True,
+            service_type=ServiceType.RETRIEVER,
+        )
+
+        llm = MicroService(
+            name="llm",
+            host=LLM_SERVER_HOST_IP,
+            port=LLM_SERVER_PORT,
+            endpoint="/v1/chat/completions",
+            use_remote_service=True,
+            service_type=ServiceType.LLM,
+        )
+        self.megaservice.add(embedding).add(retriever).add(llm)
+        self.megaservice.flow_to(embedding, retriever)
+        self.megaservice.flow_to(retriever, llm)
+
     async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = data.get("stream", True)
@@ -403,7 +444,10 @@ def start(self):
     args = parser.parse_args()
 
     chatqna = ChatQnAService(port=MEGA_SERVICE_PORT)
-    if args.without_rerank:
+    if EMBEDDINGS_USE_VLLM == "true":
+        # this also doesn't use rerank
+        chatqna.add_remote_service_with_vllm_embeddings()
+    elif args.without_rerank:
        chatqna.add_remote_service_without_rerank()
     elif args.with_guardrails:
         chatqna.add_remote_service_with_guardrails()
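
A minimal sketch (not part of the patch) of the two embedding wire formats that the updated `align_inputs`/`align_outputs` translate between. The field names and response nesting (`inputs` vs. `input`/`model`, bare list vs. `data[0]["embedding"]`) come straight from the diff; the host, port, and the non-vLLM endpoint path are illustrative placeholders, and the model ID is simply the new `EMBEDDINGS_MODEL_ID` default.

```python
# Sketch of the two embedding request/response shapes the megaservice now handles.
# Host, port, and the non-vLLM endpoint path are placeholders, not values from the patch.
import requests

text = "What is OPEA?"

# Default path (EMBEDDINGS_USE_VLLM != "true"): the embedding backend takes
# {"inputs": ...} and returns a bare list of embedding vectors.
default_resp = requests.post("http://embedding-host:6006/embed", json={"inputs": text}).json()
default_vector = default_resp[0]  # align_outputs reads data[0]

# vLLM path (EMBEDDINGS_USE_VLLM == "true"): an OpenAI-compatible /v1/embeddings
# endpoint takes {"input": ..., "model": ...} and nests each vector under "data".
vllm_resp = requests.post(
    "http://embedding-host:6006/v1/embeddings",
    json={"input": text, "model": "BAAI/bge-base-en-v1.5"},
).json()
vllm_vector = vllm_resp["data"][0]["embedding"]  # align_outputs reads data["data"][0]["embedding"]
```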