From b58a6f303ea59543152ad259ed2db5ce9bd5712e Mon Sep 17 00:00:00 2001
From: Jon Fairbanks
Date: Thu, 29 Feb 2024 18:13:18 -0800
Subject: [PATCH] Setup HuggingFaceEmbedding()

---
 utils/llama_index.py | 33 ++++++++++++++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/utils/llama_index.py b/utils/llama_index.py
index b37cf29..d1b034e 100644
--- a/utils/llama_index.py
+++ b/utils/llama_index.py
@@ -4,6 +4,10 @@
 
 import utils.logs as logs
 
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.embeddings.text_embeddings_inference import (
+    TextEmbeddingsInference,
+)
 # This is not used but required by llama-index and must be imported FIRST
 os.environ["OPENAI_API_KEY"] = "sk-abc123"
 
@@ -15,20 +19,44 @@
     set_global_service_context,
 )
 
+
+###################################
+#
+# Setup Embedding Model
+#
+###################################
+
+@st.cache_resource(show_spinner=False)
+def setup_embedding_model(
+    model: str,
+    timeout: int = 60,
+    embed_batch_size: int = 10
+):
+    embed_model = HuggingFaceEmbedding(
+        model_name=model,
+        timeout=timeout,
+        embed_batch_size=embed_batch_size,
+    )
+    return embed_model
+
+
 ###################################
 #
 # Create Service Context
 #
 ###################################
 
+# TODO: Migrate to LlamaIndex.Settings: https://docs.llamaindex.ai/en/stable/module_guides/supporting_modules/service_context_migration.html
 @st.cache_resource(show_spinner=False)
 def create_service_context(
     _llm,  # TODO: Determine type
     system_prompt: str = None,  # TODO: What are the implications of no system prompt being passed?
     embed_model: str = "BAAI/bge-large-en-v1.5",
+    embed_timeout: int = 60,
+    embed_batch_size: int = 10,
     chunk_size: int = 1024,  # Llama-Index default is 1024
-    chunk_overlap: int = 20,  # Llama-Index default is 1024
+    chunk_overlap: int = 200,  # Llama-Index default is 200
 ):
     """
     Create a service context with the specified language model and embedding model.
 
@@ -45,6 +73,7 @@ def create_service_context(
     """
     formatted_embed_model = f"local:{embed_model}"
     try:
+        embedding_model = setup_embedding_model(embed_model, embed_timeout, embed_batch_size)
         service_context = ServiceContext.from_defaults(
             llm=_llm,
             system_prompt=system_prompt,
@@ -54,8 +83,10 @@ def create_service_context(
         )
         logs.log.info(f"Service Context created successfully")
         st.session_state["service_context"] = service_context
+        # Note: this may be redundant since service_context is returned
         set_global_service_context(service_context)
+        return service_context
 
     except Exception as e:
         logs.log.error(f"Failed to create service_context: {e}")
 
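
Notes (reviewer follow-ups, not part of the commit):

1. The new embedding_model variable is created inside the try block but,
   judging from formatted_embed_model = f"local:{embed_model}" still being
   computed, ServiceContext.from_defaults() appears to keep receiving the
   "local:..." string rather than the HuggingFaceEmbedding instance. If the
   instance is meant to be used, the call would presumably change along these
   lines (a sketch: from_defaults() does accept a BaseEmbedding object via
   embed_model, but the full argument list is truncated in the hunk above, so
   the chunk_size/chunk_overlap kwargs are assumptions):

       embedding_model = setup_embedding_model(embed_model, embed_timeout, embed_batch_size)
       service_context = ServiceContext.from_defaults(
           llm=_llm,
           system_prompt=system_prompt,
           embed_model=embedding_model,  # the instance, not the "local:..." string
           chunk_size=int(chunk_size),
           chunk_overlap=int(chunk_overlap),
       )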
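2. On the TODO about migrating to LlamaIndex Settings: in llama-index >= 0.10,
   where ServiceContext is deprecated, the equivalent configuration uses the
   process-global Settings object, which would also make the
   set_global_service_context() call and the st.session_state copy
   unnecessary. A minimal sketch (configure_llama_settings is a hypothetical
   name, not in this patch):

       from llama_index.core import Settings

       def configure_llama_settings(
           llm,
           embed_model: str = "BAAI/bge-large-en-v1.5",
           chunk_size: int = 1024,
           chunk_overlap: int = 200,
       ):
           # Settings is global state; assigning here replaces both the
           # ServiceContext object and the set_global_service_context() call.
           Settings.llm = llm
           Settings.embed_model = setup_embedding_model(embed_model)
           Settings.chunk_size = chunk_size
           Settings.chunk_overlap = chunk_overlap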
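3. The TextEmbeddingsInference import is added but never used, though it hints
   at where the timeout and embed_batch_size arguments come from: TEI serves
   embeddings over HTTP, where a per-request timeout is meaningful, while
   HuggingFaceEmbedding runs the model in-process and may not accept a timeout
   keyword at all (worth verifying against the installed llama-index version
   before shipping). A server-backed variant of the helper would look roughly
   like this (a sketch: setup_tei_embedding_model is a hypothetical name, and
   the base_url default is TEI's standard local endpoint):

       @st.cache_resource(show_spinner=False)
       def setup_tei_embedding_model(
           model: str,
           base_url: str = "http://127.0.0.1:8080",
           timeout: int = 60,
           embed_batch_size: int = 10,
       ):
           # Embeddings are computed by a Text Embeddings Inference server;
           # timeout bounds each HTTP embedding request.
           return TextEmbeddingsInference(
               model_name=model,
               base_url=base_url,
               timeout=timeout,
               embed_batch_size=embed_batch_size,
           )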