diff --git a/.changelog/34310.txt b/.changelog/34310.txt new file mode 100644 index 00000000000..3efcda23b86 --- /dev/null +++ b/.changelog/34310.txt @@ -0,0 +1,11 @@ +```release-note:new-data-source +aws_bedrock_custom_model +``` + +```release-note:new-data-source +aws_bedrock_custom_models +``` + +```release-note:resource +aws_bedrock_custom_model +``` \ No newline at end of file diff --git a/internal/framework/base.go b/internal/framework/base.go index e8b29c9440c..94dbf5e792f 100644 --- a/internal/framework/base.go +++ b/internal/framework/base.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/hashicorp/terraform-plugin-framework-timeouts/resource/timeouts" "github.com/hashicorp/terraform-plugin-framework/datasource" + "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/path" "github.com/hashicorp/terraform-plugin-framework/resource" "github.com/hashicorp/terraform-plugin-framework/types" @@ -107,10 +108,26 @@ func (w *WithImportByID) ImportState(ctx context.Context, request resource.Impor resource.ImportStatePassthroughID(ctx, path.Root(names.AttrID), request, response) } -// WithNoOpUpdate is intended to be embedded in resources which have no need of an Update method. -type WithNoOpUpdate struct{} +// WithNoUpdate is intended to be embedded in resources which cannot be updated. +type WithNoUpdate struct{} -func (w *WithNoOpUpdate) Update(ctx context.Context, request resource.UpdateRequest, response *resource.UpdateResponse) { +func (w *WithNoUpdate) Update(ctx context.Context, request resource.UpdateRequest, response *resource.UpdateResponse) { + response.Diagnostics.Append(diag.NewErrorDiagnostic("not supported", "This resource's Update method should not have been called")) +} + +// WithNoOpUpdate is intended to be embedded in resources which have no need of a custom Update method. +// For example, resources where only `tags` can be updated and that is handled via transparent tagging. +type WithNoOpUpdate[T any] struct{} + +func (w *WithNoOpUpdate[T]) Update(ctx context.Context, request resource.UpdateRequest, response *resource.UpdateResponse) { + var t T + + response.Diagnostics.Append(request.Plan.Get(ctx, &t)...) + if response.Diagnostics.HasError() { + return + } + + response.Diagnostics.Append(response.State.Set(ctx, &t)...) } // DataSourceWithConfigure is a structure to be embedded within a DataSource that implements the DataSourceWithConfigure interface. diff --git a/internal/framework/types/mapof.go b/internal/framework/types/mapof.go index 8269c131f8e..eccd2ac7885 100644 --- a/internal/framework/types/mapof.go +++ b/internal/framework/types/mapof.go @@ -16,22 +16,26 @@ import ( ) var ( - _ basetypes.MapTypable = MapTypeOf[basetypes.StringValue]{} + _ basetypes.MapTypable = mapTypeOf[basetypes.StringValue]{} _ basetypes.MapValuable = MapValueOf[basetypes.StringValue]{} ) -// MapTypeOf is the attribute type of a MapValueOf. -type MapTypeOf[T attr.Value] struct { +var ( + // MapOfStringType is a custom type used for defining a Map of strings. + MapOfStringType = mapTypeOf[basetypes.StringValue]{basetypes.MapType{ElemType: basetypes.StringType{}}} +) + +type mapTypeOf[T attr.Value] struct { basetypes.MapType } -func NewMapTypeOf[T attr.Value](ctx context.Context) MapTypeOf[T] { +func NewMapTypeOf[T attr.Value](ctx context.Context) mapTypeOf[T] { var zero T - return MapTypeOf[T]{basetypes.MapType{ElemType: zero.Type(ctx)}} + return mapTypeOf[T]{basetypes.MapType{ElemType: zero.Type(ctx)}} } -func (t MapTypeOf[T]) Equal(o attr.Type) bool { - other, ok := o.(MapTypeOf[T]) +func (t mapTypeOf[T]) Equal(o attr.Type) bool { + other, ok := o.(mapTypeOf[T]) if !ok { return false @@ -40,18 +44,19 @@ func (t MapTypeOf[T]) Equal(o attr.Type) bool { return t.MapType.Equal(other.MapType) } -func (t MapTypeOf[T]) String() string { +func (t mapTypeOf[T]) String() string { var zero T return fmt.Sprintf("%T", zero) } -func (t MapTypeOf[T]) ValueFromMap(ctx context.Context, in basetypes.MapValue) (basetypes.MapValuable, diag.Diagnostics) { +func (t mapTypeOf[T]) ValueFromMap(ctx context.Context, in basetypes.MapValue) (basetypes.MapValuable, diag.Diagnostics) { var diags diag.Diagnostics var zero T if in.IsNull() { return NewMapValueOfNull[T](ctx), diags } + if in.IsUnknown() { return NewMapValueOfUnknown[T](ctx), diags } @@ -73,7 +78,7 @@ func (t MapTypeOf[T]) ValueFromMap(ctx context.Context, in basetypes.MapValue) ( return value, diags } -func (t MapTypeOf[T]) ValueFromTerraform(ctx context.Context, in tftypes.Value) (attr.Value, error) { +func (t mapTypeOf[T]) ValueFromTerraform(ctx context.Context, in tftypes.Value) (attr.Value, error) { attrValue, err := t.MapType.ValueFromTerraform(ctx, in) if err != nil { @@ -93,11 +98,11 @@ func (t MapTypeOf[T]) ValueFromTerraform(ctx context.Context, in tftypes.Value) return mapValuable, nil } -func (t MapTypeOf[T]) ValueType(ctx context.Context) attr.Value { +func (t mapTypeOf[T]) ValueType(ctx context.Context) attr.Value { return MapValueOf[T]{} } -// MapValueOf represents a Terraform Plugin Framework Map value whose elements are of type MapTypeOf. +// MapValueOf represents a Terraform Plugin Framework Map value whose elements are of type mapTypeOf. type MapValueOf[T attr.Value] struct { basetypes.MapValue } diff --git a/internal/framework/types/string_enum.go b/internal/framework/types/string_enum.go index 146d0c165e9..650a8d0c3aa 100644 --- a/internal/framework/types/string_enum.go +++ b/internal/framework/types/string_enum.go @@ -177,6 +177,7 @@ func (v StringEnum[T]) ValueEnum() T { // StringEnumValue is useful if you have a zero value StringEnum but need a // way to get a non-zero value such as when flattening. +// It's called via reflection inside AutoFlEx. func (v StringEnum[T]) StringEnumValue(value string) StringEnum[T] { return StringEnum[T]{StringValue: basetypes.NewStringValue(value)} } diff --git a/internal/framework/validators/s3_uri.go b/internal/framework/validators/s3_uri.go new file mode 100644 index 00000000000..6a762cf9332 --- /dev/null +++ b/internal/framework/validators/s3_uri.go @@ -0,0 +1,48 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package validators + +import ( + "context" + + "github.com/YakDriver/regexache" + "github.com/hashicorp/terraform-plugin-framework-validators/helpers/validatordiag" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" +) + +// s3URIValidator validates that a string Attribute's value is a valid S3 URI. +type s3URIValidator struct{} + +func (validator s3URIValidator) Description(_ context.Context) string { + return "value must be a valid S3 URI" +} + +func (validator s3URIValidator) MarkdownDescription(ctx context.Context) string { + return validator.Description(ctx) +} + +func (validator s3URIValidator) ValidateString(ctx context.Context, request validator.StringRequest, response *validator.StringResponse) { + if request.ConfigValue.IsNull() || request.ConfigValue.IsUnknown() { + return + } + + if !regexache.MustCompile(`^s3://[a-z0-9][\.\-a-z0-9]{1,61}[a-z0-9](/.*)?$`).MatchString(request.ConfigValue.ValueString()) { + response.Diagnostics.Append(validatordiag.InvalidAttributeValueDiagnostic( + request.Path, + validator.Description(ctx), + request.ConfigValue.ValueString(), + )) + return + } +} + +// S3URI returns a string validator which ensures that any configured +// attribute value: +// +// - Is a string, which represents a valid S3 URI (s3://bucket[/key]). +// +// Null (unconfigured) and unknown (known after apply) values are skipped. +func S3URI() validator.String { + return s3URIValidator{} +} diff --git a/internal/framework/validators/s3_uri_test.go b/internal/framework/validators/s3_uri_test.go new file mode 100644 index 00000000000..b8b9212224b --- /dev/null +++ b/internal/framework/validators/s3_uri_test.go @@ -0,0 +1,77 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package validators_test + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/hashicorp/terraform-plugin-framework/types" + fwvalidators "github.com/hashicorp/terraform-provider-aws/internal/framework/validators" +) + +func TestS3URIValidator(t *testing.T) { + t.Parallel() + + type testCase struct { + val types.String + expectedDiagnostics diag.Diagnostics + } + tests := map[string]testCase{ + "unknown String": { + val: types.StringUnknown(), + }, + "null String": { + val: types.StringNull(), + }, + "invalid String": { + val: types.StringValue("test-value"), + expectedDiagnostics: diag.Diagnostics{ + diag.NewAttributeErrorDiagnostic( + path.Root("test"), + "Invalid Attribute Value", + `Attribute test value must be a valid S3 URI, got: test-value`, + ), + }, + }, + "valid S3 URI": { + val: types.StringValue("s3://bucket/path/to/key"), + }, + "invalid characters": { + val: types.StringValue("s3://asbcdefg--#/key"), + expectedDiagnostics: diag.Diagnostics{ + diag.NewAttributeErrorDiagnostic( + path.Root("test"), + "Invalid Attribute Value", + `Attribute test value must be a valid S3 URI, got: s3://asbcdefg--#/key`, + ), + }, + }, + } + + for name, test := range tests { + name, test := name, test + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + request := validator.StringRequest{ + Path: path.Root("test"), + PathExpression: path.MatchRoot("test"), + ConfigValue: test.val, + } + response := validator.StringResponse{} + fwvalidators.S3URI().ValidateString(ctx, request, &response) + + if diff := cmp.Diff(response.Diagnostics, test.expectedDiagnostics); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} diff --git a/internal/service/bedrock/consts.go b/internal/service/bedrock/consts.go new file mode 100644 index 00000000000..e571fdc4591 --- /dev/null +++ b/internal/service/bedrock/consts.go @@ -0,0 +1,12 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bedrock + +import ( + "time" +) + +const ( + propagationTimeout = 2 * time.Minute +) diff --git a/internal/service/bedrock/custom_model.go b/internal/service/bedrock/custom_model.go new file mode 100644 index 00000000000..785ef799aa4 --- /dev/null +++ b/internal/service/bedrock/custom_model.go @@ -0,0 +1,679 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bedrock + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/YakDriver/regexache" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/bedrock" + awstypes "github.com/aws/aws-sdk-go-v2/service/bedrock/types" + "github.com/hashicorp/terraform-plugin-framework-timeouts/resource/timeouts" + "github.com/hashicorp/terraform-plugin-framework-validators/listvalidator" + "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/resource/schema" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/listplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/mapplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/setplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/id" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry" + "github.com/hashicorp/terraform-provider-aws/internal/enum" + "github.com/hashicorp/terraform-provider-aws/internal/errs" + "github.com/hashicorp/terraform-provider-aws/internal/errs/fwdiag" + "github.com/hashicorp/terraform-provider-aws/internal/framework" + fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex" + fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" + fwvalidators "github.com/hashicorp/terraform-provider-aws/internal/framework/validators" + tftags "github.com/hashicorp/terraform-provider-aws/internal/tags" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" + "github.com/hashicorp/terraform-provider-aws/names" +) + +// @FrameworkResource(name="Custom Model") +// @Tags(identifierAttribute="job_arn") +func newCustomModelResource(context.Context) (resource.ResourceWithConfigure, error) { + r := &customModelResource{} + + r.SetDefaultDeleteTimeout(120 * time.Minute) + + return r, nil +} + +type customModelResource struct { + framework.ResourceWithConfigure + framework.WithImportByID + framework.WithTimeouts +} + +func (r *customModelResource) Metadata(_ context.Context, request resource.MetadataRequest, resp *resource.MetadataResponse) { + resp.TypeName = "aws_bedrock_custom_model" +} + +func (r *customModelResource) Schema(ctx context.Context, request resource.SchemaRequest, response *resource.SchemaResponse) { + // This resource is a composition of the following APIs. These APIs do not have consitently named attributes, so we will normalize them here. + // - CreateModelCustomizationJob + // - GetModelCustomizationJob + // - GetCustomModel + response.Schema = schema.Schema{ + Attributes: map[string]schema.Attribute{ + "base_model_identifier": schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "custom_model_arn": schema.StringAttribute{ + Computed: true, + }, + "custom_model_kms_key_id": schema.StringAttribute{ + Optional: true, + CustomType: fwtypes.ARNType, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "custom_model_name": schema.StringAttribute{ + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 63), + }, + }, + "customization_type": schema.StringAttribute{ + CustomType: fwtypes.StringEnumType[awstypes.CustomizationType](), + Optional: true, + Computed: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + stringplanmodifier.UseStateForUnknown(), + }, + }, + "hyperparameters": schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Required: true, + ElementType: types.StringType, + PlanModifiers: []planmodifier.Map{ + mapplanmodifier.RequiresReplace(), + }, + }, + names.AttrID: framework.IDAttribute(), + "job_arn": framework.ARNAttributeComputedOnly(), + "job_name": schema.StringAttribute{ + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 63), + stringvalidator.RegexMatches(regexache.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9\+\-\.])*$`), + "must be up to 63 letters (uppercase and lowercase), numbers, plus sign, dashes, and dots, and must start with an alphanumeric"), + }, + }, + "job_status": schema.StringAttribute{ + CustomType: fwtypes.StringEnumType[awstypes.ModelCustomizationJobStatus](), + Computed: true, + }, + "role_arn": schema.StringAttribute{ + CustomType: fwtypes.ARNType, + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + names.AttrTags: tftags.TagsAttribute(), + names.AttrTagsAll: tftags.TagsAttributeComputedOnly(), + "training_metrics": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelTrainingMetricsModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "training_loss": types.Float64Type, + }, + }, + }, + "validation_metrics": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidationMetricsModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "validation_loss": types.Float64Type, + }, + }, + }, + }, + Blocks: map[string]schema.Block{ + "output_data_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelOutputDataConfigModel](ctx), + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + Validators: []validator.List{ + listvalidator.IsRequired(), + listvalidator.SizeAtLeast(1), + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "s3_uri": schema.StringAttribute{ + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + Validators: []validator.String{ + fwvalidators.S3URI(), + }, + }, + }, + }, + }, + "timeouts": timeouts.Block(ctx, timeouts.Opts{ + Create: true, + Delete: true, + }), + "training_data_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelTrainingDataConfigModel](ctx), + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + Validators: []validator.List{ + listvalidator.IsRequired(), + listvalidator.SizeAtLeast(1), + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "s3_uri": schema.StringAttribute{ + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + Validators: []validator.String{ + fwvalidators.S3URI(), + }, + }, + }, + }, + }, + "validation_data_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidationDataConfigModel](ctx), + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Blocks: map[string]schema.Block{ + "validator": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidatorConfigModel](ctx), + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + Validators: []validator.List{ + listvalidator.IsRequired(), + listvalidator.SizeAtLeast(1), + listvalidator.SizeAtMost(10), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "s3_uri": schema.StringAttribute{ + Required: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + Validators: []validator.String{ + fwvalidators.S3URI()}, + }, + }, + }, + }, + }, + }, + }, + "vpc_config": schema.ListNestedBlock{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelVPCConfigModel](ctx), + PlanModifiers: []planmodifier.List{ + listplanmodifier.RequiresReplace(), + }, + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "security_group_ids": schema.SetAttribute{ + CustomType: fwtypes.SetOfStringType, + Required: true, + ElementType: types.StringType, + PlanModifiers: []planmodifier.Set{ + setplanmodifier.RequiresReplace(), + }, + }, + "subnet_ids": schema.SetAttribute{ + CustomType: fwtypes.SetOfStringType, + Required: true, + ElementType: types.StringType, + PlanModifiers: []planmodifier.Set{ + setplanmodifier.RequiresReplace(), + }, + }, + }, + }, + }, + }, + } +} + +func (r *customModelResource) Create(ctx context.Context, request resource.CreateRequest, response *resource.CreateResponse) { + var data resourceCustomModelData + response.Diagnostics.Append(request.Plan.Get(ctx, &data)...) + if response.Diagnostics.HasError() { + return + } + + conn := r.Meta().BedrockClient(ctx) + + input := &bedrock.CreateModelCustomizationJobInput{} + response.Diagnostics.Append(fwflex.Expand(ctx, data, input)...) + if response.Diagnostics.HasError() { + return + } + + // Additional fields. + input.ClientRequestToken = aws.String(id.UniqueId()) + input.CustomModelTags = getTagsIn(ctx) + input.JobTags = getTagsIn(ctx) + + outputRaw, err := tfresource.RetryWhenAWSErrMessageContains(ctx, propagationTimeout, func() (interface{}, error) { + return conn.CreateModelCustomizationJob(ctx, input) + }, errCodeValidationException, "Could not assume provided IAM role") + + if err != nil { + response.Diagnostics.AddError("creating Bedrock Custom Model customization job", err.Error()) + + return + } + + jobARN := aws.ToString(outputRaw.(*bedrock.CreateModelCustomizationJobOutput).JobArn) + job, err := findModelCustomizationJobByID(ctx, conn, jobARN) + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model customization job (%s)", jobARN), err.Error()) + + return + } + + // Set values for unknowns. + data.CustomizationType = fwtypes.StringEnumValue(job.CustomizationType) + data.CustomModelARN = fwflex.StringToFramework(ctx, job.OutputModelArn) + data.JobARN = fwflex.StringToFramework(ctx, job.JobArn) + data.JobStatus = fwtypes.StringEnumValue(job.Status) + data.TrainingMetrics = fwtypes.NewListNestedObjectValueOfNull[customModelTrainingMetricsModel](ctx) + data.ValidationMetrics = fwtypes.NewListNestedObjectValueOfNull[customModelValidationMetricsModel](ctx) + data.setID() + + response.Diagnostics.Append(response.State.Set(ctx, &data)...) +} + +func (r *customModelResource) Read(ctx context.Context, request resource.ReadRequest, response *resource.ReadResponse) { + var data resourceCustomModelData + response.Diagnostics.Append(request.State.Get(ctx, &data)...) + if response.Diagnostics.HasError() { + return + } + + if err := data.InitFromID(); err != nil { + response.Diagnostics.AddError("parsing resource ID", err.Error()) + + return + } + + conn := r.Meta().BedrockClient(ctx) + + jobARN := data.JobARN.ValueString() + outputGJ, err := findModelCustomizationJobByID(ctx, conn, jobARN) + + if tfresource.NotFound(err) { + response.Diagnostics.Append(fwdiag.NewResourceNotFoundWarningDiagnostic(err)) + response.State.RemoveResource(ctx) + + return + } + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model customization job (%s)", jobARN), err.Error()) + + return + } + + response.Diagnostics.Append(fwflex.Flatten(ctx, outputGJ, &data)...) + if response.Diagnostics.HasError() { + return + } + + // Some fields in GetModelCustomizationJobOutput have different names than in CreateModelCustomizationJobInput. + data.CustomModelKmsKeyID = fwflex.StringToFrameworkARN(ctx, outputGJ.OutputModelKmsKeyArn) + data.CustomModelName = fwflex.StringToFramework(ctx, outputGJ.OutputModelName) + data.JobStatus = fwtypes.StringEnumValue(outputGJ.Status) + // The base model ARN in GetCustomModelOutput can contain the model version and parameter count. + baseModelARN := fwflex.StringFromFramework(ctx, data.BaseModelIdentifier) + data.BaseModelIdentifier = fwflex.StringToFrameworkARN(ctx, outputGJ.BaseModelArn) + if baseModelARN != nil { + if old, err := arn.Parse(aws.ToString(baseModelARN)); err == nil { + if new, err := arn.Parse(aws.ToString(outputGJ.BaseModelArn)); err == nil { + if len(strings.SplitN(old.Resource, ":", 2)) == 1 { + // Old ARN doesn't contain the model version and parameter count. + new.Resource = strings.SplitN(new.Resource, ":", 2)[0] + data.BaseModelIdentifier = fwtypes.ARNValue(new.String()) + } + } + } + } + + if outputGJ.OutputModelArn != nil { + customModelARN := aws.ToString(outputGJ.OutputModelArn) + outputGM, err := findCustomModelByID(ctx, conn, customModelARN) + + if tfresource.NotFound(err) { + response.Diagnostics.Append(fwdiag.NewResourceNotFoundWarningDiagnostic(err)) + response.State.RemoveResource(ctx) + + return + } + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model (%s)", customModelARN), err.Error()) + + return + } + + var dataFromGetCustomModel resourceCustomModelData + response.Diagnostics.Append(fwflex.Flatten(ctx, outputGM, &dataFromGetCustomModel)...) + if response.Diagnostics.HasError() { + return + } + + data.CustomModelARN = fwflex.StringToFramework(ctx, outputGM.ModelArn) + data.TrainingMetrics = dataFromGetCustomModel.TrainingMetrics + data.ValidationMetrics = dataFromGetCustomModel.ValidationMetrics + } + + response.Diagnostics.Append(response.State.Set(ctx, &data)...) +} + +func (r *customModelResource) Update(ctx context.Context, request resource.UpdateRequest, response *resource.UpdateResponse) { + var old, new resourceCustomModelData + response.Diagnostics.Append(request.State.Get(ctx, &old)...) + if response.Diagnostics.HasError() { + return + } + response.Diagnostics.Append(request.Plan.Get(ctx, &new)...) + if response.Diagnostics.HasError() { + return + } + + // Update is only called when `tags` are updated. + // Set unknowns to the old (in state) values. + new.CustomModelARN = old.CustomModelARN + new.JobStatus = old.JobStatus + new.TrainingMetrics = old.TrainingMetrics + new.ValidationMetrics = old.ValidationMetrics + + response.Diagnostics.Append(response.State.Set(ctx, &new)...) +} + +func (r *customModelResource) Delete(ctx context.Context, request resource.DeleteRequest, response *resource.DeleteResponse) { + var data resourceCustomModelData + response.Diagnostics.Append(request.State.Get(ctx, &data)...) + if response.Diagnostics.HasError() { + return + } + + conn := r.Meta().BedrockClient(ctx) + + if data.JobStatus.ValueEnum() == awstypes.ModelCustomizationJobStatusInProgress { + jobARN := data.JobARN.ValueString() + input := &bedrock.StopModelCustomizationJobInput{ + JobIdentifier: aws.String(jobARN), + } + + _, err := conn.StopModelCustomizationJob(ctx, input) + + if errs.IsA[*awstypes.ResourceNotFoundException](err) { + return + } + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("stopping Bedrock Custom Model customization job (%s)", jobARN), err.Error()) + + return + } + + if _, err := waitModelCustomizationJobStopped(ctx, conn, jobARN, r.DeleteTimeout(ctx, data.Timeouts)); err != nil { + response.Diagnostics.AddError(fmt.Sprintf("waiting for Bedrock Custom Model customization job (%s) stop", jobARN), err.Error()) + + return + } + } + + if !data.CustomModelARN.IsNull() { + _, err := conn.DeleteCustomModel(ctx, &bedrock.DeleteCustomModelInput{ + ModelIdentifier: fwflex.StringFromFramework(ctx, data.CustomModelARN), + }) + + if errs.IsA[*awstypes.ResourceNotFoundException](err) { + return + } + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("deleting Bedrock Custom Model (%s)", data.ID.ValueString()), err.Error()) + + return + } + } +} + +func (r *customModelResource) ModifyPlan(ctx context.Context, request resource.ModifyPlanRequest, response *resource.ModifyPlanResponse) { + r.SetTagsAll(ctx, request, response) +} + +func findCustomModelByID(ctx context.Context, conn *bedrock.Client, id string) (*bedrock.GetCustomModelOutput, error) { + input := &bedrock.GetCustomModelInput{ + ModelIdentifier: aws.String(id), + } + + output, err := conn.GetCustomModel(ctx, input) + + if errs.IsA[*awstypes.ResourceNotFoundException](err) { + return nil, &retry.NotFoundError{ + LastError: err, + LastRequest: input, + } + } + + if err != nil { + return nil, err + } + + if output == nil { + return nil, tfresource.NewEmptyResultError(input) + } + + return output, nil +} + +func findModelCustomizationJobByID(ctx context.Context, conn *bedrock.Client, id string) (*bedrock.GetModelCustomizationJobOutput, error) { + input := &bedrock.GetModelCustomizationJobInput{ + JobIdentifier: aws.String(id), + } + + output, err := findModelCustomizationJob(ctx, conn, input) + + if err != nil { + return nil, err + } + + if status := output.Status; status == awstypes.ModelCustomizationJobStatusStopped { + return nil, &retry.NotFoundError{ + Message: string(status), + LastRequest: input, + } + } + + return output, nil +} + +func findModelCustomizationJob(ctx context.Context, conn *bedrock.Client, input *bedrock.GetModelCustomizationJobInput) (*bedrock.GetModelCustomizationJobOutput, error) { + output, err := conn.GetModelCustomizationJob(ctx, input) + + if errs.IsA[*awstypes.ResourceNotFoundException](err) { + return nil, &retry.NotFoundError{ + LastError: err, + LastRequest: input, + } + } + + if err != nil { + return nil, err + } + + if output == nil { + return nil, tfresource.NewEmptyResultError(input) + } + + return output, nil +} + +func statusModelCustomizationJob(ctx context.Context, conn *bedrock.Client, id string) retry.StateRefreshFunc { + return func() (interface{}, string, error) { + input := &bedrock.GetModelCustomizationJobInput{ + JobIdentifier: aws.String(id), + } + output, err := findModelCustomizationJob(ctx, conn, input) + + if tfresource.NotFound(err) { + return nil, "", nil + } + + if err != nil { + return nil, "", err + } + + return output, string(output.Status), nil + } +} + +func waitModelCustomizationJobCompleted(ctx context.Context, conn *bedrock.Client, id string, timeout time.Duration) (*bedrock.GetModelCustomizationJobOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.ModelCustomizationJobStatusInProgress), + Target: enum.Slice(awstypes.ModelCustomizationJobStatusCompleted), + Refresh: statusModelCustomizationJob(ctx, conn, id), + Timeout: timeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + + if output, ok := outputRaw.(*bedrock.GetModelCustomizationJobOutput); ok { + tfresource.SetLastError(err, errors.New(aws.ToString(output.FailureMessage))) + + return output, err + } + + return nil, err +} + +func waitModelCustomizationJobStopped(ctx context.Context, conn *bedrock.Client, id string, timeout time.Duration) (*bedrock.GetModelCustomizationJobOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.ModelCustomizationJobStatusStopping), + Target: enum.Slice(awstypes.ModelCustomizationJobStatusStopped), + Refresh: statusModelCustomizationJob(ctx, conn, id), + Timeout: timeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + + if output, ok := outputRaw.(*bedrock.GetModelCustomizationJobOutput); ok { + tfresource.SetLastError(err, errors.New(aws.ToString(output.FailureMessage))) + + return output, err + } + + return nil, err +} + +type resourceCustomModelData struct { + BaseModelIdentifier fwtypes.ARN `tfsdk:"base_model_identifier"` + CustomModelARN types.String `tfsdk:"custom_model_arn"` + CustomModelKmsKeyID fwtypes.ARN `tfsdk:"custom_model_kms_key_id"` + CustomModelName types.String `tfsdk:"custom_model_name"` + CustomizationType fwtypes.StringEnum[awstypes.CustomizationType] `tfsdk:"customization_type"` + HyperParameters fwtypes.MapValueOf[types.String] `tfsdk:"hyperparameters"` + ID types.String `tfsdk:"id"` + JobARN types.String `tfsdk:"job_arn"` + JobName types.String `tfsdk:"job_name"` + JobStatus fwtypes.StringEnum[awstypes.ModelCustomizationJobStatus] `tfsdk:"job_status"` + OutputDataConfig fwtypes.ListNestedObjectValueOf[customModelOutputDataConfigModel] `tfsdk:"output_data_config"` + RoleARN fwtypes.ARN `tfsdk:"role_arn"` + Tags types.Map `tfsdk:"tags"` + TagsAll types.Map `tfsdk:"tags_all"` + Timeouts timeouts.Value `tfsdk:"timeouts"` + TrainingDataConfig fwtypes.ListNestedObjectValueOf[customModelTrainingDataConfigModel] `tfsdk:"training_data_config"` + TrainingMetrics fwtypes.ListNestedObjectValueOf[customModelTrainingMetricsModel] `tfsdk:"training_metrics"` + ValidationDataConfig fwtypes.ListNestedObjectValueOf[customModelValidationDataConfigModel] `tfsdk:"validation_data_config"` + ValidationMetrics fwtypes.ListNestedObjectValueOf[customModelValidationMetricsModel] `tfsdk:"validation_metrics"` + VPCConfig fwtypes.ListNestedObjectValueOf[customModelVPCConfigModel] `tfsdk:"vpc_config"` +} + +func (data *resourceCustomModelData) InitFromID() error { + data.JobARN = data.ID + + return nil +} + +func (data *resourceCustomModelData) setID() { + data.ID = data.JobARN +} + +type customModelOutputDataConfigModel struct { + S3URI types.String `tfsdk:"s3_uri"` +} + +type customModelTrainingDataConfigModel struct { + S3URI types.String `tfsdk:"s3_uri"` +} + +type customModelTrainingMetricsModel struct { + TrainingLoss types.Float64 `tfsdk:"training_loss"` +} + +type customModelValidationDataConfigModel struct { + Validators fwtypes.ListNestedObjectValueOf[customModelValidatorConfigModel] `tfsdk:"validator"` +} + +type customModelValidationMetricsModel struct { + ValidationLoss types.Float64 `tfsdk:"validation_loss"` +} + +type customModelValidatorConfigModel struct { + S3URI types.String `tfsdk:"s3_uri"` +} + +type customModelVPCConfigModel struct { + SecurityGroupIDs fwtypes.SetValueOf[types.String] `tfsdk:"security_group_ids"` + SubnetIDs fwtypes.SetValueOf[types.String] `tfsdk:"subnet_ids"` +} diff --git a/internal/service/bedrock/custom_model_data_source.go b/internal/service/bedrock/custom_model_data_source.go new file mode 100644 index 00000000000..f5e48f1df6f --- /dev/null +++ b/internal/service/bedrock/custom_model_data_source.go @@ -0,0 +1,206 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bedrock + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/datasource" + "github.com/hashicorp/terraform-plugin-framework/datasource/schema" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/hashicorp/terraform-provider-aws/internal/framework" + fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex" + fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" + tftags "github.com/hashicorp/terraform-provider-aws/internal/tags" + "github.com/hashicorp/terraform-provider-aws/names" +) + +// @FrameworkDataSource(name="Custom Model") +func newCustomModelDataSource(context.Context) (datasource.DataSourceWithConfigure, error) { + return &customModelDataSource{}, nil +} + +type customModelDataSource struct { + framework.DataSourceWithConfigure +} + +func (d *customModelDataSource) Metadata(_ context.Context, request datasource.MetadataRequest, response *datasource.MetadataResponse) { + response.TypeName = "aws_bedrock_custom_model" +} + +// Schema returns the schema for this data source. +func (d *customModelDataSource) Schema(ctx context.Context, request datasource.SchemaRequest, response *datasource.SchemaResponse) { + response.Schema = schema.Schema{ + Attributes: map[string]schema.Attribute{ + "base_model_arn": schema.StringAttribute{ + Computed: true, + }, + "creation_time": schema.StringAttribute{ + CustomType: fwtypes.TimestampType, + Computed: true, + }, + "hyperparameters": schema.MapAttribute{ + CustomType: fwtypes.MapOfStringType, + Computed: true, + ElementType: types.StringType, + }, + names.AttrID: framework.IDAttribute(), + "job_arn": schema.StringAttribute{ + Computed: true, + }, + "job_name": schema.StringAttribute{ + Computed: true, + }, + "job_tags": tftags.TagsAttributeComputedOnly(), + "model_arn": schema.StringAttribute{ + Computed: true, + }, + "model_id": schema.StringAttribute{ + Required: true, + }, + "model_kms_key_arn": schema.StringAttribute{ + Computed: true, + }, + "model_name": schema.StringAttribute{ + Computed: true, + }, + "model_tags": tftags.TagsAttributeComputedOnly(), + "output_data_config": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelOutputDataConfigModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "s3_uri": types.StringType, + }, + }, + }, + "training_data_config": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelTrainingDataConfigModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "s3_uri": types.StringType, + }, + }, + }, + "training_metrics": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelTrainingMetricsModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "training_loss": types.Float64Type, + }, + }, + }, + "validation_data_config": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidationDataConfigModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "validator": fwtypes.NewListNestedObjectTypeOf[customModelValidatorConfigModel](ctx), + }, + }, + }, + "validation_metrics": schema.ListAttribute{ + CustomType: fwtypes.NewListNestedObjectTypeOf[customModelValidationMetricsModel](ctx), + Computed: true, + ElementType: types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "validation_loss": types.Float64Type, + }, + }, + }, + }, + } +} + +func (d *customModelDataSource) Read(ctx context.Context, request datasource.ReadRequest, response *datasource.ReadResponse) { + var data customModelDataSourceModel + response.Diagnostics.Append(request.Config.Get(ctx, &data)...) + if response.Diagnostics.HasError() { + return + } + + conn := d.Meta().BedrockClient(ctx) + + modelID := data.ModelID.ValueString() + outputGM, err := findCustomModelByID(ctx, conn, modelID) + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model (%s)", modelID), err.Error()) + + return + } + + jobARN := aws.ToString(outputGM.JobArn) + outputGJ, err := findModelCustomizationJobByID(ctx, conn, jobARN) + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model customization job (%s)", jobARN), err.Error()) + + return + } + + response.Diagnostics.Append(fwflex.Flatten(ctx, outputGM, &data)...) + if response.Diagnostics.HasError() { + return + } + + // Some fields are only available in GetModelCustomizationJobOutput. + var dataFromGetModelCustomizationJob resourceCustomModelData + response.Diagnostics.Append(fwflex.Flatten(ctx, outputGJ, &dataFromGetModelCustomizationJob)...) + if response.Diagnostics.HasError() { + return + } + + data.ID = types.StringValue(modelID) + data.JobName = dataFromGetModelCustomizationJob.JobName + data.ValidationDataConfig = dataFromGetModelCustomizationJob.ValidationDataConfig + + jobTags, err := listTags(ctx, conn, jobARN) + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model customization job (%s) tags", jobARN), err.Error()) + + return + } + + data.JobTags = fwflex.FlattenFrameworkStringValueMap(ctx, jobTags.IgnoreAWS().Map()) + + modelARN := aws.ToString(outputGM.ModelArn) + modelTags, err := listTags(ctx, conn, modelARN) + + if err != nil { + response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Custom Model (%s) tags", modelARN), err.Error()) + + return + } + + data.ModelTags = fwflex.FlattenFrameworkStringValueMap(ctx, modelTags.IgnoreAWS().Map()) + + response.Diagnostics.Append(response.State.Set(ctx, &data)...) +} + +type customModelDataSourceModel struct { + BaseModelARN types.String `tfsdk:"base_model_arn"` + CreationTime fwtypes.Timestamp `tfsdk:"creation_time"` + HyperParameters fwtypes.MapValueOf[types.String] `tfsdk:"hyperparameters"` + ID types.String `tfsdk:"id"` + JobARN types.String `tfsdk:"job_arn"` + JobName types.String `tfsdk:"job_name"` + JobTags types.Map `tfsdk:"job_tags"` + ModelARN types.String `tfsdk:"model_arn"` + ModelID types.String `tfsdk:"model_id"` + ModelKMSKeyARN types.String `tfsdk:"model_kms_key_arn"` + ModelName types.String `tfsdk:"model_name"` + ModelTags types.Map `tfsdk:"model_tags"` + OutputDataConfig fwtypes.ListNestedObjectValueOf[customModelOutputDataConfigModel] `tfsdk:"output_data_config"` + TrainingDataConfig fwtypes.ListNestedObjectValueOf[customModelTrainingDataConfigModel] `tfsdk:"training_data_config"` + TrainingMetrics fwtypes.ListNestedObjectValueOf[customModelTrainingMetricsModel] `tfsdk:"training_metrics"` + ValidationDataConfig fwtypes.ListNestedObjectValueOf[customModelValidationDataConfigModel] `tfsdk:"validation_data_config"` + ValidationMetrics fwtypes.ListNestedObjectValueOf[customModelValidationMetricsModel] `tfsdk:"validation_metrics"` +} diff --git a/internal/service/bedrock/custom_model_data_source_test.go b/internal/service/bedrock/custom_model_data_source_test.go new file mode 100644 index 00000000000..ad4f04c142f --- /dev/null +++ b/internal/service/bedrock/custom_model_data_source_test.go @@ -0,0 +1,67 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bedrock_test + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/service/bedrock" + sdkacctest "github.com/hashicorp/terraform-plugin-testing/helper/acctest" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + "github.com/hashicorp/terraform-provider-aws/names" +) + +func TestAccBedrockCustomModelDataSource_basic(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + datasourceName := "data.aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_basic(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + ), + }, + { + PreConfig: func() { + testAccWaitModelCustomizationJobCompleted(ctx, t, &v) + }, + Config: testAccCustomModelDataSourceConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttrPair(resourceName, "hyperparameters", datasourceName, "hyperparameters"), + resource.TestCheckResourceAttrPair(resourceName, "job_arn", datasourceName, "job_arn"), + resource.TestCheckResourceAttrPair(resourceName, "job_name", datasourceName, "job_name"), + resource.TestCheckResourceAttrPair(resourceName, "custom_model_arn", datasourceName, "model_arn"), + resource.TestCheckResourceAttrPair(resourceName, "custom_model_kms_key_id", datasourceName, "model_kms_key_arn"), + resource.TestCheckResourceAttrPair(resourceName, "custom_model_name", datasourceName, "model_name"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.#", datasourceName, "output_data_config.#"), + resource.TestCheckResourceAttrPair(resourceName, "training_data_config.#", datasourceName, "training_data_config.#"), + resource.TestCheckResourceAttrPair(resourceName, "training_metrics.#", datasourceName, "training_metrics.#"), + resource.TestCheckResourceAttrPair(resourceName, "validation_data_config.#", datasourceName, "validation_data_config.#"), + resource.TestCheckResourceAttrPair(resourceName, "validation_metrics.#", datasourceName, "validation_metrics.#"), + ), + }, + }, + }) +} + +func testAccCustomModelDataSourceConfig_basic(rName string) string { + return acctest.ConfigCompose(testAccCustomModelConfig_basic(rName), ` +data "aws_bedrock_custom_model" "test" { + model_id = aws_bedrock_custom_model.test.custom_model_arn +} +`) +} diff --git a/internal/service/bedrock/custom_model_test.go b/internal/service/bedrock/custom_model_test.go new file mode 100644 index 00000000000..7a04462d26c --- /dev/null +++ b/internal/service/bedrock/custom_model_test.go @@ -0,0 +1,694 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bedrock_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrock" + sdkacctest "github.com/hashicorp/terraform-plugin-testing/helper/acctest" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/terraform" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + "github.com/hashicorp/terraform-provider-aws/internal/conns" + tfbedrock "github.com/hashicorp/terraform-provider-aws/internal/service/bedrock" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" + "github.com/hashicorp/terraform-provider-aws/names" +) + +func TestAccBedrockCustomModel_basic(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttrSet(resourceName, "base_model_identifier"), + resource.TestCheckNoResourceAttr(resourceName, "custom_model_arn"), + resource.TestCheckNoResourceAttr(resourceName, "custom_model_kms_key_id"), + resource.TestCheckResourceAttr(resourceName, "custom_model_name", rName), + resource.TestCheckResourceAttr(resourceName, "customization_type", "FINE_TUNING"), + resource.TestCheckResourceAttr(resourceName, "hyperparameters.%", "4"), + resource.TestCheckResourceAttr(resourceName, "hyperparameters.batchSize", "1"), + resource.TestCheckResourceAttr(resourceName, "hyperparameters.epochCount", "1"), + resource.TestCheckResourceAttr(resourceName, "hyperparameters.learningRate", "0.005"), + resource.TestCheckResourceAttr(resourceName, "hyperparameters.learningRateWarmupSteps", "0"), + resource.TestCheckResourceAttrSet(resourceName, "job_arn"), + resource.TestCheckResourceAttr(resourceName, "job_name", rName), + resource.TestCheckResourceAttr(resourceName, "job_status", "InProgress"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "output_data_config.0.s3_uri"), + resource.TestCheckResourceAttrSet(resourceName, "role_arn"), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "training_data_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "training_data_config.0.s3_uri"), + resource.TestCheckNoResourceAttr(resourceName, "training_metrics"), + resource.TestCheckResourceAttr(resourceName, "validation_data_config.#", "0"), + resource.TestCheckNoResourceAttr(resourceName, "validation_metrics"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"base_model_identifier"}, + }, + }, + }) +} + +func TestAccBedrockCustomModel_disappears(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_basic(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + acctest.CheckFrameworkResourceDisappears(ctx, acctest.Provider, tfbedrock.ResourceCustomModel, resourceName), + ), + ExpectNonEmptyPlan: true, + }, + }, + }) +} + +func TestAccBedrockCustomModel_tags(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_tags1(rName, "key1", "value1"), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"base_model_identifier"}, + }, + { + Config: testAccCustomModelConfig_tags2(rName, "key1", "value1updated", "key2", "value2"), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttr(resourceName, "tags.%", "2"), + resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1updated"), + resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + ), + }, + { + Config: testAccCustomModelConfig_tags1(rName, "key2", "value2"), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + ), + }, + }, + }) +} + +func TestAccBedrockCustomModel_kmsKey(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_kmsKey(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttrSet(resourceName, "custom_model_kms_key_id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"base_model_identifier"}, + }, + }, + }) +} + +func TestAccBedrockCustomModel_validationDataConfig(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_validationDataConfig(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttr(resourceName, "validation_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "validation_data_config.0.validator.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "validation_data_config.0.validator.0.s3_uri"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"base_model_identifier"}, + }, + }, + }) +} + +func TestAccBedrockCustomModel_validationDataConfigWaitForCompletion(t *testing.T) { + ctx := acctest.Context(t) + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_validationDataConfig(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttr(resourceName, "validation_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "validation_data_config.0.validator.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "validation_data_config.0.validator.0.s3_uri"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"base_model_identifier"}, + }, + { + PreConfig: func() { + testAccWaitModelCustomizationJobCompleted(ctx, t, &v) + }, + Config: testAccCustomModelConfig_validationDataConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "job_status", "Completed"), + resource.TestCheckResourceAttr(resourceName, "training_metrics.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "training_metrics.0.training_loss"), + resource.TestCheckResourceAttr(resourceName, "validation_metrics.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "validation_metrics.0.validation_loss"), + ), + }, + }, + }) +} + +func TestAccBedrockCustomModel_vpcConfig(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_bedrock_custom_model.test" + var v bedrock.GetModelCustomizationJobOutput + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) }, + ErrorCheck: acctest.ErrorCheck(t, names.BedrockEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckCustomModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccCustomModelConfig_vpcConfig(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckCustomModelExists(ctx, resourceName, &v), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnet_ids.#", "2"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"base_model_identifier"}, + }, + }, + }) +} + +func testAccWaitModelCustomizationJobCompleted(ctx context.Context, t *testing.T, v *bedrock.GetModelCustomizationJobOutput) { + conn := acctest.Provider.Meta().(*conns.AWSClient).BedrockClient(ctx) + + jobARN := aws.ToString(v.JobArn) + const ( + timeout = 2 * time.Hour + ) + _, err := tfbedrock.WaitModelCustomizationJobCompleted(ctx, conn, jobARN, timeout) + + if err != nil { + t.Logf("waiting for Bedrock Custom Model customization job (%s) complete: %s", jobARN, err) + } +} + +func testAccCheckCustomModelDestroy(ctx context.Context) resource.TestCheckFunc { + return func(s *terraform.State) error { + conn := acctest.Provider.Meta().(*conns.AWSClient).BedrockClient(ctx) + + for _, rs := range s.RootModule().Resources { + if rs.Type != "aws_bedrock_custom_model" { + continue + } + + output, err := tfbedrock.FindModelCustomizationJobByID(ctx, conn, rs.Primary.ID) + + if tfresource.NotFound(err) { + continue + } + + if err != nil { + return err + } + + // Check the custom model. + if modelARN := aws.ToString(output.OutputModelArn); modelARN != "" { + _, err := tfbedrock.FindCustomModelByID(ctx, conn, modelARN) + + if tfresource.NotFound(err) { + continue + } + + if err != nil { + return err + } + } + + return fmt.Errorf("Bedrock Custom Model %s still exists", rs.Primary.ID) + } + + return nil + } +} + +func testAccCheckCustomModelExists(ctx context.Context, n string, v *bedrock.GetModelCustomizationJobOutput) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[n] + if !ok { + return fmt.Errorf("Not found: %s", n) + } + + conn := acctest.Provider.Meta().(*conns.AWSClient).BedrockClient(ctx) + + output, err := tfbedrock.FindModelCustomizationJobByID(ctx, conn, rs.Primary.ID) + + if err != nil { + return err + } + + *v = *output + + return nil + } +} + +func testAccCustomModelConfig_base(rName string) string { + return fmt.Sprintf(` +data "aws_caller_identity" "current" {} +data "aws_region" "current" {} +data "aws_partition" "current" {} + +resource "aws_s3_bucket" "training" { + bucket = "%[1]s-training" +} + +resource "aws_s3_bucket" "validation" { + bucket = "%[1]s-validation" +} + +resource "aws_s3_bucket" "output" { + bucket = "%[1]s-output" + force_destroy = true +} + +resource "aws_s3_object" "training" { + bucket = aws_s3_bucket.training.id + key = "data/train.jsonl" + source = "test-fixtures/train.jsonl" +} + +resource "aws_s3_object" "validation" { + bucket = aws_s3_bucket.validation.id + key = "data/validate.jsonl" + source = "test-fixtures/validate.jsonl" +} + +resource "aws_iam_role" "test" { + name = %[1]q + + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-iam-role.html#model-customization-iam-role-trust. + assume_role_policy = < 0 { + return tags + } + } + + return nil +} + +// setTagsOut sets bedrock service tags in Context. +func setTagsOut(ctx context.Context, tags []awstypes.Tag) { + if inContext, ok := tftags.FromContext(ctx); ok { + inContext.TagsOut = option.Some(KeyValueTags(ctx, tags)) + } +} + +// updateTags updates bedrock service tags. +// The identifier is typically the Amazon Resource Name (ARN), although +// it may also be a different identifier depending on the service. +func updateTags(ctx context.Context, conn *bedrock.Client, identifier string, oldTagsMap, newTagsMap any, optFns ...func(*bedrock.Options)) error { + oldTags := tftags.New(ctx, oldTagsMap) + newTags := tftags.New(ctx, newTagsMap) + + ctx = tflog.SetField(ctx, logging.KeyResourceId, identifier) + + removedTags := oldTags.Removed(newTags) + removedTags = removedTags.IgnoreSystem(names.Bedrock) + if len(removedTags) > 0 { + input := &bedrock.UntagResourceInput{ + ResourceARN: aws.String(identifier), + TagKeys: removedTags.Keys(), + } + + _, err := conn.UntagResource(ctx, input, optFns...) + + if err != nil { + return fmt.Errorf("untagging resource (%s): %w", identifier, err) + } + } + + updatedTags := oldTags.Updated(newTags) + updatedTags = updatedTags.IgnoreSystem(names.Bedrock) + if len(updatedTags) > 0 { + input := &bedrock.TagResourceInput{ + ResourceARN: aws.String(identifier), + Tags: Tags(updatedTags), + } + + _, err := conn.TagResource(ctx, input, optFns...) + + if err != nil { + return fmt.Errorf("tagging resource (%s): %w", identifier, err) + } + } + + return nil +} + +// UpdateTags updates bedrock service tags. +// It is called from outside this package. +func (p *servicePackage) UpdateTags(ctx context.Context, meta any, identifier string, oldTags, newTags any) error { + return updateTags(ctx, meta.(*conns.AWSClient).BedrockClient(ctx), identifier, oldTags, newTags) +} diff --git a/internal/service/bedrock/test-fixtures/train.jsonl b/internal/service/bedrock/test-fixtures/train.jsonl new file mode 100644 index 00000000000..edbb920603f --- /dev/null +++ b/internal/service/bedrock/test-fixtures/train.jsonl @@ -0,0 +1 @@ +{"prompt": "what is AWS", "completion": "it's Amazon Web Services"} \ No newline at end of file diff --git a/internal/service/bedrock/test-fixtures/validate.jsonl b/internal/service/bedrock/test-fixtures/validate.jsonl new file mode 100644 index 00000000000..edbb920603f --- /dev/null +++ b/internal/service/bedrock/test-fixtures/validate.jsonl @@ -0,0 +1 @@ +{"prompt": "what is AWS", "completion": "it's Amazon Web Services"} \ No newline at end of file diff --git a/internal/service/route53domains/delegation_signer_record.go b/internal/service/route53domains/delegation_signer_record.go index b4758d99ed7..5196d05b617 100644 --- a/internal/service/route53domains/delegation_signer_record.go +++ b/internal/service/route53domains/delegation_signer_record.go @@ -42,7 +42,7 @@ func newDelegationSignerRecordResource(context.Context) (resource.ResourceWithCo type delegationSignerRecordResource struct { framework.ResourceWithConfigure - framework.WithNoOpUpdate + framework.WithNoUpdate framework.WithTimeouts framework.WithImportByID } diff --git a/website/docs/d/bedrock_custom_model.html.markdown b/website/docs/d/bedrock_custom_model.html.markdown new file mode 100644 index 00000000000..fb90ff3a7dc --- /dev/null +++ b/website/docs/d/bedrock_custom_model.html.markdown @@ -0,0 +1,50 @@ +--- +subcategory: "Amazon Bedrock" +layout: "aws" +page_title: "AWS: aws_bedrock_custom_model" +description: |- + Returns properties of a specific Amazon Bedrock custom model. +--- + +# Data Source: aws_bedrock_custom_model + +Returns properties of a specific Amazon Bedrock custom model. + +## Example Usage + +```terraform +data "aws_bedrock_custom_model" "test" { + model_id = "arn:aws:bedrock:us-west-2:123456789012:custom-model/amazon.titan-text-express-v1:0:8k/ly16hhi765j4 " +} +``` + +## Argument Reference + +* `model_id` – (Required) Name or ARN of the custom model. + +## Attribute Reference + +This data source exports the following attributes in addition to the arguments above: + +* `base_model_arn` - ARN of the base model. +* `creation_time` - Creation time of the model. +* `hyperparameters` - Hyperparameter values associated with this model. +* `job_arn` - Job ARN associated with this model. +* `job_name` - Job name associated with this model. +* `job_tags` - Key-value mapping of tags for the fine-tuning job. +* `model_arn` - ARN associated with this model. +* `model_kms_key_arn` - The custom model is encrypted at rest using this key. +* `model_name` - Model name associated with this model. +* `model_tags` - Key-value mapping of tags for the model. +* `output_data_config` - Output data configuration associated with this custom model. + * `s3_uri` - The S3 URI where the output data is stored. +* `training_data_config` - Information about the training dataset. + * `s3_uri` - The S3 URI where the training data is stored. +* `training_metrics` - Metrics associated with the customization job. + * `training_loss` - Loss metric associated with the customization job. +* `validation_data_config` - Information about the validation dataset. + * `validator` - Information about the validators. + * `s3_uri` - The S3 URI where the validation data is stored.. +* `validation_metrics` - The loss metric for each validator that you provided. + * `validation_loss` - The validation loss associated with the validator. + \ No newline at end of file diff --git a/website/docs/d/bedrock_custom_models.html.markdown b/website/docs/d/bedrock_custom_models.html.markdown new file mode 100644 index 00000000000..eb55165956d --- /dev/null +++ b/website/docs/d/bedrock_custom_models.html.markdown @@ -0,0 +1,30 @@ +--- +subcategory: "Amazon Bedrock" +layout: "aws" +page_title: "AWS: aws_bedrock_custom_models" +description: |- + Returns a list of Amazon Bedrock custom models. +--- + +# Data Source: aws_bedrock_custom_models + +Returns a list of Amazon Bedrock custom models. + +## Example Usage + +```terraform +data "aws_bedrock_custom_models" "test" {} +``` + +## Argument Reference + +None. + +## Attribute Reference + +This data source exports the following attributes in addition to the arguments above: + +* `model_summaries` - Model summaries. + * `creation_time` - Creation time of the model. + * `model_arn` - The ARN of the custom model. + * `model_name` - The name of the custom model. diff --git a/website/docs/r/bedrock_custom_model.html.markdown b/website/docs/r/bedrock_custom_model.html.markdown new file mode 100644 index 00000000000..631e8cb21dc --- /dev/null +++ b/website/docs/r/bedrock_custom_model.html.markdown @@ -0,0 +1,113 @@ +--- +subcategory: "Amazon Bedrock" +layout: "aws" +page_title: "AWS: aws_bedrock_custom_model" +description: |- + Manages an Amazon Bedrock custom model. +--- + +# Resource: aws_bedrock_custom_model + +Manages an Amazon Bedrock custom model. +Model customization is the process of providing training data to a base model in order to improve its performance for specific use-cases. + +This Terraform resource interacts with two Amazon Bedrock entities: + +1. A Continued Pre-training or Fine-tuning job which is started when the Terraform resource is created. The customization job can take several hours to run to completion. The duration of the job depends on the size of the training data (number of records, input tokens, and output tokens), and [hyperparameters](https://docs.aws.amazon.com/bedrock/latest/userguide/custom-models-hp.html) (number of epochs, and batch size). +2. The custom model output on successful completion of the customization job. + +This resource's [behaviors](https://developer.hashicorp.com/terraform/language/resources/behavior) correspond to operations on these Amazon Bedrock entities: + +* [_Create_](https://developer.hashicorp.com/terraform/plugin/framework/resources/create) starts the customization job and immediately returns. +* [_Read_](https://developer.hashicorp.com/terraform/plugin/framework/resources/read) returns the status and results of the customization job. If the customization job has completed, the output model's properties are returned. +* [_Update_](https://developer.hashicorp.com/terraform/plugin/framework/resources/update) updates the customization job's [tags](https://docs.aws.amazon.com/bedrock/latest/userguide/tagging.html). +* [_Delete_](https://developer.hashicorp.com/terraform/plugin/framework/resources/delete) stops the customization job if it is still active. If the customization job has completed, the custom model output by the job is deleted. + +## Example Usage + +```terraform +data "aws_bedrock_foundation_model" "example" { + model_id = "amazon.titan-text-express-v1" +} + +resource "aws_bedrock_custom_model" "example" { + custom_model_name = "example-model" + job_name = "example-job-1" + base_model_identifier = data.aws_bedrock_foundation_model.example.model_arn + role_arn = aws_iam_role.example.arn + + hyperparameters = { + "epochCount" = "1" + "batchSize" = "1" + "learningRate" = "0.005" + "learningRateWarmupSteps" = "0" + } + + output_data_config { + s3_uri = "s3://${aws_s3_bucket.output.id}/data/" + } + + training_data_config { + s3_uri = "s3://${aws_s3_bucket.training.id}/data/train.jsonl" + } +} +``` + +## Argument Reference + +This resource supports the following arguments: + +* `base_model_identifier` - (Required) The Amazon Resource Name (ARN) of the base model. +* `custom_model_kms_key_id` - (Optional) The custom model is encrypted at rest using this key. Specify the key ARN. +* `custom_model_name` - (Required) Name for the custom model. +* `customization_type` -(Optional) The customization type. Valid values: `FINE_TUNING`, `CONTINUED_PRE_TRAINING`. +* `hyperparameters` - (Required) [Parameters](https://docs.aws.amazon.com/bedrock/latest/userguide/custom-models-hp.html) related to tuning the model. +* `job_name` - (Required) A name for the customization job. +* `output_data_config` - (Required) S3 location for the output data. + * `s3_uri` - (Required) The S3 URI where the output data is stored. +* `role_arn` - (Required) The Amazon Resource Name (ARN) of an IAM role that Bedrock can assume to perform tasks on your behalf. +* `tags` - (Optional) A map of tags to assign to the customization job and custom model. If configured with a provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. +* `training_data_config` - (Required) Information about the training dataset. + * `s3_uri` - (Required) The S3 URI where the training data is stored. +* `validation_data_config` - (Optional) Information about the validation dataset. + * `validator` - (Required) Information about the validators. + * `s3_uri` - (Required) The S3 URI where the validation data is stored. +* `vpc_config` - (Optional) Configuration parameters for the private Virtual Private Cloud (VPC) that contains the resources you are using for this job. + * `security_group_ids` – (Required) VPC configuration security group IDs. + * `subnet_ids` – (Required) VPC configuration subnets. + +## Attribute Reference + +This resource exports the following attributes in addition to the arguments above: + +* `custom_model_arn` - The ARN of the output model. +* `job_arn` - The ARN of the customization job. +* `job_status` - The status of the customization job. A successful job transitions from `InProgress` to `Completed` when the output model is ready to use. +* `tags_all` - Map of tags assigned to the resource, including those inherited from the provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block). +* `training_metrics` - Metrics associated with the customization job. + * `training_loss` - Loss metric associated with the customization job. +* `validation_metrics` - The loss metric for each validator that you provided. + * `validation_loss` - The validation loss associated with the validator. + +## Timeouts + +[Configuration options](https://developer.hashicorp.com/terraform/language/resources/syntax#operation-timeouts): + +* `delete` - (Default `120m`) + +## Import + +In Terraform v1.5.0 and later, use an [`import` block](https://developer.hashicorp.com/terraform/language/import) to import Bedrock Custom Model using the `job_arn`. For example: + +```terraform +import { + to = aws_bedrock_custom_model.example + model_id = "arn:aws:bedrock:us-west-2:123456789012:model-customization-job/amazon.titan-text-express-v1:0:8k/1y5n57gh5y2e" +} +``` + +Using `terraform import`, import Bedrock custom model using the `job_arn`. For example: + +```console +% terraform import aws_bedrock_custom_model.example arn:aws:bedrock:us-west-2:123456789012:model-customization-job/amazon.titan-text-express-v1:0:8k/1y5n57gh5y2e +```