Skip to content

Commit

Permalink
More updates to integrated vectorization, fixes type checks
Browse files: browse the repository at this point in the history
  • Loading branch information
pamelafox committed Oct 17, 2024
1 parent d4e40b8 commit 06c0956
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 123 deletions.
4 changes: 4 additions & 0 deletions app/backend/prepdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ async def main(strategy: Strategy, setup_index: bool = True):

ingestion_strategy: Strategy
if use_int_vectorization:

if not openai_embeddings_service or not isinstance(openai_embeddings_service, AzureOpenAIEmbeddingService):
raise Exception("Integrated vectorization strategy requires an Azure OpenAI embeddings service")

ingestion_strategy = IntegratedVectorizerStrategy(
search_info=search_info,
list_file_strategy=list_file_strategy,
Expand Down
13 changes: 2 additions & 11 deletions app/backend/prepdocslib/integratedvectorizerstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,14 @@ def __init__(
list_file_strategy: ListFileStrategy,
blob_manager: BlobManager,
search_info: SearchInfo,
embeddings: Optional[AzureOpenAIEmbeddingService],
embeddings: AzureOpenAIEmbeddingService,
subscription_id: str,
search_service_user_assigned_id: str,
document_action: DocumentAction = DocumentAction.Add,
search_analyzer_name: Optional[str] = None,
use_acls: bool = False,
category: Optional[str] = None,
):
if not embeddings or not isinstance(embeddings, AzureOpenAIEmbeddingService):
raise Exception("Expecting AzureOpenAI embedding service")

self.list_file_strategy = list_file_strategy
self.blob_manager = blob_manager
Expand Down Expand Up @@ -78,9 +76,6 @@ async def create_embedding_skill(self, index_name: str):
outputs=[OutputFieldMappingEntry(name="textItems", target_name="pages")],
)

if self.embeddings is None:
raise ValueError("Expecting Azure Open AI instance")

embedding_skill = AzureOpenAIEmbeddingSkill(
name=f"{index_name}-embedding-skill",
description="Skill to generate embeddings via Azure OpenAI",
Expand Down Expand Up @@ -123,6 +118,7 @@ async def create_embedding_skill(self, index_name: str):
return skillset

async def setup(self):
logger.info("Setting up search index using integrated vectorization...")
search_manager = SearchManager(
search_info=self.search_info,
search_analyzer_name=self.search_analyzer_name,
Expand All @@ -132,12 +128,8 @@ async def setup(self):
search_images=False,
)

if self.embeddings is None:
raise ValueError("Expecting Azure Open AI instance")

await search_manager.create_index()

# create indexer client
ds_client = self.search_info.create_search_indexer_client()
ds_container = SearchIndexerDataContainer(name=self.blob_manager.container)
data_source_connection = SearchIndexerDataSourceConnection(
Expand All @@ -149,7 +141,6 @@ async def setup(self):
)

await ds_client.create_or_update_data_source_connection(data_source_connection)
logger.info("Search indexer data source connection updated.")

embedding_skillset = await self.create_embedding_skill(self.search_info.index_name)
await ds_client.create_or_update_skillset(embedding_skillset)
Expand Down
265 changes: 153 additions & 112 deletions app/backend/prepdocslib/searchmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

from .blobmanager import BlobManager
from .embeddings import OpenAIEmbeddings
from .embeddings import AzureOpenAIEmbeddingService, OpenAIEmbeddings
from .listfilestrategy import File
from .strategy import SearchInfo
from .textsplitter import SplitPage
Expand Down Expand Up @@ -67,149 +67,190 @@ def __init__(
self.search_images = search_images

async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] = None):
logger.info("Ensuring search index %s exists", self.search_info.index_name)
logger.info("Checking whether search index %s exists...", self.search_info.index_name)

async with self.search_info.create_search_index_client() as search_index_client:
fields = [
(
SimpleField(name="id", type="Edm.String", key=True)
if not self.use_int_vectorization
else SearchField(
name="id",

if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]:
logger.info("Creating new search index %s", self.search_info.index_name)
fields = [
(
SimpleField(name="id", type="Edm.String", key=True)
if not self.use_int_vectorization
else SearchField(
name="id",
type="Edm.String",
key=True,
sortable=True,
filterable=True,
facetable=True,
analyzer_name="keyword",
)
),
SearchableField(
name="content",
type="Edm.String",
key=True,
sortable=True,
filterable=True,
facetable=True,
analyzer_name="keyword",
)
),
SearchableField(
name="content",
type="Edm.String",
analyzer_name=self.search_analyzer_name,
),
SearchField(
name="embedding",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
hidden=False,
searchable=True,
filterable=False,
sortable=False,
facetable=False,
vector_search_dimensions=self.embedding_dimensions,
vector_search_profile_name="embedding_config",
),
SimpleField(name="category", type="Edm.String", filterable=True, facetable=True),
SimpleField(
name="sourcepage",
type="Edm.String",
filterable=True,
facetable=True,
),
SimpleField(
name="sourcefile",
type="Edm.String",
filterable=True,
facetable=True,
),
SimpleField(
name="storageUrl",
type="Edm.String",
filterable=True,
facetable=False,
),
]
if self.use_acls:
fields.append(
SimpleField(
name="oids",
type=SearchFieldDataType.Collection(SearchFieldDataType.String),
filterable=True,
)
)
fields.append(
SimpleField(
name="groups",
type=SearchFieldDataType.Collection(SearchFieldDataType.String),
filterable=True,
)
)
if self.use_int_vectorization:
fields.append(SearchableField(name="parent_id", type="Edm.String", filterable=True))
if self.search_images:
fields.append(
analyzer_name=self.search_analyzer_name,
),
SearchField(
name="imageEmbedding",
name="embedding",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
hidden=False,
searchable=True,
filterable=False,
sortable=False,
facetable=False,
vector_search_dimensions=1024,
vector_search_dimensions=self.embedding_dimensions,
vector_search_profile_name="embedding_config",
),
)

