Skip to content

Commit

Permalink
refactor variable names to be more generic and add integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jgbradley1 committed Jan 3, 2025
1 parent ff5714a commit 0252646
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 32 deletions.
46 changes: 17 additions & 29 deletions backend/src/api/index_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@

import inspect
import os
import shutil
import traceback

import graphrag.api as api
import yaml
from fastapi import (
APIRouter,
HTTPException,
)
from fastapi.responses import StreamingResponse
from graphrag.prompt_tune.cli import prompt_tune as generate_fine_tune_prompts
from graphrag.config.create_graphrag_config import create_graphrag_config

from src.api.azure_clients import AzureClientManager
from src.api.common import (
Expand All @@ -27,7 +26,7 @@

@index_configuration_route.get(
"/prompts",
summary="Generate graphrag prompts from user-provided data.",
summary="Generate prompts from user-provided data.",
description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
)
async def generate_prompts(storage_name: str, limit: int = 5):
Expand All @@ -44,29 +43,23 @@ async def generate_prompts(storage_name: str, limit: int = 5):
status_code=500,
detail=f"Data container '{storage_name}' does not exist.",
)

# load pipeline configuration file (settings.yaml) for input data and other settings
this_directory = os.path.dirname(
os.path.abspath(inspect.getfile(inspect.currentframe()))
)

# write custom settings.yaml to a file and store in a temporary directory
data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
data["input"]["container_name"] = sanitized_storage_name
temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
shutil.rmtree(temp_dir, ignore_errors=True)
os.makedirs(temp_dir, exist_ok=True)
with open(f"{temp_dir}/settings.yaml", "w") as f:
yaml.dump(data, f, default_flow_style=False)
graphrag_config = create_graphrag_config(values=data, root_dir=".")

# generate prompts
try:
await generate_fine_tune_prompts(
config=f"{temp_dir}/settings.yaml",
root=temp_dir,
domain="",
selection_method="random",
# NOTE: we need to call api.generate_indexing_prompts
prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
config=graphrag_config,
root=".",
limit=limit,
skip_entity_types=True,
output=f"{temp_dir}/prompts",
selection_method="random",
)
except Exception as e:
logger = LoggerSingleton().get_instance()
Expand All @@ -84,14 +77,9 @@ async def generate_prompts(storage_name: str, limit: int = 5):
detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
)

# zip up the generated prompt files and return the zip file
temp_archive = (
f"{temp_dir}/prompts" # will become a zip file with the name prompts.zip
)
shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")

def iterfile(file_path: str):
with open(file_path, mode="rb") as file_like:
yield from file_like

return StreamingResponse(iterfile(f"{temp_archive}.zip"))
content = {
"entity_extraction_prompt": prompts[0],
"entity_summarization_prompt": prompts[1],
"community_summarization_prompt": prompts[2],
}
return content # return a fastapi.responses.JSONResponse object
35 changes: 35 additions & 0 deletions backend/tests/integration/test_utils_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def cosmos_index_job_entry(cosmos_client) -> Generator[str, None, None]:


def test_pipeline_job_interface(cosmos_index_job_entry):
"""Test the src.utils.pipeline.PipelineJob class interface."""
pipeline_job = PipelineJob()

# test creating a new entry
pipeline_job.create_item(
id="synthetic_id",
Expand All @@ -69,3 +71,36 @@ def test_pipeline_job_interface(cosmos_index_job_entry):
assert pipeline_job.status == PipelineJobState.COMPLETE
assert pipeline_job.percent_complete == 50.0
assert pipeline_job.progress == "some progress"
assert pipeline_job.calculate_percent_complete() == 50.0

# test setters and getters
pipeline_job.id = "newID"
assert pipeline_job.id == "newID"
pipeline_job.epoch_request_time = 1
assert pipeline_job.epoch_request_time == 1

pipeline_job.human_readable_index_name = "new_human_readable_index_name"
assert pipeline_job.human_readable_index_name == "new_human_readable_index_name"
pipeline_job.sanitized_index_name = "new_sanitized_index_name"
assert pipeline_job.sanitized_index_name == "new_sanitized_index_name"

pipeline_job.human_readable_storage_name = "new_human_readable_storage_name"
assert pipeline_job.human_readable_storage_name == "new_human_readable_storage_name"
pipeline_job.sanitized_storage_name = "new_sanitized_storage_name"
assert pipeline_job.sanitized_storage_name == "new_sanitized_storage_name"

pipeline_job.entity_extraction_prompt = "new_entity_extraction_prompt"
assert pipeline_job.entity_extraction_prompt == "new_entity_extraction_prompt"
pipeline_job.community_report_prompt = "new_community_report_prompt"
assert pipeline_job.community_report_prompt == "new_community_report_prompt"
pipeline_job.summarize_descriptions_prompt = "new_summarize_descriptions_prompt"
assert pipeline_job.summarize_descriptions_prompt == "new_summarize_descriptions_prompt"

pipeline_job.all_workflows = ["new_workflow1", "new_workflow2", "new_workflow3"]
assert len(pipeline_job.all_workflows) == 3

pipeline_job.completed_workflows = ["new_workflow1", "new_workflow2"]
assert len(pipeline_job.completed_workflows) == 2

pipeline_job.failed_workflows = ["new_workflow3"]
assert len(pipeline_job.failed_workflows) == 1
6 changes: 3 additions & 3 deletions infra/deploy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,12 @@ deployAzureResources () {
--resource-group $RESOURCE_GROUP \
--template-file ./main.bicep \
--parameters "resourceBaseName=$RESOURCE_BASE_NAME" \
--parameters "graphRagName=$RESOURCE_GROUP" \
--parameters "resourceGroupName=$RESOURCE_GROUP" \
--parameters "apimName=$APIM_NAME" \
--parameters "apimTier=$APIM_TIER" \
--parameters "publisherName=$PUBLISHER_NAME" \
--parameters "apiPublisherName=$PUBLISHER_NAME" \
--parameters "apiPublisherEmail=$PUBLISHER_EMAIL" \
--parameters "aksSshRsaPublicKey=$SSH_PUBLICKEY" \
--parameters "publisherEmail=$PUBLISHER_EMAIL" \
--parameters "enablePrivateEndpoints=$ENABLE_PRIVATE_ENDPOINTS" \
--parameters "acrName=$CONTAINER_REGISTRY_NAME" \
--parameters "deployerPrincipalId=$deployerPrincipalId" \
Expand Down

0 comments on commit 0252646

Please sign in to comment.