From 4964a10e69574d4c8a003c8f611a48e39472f8ae Mon Sep 17 00:00:00 2001 From: chaosinthecrd Date: Mon, 8 Jan 2024 14:29:53 +0000 Subject: [PATCH] first pass for implementing kms support Signed-off-by: chaosinthecrd --- cryptoutil/util.go | 72 +++++++ go.mod | 15 ++ go.sum | 30 +++ signer/kms/aws/client.go | 373 +++++++++++++++++++++++++++++++++++ signer/kms/aws/go.mod | 29 +++ signer/kms/aws/go.sum | 44 +++++ signer/kms/aws/signer.go | 232 ++++++++++++++++++++++ signer/kms/signerprovider.go | 140 +++++++++++++ 8 files changed, 935 insertions(+) create mode 100644 signer/kms/aws/client.go create mode 100644 signer/kms/aws/go.mod create mode 100644 signer/kms/aws/go.sum create mode 100644 signer/kms/aws/signer.go create mode 100644 signer/kms/signerprovider.go diff --git a/cryptoutil/util.go b/cryptoutil/util.go index 84d86f48..3685689e 100644 --- a/cryptoutil/util.go +++ b/cryptoutil/util.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "encoding/hex" "encoding/pem" + "errors" "fmt" "io" ) @@ -147,3 +148,74 @@ func TryParseCertificate(data []byte) (*x509.Certificate, error) { return cert, nil } + +// ComputeDigestForSigning calculates the digest value for the specified message using a hash function selected by the following process: +// +// - if a digest value is already specified in a SignOption and the length of the digest matches that of the selected hash function, the +// digest value will be returned without any further computation +// - if a hash function is given using WithCryptoSignerOpts(opts) as a SignOption, it will be used (if it is in the supported list) +// - otherwise defaultHashFunc will be used (if it is in the supported list) +func ComputeDigestForSigning(rawMessage io.Reader, defaultHashFunc crypto.Hash, supportedHashFuncs []crypto.Hash) (digest []byte, hashedWith crypto.Hash, err error) { + var cryptoSignerOpts crypto.SignerOpts = defaultHashFunc + hashedWith = cryptoSignerOpts.HashFunc() + if !isSupportedAlg(hashedWith, supportedHashFuncs) { + return nil, crypto.Hash(0), fmt.Errorf("unsupported hash algorithm: %q not in %v", hashedWith.String(), supportedHashFuncs) + } + if len(digest) > 0 { + if hashedWith != crypto.Hash(0) && len(digest) != hashedWith.Size() { + err = errors.New("unexpected length of digest for hash function specified") + } + return + } + digest, err = hashMessage(rawMessage, hashedWith) + return +} + +// ComputeDigestForVerifying calculates the digest value for the specified message using a hash function selected by the following process: +// +// - if a digest value is already specified in a SignOption and the length of the digest matches that of the selected hash function, the +// digest value will be returned without any further computation +// - if a hash function is given using WithCryptoSignerOpts(opts) as a SignOption, it will be used (if it is in the supported list) +// - otherwise defaultHashFunc will be used (if it is in the supported list) +func ComputeDigestForVerifying(rawMessage io.Reader, defaultHashFunc crypto.Hash, supportedHashFuncs []crypto.Hash) (digest []byte, hashedWith crypto.Hash, err error) { + var cryptoSignerOpts crypto.SignerOpts = defaultHashFunc + hashedWith = cryptoSignerOpts.HashFunc() + if !isSupportedAlg(hashedWith, supportedHashFuncs) { + return nil, crypto.Hash(0), fmt.Errorf("unsupported hash algorithm: %q not in %v", hashedWith.String(), supportedHashFuncs) + } + if len(digest) > 0 { + if hashedWith != crypto.Hash(0) && len(digest) != hashedWith.Size() { + err = errors.New("unexpected length of digest for hash function specified") + } + return + } + digest, err = hashMessage(rawMessage, hashedWith) + return +} + +func isSupportedAlg(alg crypto.Hash, supportedAlgs []crypto.Hash) bool { + if supportedAlgs == nil { + return true + } + for _, supportedAlg := range supportedAlgs { + if alg == supportedAlg { + return true + } + } + return false +} + +func hashMessage(rawMessage io.Reader, hashFunc crypto.Hash) ([]byte, error) { + if rawMessage == nil { + return nil, errors.New("message cannot be nil") + } + if hashFunc == crypto.Hash(0) { + return io.ReadAll(rawMessage) + } + hasher := hashFunc.New() + // avoids reading entire message into memory + if _, err := io.Copy(hasher, rawMessage); err != nil { + return nil, fmt.Errorf("hashing message: %w", err) + } + return hasher.Sum(nil), nil +} diff --git a/go.mod b/go.mod index afee1830..2d53c0cb 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,16 @@ module github.com/in-toto/go-witness go 1.19 require ( + github.com/aws/aws-sdk-go-v2 v1.17.5 + github.com/aws/aws-sdk-go-v2/config v1.18.14 + github.com/aws/aws-sdk-go-v2/service/kms v1.20.4 github.com/davecgh/go-spew v1.1.1 github.com/digitorus/pkcs7 v0.0.0-20230220124406-51331ccfc40f github.com/digitorus/timestamp v0.0.0-20230220124323-d542479a2425 github.com/edwarnicke/gitoid v0.0.0-20220710194850-1be5bfda1f9d github.com/go-git/go-git/v5 v5.11.0 github.com/in-toto/archivista v0.2.0 + github.com/jellydator/ttlcache/v3 v3.1.1 github.com/mattn/go-isatty v0.0.20 github.com/open-policy-agent/opa v0.49.2 github.com/owenrumney/go-sarif v1.1.1 @@ -25,6 +29,16 @@ require ( dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.0.0 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.13.14 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.29 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.30 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.23 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.18.4 // indirect + github.com/aws/smithy-go v1.13.5 // indirect github.com/cloudflare/circl v1.3.3 // indirect github.com/coreos/go-oidc/v3 v3.5.0 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect @@ -53,6 +67,7 @@ require ( github.com/zclconf/go-cty v1.12.1 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/oauth2 v0.7.0 // indirect + golang.org/x/sync v0.5.0 // indirect golang.org/x/tools v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index dd3809c3..4f909f26 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,32 @@ github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdK github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/aws/aws-sdk-go v1.44.334 h1:h2bdbGb//fez6Sv6PaYv868s9liDeoYM6hYsAqTB4MU= github.com/aws/aws-sdk-go v1.44.334/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/aws/aws-sdk-go-v2 v1.17.5 h1:TzCUW1Nq4H8Xscph5M/skINUitxM5UBAyvm2s7XBzL4= +github.com/aws/aws-sdk-go-v2 v1.17.5/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2/config v1.18.14 h1:rI47jCe0EzuJlAO5ptREe3LIBAyP5c7gR3wjyYVjuOM= +github.com/aws/aws-sdk-go-v2/config v1.18.14/go.mod h1:0pI6JQBHKwd0JnwAZS3VCapLKMO++UL2BOkWwyyzTnA= +github.com/aws/aws-sdk-go-v2/credentials v1.13.14 h1:jE34fUepssrhmYpvPpdbd+d39PHpuignDpNPNJguP60= +github.com/aws/aws-sdk-go-v2/credentials v1.13.14/go.mod h1:85ckagDuzdIOnZRwws1eLKnymJs3ZM1QwVC1XcuNGOY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.23 h1:Kbiv9PGnQfG/imNI4L/heyUXvzKmcWSBeDvkrQz5pFc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.23/go.mod h1:mOtmAg65GT1HIL/HT/PynwPbS+UG0BgCZ6vhkPqnxWo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.29 h1:9/aKwwus0TQxppPXFmf010DFrE+ssSbzroLVYINA+xE= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.29/go.mod h1:Dip3sIGv485+xerzVv24emnjX5Sg88utCL8fwGmCeWg= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.23 h1:b/Vn141DBuLVgXbhRWIrl9g+ww7G+ScV5SzniWR13jQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.23/go.mod h1:mr6c4cHC+S/MMkrjtSlG4QA36kOznDep+0fga5L/fGQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.30 h1:IVx9L7YFhpPq0tTnGo8u8TpluFu7nAn9X3sUDMb11c0= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.30/go.mod h1:vsbq62AOBwQ1LJ/GWKFxX8beUEYeRp/Agitrxee2/qM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.23 h1:QoOybhwRfciWUBbZ0gp9S7XaDnCuSTeK/fySB99V1ls= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.23/go.mod h1:9uPh+Hrz2Vn6oMnQYiUi/zbh3ovbnQk19YKINkQny44= +github.com/aws/aws-sdk-go-v2/service/kms v1.20.4 h1:FOY3JSIwgItCdaeuLKjtijD8Enx6BHy5nSS/V6COOeA= +github.com/aws/aws-sdk-go-v2/service/kms v1.20.4/go.mod h1:oTK4GAHgyFSGKzhReYfD19/vjtgUOPwCbm7v5MgWLW4= +github.com/aws/aws-sdk-go-v2/service/sso v1.12.3 h1:bUeZTWfF1vBdZnoNnnq70rB/CzdZD7NR2Jg2Ax+rvjA= +github.com/aws/aws-sdk-go-v2/service/sso v1.12.3/go.mod h1:jtLIhd+V+lft6ktxpItycqHqiVXrPIRjWIsFIlzMriw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.3 h1:G/+7NUi+q+H0LG3v32jfV4OkaQIcpI92g0owbXKk6NY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.3/go.mod h1:zVwRrfdSmbRZWkUkWjOItY7SOalnFnq/Yg2LVPqDjwc= +github.com/aws/aws-sdk-go-v2/service/sts v1.18.4 h1:j0USUNbl9c/8tBJ8setEbwxc7wva0WyoeAaFRiyTUT8= +github.com/aws/aws-sdk-go-v2/service/sts v1.18.4/go.mod h1:1mKZHLLpDMHTNSYPJ7qrcnCQdHCWsNQaT0xRvq2u80s= +github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= +github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA= @@ -106,6 +132,8 @@ github.com/in-toto/archivista v0.2.0 h1:FViuHMVVETborvOqlmSYdROY8RmX3CO0V0MOhU/R github.com/in-toto/archivista v0.2.0/go.mod h1:qt9uN4TkHWUgR5A2wxRqQIBizSl32P2nI2AjESskkr0= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= +github.com/jellydator/ttlcache/v3 v3.1.1 h1:RCgYJqo3jgvhl+fEWvjNW8thxGWsgxi+TPhRir1Y9y8= +github.com/jellydator/ttlcache/v3 v3.1.1/go.mod h1:hi7MGFdMAwZna5n2tuvh63DvFLzVKySzCVW6+0gA2n4= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -213,6 +241,7 @@ github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtC go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.step.sm/crypto v0.25.2 h1:NgoI3bcNF0iLI+Rwq00brlJyFfMqseLOa8L8No3Daog= go.step.sm/crypto v0.25.2/go.mod h1:4pUEuZ+4OAf2f70RgW5oRv/rJudibcAAWQg5prC3DT8= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -255,6 +284,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/signer/kms/aws/client.go b/signer/kms/aws/client.go new file mode 100644 index 00000000..1287924c --- /dev/null +++ b/signer/kms/aws/client.go @@ -0,0 +1,373 @@ +// +// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package aws implement the interface with amazon aws kms service +package aws + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net/http" + "os" + "regexp" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + akms "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/in-toto/go-witness/cryptoutil" + "github.com/in-toto/go-witness/signer/kms" + ttlcache "github.com/jellydator/ttlcache/v3" +) + +func init() { + kms.AddProvider(ReferenceScheme, func(ctx context.Context, ksp *kms.KMSSignerProvider) (cryptoutil.Signer, error) { + return LoadSignerVerifier(ctx, ksp) + }) +} + +const ( + cacheKey = "signer" + // ReferenceScheme schemes for various KMS services are copied from https://github.com/google/go-cloud/tree/master/secrets + ReferenceScheme = "awskms://" +) + +type awsClient struct { + client *akms.Client + endpoint string + keyID string + alias string + keyCache *ttlcache.Cache[string, cmk] +} + +var ( + errKMSReference = errors.New("kms specification should be in the format awskms://[ENDPOINT]/[ID/ALIAS/ARN] (endpoint optional)") + + // Key ID/ALIAS/ARN conforms to KMS standard documented here: https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#key-id + // Key format examples: + // Key ID: awskms:///1234abcd-12ab-34cd-56ef-1234567890ab + // Key ID with endpoint: awskms://localhost:4566/1234abcd-12ab-34cd-56ef-1234567890ab + // Key ARN: awskms:///arn:aws:kms:us-east-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab + // Key ARN with endpoint: awskms://localhost:4566/arn:aws:kms:us-east-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab + // Alias name: awskms:///alias/ExampleAlias + // Alias name with endpoint: awskms://localhost:4566/alias/ExampleAlias + // Alias ARN: awskms:///arn:aws:kms:us-east-2:111122223333:alias/ExampleAlias + // Alias ARN with endpoint: awskms://localhost:4566/arn:aws:kms:us-east-2:111122223333:alias/ExampleAlias + uuidRE = `m?r?k?-?[A-Fa-f0-9]{8}-?[A-Fa-f0-9]{4}-?[A-Fa-f0-9]{4}-?[A-Fa-f0-9]{4}-?[A-Fa-f0-9]{12}` + arnRE = `arn:(?:aws|aws-us-gov|aws-cn):kms:[a-z0-9-]+:\d{12}:` + hostRE = `([^/]*)/` + keyIDRE = regexp.MustCompile(`^awskms://` + hostRE + `(` + uuidRE + `)$`) + keyARNRE = regexp.MustCompile(`^awskms://` + hostRE + `(` + arnRE + `key/` + uuidRE + `)$`) + aliasNameRE = regexp.MustCompile(`^awskms://` + hostRE + `((alias/.*))$`) + aliasARNRE = regexp.MustCompile(`^awskms://` + hostRE + `(` + arnRE + `(alias/.*))$`) + allREs = []*regexp.Regexp{keyIDRE, keyARNRE, aliasNameRE, aliasARNRE} +) + +// ValidReference returns a non-nil error if the reference string is invalid +func ValidReference(ref string) error { + for _, re := range allREs { + if re.MatchString(ref) { + return nil + } + } + return errKMSReference +} + +// ParseReference parses an awskms-scheme URI into its constituent parts. +func ParseReference(resourceID string) (endpoint, keyID, alias string, err error) { + var v []string + for _, re := range allREs { + v = re.FindStringSubmatch(resourceID) + if len(v) >= 3 { + endpoint, keyID = v[1], v[2] + if len(v) == 4 { + alias = v[3] + } + return + } + } + err = fmt.Errorf("invalid awskms format %q", resourceID) + return +} + +func newAWSClient(ctx context.Context, ksp *kms.KMSSignerProvider) (*awsClient, error) { + if err := ValidReference(ksp.Reference); err != nil { + return nil, err + } + a := &awsClient{} + var err error + a.endpoint, a.keyID, a.alias, err = ParseReference(ksp.Reference) + if err != nil { + return nil, err + } + + if err := a.setupClient(ctx, ksp); err != nil { + return nil, err + } + + a.keyCache = ttlcache.New[string, cmk]( + ttlcache.WithDisableTouchOnHit[string, cmk](), + ) + + return a, nil +} + +func (a *awsClient) setupClient(ctx context.Context, ksp *kms.KMSSignerProvider) (err error) { + opts := []func(*config.LoadOptions) error{} + if a.endpoint != "" { + opts = append(opts, config.WithEndpointResolverWithOptions( + aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + return aws.Endpoint{ + URL: "https://" + a.endpoint, + }, nil + }), + )) + } + if os.Getenv("AWS_TLS_INSECURE_SKIP_VERIFY") == "1" { + opts = append(opts, config.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint: gosec + }, + })) + } + + cfg, err := config.LoadDefaultConfig(ctx, opts...) + if err != nil { + return fmt.Errorf("loading AWS config: %w", err) + } + + a.client = akms.NewFromConfig(cfg) + return +} + +type cmk struct { + KeyMetadata *types.KeyMetadata + PublicKey crypto.PublicKey +} + +func (c *cmk) HashFunc() crypto.Hash { + switch c.KeyMetadata.SigningAlgorithms[0] { + case types.SigningAlgorithmSpecRsassaPssSha256, types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, types.SigningAlgorithmSpecEcdsaSha256: + return crypto.SHA256 + case types.SigningAlgorithmSpecRsassaPssSha384, types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, types.SigningAlgorithmSpecEcdsaSha384: + return crypto.SHA384 + case types.SigningAlgorithmSpecRsassaPssSha512, types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, types.SigningAlgorithmSpecEcdsaSha512: + return crypto.SHA512 + default: + return 0 + } +} + +func (c *cmk) Verifier() (cryptoutil.Verifier, error) { + switch c.KeyMetadata.SigningAlgorithms[0] { + case types.SigningAlgorithmSpecRsassaPssSha256, types.SigningAlgorithmSpecRsassaPssSha384, types.SigningAlgorithmSpecRsassaPssSha512: + pub, ok := c.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("public key is not rsa") + } + return cryptoutil.NewRSAVerifier(pub, c.HashFunc()), nil + case types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, types.SigningAlgorithmSpecRsassaPkcs1V15Sha512: + pub, ok := c.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("public key is not rsa") + } + return cryptoutil.NewRSAVerifier(pub, c.HashFunc()), nil + case types.SigningAlgorithmSpecEcdsaSha256, types.SigningAlgorithmSpecEcdsaSha384, types.SigningAlgorithmSpecEcdsaSha512: + pub, ok := c.PublicKey.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("public key is not ecdsa") + } + return cryptoutil.NewECDSAVerifier(pub, c.HashFunc()), nil + default: + return nil, fmt.Errorf("signing algorithm unsupported") + } +} + +func (a *awsClient) fetchCMK(ctx context.Context) (*cmk, error) { + var err error + cmk := &cmk{} + cmk.PublicKey, err = a.fetchPublicKey(ctx) + if err != nil { + return nil, err + } + cmk.KeyMetadata, err = a.fetchKeyMetadata(ctx) + if err != nil { + return nil, err + } + return cmk, nil +} + +func (a *awsClient) getHashFunc(ctx context.Context) (crypto.Hash, error) { + cmk, err := a.getCMK(ctx) + if err != nil { + return 0, err + } + return cmk.HashFunc(), nil +} + +func (a *awsClient) getCMK(ctx context.Context) (*cmk, error) { + var lerr error + loader := ttlcache.LoaderFunc[string, cmk]( + func(c *ttlcache.Cache[string, cmk], key string) *ttlcache.Item[string, cmk] { + var k *cmk + k, lerr = a.fetchCMK(ctx) + if lerr == nil { + return c.Set(cacheKey, *k, time.Second*300) + } + return nil + }, + ) + + item := a.keyCache.Get(cacheKey, ttlcache.WithLoader[string, cmk](loader)) + if lerr == nil { + cmk := item.Value() + return &cmk, nil + } + return nil, lerr +} + +func (a *awsClient) createKey(ctx context.Context, algorithm string) (crypto.PublicKey, error) { + if a.alias == "" { + return nil, errors.New("must use alias key format") + } + + // look for existing key first + cmk, err := a.getCMK(ctx) + if err == nil { + out := cmk.PublicKey + return out, nil + } + + // return error if not *kms.NotFoundException + var errNotFound *types.NotFoundException + if !errors.As(err, &errNotFound) { + return nil, fmt.Errorf("looking up key: %w", err) + } + + usage := types.KeyUsageTypeSignVerify + description := "Created by Sigstore" + key, err := a.client.CreateKey(ctx, &akms.CreateKeyInput{ + CustomerMasterKeySpec: types.CustomerMasterKeySpec(algorithm), + KeyUsage: usage, + Description: &description, + }) + if err != nil { + return nil, fmt.Errorf("creating key: %w", err) + } + + _, err = a.client.CreateAlias(ctx, &akms.CreateAliasInput{ + AliasName: &a.alias, + TargetKeyId: key.KeyMetadata.KeyId, + }) + if err != nil { + return nil, fmt.Errorf("creating alias %q: %w", a.alias, err) + } + + cmk, err = a.getCMK(ctx) + if err != nil { + return nil, fmt.Errorf("retrieving PublicKey from cache: %w", err) + } + + return cmk.PublicKey, err +} + +func (a *awsClient) verify(ctx context.Context, sig, message io.Reader) error { + cmk, err := a.getCMK(ctx) + if err != nil { + return err + } + verifier, err := cmk.Verifier() + if err != nil { + return err + } + + s, err := io.ReadAll(sig) + if err != nil { + return err + } + + return verifier.Verify(message, s) +} + +func (a *awsClient) verifyRemotely(ctx context.Context, sig, digest []byte) error { + cmk, err := a.getCMK(ctx) + if err != nil { + return err + } + alg := cmk.KeyMetadata.SigningAlgorithms[0] + messageType := types.MessageTypeDigest + if _, err := a.client.Verify(ctx, &akms.VerifyInput{ + KeyId: &a.keyID, + Message: digest, + MessageType: messageType, + Signature: sig, + SigningAlgorithm: alg, + }); err != nil { + return fmt.Errorf("unable to verify signature: %w", err) + } + return nil +} + +func (a *awsClient) sign(ctx context.Context, digest []byte, _ crypto.Hash) ([]byte, error) { + cmk, err := a.getCMK(ctx) + if err != nil { + return nil, err + } + alg := cmk.KeyMetadata.SigningAlgorithms[0] + + messageType := types.MessageTypeDigest + out, err := a.client.Sign(ctx, &akms.SignInput{ + KeyId: &a.keyID, + Message: digest, + MessageType: messageType, + SigningAlgorithm: alg, + }) + if err != nil { + return nil, fmt.Errorf("signing with kms: %w", err) + } + return out.Signature, nil +} + +func (a *awsClient) fetchPublicKey(ctx context.Context) (crypto.PublicKey, error) { + out, err := a.client.GetPublicKey(ctx, &akms.GetPublicKeyInput{ + KeyId: &a.keyID, + }) + if err != nil { + return nil, fmt.Errorf("getting public key: %w", err) + } + key, err := x509.ParsePKIXPublicKey(out.PublicKey) + if err != nil { + return nil, fmt.Errorf("parsing public key: %w", err) + } + return key, nil +} + +func (a *awsClient) fetchKeyMetadata(ctx context.Context) (*types.KeyMetadata, error) { + out, err := a.client.DescribeKey(ctx, &akms.DescribeKeyInput{ + KeyId: &a.keyID, + }) + if err != nil { + return nil, fmt.Errorf("getting key metadata: %w", err) + } + return out.KeyMetadata, nil +} diff --git a/signer/kms/aws/go.mod b/signer/kms/aws/go.mod new file mode 100644 index 00000000..54b316db --- /dev/null +++ b/signer/kms/aws/go.mod @@ -0,0 +1,29 @@ +module github.com/in-toto/go-witness/signer/kms/aws + +replace github.com/in-toto/go-witness => ../../../ + +go 1.21 + +require ( + github.com/aws/aws-sdk-go-v2 v1.24.0 + github.com/aws/aws-sdk-go-v2/config v1.26.2 + github.com/aws/aws-sdk-go-v2/service/kms v1.27.7 + github.com/in-toto/go-witness v0.0.0-00010101000000-000000000000 + github.com/jellydator/ttlcache/v3 v3.1.1 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.16.13 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.26.6 // indirect + github.com/aws/smithy-go v1.19.0 // indirect + github.com/google/go-cmp v0.5.9 // indirect + golang.org/x/sync v0.5.0 // indirect +) diff --git a/signer/kms/aws/go.sum b/signer/kms/aws/go.sum new file mode 100644 index 00000000..548d2378 --- /dev/null +++ b/signer/kms/aws/go.sum @@ -0,0 +1,44 @@ +github.com/aws/aws-sdk-go-v2 v1.24.0 h1:890+mqQ+hTpNuw0gGP6/4akolQkSToDJgHfQE7AwGuk= +github.com/aws/aws-sdk-go-v2 v1.24.0/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2/config v1.26.2 h1:+RWLEIWQIGgrz2pBPAUoGgNGs1TOyF4Hml7hCnYj2jc= +github.com/aws/aws-sdk-go-v2/config v1.26.2/go.mod h1:l6xqvUxt0Oj7PI/SUXYLNyZ9T/yBPn3YTQcJLLOdtR8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.13 h1:WLABQ4Cp4vXtXfOWOS3MEZKr6AAYUpMczLhgKtAjQ/8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.13/go.mod h1:Qg6x82FXwW0sJHzYruxGiuApNo31UEtJvXVSZAXeWiw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 h1:w98BT5w+ao1/r5sUuiH6JkVzjowOKeOJRHERyy1vh58= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10/go.mod h1:K2WGI7vUvkIv1HoNbfBA1bvIZ+9kL3YVmWxeKuLQsiw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 h1:v+HbZaCGmOwnTTVS86Fleq0vPzOd7tnJGbFhP0stNLs= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9/go.mod h1:Xjqy+Nyj7VDLBtCMkQYOw1QYfAEZCVLrfI0ezve8wd4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 h1:N94sVhRACtXyVcjXxrwK1SKFIJrA9pOJ5yu2eSHnmls= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9/go.mod h1:hqamLz7g1/4EJP+GH5NBhcUMLjW+gKLQabgyz6/7WAU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 h1:Nf2sHxjMJR8CSImIVCONRi4g0Su3J+TSTbS7G0pUeMU= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9/go.mod h1:idky4TER38YIjr2cADF1/ugFMKvZV7p//pVeV5LZbF0= +github.com/aws/aws-sdk-go-v2/service/kms v1.27.7 h1:wN7AN7iOiAgT9HmdifZNSvbr6S7gSpLjSSOQHIaGmFc= +github.com/aws/aws-sdk-go-v2/service/kms v1.27.7/go.mod h1:D9FVDkZjkZnnFHymJ3fPVz0zOUlNSd0xcIIVmmrAac8= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 h1:ldSFWz9tEHAwHNmjx2Cvy1MjP5/L9kNoR0skc6wyOOM= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5/go.mod h1:CaFfXLYL376jgbP7VKC96uFcU8Rlavak0UlAwk1Dlhc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 h1:2k9KmFawS63euAkY4/ixVNsYYwrwnd5fIvgEKkfZFNM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5/go.mod h1:W+nd4wWDVkSUIox9bacmkBP5NMFQeTJ/xqNabpzSR38= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.6 h1:HJeiuZ2fldpd0WqngyMR6KW7ofkXNLyOaHwEIGm39Cs= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.6/go.mod h1:XX5gh4CB7wAs4KhcF46G6C8a2i7eupU19dcAAE+EydU= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/jellydator/ttlcache/v3 v3.1.1 h1:RCgYJqo3jgvhl+fEWvjNW8thxGWsgxi+TPhRir1Y9y8= +github.com/jellydator/ttlcache/v3 v3.1.1/go.mod h1:hi7MGFdMAwZna5n2tuvh63DvFLzVKySzCVW6+0gA2n4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/signer/kms/aws/signer.go b/signer/kms/aws/signer.go new file mode 100644 index 00000000..f0589f92 --- /dev/null +++ b/signer/kms/aws/signer.go @@ -0,0 +1,232 @@ +// +// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aws + +import ( + "context" + "crypto" + "fmt" + "io" + + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/in-toto/go-witness/cryptoutil" + kms "github.com/in-toto/go-witness/signer/kms" +) + +var awsSupportedAlgorithms = []types.CustomerMasterKeySpec{ + types.CustomerMasterKeySpecRsa2048, + types.CustomerMasterKeySpecRsa3072, + types.CustomerMasterKeySpecRsa4096, + types.CustomerMasterKeySpecEccNistP256, + types.CustomerMasterKeySpecEccNistP384, + types.CustomerMasterKeySpecEccNistP521, +} + +var awsSupportedHashFuncs = []crypto.Hash{ + crypto.SHA256, + crypto.SHA384, + crypto.SHA512, +} + +// SignerVerifier is a cryptoutil.SignerVerifier that uses the AWS Key Management Service +type SignerVerifier struct { + client *awsClient +} + +// LoadSignerVerifier generates signatures using the specified key object in AWS KMS and hash algorithm. +func LoadSignerVerifier(ctx context.Context, ksp *kms.KMSSignerProvider) (*SignerVerifier, error) { + a := &SignerVerifier{} + + var err error + a.client, err = newAWSClient(ctx, ksp) + if err != nil { + return nil, err + } + + return a, nil +} + +// NOTE: This might ben all wrong but setting it like so for now +// KeyID returnst the key identifier for the key used by this signer. +func (a *SignerVerifier) KeyID() (string, error) { + return a.client.keyID, nil +} + +// SignMessage signs the provided message using AWS KMS. If the message is provided, +// this method will compute the digest according to the hash function specified +// when the Signer was created. +// +// SignMessage recognizes the following Options listed in order of preference: +// +// - WithContext() +// +// - WithDigest() +// +// - WithCryptoSignerOpts() +// +// All other options are ignored if specified. +func (a *SignerVerifier) Sign(message io.Reader) ([]byte, error) { + var err error + ctx := context.Background() + var digest []byte + + var signerOpts crypto.SignerOpts + signerOpts, err = a.client.getHashFunc(ctx) + if err != nil { + return nil, fmt.Errorf("getting fetching default hash function: %w", err) + } + + hf := signerOpts.HashFunc() + + digest, _, err = cryptoutil.ComputeDigestForVerifying(message, hf, awsSupportedHashFuncs) + if err != nil { + return nil, err + } + + return a.client.sign(ctx, digest, hf) +} + +// PublicKey returns the public key that can be used to verify signatures created by +// this signer. If the caller wishes to specify the context to use to obtain +// the public key, pass option.WithContext(desiredCtx). +// +// All other options are ignored if specified. +func (a *SignerVerifier) Verifier() (cryptoutil.Verifier, error) { + return a, nil +} + +// Bytes returns the bytes of the public key that can be used to verify signatures created by the signer. +func (a *SignerVerifier) Bytes() ([]byte, error) { + ctx := context.Background() + p, err := a.client.fetchPublicKey(ctx) + if err != nil { + return nil, err + } + + return cryptoutil.PublicPemBytes(p) +} + +// VerifySignature verifies the signature for the given message. Unless provided +// in an option, the digest of the message will be computed using the hash function specified +// when the SignerVerifier was created. +// +// This function returns nil if the verification succeeded, and an error message otherwise. +// +// This function recognizes the following Options listed in order of preference: +// +// - WithContext() +// +// - WithDigest() +// +// - WithRemoteVerification() +// +// - WithCryptoSignerOpts() +// +// All other options are ignored if specified. +func (a *SignerVerifier) Verify(message io.Reader, sig []byte) (err error) { + ctx := context.Background() + var digest []byte + // var remoteVerification bool + + //for _, opt := range opts { + // opt.ApplyContext(&ctx) + // opt.ApplyDigest(&digest) + // opt.ApplyRemoteVerification(&remoteVerification) + //} + + var signerOpts crypto.SignerOpts + signerOpts, err = a.client.getHashFunc(ctx) + if err != nil { + return fmt.Errorf("getting hash func: %w", err) + } + hf := signerOpts.HashFunc() + + digest, _, err = cryptoutil.ComputeDigestForVerifying(message, hf, awsSupportedHashFuncs) + if err != nil { + return err + } + + return a.client.verifyRemotely(ctx, sig, digest) +} + +// CreateKey attempts to create a new key in Vault with the specified algorithm. +func (a *SignerVerifier) CreateKey(ctx context.Context, algorithm string) (crypto.PublicKey, error) { + return a.client.createKey(ctx, algorithm) +} + +type cryptoSignerWrapper struct { + ctx context.Context + hashFunc crypto.Hash + sv *SignerVerifier + errFunc func(error) +} + +func (c cryptoSignerWrapper) Public() crypto.PublicKey { + ctx := context.Background() + + cmk, err := c.sv.client.getCMK(ctx) + if err != nil { + return nil + } + + return cmk.PublicKey +} + +func (c cryptoSignerWrapper) Sign(message io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + //hashFunc := c.hashFunc + //if opts != nil { + // hashFunc = opts.HashFunc() + //} + //awsOptions := []signature.SignOption{ + // options.WithContext(c.ctx), + // options.WithDigest(digest), + // options.WithCryptoSignerOpts(hashFunc), + //} + + return c.sv.Sign(message) +} + +// CryptoSigner returns a crypto.Signer object that uses the underlying SignerVerifier, along with a crypto.SignerOpts object +// that allows the KMS to be used in APIs that only accept the standard golang objects +func (a *SignerVerifier) CryptoSigner(ctx context.Context, errFunc func(error)) (crypto.Signer, crypto.SignerOpts, error) { + defaultHf, err := a.client.getHashFunc(ctx) + if err != nil { + return nil, nil, fmt.Errorf("getting fetching default hash function: %w", err) + } + + csw := &cryptoSignerWrapper{ + ctx: ctx, + sv: a, + hashFunc: defaultHf, + errFunc: errFunc, + } + + return csw, defaultHf, nil +} + +// SupportedAlgorithms returns the list of algorithms supported by the AWS KMS service +func (*SignerVerifier) SupportedAlgorithms() []string { + s := make([]string, len(awsSupportedAlgorithms)) + for i := range awsSupportedAlgorithms { + s[i] = string(awsSupportedAlgorithms[i]) + } + return s +} + +// DefaultAlgorithm returns the default algorithm for the AWS KMS service +func (*SignerVerifier) DefaultAlgorithm() string { + return string(types.CustomerMasterKeySpecEccNistP256) +} diff --git a/signer/kms/signerprovider.go b/signer/kms/signerprovider.go new file mode 100644 index 00000000..d5f805fd --- /dev/null +++ b/signer/kms/signerprovider.go @@ -0,0 +1,140 @@ +// Copyright 2023 The Witness Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kms + +import ( + "context" + "crypto" + "fmt" + "strings" + + "github.com/in-toto/go-witness/cryptoutil" + "github.com/in-toto/go-witness/registry" + "github.com/in-toto/go-witness/signer" +) + +func init() { + signer.Register("kms", func() signer.SignerProvider { return New() }, + registry.StringConfigOption( + "ref", + "The KMS Reference URI to use for connecting to the KMS service", + "", + func(sp signer.SignerProvider, ref string) (signer.SignerProvider, error) { + ksp, ok := sp.(*KMSSignerProvider) + if !ok { + return sp, fmt.Errorf("provided signer provider is not a kms signer provider") + } + + WithRef(ref)(ksp) + return ksp, nil + }, + ), + registry.StringConfigOption( + "hashType", + "The hash type to use for signing", + "", + func(sp signer.SignerProvider, hash string) (signer.SignerProvider, error) { + ksp, ok := sp.(*KMSSignerProvider) + if !ok { + return sp, fmt.Errorf("provided signer provider is not a kms signer provider") + } + + WithHash(hash)(ksp) + return ksp, nil + }, + ), + ) +} + +type KMSSignerProvider struct { + Reference string + HashFunc crypto.Hash +} + +type Option func(*KMSSignerProvider) + +func WithRef(ref string) Option { + return func(ksp *KMSSignerProvider) { + ksp.Reference = ref + } +} + +func WithHash(hash string) Option { + return func(ksp *KMSSignerProvider) { + // case switch to match hash type string and set hashFunc + switch hash { + case "SHA224": + ksp.HashFunc = crypto.SHA224 + case "SHA256": + ksp.HashFunc = crypto.SHA256 + case "SHA384": + ksp.HashFunc = crypto.SHA384 + case "SHA512": + ksp.HashFunc = crypto.SHA512 + default: + ksp.HashFunc = crypto.SHA256 + } + } +} + +func New(opts ...Option) *KMSSignerProvider { + ksp := KMSSignerProvider{} + + for _, opt := range opts { + opt(&ksp) + } + + return &ksp +} + +// ProviderInit is a function that initializes provider-specific SignerVerifier. +// +// It takes a provider-specific resource ID and hash function, and returns a +// SignerVerifier using that resource, or any error that was encountered. +type ProviderInit func(context.Context, *KMSSignerProvider) (cryptoutil.Signer, error) + +// AddProvider adds the provider implementation into the local cache +func AddProvider(keyResourceID string, init ProviderInit) { + providersMap[keyResourceID] = init +} + +func (ksp *KMSSignerProvider) Signer(ctx context.Context) (cryptoutil.Signer, error) { + for ref, pi := range providersMap { + if strings.HasPrefix(ksp.Reference, ref) { + return pi(ctx, ksp) + } + } + return nil, &ProviderNotFoundError{ref: ksp.Reference} +} + +var providersMap = map[string]ProviderInit{} + +// SupportedProviders returns list of initialized providers +func SupportedProviders() []string { + keys := make([]string, 0, len(providersMap)) + for key := range providersMap { + keys = append(keys, key) + } + return keys +} + +// ProviderNotFoundError indicates that no matching KMS provider was found +type ProviderNotFoundError struct { + ref string +} + +func (e *ProviderNotFoundError) Error() string { + return fmt.Sprintf("no kms provider found for key reference: %s", e.ref) +}