Skip to content

Commit

Permalink
Fix GenAI stacks deployment and streamlit sagemaker invocation errors (#224)
Browse files Browse the repository at this point in the history

* json encode prompt

* Fix payload json format and ContentType

* Use index 0 to access generated text

* Switch to ECR python slim base image

* Update requirements

* Update model data sources and env vars
  • Loading branch information
iamsouravin authored Jul 3, 2024
1 parent 512eed8 commit 957d704
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 14 deletions.
8 changes: 4 additions & 4 deletions cdk/examples/generative_ai_service/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
aws-cdk-lib==2.122.0
python-dotenv==0.21.0
aws-cdk-lib==2.147.3
python-dotenv==1.0.1
streamlit
boto3
sagemaker==2.218.0
sagemaker==2.224.2
sentence_transformers
opensearch-py
torch==2.1.1
torch==2.3.1
2 changes: 1 addition & 1 deletion cdk/examples/generative_ai_service/web-app/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.9
FROM public.ecr.aws/docker/library/python:3.9-slim
WORKDIR /app
COPY requirements.txt ./requirements.txt
RUN pip3 install -r requirements.txt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def get_parameter(name):

conversation = """Customers were very excited about the wireless charging feature, but the launch has not lived up to their expectations. The phones are not reliably charging and that is frustrating since it is such a fundamental aspect of any electronic device."""

parameters = {
'max_new_tokens': 50,
'top_k': 50,
'top_p': 0.95,
'do_sample': True,
}

with st.spinner("Retrieving configurations..."):
all_configs_loaded = False

Expand All @@ -47,13 +54,14 @@ def get_parameter(name):
with st.spinner("Wait for it..."):
try:
prompt = f"{context}\n{query}"
payload = {'inputs': prompt,'parameters': parameters}
response = runtime.invoke_endpoint(
EndpointName=endpoint_name,
Body=prompt,
ContentType="application/x-text",
Body=json.dumps(payload).encode('utf-8'),
ContentType="application/json",
)
response_body = json.loads(response["Body"].read().decode())
generated_text = response_body["generated_text"]
generated_text = response_body[0]["generated_text"]
st.write(generated_text)

except requests.exceptions.ConnectionError as errc:
Expand Down
16 changes: 14 additions & 2 deletions cdk/examples/other_stack/txt2img_generative_ai_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,25 @@ def __init__(
containers=[
CfnModel.ContainerDefinitionProperty(
image=model_info["model_docker_image"],
model_data_url= "s3://"+model_info["model_bucket_name"]+"/"+model_info["model_bucket_key"],
model_data_source=CfnModel.ModelDataSourceProperty(
s3_data_source=CfnModel.S3DataSourceProperty(
compression_type='None',
s3_data_type='S3Prefix',
s3_uri=f's3://{model_info["model_bucket_name"]}/{model_info["model_bucket_key"]}',
model_access_config=CfnModel.ModelAccessConfigProperty(
accept_eula=True,
),
),
),
environment={
"MMS_MAX_RESPONSE_SIZE": "20000000",
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": model_info["region_name"],
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"ENDPOINT_SERVER_TIMEOUT": "3600",
"MODEL_CACHE_ROOT": "/opt/ml/model",
"SAGEMAKER_ENV": "1",
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
}
)
]
Expand Down
19 changes: 15 additions & 4 deletions cdk/examples/other_stack/txt2txt_generative_ai_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,27 @@ def __init__(
containers=[
CfnModel.ContainerDefinitionProperty(
image=model_info["model_docker_image"],
model_data_url= "s3://"+model_info["model_bucket_name"]+"/"+model_info["model_bucket_key"],
model_data_source=CfnModel.ModelDataSourceProperty(
s3_data_source=CfnModel.S3DataSourceProperty(
compression_type='None',
s3_data_type='S3Prefix',
s3_uri=f's3://{model_info["model_bucket_name"]}/{model_info["model_bucket_key"]}',
model_access_config=CfnModel.ModelAccessConfigProperty(
accept_eula=True,
),
),
),
environment={
"MODEL_CACHE_ROOT": "/opt/ml/model",
"SAGEMAKER_ENV": "1",
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
"ENDPOINT_SERVER_TIMEOUT": "3600",
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": model_info["region_name"],
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code/",
"TS_DEFAULT_WORKERS_PER_MODEL": "1"
"HF_MODEL_ID": "/opt/ml/model",
"MAX_INPUT_LENGTH": "1024",
"MAX_TOTAL_TOKENS": "2048",
"SM_NUM_GPUS": "1",
}
)
]
Expand Down

0 comments on commit 957d704

Please sign in to comment.