diff --git a/certtostore.go b/certtostore.go index 0a5a993..dd2dfd6 100644 --- a/certtostore.go +++ b/certtostore.go @@ -18,6 +18,8 @@ package certtostore import ( "bytes" "crypto" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -28,6 +30,10 @@ import ( "io/ioutil" "os" "path/filepath" + + // BEGIN-INTERNAL + // internal content #1 + // END-INTERNAL ) const ( @@ -106,7 +112,7 @@ type FileStorage struct { certFile string caCertFile string keyFile string - key *rsa.PrivateKey + key crypto.Signer } // NewFileStorage sets up a new file storage struct for use by StoreCert. @@ -166,14 +172,28 @@ func (f FileStorage) CertificateChain() ([][]*x509.Certificate, error) { return certificateChain(cert, intermediate) } -// Generate creates a new RSA private key and returns a signer that can be used to make a CSR for the key. +var ecdsaCurves = map[int]elliptic.Curve{ + 256: elliptic.P256(), + 384: elliptic.P384(), + 521: elliptic.P521(), +} + +// Generate creates a new ECDSA or RSA private key and returns a signer that can be used to make a CSR for the key. func (f *FileStorage) Generate(opts GenerateOpts) (crypto.Signer, error) { + var err error switch opts.Algorithm { case RSA: - var err error f.key, err = rsa.GenerateKey(rand.Reader, opts.Size) return f.key, err + case EC: + curve, ok := ecdsaCurves[opts.Size] + if !ok { + return nil, fmt.Errorf("invalid ecdsa curve size: %d", opts.Size) + } + f.key, err = ecdsa.GenerateKey(curve, rand.Reader) + return f.key, err default: + return nil, fmt.Errorf("unsupported key type: %q", opts.Algorithm) } } @@ -219,6 +239,8 @@ func (f *FileStorage) Store(cert *x509.Certificate, intermediate *x509.Certifica } // Sign returns a signature for the provided digest. +// The opts are passed to the private key's Sign method, as per the crypto.Signer interface. +// https://pkg.go.dev/crypto#Signer func (f FileStorage) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { tlsCert, err := tls.LoadX509KeyPair(f.certFile, f.keyFile) if err != nil { @@ -232,6 +254,9 @@ func (f FileStorage) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) } // Decrypt decrypts msg. Returns an error if not implemented. +// The opts are passed to the private key's Decrypt method, as per the crypto.Decrypter interface. +// https://pkg.go.dev/crypto#Decrypter +// Only RSA keys are supported for decryption. func (f FileStorage) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) ([]byte, error) { tlsCert, err := tls.LoadX509KeyPair(f.certFile, f.keyFile) if err != nil { diff --git a/certtostore_test.go b/certtostore_test.go index bde425d..d140141 100644 --- a/certtostore_test.go +++ b/certtostore_test.go @@ -16,6 +16,7 @@ package certtostore import ( "crypto" + "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/sha256" @@ -35,7 +36,7 @@ import ( // TODO(b/142911419): Create OS specific test packages to cover each // CertStorage implementation independently. -func generateCertificate(caStore CertStorage) (CertStorage, error) { +func generateCertificate(caStore CertStorage, opts GenerateOpts) (CertStorage, error) { dir, err := ioutil.TempDir("", "certstorage_cert_test") if err != nil { return nil, fmt.Errorf("ioutil.Tempdir: %v", err) @@ -60,10 +61,6 @@ func generateCertificate(caStore CertStorage) (CertStorage, error) { IsCA: false, SignatureAlgorithm: x509.SHA256WithRSA, } - opts := GenerateOpts{ - Algorithm: RSA, - Size: 2048, - } leafSigner, err := leafStore.Generate(opts) if err != nil { return nil, fmt.Errorf("leafStore.Generate(%v): %v", opts, err) @@ -96,67 +93,97 @@ func generateCertificate(caStore CertStorage) (CertStorage, error) { } func TestCredential(t *testing.T) { - ca := NewFileStorage(testdata.CAPath()) - // Use the CA CertStorage to issue a leaf cert. - leafStore, err := generateCertificate(ca) - if err != nil { - t.Fatalf("error generating certificate: %v", err) - } - // Retrieve the leaf cert. - leafCrt, err := leafStore.Cert() - if err != nil { - t.Fatalf("error retrieving certificate: %v", err) - } - // Retrieve a certificate and key for the CA. - caCrt, err := ca.Cert() - if err != nil { - t.Fatalf("error retrieving CA certificate: %v", err) - } - caKey, err := ca.Key() - if err != nil { - t.Fatalf("error retrieving CA credential: %v", err) - } - // Exercise CertificateChain. - chains, err := leafStore.CertificateChain() - if err != nil { - t.Fatalf("error retrieving certificate chain: %v", err) - } - for ci, chain := range chains { - for i, cert := range chain { - t.Logf("%d.%d: %s", ci, i, cert.Subject) - } - } - if len(chains) != 1 { - t.Fatalf("%d chains found, expected 1", len(chains)) - } - if len(chains[0]) < 2 { - t.Fatalf("%d chain entries found, expected at least 2", len(chains[0])) - } - if !leafCrt.Equal(chains[0][0]) { - t.Errorf("certificate chain[0] is not the leaf") - } - if !caCrt.Equal(chains[0][1]) { - t.Errorf("certificate chain[1] is not the ca") - } - // Exercise the CA Public key by verifying the leaf cert. - caPub := caKey.Public() - if caPub == nil { - t.Fatal("CA public key not found") - } - rsaPub, ok := caPub.(*rsa.PublicKey) - if !ok { - t.Fatal("CA public key is not RSA") - } - leafHash := sha256.Sum256(leafCrt.RawTBSCertificate) - if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, leafHash[:], leafCrt.Signature); err != nil { - t.Fatalf("error verifying certificate signature: %v", err) + for _, testCase := range []struct { + name string + opts GenerateOpts + }{ + { + name: "rsa-2048", + opts: GenerateOpts{ + Algorithm: RSA, + Size: 2048, + }, + }, + { + name: "ecdsa-p256", + opts: GenerateOpts{ + Algorithm: EC, + Size: 256, + }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + ca := NewFileStorage(testdata.CAPath()) + // Use the CA CertStorage to issue a leaf cert. + leafStore, err := generateCertificate(ca, testCase.opts) + if err != nil { + t.Fatalf("error generating certificate: %v", err) + } + // Retrieve the leaf cert. + leafCrt, err := leafStore.Cert() + if err != nil { + t.Fatalf("error retrieving certificate: %v", err) + } + // Retrieve a certificate and key for the CA. + caCrt, err := ca.Cert() + if err != nil { + t.Fatalf("error retrieving CA certificate: %v", err) + } + caKey, err := ca.Key() + if err != nil { + t.Fatalf("error retrieving CA credential: %v", err) + } + // Exercise CertificateChain. + chains, err := leafStore.CertificateChain() + if err != nil { + t.Fatalf("error retrieving certificate chain: %v", err) + } + for ci, chain := range chains { + for i, cert := range chain { + t.Logf("%d.%d: %s", ci, i, cert.Subject) + } + } + if len(chains) != 1 { + t.Fatalf("%d chains found, expected 1", len(chains)) + } + if len(chains[0]) < 2 { + t.Fatalf("%d chain entries found, expected at least 2", len(chains[0])) + } + if !leafCrt.Equal(chains[0][0]) { + t.Errorf("certificate chain[0] is not the leaf") + } + if !caCrt.Equal(chains[0][1]) { + t.Errorf("certificate chain[1] is not the ca") + } + // Exercise the CA Public key by verifying the leaf cert. + caPub := caKey.Public() + if caPub == nil { + t.Fatal("CA public key not found") + } + rsaPub, ok := caPub.(*rsa.PublicKey) + if !ok { + t.Fatal("CA public key is not RSA") + } + leafHash := sha256.Sum256(leafCrt.RawTBSCertificate) + if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, leafHash[:], leafCrt.Signature); err != nil { + t.Fatalf("error verifying certificate signature: %v", err) + } + }) } } func verifySig(pub crypto.PublicKey, sig []byte, digest []byte) error { + if len(digest) == 0 { + return fmt.Errorf("digest is empty") + } switch pub := pub.(type) { case *rsa.PublicKey: return rsa.VerifyPKCS1v15(pub, crypto.SHA256, digest, sig) + case *ecdsa.PublicKey: + if !ecdsa.VerifyASN1(pub, digest, sig) { + return fmt.Errorf("signature verification failed") + } + return nil default: return fmt.Errorf("unsupported public key type: %T", pub) } @@ -164,33 +191,80 @@ func verifySig(pub crypto.PublicKey, sig []byte, digest []byte) error { func TestSign(t *testing.T) { testmsg := []byte("test") - digest := sha256.Sum256(testmsg) - ca := NewFileStorage(testdata.CAPath()) - // Use the CA CertStorage to issue a leaf cert. - leafStore, err := generateCertificate(ca) - if err != nil { - t.Fatalf("error generating certificate: %v", err) - } - k, err := leafStore.Key() - if err != nil { - t.Fatalf("error retrieving key: %v", err) - } + for _, testCase := range []struct { + name string + hash crypto.Hash + opts GenerateOpts + }{ + { + name: "rsa-2048", + hash: crypto.SHA256, + opts: GenerateOpts{ + Algorithm: RSA, + Size: 2048, + }, + }, + { + name: "ecdsa-p256", + hash: crypto.SHA256, + opts: GenerateOpts{ + Algorithm: EC, + Size: 256, + }, + }, + { + name: "ecdsa-p384", + hash: crypto.SHA384, + opts: GenerateOpts{ + Algorithm: EC, + Size: 384, + }, + }, + { + name: "ecdsa-p521", + hash: crypto.SHA512, + opts: GenerateOpts{ + Algorithm: EC, + Size: 521, + }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + ca := NewFileStorage(testdata.CAPath()) + // Use the CA CertStorage to issue a leaf cert. + leafStore, err := generateCertificate(ca, testCase.opts) + if err != nil { + t.Fatalf("error generating certificate: %v", err) + } + k, err := leafStore.Key() + if err != nil { + t.Fatalf("error retrieving key: %v", err) + } - sig, err := k.Sign(rand.Reader, digest[:], crypto.SHA256) - if err != nil { - t.Fatalf("error signing: %v", err) - } - if len(sig) == 0 { - t.Fatalf("signature is empty") - } + // Hash the test message using the given hash function. + h := testCase.hash.New() + if _, err := h.Write(testmsg); err != nil { + t.Fatalf("error writing data to hash: %v", err) + } + digest := h.Sum(nil) - pub := k.Public() - if pub == nil { - t.Fatal("public key is nil") - } - err = verifySig(pub, sig, digest[:]) - if err != nil { - t.Fatalf("error verifying signature: %v", err) + sig, err := k.Sign(rand.Reader, digest, testCase.hash) + if err != nil { + t.Fatalf("error signing: %v", err) + } + if len(sig) == 0 { + t.Fatalf("signature is empty") + } + + pub := k.Public() + if pub == nil { + t.Fatal("public key is nil") + } + err = verifySig(pub, sig, digest[:]) + if err != nil { + t.Fatalf("error verifying signature: %v", err) + } + }) } } @@ -215,67 +289,85 @@ func TestDecrypt(t *testing.T) { } func TestFileStore(t *testing.T) { - pem, err := testdata.Certificate() - if err != nil { - t.Fatalf("testdata.Certificate: %v", err) - } - xc, err := PEMToX509(pem) - if err != nil { - t.Fatalf("error decoding test certificate: %v", err) - } + for _, testCase := range []struct { + name string + opts GenerateOpts + }{ + { + name: "rsa-2048", + opts: GenerateOpts{ + Algorithm: RSA, + Size: 2048, + }, + }, + { + name: "ecdsa-p256", + opts: GenerateOpts{ + Algorithm: EC, + Size: 256, + }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + pem, err := testdata.Certificate() + if err != nil { + t.Fatalf("testdata.Certificate: %v", err) + } + xc, err := PEMToX509(pem) + if err != nil { + t.Fatalf("error decoding test certificate: %v", err) + } - dir, err := ioutil.TempDir("", "certstorage_test") - if err != nil { - t.Fatalf("failed to create temporary dir: %v", err) - } - tc := NewFileStorage(dir) - cert, err := tc.Cert() - if err != nil { - t.Errorf("error while reading empty cert: %v", err) - } - if cert != nil { - t.Errorf("expected cert on new file store to be nil, instead %v", cert) - } + dir, err := ioutil.TempDir("", "certstorage_test") + if err != nil { + t.Fatalf("failed to create temporary dir: %v", err) + } + tc := NewFileStorage(dir) + cert, err := tc.Cert() + if err != nil { + t.Errorf("error while reading empty cert: %v", err) + } + if cert != nil { + t.Errorf("expected cert on new file store to be nil, instead %v", cert) + } - cert, err = tc.Intermediate() - if err != nil { - t.Errorf("error while reading empty intermediate: %v", err) - } - if cert != nil { - t.Errorf("expected intermediate on new file store to be nil, instead %v", cert) - } + cert, err = tc.Intermediate() + if err != nil { + t.Errorf("error while reading empty intermediate: %v", err) + } + if cert != nil { + t.Errorf("expected intermediate on new file store to be nil, instead %v", cert) + } - opts := GenerateOpts{ - Algorithm: RSA, - Size: 2048, - } - signer, err := tc.Generate(opts) - if err != nil { - t.Errorf("failed to generate signer: %v", err) - } - _, err = x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, signer) - if err != nil { - t.Errorf("failed to create signed CSR with signer from Generate: %v", err) - } + signer, err := tc.Generate(testCase.opts) + if err != nil { + t.Errorf("failed to generate signer: %v", err) + } + _, err = x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, signer) + if err != nil { + t.Errorf("failed to create signed CSR with signer from Generate: %v", err) + } - if err := tc.Store(xc, xc); err != nil { - t.Errorf("store failed: %v", err) - } + if err := tc.Store(xc, xc); err != nil { + t.Errorf("store failed: %v", err) + } - cert, err = tc.Cert() - if err != nil { - t.Fatalf("error while reading back written cert: %v", err) - } - if !cert.Equal(xc) { - t.Errorf("expected read-back cert to match xc, instead it's %v", cert) - } + cert, err = tc.Cert() + if err != nil { + t.Fatalf("error while reading back written cert: %v", err) + } + if !cert.Equal(xc) { + t.Errorf("expected read-back cert to match xc, instead it's %v", cert) + } - cert, err = tc.Intermediate() - if err != nil { - t.Fatalf("error while reading back written intermediate: %v", err) - } - if !cert.Equal(xc) { - t.Errorf("expected read-back intermediate to match xc, instead it's %v", cert) + cert, err = tc.Intermediate() + if err != nil { + t.Fatalf("error while reading back written intermediate: %v", err) + } + if !cert.Equal(xc) { + t.Errorf("expected read-back intermediate to match xc, instead it's %v", cert) + } + }) } }