Skip to content

Commit

Permalink
Replace SageMaker with Bedrock connector for generating embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Aug 14, 2024
1 parent 5ccf0a0 commit d29b193
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 248 deletions.
32 changes: 16 additions & 16 deletions data_services/.terraform.lock.hcl

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 41 additions & 47 deletions data_services/opensearch_connector.tf
Original file line number Diff line number Diff line change
@@ -1,78 +1,73 @@
locals {
connector_spec = { for key, config in var.sagemaker_configurations : key => {
name = "${local.namespace}-${config.name}-embedding"
description = "Opensearch Connector for ${config.name}"
version = 1
protocol = "aws_sigv4"
connector_spec = {
name = "${local.namespace}-embedding"
description = "Opensearch Connector for ${var.embedding_model_name} via Amazon Bedrock"
version = 1
protocol = "aws_sigv4"

credential = {
roleArn = aws_iam_role.opensearch_connector.arn
}

parameters = {
region = data.aws_region.current.name
service_name = "sagemaker"
region = data.aws_region.current.name
service_name = "bedrock"
model_name = var.embedding_model_name
}

actions = [
{
action_type = "predict"
method = "POST"
action_type = "predict"
method = "POST"

headers = {
"content-type" = "application/json"
}

url = local.embedding_invocation_url[key]
post_process_function = file("${path.module}/opensearch_connector/post-process.painless")
request_body = "{\"inputs\": $${parameters.input}}"
url = "https://bedrock-runtime.$${parameters.region}.amazonaws.com/model/$${parameters.model_name}/invoke"
post_process_function = file("${path.module}/opensearch_connector/post-process.painless")
request_body = "{\"texts\": $${parameters.input}, \"input_type\": \"search_document\"}"
}
]

client_config = {
max_connections = config.max_concurrency / var.opensearch_cluster_nodes
connection_timeout = 5000
read_timeout = 60000
}
}}
}
}

data "aws_iam_policy_document" "opensearch_connector_assume_role" {
statement {
effect = "Allow"
actions = ["sts:AssumeRole"]
effect = "Allow"
actions = ["sts:AssumeRole"]

principals {
type = "Service"
identifiers = ["opensearchservice.amazonaws.com"]
type = "Service"
identifiers = ["opensearchservice.amazonaws.com"]
}
}
}

