Setup HuggingFaceEmbedding()
jonfairbanks committed Mar 1, 2024
1 parent 01f74da commit b58a6f3
Showing 1 changed file with 32 additions and 1 deletion.
utils/llama_index.py (33 changes: 32 additions & 1 deletion)
@@ -4,6 +4,10 @@

import utils.logs as logs

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.text_embeddings_inference import (
    TextEmbeddingsInference,
)

# This is not used but required by llama-index and must be set FIRST
os.environ["OPENAI_API_KEY"] = "sk-abc123"
@@ -15,20 +19,44 @@
    set_global_service_context,
)


###################################
#
# Setup Embedding Model
#
###################################

@st.cache_resource(show_spinner=False)
def setup_embedding_model(
    model: str,
    timeout: int = 60,
    embed_batch_size: int = 10
):
    embed_model = HuggingFaceEmbedding(
        model_name=model,
        timeout=timeout,
        embed_batch_size=embed_batch_size,
    )
    return embed_model


###################################
#
# Create Service Context
#
###################################

# TODO: Migrate to LlamaIndex.Settings: https://docs.llamaindex.ai/en/stable/module_guides/supporting_modules/service_context_migration.html

@st.cache_resource(show_spinner=False)
def create_service_context(
    _llm, # TODO: Determine type
    system_prompt: str = None, # TODO: What are the implications of no system prompt being passed?
    embed_model: str = "BAAI/bge-large-en-v1.5",
    embed_timeout: int = 60,
    embed_batch_size: int = 10,
    chunk_size: int = 1024, # Llama-Index default is 1024
-    chunk_overlap: int = 20, # Llama-Index default is 1024
+    chunk_overlap: int = 200, # Llama-Index default is 200
):
"""
Create a service context with the specified language model and embedding model.
Expand All @@ -45,6 +73,7 @@ def create_service_context(
"""
formatted_embed_model = f"local:{embed_model}"
try:
embedding_model = setup_embedding_model(embed_model, embed_timeout, embed_batch_size)
service_context = ServiceContext.from_defaults(
llm=_llm,
system_prompt=system_prompt,
Expand All @@ -54,8 +83,10 @@ def create_service_context(
)
logs.log.info(f"Service Context created successfully")
st.session_state["service_context"] = service_context

# Note: this may be redundant since service_context is returned
set_global_service_context(service_context)

return service_context
except Exception as e:
logs.log.error(f"Failed to create service_context: {e}")
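
For context, a minimal usage sketch of the new setup (hypothetical: the calling code is not part of this commit, and the Ollama import path assumes a llama-index release that ships llama_index.llms.ollama):

from llama_index.llms.ollama import Ollama

from utils.llama_index import create_service_context, setup_embedding_model

# Any llama-index LLM would work; Ollama is only a stand-in for illustration.
llm = Ollama(model="llama2", request_timeout=120.0)

# Builds (and caches) the embedding model, stores the service context in
# st.session_state["service_context"], and also registers it globally.
service_context = create_service_context(
    llm,
    system_prompt="You are a helpful assistant.",
)

# The cached embedding model can also be used directly:
embed = setup_embedding_model("BAAI/bge-large-en-v1.5")
vector = embed.get_text_embedding("hello world")  # returns a list of floats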
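The TextEmbeddingsInference import added at the top of the file is not referenced anywhere in this diff. For comparison, a sketch of how a TEI-backed embedding model could be constructed with the same timeout and batch-size parameters (the base_url is an assumption, TEI's conventional local endpoint):

from llama_index.embeddings.text_embeddings_inference import (
    TextEmbeddingsInference,
)

# Remote Text Embeddings Inference server instead of a local HuggingFace model;
# timeout and embed_batch_size mirror the defaults in setup_embedding_model().
tei_embed = TextEmbeddingsInference(
    model_name="BAAI/bge-large-en-v1.5",
    base_url="http://127.0.0.1:8080",  # assumed local TEI endpoint
    timeout=60,
    embed_batch_size=10,
)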

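On the migration TODO in the diff: the linked guide replaces ServiceContext with the global Settings object (llama-index >= 0.10). An illustrative sketch of the equivalent configuration, not part of this commit:

from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Replaces ServiceContext.from_defaults(...) and set_global_service_context(...).
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
Settings.chunk_size = 1024
Settings.chunk_overlap = 200
# Settings.llm = ...  # the LLM is assigned here as well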