Skip to content

Commit

Permalink
chore: Refactor embeddings instantiation to use a singleton pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
dartpain committed Jun 14, 2024
1 parent 558ecd8 commit 3454309
Showing 1 changed file with 38 additions and 20 deletions.
58 changes: 38 additions & 20 deletions application/vectorstore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings

class EmbeddingsSingleton:
    """Shared cache of embeddings objects.

    ``get_instance`` builds one object per distinct combination of
    embeddings name and constructor arguments, and returns that same object
    on every later request for the same combination.
    """

    # Maps a cache key (see ``_cache_key``) to the instance built for it.
    _instances = {}

    @staticmethod
    def _cache_key(embeddings_name, args, kwargs):
        """Return a hashable key identifying one embeddings configuration.

        The constructor arguments are folded in through ``repr`` because
        values such as ``model_kwargs={"device": "cpu"}`` are dicts and
        cannot be hashed directly. Keyword arguments are sorted by name so
        the order they were passed in does not change the key.

        Args:
            embeddings_name: Name used to select the embeddings class.
            args: Tuple of positional constructor arguments.
            kwargs: Dict of keyword constructor arguments.

        Returns:
            A tuple of strings usable as a dictionary key.
        """
        return (embeddings_name, repr(args), repr(sorted(kwargs.items())))

    @staticmethod
    def get_instance(embeddings_name, *args, **kwargs):
        """Return the cached instance for a name and its constructor arguments.

        The cache key includes the arguments as well as the name. Keying on
        the name alone would hand the instance built with the first caller's
        arguments (for example its API key) to every later caller, even when
        they pass different ones.

        Args:
            embeddings_name: Name used to select the embeddings class.
            *args: Positional arguments forwarded to the class constructor.
            **kwargs: Keyword arguments forwarded to the class constructor.

        Returns:
            The embeddings object for this configuration; created on the
            first request and reused afterwards.

        Raises:
            ValueError: If ``embeddings_name`` is not a known name.
        """
        cache = EmbeddingsSingleton._instances
        key = EmbeddingsSingleton._cache_key(embeddings_name, args, kwargs)
        if key not in cache:
            cache[key] = EmbeddingsSingleton._create_instance(
                embeddings_name, *args, **kwargs
            )
        return cache[key]

    @staticmethod
    def _create_instance(embeddings_name, *args, **kwargs):
        """Construct a new embeddings object for ``embeddings_name``.

        Args:
            embeddings_name: Name used to select the embeddings class.
            *args: Positional arguments forwarded to the class constructor.
            **kwargs: Keyword arguments forwarded to the class constructor.

        Returns:
            A newly constructed embeddings object.

        Raises:
            ValueError: If ``embeddings_name`` is not a known name.
        """
        # Both spellings of the all-mpnet-base-v2 name ("/" and "-" after
        # "sentence-transformers") select the same class.
        embeddings_factory = {
            "openai_text-embedding-ada-002": OpenAIEmbeddings,
            "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
            "huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
            "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
            "cohere_medium": CohereEmbeddings,
        }

        if embeddings_name not in embeddings_factory:
            raise ValueError(f"Invalid embeddings_name: {embeddings_name}")

        return embeddings_factory[embeddings_name](*args, **kwargs)

class BaseVectorStore(ABC):
def __init__(self):
pass
Expand All @@ -20,42 +44,36 @@ def is_azure_configured(self):
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME

def _get_embeddings(self, embeddings_name, embeddings_key=None):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
"cohere_medium": CohereEmbeddings
}

if embeddings_name not in embeddings_factory:
raise ValueError(f"Invalid embeddings_name: {embeddings_name}")

if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(

Check warning on line 50 in application/vectorstore/base.py

View check run for this annotation

Codecov / codecov/patch

application/vectorstore/base.py#L50

Added line #L50 was not covered by tests
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
openai_api_key=embeddings_key
)
elif embeddings_name == "cohere_medium":
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(

Check warning on line 60 in application/vectorstore/base.py

View check run for this annotation

Codecov / codecov/patch

application/vectorstore/base.py#L60

Added line #L60 was not covered by tests
embeddings_name,
cohere_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./model/all-mpnet-base-v2"):
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(

Check warning on line 66 in application/vectorstore/base.py

View check run for this annotation

Codecov / codecov/patch

application/vectorstore/base.py#L66

Added line #L66 was not covered by tests
embeddings_name,
model_name="./model/all-mpnet-base-v2",
model_kwargs={"device": "cpu"},
model_kwargs={"device": "cpu"}
)
else:
embedding_instance = embeddings_factory[embeddings_name](
model_kwargs={"device": "cpu"},
embedding_instance = EmbeddingsSingleton.get_instance(

Check warning on line 72 in application/vectorstore/base.py

View check run for this annotation

Codecov / codecov/patch

application/vectorstore/base.py#L72

Added line #L72 was not covered by tests
embeddings_name,
model_kwargs={"device": "cpu"}
)
else:
embedding_instance = embeddings_factory[embeddings_name]()

return embedding_instance
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)

Check warning on line 77 in application/vectorstore/base.py

View check run for this annotation

Codecov / codecov/patch

application/vectorstore/base.py#L77

Added line #L77 was not covered by tests

return embedding_instance

Check warning on line 79 in application/vectorstore/base.py

View check run for this annotation

Codecov / codecov/patch

application/vectorstore/base.py#L79

Added line #L79 was not covered by tests

0 comments on commit 3454309

Please sign in to comment.