data "aws_iam_policy_document" "opensearch_connector_role" {
statement {
effect = "Allow"
actions = [
"sagemaker:InvokeEndpoint",
"sagemaker:InvokeEndpointAsync"
effect = "Allow"
actions = [
"bedrock:InvokeModel",
"bedrock:InvokeModelWithResultStream",
]
resources = [ for endpoint in aws_sagemaker_endpoint.serverless_inference : endpoint.arn ]
resources = ["*"]
}
}

resource "aws_iam_policy" "opensearch_connector" {
name = "${local.namespace}-opensearch-connector"
policy = data.aws_iam_policy_document.opensearch_connector_role.json
name = "${local.namespace}-opensearch-connector"
policy = data.aws_iam_policy_document.opensearch_connector_role.json
}

resource "aws_iam_role" "opensearch_connector" {
name = "${local.namespace}-opensearch-connector"
assume_role_policy = data.aws_iam_policy_document.opensearch_connector_assume_role.json
name = "${local.namespace}-opensearch-connector"
assume_role_policy = data.aws_iam_policy_document.opensearch_connector_assume_role.json
}

resource "aws_iam_role_policy_attachment" "opensearch_connector" {
role = aws_iam_role.opensearch_connector.id
policy_arn = aws_iam_policy.opensearch_connector.arn
role = aws_iam_role.opensearch_connector.id
policy_arn = aws_iam_policy.opensearch_connector.arn
}

data "aws_iam_policy_document" "deploy_model_lambda" {
Expand All @@ -93,29 +88,28 @@ module "deploy_model_lambda" {
source = "terraform-aws-modules/lambda/aws"
version = "~> 7.2.1"

function_name = "${local.namespace}-deploy-opensearch-ml-model"
description = "Utility lambda to deploy a SageMaker model within Opensearch"
handler = "index.handler"
runtime = "nodejs18.x"
source_path = "${path.module}/deploy_model_lambda"
timeout = 30
attach_policy_json = true
policy_json = data.aws_iam_policy_document.deploy_model_lambda.json
function_name = "${local.namespace}-deploy-opensearch-ml-model"
description = "Utility lambda to deploy an embedding model within Opensearch"
handler = "index.handler"
runtime = "nodejs18.x"
source_path = "${path.module}/deploy_model_lambda"
timeout = 30
attach_policy_json = true
policy_json = data.aws_iam_policy_document.deploy_model_lambda.json

environment_variables = {
OPENSEARCH_ENDPOINT = aws_opensearch_domain.elasticsearch.endpoint
}
}

resource "aws_lambda_invocation" "deploy_model" {
for_each = local.connector_spec
function_name = module.deploy_model_lambda.lambda_function_name
lifecycle_scope = "CRUD"

input = jsonencode({
namespace = local.namespace
connector_spec = local.connector_spec[each.key]
model_name = "${each.value.name}/huggingface/${var.model_repository}"
model_version = "1.0.0"
namespace = local.namespace
connector_spec = local.connector_spec
model_name = var.embedding_model_name
model_version = "1.0.0"
})
}
15 changes: 10 additions & 5 deletions data_services/opensearch_connector/post-process.painless
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// Post-process a Bedrock embedding response into the OpenSearch ML tensor
// format. Bedrock (Cohere-style) returns {"embeddings": [[...], ...]}; we
// take the first embedding since the connector sends a single input text.
def name = 'sentence_embedding';
def dataType = 'FLOAT32';
// No embeddings array at all: surface the upstream message unchanged.
if (params.embeddings == null || params.embeddings.length == 0) {
  return params.message;
}

def embedding = params.embeddings[0];
// Defensive: an empty first vector also falls back to the raw message.
if (embedding == null || embedding.length == 0) {
  return params.message;
}
def shape = [embedding.length];
def json = '{"name":"' + name + '","data_type":"' + dataType + '","shape":' + shape + ',"data":' + embedding + '}';
return json;
26 changes: 13 additions & 13 deletions data_services/outputs.tf
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Decode the deploy-model lambda's response; the lambda returns a JSON body
# nested inside the invocation result.
locals {
  deploy_model_result = jsondecode(aws_lambda_invocation.deploy_model.result)
  deploy_model_body   = jsondecode(local.deploy_model_result.body)
}

output "elasticsearch" {
Expand All @@ -14,22 +14,22 @@ output "elasticsearch" {
}

output "inference" {
value = { for key, value in local.deploy_model_body : key => {
endpoint_name = aws_sagemaker_endpoint.serverless_inference[key].name
invocation_url = local.embedding_invocation_url[key]
opensearch_model_id = lookup(value, "model_id", "DEPLOY ERROR")
}}
value = {
endpoint_name = var.embedding_model_name
invocation_url = "https://bedrock-runtime.${data.aws_region.current.name}.amazonaws.com/model/${var.embedding_model_name}/invoke"
opensearch_model_id = lookup(local.deploy_model_body, "model_id", "DEPLOY ERROR")
}
}

output "search_snapshot_configuration" {
value = {
create_url = "https://${aws_opensearch_domain.elasticsearch.endpoint}/_snapshot/"
create_doc = jsonencode({
type = "s3"
create_url = "https://${aws_opensearch_domain.elasticsearch.endpoint}/_snapshot/"
create_doc = jsonencode({
type = "s3"
settings = {
bucket = aws_s3_bucket.elasticsearch_snapshot_bucket.id
region = data.aws_region.current.name
role_arn = aws_iam_role.elasticsearch_snapshot_bucket_access.arn
bucket = aws_s3_bucket.elasticsearch_snapshot_bucket.id
region = data.aws_region.current.name
role_arn = aws_iam_role.elasticsearch_snapshot_bucket_access.arn
}
})
}
Expand Down
Loading

0 comments on commit d29b193

Please sign in to comment.