diff --git a/application/app.py b/application/app.py
index 3ba2de69a..bcf2b4665 100644
--- a/application/app.py
+++ b/application/app.py
@@ -25,7 +25,6 @@
     CohereEmbeddings,
     HuggingFaceInstructEmbeddings,
 )
-from langchain.llms import GPT4All
 from langchain.prompts import PromptTemplate
 from langchain.prompts.chat import (
     ChatPromptTemplate,
@@ -50,11 +49,20 @@
 else:
     gpt_model = 'gpt-3.5-turbo'
 
-if settings.LLM_NAME == "manifest":
-    from manifest import Manifest
-    from langchain.llms.manifest import ManifestWrapper
-    manifest = Manifest(client_name="huggingface", client_connection="http://127.0.0.1:5000")
+if settings.SELF_HOSTED_MODEL == True:
+    from langchain.llms import HuggingFacePipeline
+    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+    model_id = settings.LLM_NAME  # hf model id (Arc53/docsgpt-7b-falcon, Arc53/docsgpt-14b)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id)
+    pipe = pipeline(
+        "text-generation", model=model,
+        tokenizer=tokenizer, max_new_tokens=2000,
+        device_map="auto", eos_token_id=tokenizer.eos_token_id
+    )
+    hf = HuggingFacePipeline(pipeline=pipe)
 
 
 # Redirect PosixPath to WindowsPath on Windows
 
@@ -346,14 +354,10 @@ def api_answer():
         p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
     elif settings.LLM_NAME == "openai":
         llm = OpenAI(openai_api_key=api_key, temperature=0)
-    elif settings.LLM_NAME == "manifest":
-        llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
-    elif settings.LLM_NAME == "huggingface":
-        llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
+    elif settings.SELF_HOSTED_MODEL:
+        llm = hf
     elif settings.LLM_NAME == "cohere":
         llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
-    elif settings.LLM_NAME == "gpt4all":
-        llm = GPT4All(model=settings.MODEL_PATH)
     else:
         raise ValueError("unknown LLM model")
 
@@ -369,7 +373,7 @@
         # result = chain({"question": question, "chat_history": chat_history})
         # generate async with async generate method
         result = run_async_chain(chain, question, chat_history)
-    elif settings.LLM_NAME == "gpt4all":
+    elif settings.SELF_HOSTED_MODEL:
         question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
         doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
         chain = ConversationalRetrievalChain(
diff --git a/application/core/settings.py b/application/core/settings.py
index 34c8c0234..08673475e 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -11,6 +11,7 @@ class Settings(BaseSettings):
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = "./models/gpt4all-model.bin"
     TOKENS_MAX_HISTORY: int = 150
+    SELF_HOSTED_MODEL: bool = False
 
     API_URL: str = "http://localhost:7091"  # backend url for celery worker
 
diff --git a/docker-compose.yaml b/docker-compose.yaml
index d5dd10e5b..aaef62eab 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -19,6 +19,7 @@ services:
       - CELERY_BROKER_URL=redis://redis:6379/0
       - CELERY_RESULT_BACKEND=redis://redis:6379/1
       - MONGO_URI=mongodb://mongo:27017/docsgpt
+      - SELF_HOSTED_MODEL=$SELF_HOSTED_MODEL
     ports:
       - "7091:7091"
     volumes:
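
Below is a minimal sketch of exercising the new SELF_HOSTED_MODEL path outside the Flask app, assuming the langchain and transformers versions the project already pins. The model id is one of the examples named in the diff's own comment, and the prompt string is illustrative only. One deliberate difference from the diff: device_map="auto" is applied at model load (which requires the accelerate package) rather than passed to pipeline(), since the device map takes effect at load time when pipeline() is handed an already-instantiated model.

    import os

    # Mirror the environment that application/core/settings.py reads;
    # pydantic's BaseSettings parses "true" into the SELF_HOSTED_MODEL bool.
    os.environ["SELF_HOSTED_MODEL"] = "true"
    os.environ["LLM_NAME"] = "Arc53/docsgpt-7b-falcon"

    from langchain.llms import HuggingFacePipeline
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    model_id = os.environ["LLM_NAME"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" shards the weights across available devices (needs accelerate).
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    # Same generation settings as the diff: up to 2000 new tokens, stopping at EOS.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=2000,
        eos_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # LLM objects in this LangChain version are callable on a plain prompt string.
    print(llm("What does DocsGPT do?"))

With the docker-compose change above, the flag should also be switchable from the host shell, e.g. SELF_HOSTED_MODEL=true docker-compose up. Note that LLM_NAME is not forwarded in this diff, so inside the container it falls back to whatever default or env file the backend already uses for that setting.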