Skip to content

Commit

Permalink
Merge pull request #504 from aws-samples/xuhan-dev
Browse files Browse the repository at this point in the history
fix: fix UAT issues
  • Loading branch information
NingLu authored Jan 7, 2025
2 parents 7c74364 + 1d6d485 commit 1c8bdf0
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 215 deletions.
234 changes: 141 additions & 93 deletions source/lambda/etl/sfn_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,124 +2,172 @@
import logging
import os
from datetime import datetime, timezone
from typing import Dict, List, TypedDict

import boto3
from chatbot_management import create_chatbot
from constant import ExecutionStatus, IndexType, UiStatus
from utils.parameter_utils import get_query_parameter

# Initialize AWS resources once
client = boto3.client("stepfunctions")
dynamodb = boto3.resource("dynamodb")
execution_table = dynamodb.Table(os.environ.get("EXECUTION_TABLE_NAME"))
index_table = dynamodb.Table(os.environ.get("INDEX_TABLE_NAME"))
chatbot_table = dynamodb.Table(os.environ.get("CHATBOT_TABLE_NAME"))
model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME"))
embedding_endpoint = os.environ.get("EMBEDDING_ENDPOINT")
index_table = dynamodb.Table(os.environ.get("INDEX_TABLE_NAME"))
sfn_arn = os.environ.get("SFN_ARN")
create_time = str(datetime.now(timezone.utc))


# Consolidate constants at the top
CORS_HEADERS = {
"Content-Type": "application/json",
"Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "*",
}

