diff --git a/modules/aws-http4s/src/smithy4s/aws/internals/AwsPayloadSignature.scala b/modules/aws-http4s/src/smithy4s/aws/internals/AwsPayloadSignature.scala new file mode 100644 index 000000000..6a37afe8b --- /dev/null +++ b/modules/aws-http4s/src/smithy4s/aws/internals/AwsPayloadSignature.scala @@ -0,0 +1,76 @@ +/* + * Copyright 2021-2024 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s.aws +package internals + +import cats.effect.Concurrent +import cats.effect.Resource +import cats.syntax.all._ +import fs2.Chunk +import org.http4s._ +import org.http4s.client.Client +import org.typelevel.ci.CIString +import smithy4s._ +import smithy4s.aws.kernel.AwsCrypto._ + +private[aws] sealed trait AwsPayloadSignature { + import AwsPayloadSignature._ + val headerValue: String = this match { + case Sha256(v) => v + case UnsignedPayload => "UNSIGNED-PAYLOAD" + // case StreamingUnsignedPayload => "STREAMING-UNSIGNED-PAYLOAD-TRAILER" + } +} + +/** + * This is a draft API. There are many other ways to include the payload in the signature. + * Some of which are complex: using trailers and/or multiple chunks + */ +private[aws] object AwsPayloadSignature { + case class Sha256(value: String) extends AwsPayloadSignature + case object UnsignedPayload extends AwsPayloadSignature + // case object StreamingUnsignedPayload extends AwsPayloadSignature + + val `X-Amz-Content-SHA256` = CIString("X-Amz-Content-SHA256") + + def makeHeader(value: AwsPayloadSignature): Header.Raw = + Header.Raw(`X-Amz-Content-SHA256`, value.headerValue) + + + def signSingleChunk[F[_]: Concurrent]: Endpoint.Middleware[Client[F]] = + new Endpoint.Middleware[Client[F]] { + def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: service.Endpoint[_, _, _, _, _] + ): Client[F] => Client[F] = { client => + Client { request => + Resource.eval(hashSingleChunk(request)).flatMap { request => + client.run(request) + } + } + } + } + + private def hashSingleChunk[F[_]: Concurrent]( + request: Request[F] + ): F[Request[F]] = { + request.body.chunks.compile.to(Chunk).map(_.flatten).map { body => + val payloadHash = sha256HexDigest(body.toArray) + val signature = AwsPayloadSignature.Sha256(payloadHash) + request.putHeaders(AwsPayloadSignature.makeHeader(signature)) + } + } +} diff --git a/modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala b/modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala index 21f55f007..cce1bb6dd 100644 --- a/modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala +++ b/modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala @@ -20,20 +20,18 @@ package internals import cats.effect.Concurrent import cats.effect.Resource import cats.syntax.all._ -import fs2.Chunk import org.http4s._ import org.http4s.client.Client import org.typelevel.ci.CIString import smithy4s._ import smithy4s.aws.kernel.AwsCrypto._ +import smithy4s.aws.internals.AwsPayloadSignature.`X-Amz-Content-SHA256` import java.net.URLEncoder import java.nio.charset.StandardCharsets /** * A Client middleware that signs http requests before they are sent to AWS. - * This works by compiling the body of the request in memory in a chunk before sending - * it back, which means it is not proper to use it in the context of streaming. */ private[aws] object AwsSigning { @@ -108,8 +106,7 @@ private[aws] object AwsSigning { // scalafmt: { align.preset = most, danglingParentheses.preset = false, maxColumn = 240, align.tokens = [{code = ":"}]} (request: Request[F]) => { - val bodyF = request.body.chunks.compile.to(Chunk).map(_.flatten) - val awsHeadersF = (bodyF, timestamp, credentials, region).mapN { case (body, timestamp, credentials, region) => + val awsHeadersF = (timestamp, credentials, region).mapN { case (timestamp, credentials, region) => val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request" val queryParams: Vector[(String, String)] = request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") } @@ -122,23 +119,39 @@ private[aws] object AwsSigning { } .mkString("&") - // // !\ Important: these must remain in the same order - val baseHeadersList = List( + val amzHeaders: List[(CIString, String)] = request.headers.headers + .filter(_.name.toString.toLowerCase.startsWith("x-amz")) + .map(h => (h.name, h.value)) + .filterNot(_._2 == null) + + // It is assumed that the hash value is computed before this middleware run + // via another middleware. If it is not, we use a default value. + val contentSha = amzHeaders.find(_._1 == `X-Amz-Content-SHA256`) + val payloadHash = contentSha.map(_._2).getOrElse(AwsPayloadSignature.UnsignedPayload.headerValue) + val missingContentShaHeader = + if (contentSha.isEmpty) List(`X-Amz-Content-SHA256` -> AwsPayloadSignature.UnsignedPayload.headerValue) + else List.empty + + val addedHeaders: List[(CIString, String)] = List( `Content-Type` -> request.contentType.map(contentType.value(_)).orNull, `Host` -> request.uri.host.map(_.renderString).orNull, `X-Amz-Date` -> timestamp.conciseDateTime, `X-Amz-Security-Token` -> credentials.sessionToken.orNull, `X-Amz-Target` -> (serviceName + "." + operationName) - ).filterNot(_._2 == null) + ).filterNot(_._2 == null) ++ + // we also include the header, if it was not because it is required + missingContentShaHeader + + // Headers included in the signature needs to be sorted alphabetically + val allHeaders = (addedHeaders ++ amzHeaders).sortBy(_._1) - val canonicalHeadersString = baseHeadersList + val canonicalHeadersString = allHeaders .map { case (key, value) => key.toString.toLowerCase + ":" + value.trim } .mkString(newline) - lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";") + lazy val signedHeadersString = allHeaders.map(_._1).map(_.toString.toLowerCase()).mkString(";") - val payloadHash = sha256HexDigest(body.toArray) val pathString = request.uri.path.toAbsolute.renderString val canonicalRequest = new StringBuilder() .append(request.method.name.toUpperCase()) @@ -171,7 +184,7 @@ private[aws] object AwsSigning { val signature = toHexString(hmacSha256(stringToSign, signatureKey)) val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature" val authHeader = Headers("Authorization" -> authHeaderValue) - val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) }) + val baseHeaders = Headers(addedHeaders.map { case (k, v) => Header.Raw(k, v) }) authHeader ++ baseHeaders }