Skip to content

Commit

Permalink
Refactor summarizer
Browse files Browse the repository at this point in the history
  • Loading branch information
amandeepg committed Nov 28, 2024
1 parent 55df405 commit 13f8f21
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 5,107 deletions.
5,003 changes: 0 additions & 5,003 deletions serverless/path-summarize-alerts/package-lock.json

This file was deleted.

45 changes: 0 additions & 45 deletions serverless/path-summarize-alerts/requirements.txt

This file was deleted.

4 changes: 3 additions & 1 deletion serverless/path-summarize-alerts/serverless.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ provider:
memorySize: 256
timeout: 30
logRetentionInDays: 14
apiGateway:
minimumCompressionSize: 0

iam:
role:
Expand All @@ -36,7 +38,7 @@ provider:

functions:
summarize:
handler: src/handler.handler # Updated handler path
handler: src/handler.handler
events:
- http:
path: /summarize
Expand Down
37 changes: 17 additions & 20 deletions serverless/path-summarize-alerts/src/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@ def handler(event: APIGatewayProxyEvent, context: LambdaContext) -> Dict[str, An

try:
# Get input string from event
query_params = event.get('queryStringParameters', {})
if not query_params or 'input' not in query_params:
query_params = event.get("queryStringParameters", {})
if not query_params or "input" not in query_params:
logger.error("Missing input parameter in request")
return {
'statusCode': 400,
'body': json.dumps({'error': 'Missing input parameter'})
"statusCode": 400,
"body": json.dumps({"error": "Missing input parameter"}),
}

input_text = query_params['input']
input_text = query_params["input"]
logger.info(f"Processing request with input length: {len(input_text)}")

skip_cache = 'skip_cache' in query_params and query_params['skip_cache'] == os.environ.get(
"SKIP_CACHE_MAGIC_WORD")
skip_cache = "skip_cache" in query_params and query_params[
"skip_cache"
] == os.environ.get("SKIP_CACHE_MAGIC_WORD")

summarizer = AlertSummarizer()
result = summarizer.summarize(input_text, skip_cache=skip_cache)
Expand All @@ -40,22 +41,18 @@ def handler(event: APIGatewayProxyEvent, context: LambdaContext) -> Dict[str, An
logger.debug(f"Response: {result.model_dump_json()}")

return {
'statusCode': 200,
'body': result.model_dump_json(),
'headers': {
'Content-Type': 'application/json',
'Cache-Control': 'max-age=86400' # Cache for 24 hours
}
"statusCode": 200,
"body": result.model_dump_json(),
"headers": {
"Content-Type": "application/json",
"Cache-Control": "max-age=86400", # Cache for 24 hours
},
}

except Exception as e:
logger.exception("Error processing request")
return {
'statusCode': 500,
'body': json.dumps({
'error': str(e)
}),
'headers': {
'Content-Type': 'application/json'
}
"statusCode": 500,
"body": json.dumps({"error": str(e)}),
"headers": {"Content-Type": "application/json"},
}
28 changes: 19 additions & 9 deletions serverless/path-summarize-alerts/src/lib/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@

class CacheService:
def __init__(self, bucket_name: str):
self.s3_client = boto3.client('s3')
self.s3_client = boto3.client("s3")
self.bucket_name = bucket_name
logger.info(f"Initialized CacheService with bucket: {bucket_name}")

@staticmethod
def hash_category_key() -> str:
    """Return a SHA-1 hex digest identifying the current prompt configuration.

    The digest is computed over the ``CacheResponse`` JSON schema,
    ``SYSTEM_MESSAGE`` and ``MODEL_NAME``, joined with ``|``.

    Returns:
        The 40-character hexadecimal SHA-1 digest.
    """
    # model_json_schema has to be *called*. Interpolating the bound method
    # itself only embeds its repr, which stays the same when the schema
    # changes, so the schema would never influence the digest.
    # sort_keys keeps the serialized form independent of dict ordering.
    # NOTE(review): this changes the digest value compared with the previous
    # implementation; if the digest is part of stored cache keys, entries
    # written before this fix will no longer be matched -- confirm that is
    # acceptable.
    schema_text = json.dumps(CacheResponse.model_json_schema(), sort_keys=True)
    input_string = f"{schema_text}|{SYSTEM_MESSAGE}|{MODEL_NAME}"
    return hashlib.sha1(input_string.encode("utf-8")).hexdigest()

@staticmethod
Expand All @@ -34,11 +36,17 @@ def create_versioned_key(hash_key: str) -> str:
def get(self, hash_key: str) -> Optional[str]:
"""Try to get cached response from S3."""
versioned_key = self.create_versioned_key(hash_key)
logger.debug(f"Attempting to retrieve cached response for versioned key: {versioned_key}")
logger.debug(
f"Attempting to retrieve cached response for versioned key: {versioned_key}"
)
try:
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=versioned_key)
data = response['Body'].read().decode('utf-8')
logger.info(f"Successfully retrieved cached response for versioned key: {versioned_key}")
response = self.s3_client.get_object(
Bucket=self.bucket_name, Key=versioned_key
)
data = response["Body"].read().decode("utf-8")
logger.info(
f"Successfully retrieved cached response for versioned key: {versioned_key}"
)
logger.debug(f"Cache data: {json.dumps(data)}")
return data
except self.s3_client.exceptions.NoSuchKey:
Expand All @@ -59,9 +67,11 @@ def save(self, hash_key: str, data: str) -> None:
Bucket=self.bucket_name,
Key=versioned_key,
Body=data,
ContentType='application/json'
ContentType="application/json",
)
logger.info(
f"Successfully cached response for versioned key: {versioned_key}"
)
logger.info(f"Successfully cached response for versioned key: {versioned_key}")
logger.debug(f"Cached data: {data}")
except Exception as e:
logger.error(f"Error saving to lib: {str(e)}", exc_info=True)
Expand Down
2 changes: 1 addition & 1 deletion serverless/path-summarize-alerts/src/lib/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
BUCKET_NAME = 'path-summarize-data'
BUCKET_NAME = "path-summarize-data"
MODEL_NAME = "gpt-4o"
SYSTEM_MESSAGE = """
You are a very helpful transit agency alert summarizer to take alert text from a transit agency (the PATH subway transit system in NYC and NJ) and make it more digestible for transit riders.
Expand Down
46 changes: 31 additions & 15 deletions serverless/path-summarize-alerts/src/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ class TimeRange(BaseModel):
chain_of_thought: str = Field(
..., description="Step by step reasoning to get the correct time range"
)
start_time: Optional[datetime] = Field(...,
description="The date and time when the alert begins being relevant, if applicable")
end_time: Optional[datetime] = Field(...,
description="The date and time when the alert ends being relevant, if applicable")
start_time: Optional[datetime] = Field(
...,
description="The date and time when the alert begins being relevant, if applicable",
)
end_time: Optional[datetime] = Field(
...,
description="The date and time when the alert ends being relevant, if applicable",
)


class Summary(BaseModel):
Expand All @@ -51,25 +55,25 @@ class Summary(BaseModel):
class AffectedStations(BaseModel):
    # Stations affected by an alert, together with the reasoning that led to
    # them. Documented with comments rather than a docstring because pydantic
    # would surface a docstring as the model description in the JSON schema.

    # Required free-text reasoning field.
    chain_of_thought: str = Field(
        default=...,
        description="Step by step reasoning to get the stations that are going to be affected by this alert, i.e. riders at this station really care about this alert",
    )

    # Required field that may be None.
    affected_stations: Optional[List[PathStation]]


class AffectedRoutes(BaseModel):
chain_of_thought: str = Field(
...,
description="Step by step reasoning to get the routes that are going to be affected by this alert, i.e. riders on this line really care about this alert"
description="Step by step reasoning to get the routes that are going to be affected by this alert, i.e. riders on this line really care about this alert",
)
affected_routes: Optional[List[PathLine]]

@field_serializer('affected_routes')
@field_serializer("affected_routes")
def get_lines_value(self, routes: Optional[List[PathLine]]) -> Optional[List[str]]:
if routes is None:
return None
return [route.name for route in routes]

@field_validator('affected_routes', mode='before')
@field_validator("affected_routes", mode="before")
@classmethod
def validate_routes(cls, value: Optional[List[str]]) -> Optional[List[PathLine]]:
if value is None:
Expand All @@ -86,24 +90,36 @@ def validate_routes(cls, value: Optional[List[str]]) -> Optional[List[PathLine]]
try:
valid_routes.append(PathLine(route.lower()))
except ValueError:
raise ValueError(f"Invalid route: {route}. Valid routes are: {[e.name for e in PathLine]}")
raise ValueError(
f"Invalid route: {route}. Valid routes are: {[e.name for e in PathLine]}"
)
elif isinstance(route, PathLine):
valid_routes.append(route)
else:
raise ValueError(f"Route must be string or PathLine enum, got {type(route)}")
raise ValueError(
f"Route must be string or PathLine enum, got {type(route)}"
)

return valid_routes


class AlertSummary(BaseModel):
    # Structured result for one alert. Documented with comments rather than a
    # docstring because pydantic would surface a docstring as the model
    # description in the JSON schema.
    # Every field is required; the explicit ``...`` spells that out in the
    # same way TimeRange does.

    text: Summary

    is_delay: bool = Field(
        ...,
        description="indicating if the alert is about a delay on lines, true is yes, false if it is a general announcement",
    )

    is_relevant: bool = Field(
        ...,
        description="indicating if the alert affects the rider's experience or not",
    )

    # The Optional fields below must be present but may be None.
    duration: Optional[TimeRange] = Field(
        ...,
        description="The date and time that this alert begins and ends being applicable. Events at Red Bull Arena are usually 3 hours.",
    )

    # Stations are declared before routes; field order is kept exactly
    # because it determines the order of properties in the schema.
    affected_stations: Optional[AffectedStations] = Field(
        ...,
        description="List of affected stations",
    )

    affected_routes: Optional[AffectedRoutes] = Field(
        ...,
        description="List of affected routes",
    )


class CacheResponse(BaseModel):
Expand Down
34 changes: 21 additions & 13 deletions serverless/path-summarize-alerts/src/lib/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from .cache import CacheService
from .constants import BUCKET_NAME, SYSTEM_MESSAGE, MODEL_NAME
from .models import *
from .models import CacheResponse
from .models import CacheResponse, AlertSummary

logger = Logger()
tracer = Tracer()
Expand All @@ -22,14 +21,18 @@ def __init__(self):
@staticmethod
def hash_string(input_string: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``input_string``.

    The string is encoded as UTF-8 before hashing. The digest and the
    length of the input are logged at debug level.
    """
    encoded = input_string.encode("utf-8")
    digest = hashlib.sha1(encoded).hexdigest()
    input_length = len(input_string)
    logger.debug(f"Generated hash: {digest} for input length: {input_length}")
    return digest

