diff --git a/pkcs12/resource_pkcs12.go b/pkcs12/resource_pkcs12.go index f3e99e8..5b3faa0 100644 --- a/pkcs12/resource_pkcs12.go +++ b/pkcs12/resource_pkcs12.go @@ -2,6 +2,7 @@ package pkcs12 import ( "context" + "crypto/x509" "fmt" "encoding/base64" @@ -60,7 +61,23 @@ func resourcePkcs12() *schema.Resource { } } -func resourcePkcs12Create(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { +func decodeCerts(certStr []byte) (*x509.Certificate, []*x509.Certificate, error) { + certificates, err := decodeCertificates(certStr) + if err != nil { + return nil, nil, err + } + if len(certificates) == 0 { + return nil, nil, fmt.Errorf("cert_pem must contains at least one certificate") + } + certificate := certificates[0] + caListAndIntermediate := []*x509.Certificate{} + if len(certificates) > 1 { + caListAndIntermediate = certificates[1:] + } + return certificate, caListAndIntermediate, nil +} + +func resourcePkcs12Create(ctx context.Context, d *schema.ResourceData, _ interface{}) diag.Diagnostics { var diags diag.Diagnostics var err error certStr := d.Get("cert_pem").(string) @@ -69,15 +86,12 @@ func resourcePkcs12Create(ctx context.Context, d *schema.ResourceData, m interfa password := d.Get("password").(string) caStr := d.Get("ca_pem").(string) - certificates, err := decodeCertificates([]byte(certStr)) + certificate, caListAndIntermediate, err := decodeCerts([]byte(certStr)) + if err != nil { return diag.FromErr(err) } - if len(certificates) == 0 { - return diag.FromErr(fmt.Errorf("cert_pem must contains at least one certificate")) - } - certificate := certificates[0] - caListAndIntermediate := certificates[1:] + // Read private filekey, fails if given data does not contain any private key privateKeys, err := decodePrivateKeysFromPem([]byte(privatekeyStr), []byte(privatekeyPass)) if err != nil { @@ -100,7 +114,6 @@ func resourcePkcs12Create(ctx context.Context, d *schema.ResourceData, m interfa if err != nil { return diag.FromErr(err) } - d.SetId(hashForState("pkcs12_" + password + certStr + privatekeyStr + caStr)) d.Set("result", base64.StdEncoding.EncodeToString(res)) return diags diff --git a/pkcs12/utils_test.go b/pkcs12/utils_test.go index 299e70f..0af9c7d 100644 --- a/pkcs12/utils_test.go +++ b/pkcs12/utils_test.go @@ -235,42 +235,42 @@ ogrIU+Z+JyIPd47DI8acKlzGeR2Wn5hQrdQApC0Ve2Lvmbz8Hj67pJ4= ) func TestDecodeCertificateAllInOne(t *testing.T) { - list, err := decodeCertificates(allInOnePem) + cert, list, err := decodeCerts(allInOnePem) if err != nil { t.Error(err) t.FailNow() } - if len(list) != 3 { + if len(list) != 2 { t.Log(len(list)) t.Error("certificate list must a certificate and ca's") t.FailNow() } - if list[0].IsCA { + if cert.IsCA { t.Error("certificate[0] must not be a CA") } + if !list[0].IsCA { + t.Error("certificate[0] must be a CA") + } if !list[1].IsCA { t.Error("certificate[1] must be a CA") } - if !list[2].IsCA { - t.Error("certificate[2] must be a CA") - } } func TestDecodeCertificate(t *testing.T) { - list, err := decodeCertificates(certificateExample) + cert, list, err := decodeCerts(certificateExample) if err != nil { t.Error(err) t.FailNow() } - if len(list) != 1 { - t.Error("certificate list must contain one entry") + if len(list) != 0 { + t.Error("certificateExample must not contain any CAs") t.FailNow() } - if list[0].IsCA { + if cert.IsCA { t.Error("certificate must not ba a CA") } }