diff --git a/apps/chatbot/src/modules/models.py b/apps/chatbot/src/modules/models.py index 986a92842e..5c16704e7d 100644 --- a/apps/chatbot/src/modules/models.py +++ b/apps/chatbot/src/modules/models.py @@ -1,3 +1,4 @@ +import boto3 import os import logging @@ -34,11 +35,11 @@ MODEL_MAXTOKENS = os.getenv("CHB_MODEL_MAXTOKENS", "768") EMBED_MODEL_ID = os.getenv("CHB_EMBED_MODEL_ID") +CROSS_ACCOUNT_ROLE_ARN = os.getenv("CHB_CROSS_ACCOUNT_ROLE_ARN") def get_llm(): if PROVIDER == "aws": - class ModelEventHandler(BaseEventHandler): @classmethod def class_name(cls) -> str: @@ -56,15 +57,37 @@ def handle(self, event) -> None: root_dispatcher = get_dispatcher() root_dispatcher.add_event_handler(ModelEventHandler()) - - llm = BedrockConverse( - model=MODEL_ID, - temperature=float(MODEL_TEMPERATURE), - max_tokens=int(MODEL_MAXTOKENS), - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_BEDROCK_REGION - ) + if CROSS_ACCOUNT_ROLE_ARN: + # Create an STS client + sts_client = boto3.client('sts') + + # Assume the role + assumed_role_object = sts_client.assume_role( + RoleArn=CROSS_ACCOUNT_ROLE_ARN, + RoleSessionName="chatbot-cross-account-generation" + ) + + # Retrieve the temporary credentials + credentials = assumed_role_object['Credentials'] + + llm = BedrockConverse( + model=MODEL_ID, + temperature=float(MODEL_TEMPERATURE), + max_tokens=int(MODEL_MAXTOKENS), + aws_access_key_id=credentials['AccessKeyId'], + aws_secret_access_key=credentials['SecretAccessKey'], + aws_session_token=credentials['SessionToken'], + region_name=AWS_BEDROCK_REGION + ) + else: + llm = BedrockConverse( + model=MODEL_ID, + temperature=float(MODEL_TEMPERATURE), + max_tokens=int(MODEL_MAXTOKENS), + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + region_name=AWS_BEDROCK_REGION + ) else: llm = Gemini( @@ -86,14 +109,34 @@ def handle(self, event) -> None: def get_embed_model(): - if PROVIDER == "aws": - embed_model = BedrockEmbedding( - model_name = EMBED_MODEL_ID, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_BEDROCK_REGION - ) + if CROSS_ACCOUNT_ROLE_ARN: + # Create an STS client + sts_client = boto3.client('sts') + + # Assume the role + assumed_role_object = sts_client.assume_role( + RoleArn=CROSS_ACCOUNT_ROLE_ARN, + RoleSessionName="chatbot-cross-account-generation" + ) + + # Retrieve the temporary credentials + credentials = assumed_role_object['Credentials'] + + embed_model = BedrockEmbedding( + model_name = EMBED_MODEL_ID, + aws_access_key_id=credentials['AccessKeyId'], + aws_secret_access_key=credentials['SecretAccessKey'], + aws_session_token=credentials['SessionToken'], + region_name=AWS_BEDROCK_REGION + ) + else: + embed_model = BedrockEmbedding( + model_name = EMBED_MODEL_ID, + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + region_name=AWS_BEDROCK_REGION + ) else: embed_model = GeminiEmbedding( api_key=GOOGLE_API_KEY, diff --git a/apps/infrastructure/src/modules/chatbot/data.tf b/apps/infrastructure/src/modules/chatbot/data.tf index 3e20bd5788..aca1d31f6f 100644 --- a/apps/infrastructure/src/modules/chatbot/data.tf +++ b/apps/infrastructure/src/modules/chatbot/data.tf @@ -26,6 +26,15 @@ data "aws_iam_policy_document" "lambda_s3_policy" { ] resources = ["*"] } + + statement { + sid = "AssumeCrossAccountRole" + effect = "Allow" + actions = [ + "sts:AssumeRole" + ] + resources = [local.cross_account_role_arn] + } } data "aws_iam_policy_document" "lambda_dynamodb_policy" { diff --git a/apps/infrastructure/src/modules/chatbot/lambda_chatbot.tf b/apps/infrastructure/src/modules/chatbot/lambda_chatbot.tf index 452bef0838..11bdc8ffcc 100644 --- a/apps/infrastructure/src/modules/chatbot/lambda_chatbot.tf +++ b/apps/infrastructure/src/modules/chatbot/lambda_chatbot.tf @@ -1,5 +1,5 @@ locals { - lambda_env_variables = { + lambda_env_variables = merge({ CHB_AWS_S3_BUCKET = module.s3_bucket_llamaindex.s3_bucket_id CHB_AWS_GUARDRAIL_ID = awscc_bedrock_guardrail.guardrail.guardrail_id CHB_AWS_GUARDRAIL_VERSION = awscc_bedrock_guardrail_version.guardrail.version @@ -23,7 +23,8 @@ locals { CHB_GOOGLE_API_KEY = module.google_api_key_ssm_parameter.ssm_parameter_name CHB_QUERY_TABLE_PREFIX = local.prefix CHB_LLAMAINDEX_INDEX_ID = module.index_id_ssm_parameter.ssm_parameter_name - } + }, + var.environment == "prod" ? { CHB_CROSS_ACCOUNT_ROLE_ARN = local.cross_account_role_arn } : {}) } module "lambda_function" { diff --git a/apps/infrastructure/src/modules/chatbot/locals.tf b/apps/infrastructure/src/modules/chatbot/locals.tf index a7d32eac97..9e319a557f 100644 --- a/apps/infrastructure/src/modules/chatbot/locals.tf +++ b/apps/infrastructure/src/modules/chatbot/locals.tf @@ -5,4 +5,6 @@ locals { redis_container_name = "redis-stack" lambda_timeout = 180 + + cross_account_role_arn = "arn:aws:iam::039804388894:role/chatbot-dev-cross-account-generation" } \ No newline at end of file