# Initialize logging at the top level
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def handler(event, context):
# Check the event for possible S3 created event
input_payload = {}
logger.info(event)
resp_header = {
"Content-Type": "application/json",
"Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "*",
def validate_index_type(index_type: str) -> bool:
    """Tell whether ``index_type`` equals the value of a supported index type.

    The supported members are ``IndexType.QD``, ``IndexType.QQ`` and
    ``IndexType.INTENTION``; the comparison is made against their ``.value``.
    """
    supported_members = (IndexType.QD, IndexType.QQ, IndexType.INTENTION)
    supported_values = tuple(member.value for member in supported_members)
    return index_type in supported_values


def get_etl_info(group_name: str, chatbot_id: str, index_type: str):
    """Look up the index id and embedding model settings for a chatbot.

    Reads the chatbot item (keyed by ``groupName``/``chatbotId``) and the
    model item whose ``modelId`` is ``"<chatbot_id>-embedding"``.

    Returns:
        A tuple ``(index_id, model_type, model_endpoint)`` where ``index_id``
        is the first value stored under ``indexIds[index_type]["value"]`` and
        the other two come from the model item's ``parameter`` mapping
        (``ModelType`` and ``ModelEndpoint``; either may be ``None``).

    Raises:
        ValueError: If either item is missing, or no index is registered for
            ``index_type``.
    """
    chatbot_key = {"groupName": group_name, "chatbotId": chatbot_id}
    chatbot_record = chatbot_table.get_item(Key=chatbot_key).get("Item")

    model_key = {"groupName": group_name, "modelId": f"{chatbot_id}-embedding"}
    model_record = model_table.get_item(Key=model_key).get("Item")

    if not chatbot_record or not model_record:
        raise ValueError("Chatbot or model not found")

    model_params = model_record.get("parameter", {})

    indices_by_type = chatbot_record.get("indexIds", {})
    type_entry = indices_by_type.get(index_type, {})
    candidate_indices = type_entry.get("value", {})
    if not candidate_indices:
        raise ValueError("No indices found for the given index type")

    # Only the first registered index of this type is used.
    first_index_id, *_ = candidate_indices.values()
    model_type = model_params.get("ModelType")
    model_endpoint = model_params.get("ModelEndpoint")
    return first_index_id, model_type, model_endpoint


def create_execution_record(
    execution_id: str, input_body: Dict, sfn_execution_id: str
) -> None:
    """Write a new in-progress execution item to the execution table.

    The stored item is a shallow copy of ``input_body`` extended with the
    bookkeeping fields below; ``input_body`` itself is left unmodified.
    Bookkeeping fields win over same-named keys in ``input_body``.

    Args:
        execution_id: Stored under ``executionId``.
        input_body: Request fields copied into the item. A ``tableItemId``
            key, if present, is not persisted.
        sfn_execution_id: Stored under ``sfnExecutionId``.
    """
    execution_record = {
        **input_body,
        "sfnExecutionId": sfn_execution_id,
        "executionStatus": ExecutionStatus.IN_PROGRESS.value,
        "executionId": execution_id,
        "uiStatus": UiStatus.ACTIVE.value,
        # Timestamp is taken per call so every record gets its own time.
        "createTime": str(datetime.now(timezone.utc)),
    }
    # pop() with a default instead of `del`: a body without "tableItemId"
    # must still produce a record rather than fail with KeyError.
    execution_record.pop("tableItemId", None)
    execution_table.put_item(Item=execution_record)


def handler(event: Dict, context) -> Dict:
"""Main Lambda handler for ETL operations."""
logger.info(event)

authorizer_type = event["requestContext"].get("authorizer", {}).get("authorizerType")
if authorizer_type == "lambda_authorizer":
claims = json.loads(event["requestContext"]["authorizer"]["claims"])
try:
# Validate and extract authorization
authorizer = event["requestContext"].get("authorizer", {})
if authorizer.get("authorizerType") != "lambda_authorizer":
raise ValueError("Invalid authorizer type")

claims = json.loads(authorizer.get("claims", {}))
if "use_api_key" in claims:
group_name = get_query_parameter(event, "GroupName", "Admin")
cognito_groups_list = [group_name]
else:
cognito_groups = claims["cognito:groups"]
cognito_groups_list = cognito_groups.split(",")
else:
logger.error("Invalid authorizer type")
cognito_groups_list = claims["cognito:groups"].split(",")

# Process input
input_body = json.loads(event["body"])
index_type = input_body.get("indexType")

if not validate_index_type(index_type):
return {
"statusCode": 400,
"headers": CORS_HEADERS,
"body": f"Invalid indexType, valid values are {', '.join([t.value for t in IndexType])}",
}

group_name = input_body.get("groupName") or (
"Admin"
if "Admin" in cognito_groups_list
else cognito_groups_list[0]
)
chatbot_id = input_body.get("chatbotId", group_name.lower())
index_id, embedding_model_type, embedding_endpoint = get_etl_info(
group_name, chatbot_id, index_type
)

# Update input body with processed values
input_body.update(
{
"chatbotId": chatbot_id,
"groupName": group_name,
"tableItemId": context.aws_request_id,
"indexId": index_id,
"embeddingModelType": embedding_model_type,
"embeddingEndpoint": embedding_endpoint,
}
)

# Start step function and create execution record
sfn_response = client.start_execution(
stateMachineArn=sfn_arn, input=json.dumps(input_body)
)

execution_id = context.aws_request_id
create_execution_record(
execution_id,
input_body,
sfn_response["executionArn"].split(":")[-1],
)

return {
"statusCode": 403,
"headers": resp_header,
"body": json.dumps({"error": "Invalid authorizer type"}),
"statusCode": 200,
"headers": CORS_HEADERS,
"body": json.dumps(
{
"execution_id": execution_id,
"step_function_arn": sfn_response["executionArn"],
"input_payload": input_body,
}
),
}

# Parse the body from the event object
input_body = json.loads(event["body"])
if "indexType" not in input_body or input_body["indexType"] not in [
IndexType.QD.value,
IndexType.QQ.value,
IndexType.INTENTION.value,
]:
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return {
"statusCode": 400,
"headers": resp_header,
"body": (
f"Invalid indexType, valid values are "
f"{IndexType.QD.value}, {IndexType.QQ.value}, "
f"{IndexType.INTENTION.value}"
),
"statusCode": 500,
"headers": CORS_HEADERS,
"body": json.dumps({"error": str(e)}),
}
index_type = input_body["indexType"]
group_name = "Admin" if "Admin" in cognito_groups_list else cognito_groups_list[0]
chatbot_id = input_body.get("chatbotId", group_name.lower())

if "indexId" in input_body:
index_id = input_body["indexId"]
else:
# Use default index id if not specified in the request
index_id = f"{chatbot_id}-qd-default"
if index_type == IndexType.QQ.value:
index_id = f"{chatbot_id}-qq-default"
elif index_type == IndexType.INTENTION.value:
index_id = f"{chatbot_id}-intention-default"

if "tag" in input_body:
tag = input_body["tag"]
else:
tag = index_id

input_body["indexId"] = index_id
input_body["groupName"] = group_name if "groupName" not in input_body else input_body["groupName"]
chatbot_event_body = input_body
chatbot_event_body["group_name"] = group_name
chatbot_event = {"body": json.dumps(chatbot_event_body)}
chatbot_result = create_chatbot(chatbot_event, group_name)

input_body["tableItemId"] = context.aws_request_id
input_body["chatbotId"] = chatbot_id
input_body["embeddingModelType"] = chatbot_result["modelType"]
input_payload = json.dumps(input_body)
response = client.start_execution(stateMachineArn=sfn_arn, input=input_payload)

# Update execution table item
if "tableItemId" in input_body:
del input_body["tableItemId"]
execution_id = response["executionArn"].split(":")[-1]
input_body["sfnExecutionId"] = execution_id
input_body["executionStatus"] = ExecutionStatus.IN_PROGRESS.value
input_body["indexId"] = index_id
input_body["executionId"] = context.aws_request_id
input_body["uiStatus"] = UiStatus.ACTIVE.value
input_body["createTime"] = create_time

execution_table.put_item(Item=input_body)

return {
"statusCode": 200,
"headers": resp_header,
"body": json.dumps(
{
"execution_id": context.aws_request_id,
"step_function_arn": response["executionArn"],
"input_payload": input_payload,
}
),
}
Loading

0 comments on commit 1c8bdf0

Please sign in to comment.