index = SearchIndex(
name=self.search_info.index_name,
fields=fields,
semantic_search=SemanticSearch(
configurations=[
SemanticConfiguration(
name="default",
prioritized_fields=SemanticPrioritizedFields(
title_field=None, content_fields=[SemanticField(field_name="content")]
),
SimpleField(name="category", type="Edm.String", filterable=True, facetable=True),
SimpleField(
name="sourcepage",
type="Edm.String",
filterable=True,
facetable=True,
),
SimpleField(
name="sourcefile",
type="Edm.String",
filterable=True,
facetable=True,
),
SimpleField(
name="storageUrl",
type="Edm.String",
filterable=True,
facetable=False,
),
]
if self.use_acls:
fields.append(
SimpleField(
name="oids",
type=SearchFieldDataType.Collection(SearchFieldDataType.String),
filterable=True,
)
]
),
vector_search=VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
name="hnsw_config",
parameters=HnswParameters(metric="cosine"),
)
fields.append(
SimpleField(
name="groups",
type=SearchFieldDataType.Collection(SearchFieldDataType.String),
filterable=True,
)
],
profiles=[
VectorSearchProfile(
name="embedding_config",
algorithm_configuration_name="hnsw_config",
vectorizer_name=(
f"{self.search_info.index_name}-vectorizer" if self.use_int_vectorization else None
),
)
if self.use_int_vectorization:
logger.info("Including parent_id field in new index %s", self.search_info.index_name)
fields.append(SearchableField(name="parent_id", type="Edm.String", filterable=True))
if self.search_images:
logger.info("Including imageEmbedding field in new index %s", self.search_info.index_name)
fields.append(
SearchField(
name="imageEmbedding",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
hidden=False,
searchable=True,
filterable=False,
sortable=False,
facetable=False,
vector_search_dimensions=1024,
vector_search_profile_name="embedding_config",
),
],
vectorizers=[
)

vectorizers = []
if self.embeddings and isinstance(self.embeddings, AzureOpenAIEmbeddingService):
logger.info(
"Including vectorizer for search index %s, using Azure OpenAI service %s",
self.search_info.index_name,
self.embeddings.open_ai_service,
)
vectorizers.append(
AzureOpenAIVectorizer(
vectorizer_name=f"{self.search_info.index_name}-vectorizer",
parameters=AzureOpenAIVectorizerParameters(
resource_url=f"https://{self.embeddings.open_ai_service}.openai.azure.com",
resource_url=self.embeddings.open_ai_endpoint,
deployment_name=self.embeddings.open_ai_deployment,
model_name=self.embeddings.open_ai_model_name,
),
),
],
),
)
if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]:
logger.info("Creating %s search index", self.search_info.index_name)
)
)
else:
logger.info(
"Not including vectorizer for search index %s, no Azure OpenAI service found",
self.search_info.index_name,
)

index = SearchIndex(
name=self.search_info.index_name,
fields=fields,
semantic_search=SemanticSearch(
configurations=[
SemanticConfiguration(
name="default",
prioritized_fields=SemanticPrioritizedFields(
title_field=None, content_fields=[SemanticField(field_name="content")]
),
)
]
),
vector_search=VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
name="hnsw_config",
parameters=HnswParameters(metric="cosine"),
)
],
profiles=[
VectorSearchProfile(
name="embedding_config",
algorithm_configuration_name="hnsw_config",
vectorizer_name=(
f"{self.search_info.index_name}-vectorizer" if self.use_int_vectorization else None
),
),
],
vectorizers=vectorizers,
),
)

await search_index_client.create_index(index)
else:
logger.info("Search index %s already exists", self.search_info.index_name)
index_definition = await search_index_client.get_index(self.search_info.index_name)
if not any(field.name == "storageUrl" for field in index_definition.fields):
existing_index = await search_index_client.get_index(self.search_info.index_name)
if not any(field.name == "storageUrl" for field in existing_index.fields):
logger.info("Adding storageUrl field to index %s", self.search_info.index_name)
index_definition.fields.append(
existing_index.fields.append(
SimpleField(
name="storageUrl",
type="Edm.String",
filterable=True,
facetable=False,
),
)
await search_index_client.create_or_update_index(index_definition)
await search_index_client.create_or_update_index(existing_index)

if existing_index.vector_search is not None and (
existing_index.vector_search.vectorizers is None
or len(existing_index.vector_search.vectorizers) == 0
):
if self.embeddings is not None:
logger.info("Adding vectorizer to search index %s", self.search_info.index_name)
existing_index.vector_search.vectorizers = [
AzureOpenAIVectorizer(
vectorizer_name=f"{self.search_info.index_name}-vectorizer",
parameters=AzureOpenAIVectorizerParameters(
resource_url=self.embeddings.open_ai_endpoint,
deployment_name=self.embeddings.open_ai_deployment,
model_name=self.embeddings.open_ai_model_name,
),
)
]
await search_index_client.create_or_update_index(existing_index)
else:
logger.info(
"Can't add vectorizer to search index %s since embeddings service isn't defined",
self.search_info,
)

async def update_content(
self, sections: List[Section], image_embeddings: Optional[List[List[float]]] = None, url: Optional[str] = None
Expand Down

0 comments on commit 06c0956

Please sign in to comment.