@tracer.capture_method
def summarize(self, input_text: str, skip_cache: bool) -> CacheResponse:
"""Summarize the input text using OpenAI API with caching."""
logger.info(f"Processing new summarization request. Input length: {len(input_text)}")
logger.info(
f"Processing new summarization request. Input length: {len(input_text)}"
)
logger.debug(f"Raw input text: {input_text}")

hash_key = self.hash_string(input_text)
Expand All @@ -47,19 +50,24 @@ def summarize(self, input_text: str, skip_cache: bool) -> CacheResponse:
logger.info("Making OpenAI API request")
logger.debug(f"System message: {SYSTEM_MESSAGE}")

ai_response, completion = self.client.chat.completions.create_with_completion(
response_model=AlertSummary,
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_MESSAGE},
self.user_msg(input_text)
],
ai_response, completion = (
self.client.chat.completions.create_with_completion(
response_model=AlertSummary,
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_MESSAGE},
self.user_msg(input_text),
],
)
)

logger.info("Successfully received OpenAI API response")
logger.debug(f"Raw API response: {ai_response}")
logger.info(f"Usage: {completion.usage}")
cost = completion.usage.prompt_tokens * 2.5 / 1_000_000 + completion.usage.completion_tokens * 10 / 1_000_000
cost = (
completion.usage.prompt_tokens * 2.5 / 1_000_000
+ completion.usage.completion_tokens * 10 / 1_000_000
)
logger.info(f"Cost approx: {cost * 100:.2f} cents")

# Prepare response
Expand Down

0 comments on commit 13f8f21

Please sign in to comment.