diff --git a/ecsvdp_dh_key_agreement.go b/ecsvdp_dh_key_agreement.go index 84e3031..687210b 100644 --- a/ecsvdp_dh_key_agreement.go +++ b/ecsvdp_dh_key_agreement.go @@ -20,12 +20,12 @@ func NewEcsvdpDhKeyAgreement() *EcsvdpDhKeyAgreement { return &EcsvdpDhKeyAgreement{} } -//CalculateAgreement calculate a key following 1363 7.2.1 ECSVDP-DH +// CalculateAgreement calculate a key following 1363 7.2.1 ECSVDP-DH func (ka *EcsvdpDhKeyAgreement) CalculateAgreement(privateKey *PrivateKey, anotherPublicKey *PublicKey) ([]byte, error) { - if anotherPublicKey == nil { + if anotherPublicKey == nil || anotherPublicKey.X == nil || anotherPublicKey.Y == nil { return nil, invalidEphemeralPublicKey } - if privateKey == nil { + if privateKey == nil || privateKey.D == nil || privateKey.Curve == nil { return nil, invalidPrivateKey } diff --git a/ecsvdp_dh_key_agreement_test.go b/ecsvdp_dh_key_agreement_test.go index a492f07..b87d206 100644 --- a/ecsvdp_dh_key_agreement_test.go +++ b/ecsvdp_dh_key_agreement_test.go @@ -32,4 +32,48 @@ func TestCalculateAgreement(t *testing.T) { assert.NotNil(t, realZ) assert.Equal(t, string(expectedZ), string(realZ)) + +} + +func TestCalculateAgreementWithInvalidParams(t *testing.T) { + { + realZ, err := NewEcsvdpDhKeyAgreement().CalculateAgreement(nil, nil) + assert.Nil(t, realZ) + assert.Equal(t, invalidEphemeralPublicKey, err) + } + { + privateKey := &PrivateKey{} + anotherPublicKey := &PublicKey{} + realZ, err := NewEcsvdpDhKeyAgreement().CalculateAgreement(privateKey, anotherPublicKey) + assert.Nil(t, realZ) + assert.NotNil(t, err) + assert.Equal(t, invalidEphemeralPublicKey, err) + } + { + privateKey := &PrivateKey{} + publicBytes, err := hex.DecodeString("040fe85dfc76083c4d3e9dda070df0ce6bc5b7a837c2b7975c32df26cca3f610725fa4d126cc2d0cc23762dbb199e5f7f4bc6281946f0086ef0800d288192aa1da") + anotherPublicKey, err := DeserializePublicKey(publicBytes) + assert.Nil(t, err) + assert.NotNil(t, anotherPublicKey) + realZ, err := NewEcsvdpDhKeyAgreement().CalculateAgreement(privateKey, anotherPublicKey) + assert.Nil(t, realZ) + assert.Equal(t, invalidPrivateKey, err) + } + { + privateBytes, err := hex.DecodeString("7819b30ff63ebd35f9fcf4233ccb7ecd8e9d90db8ec977cdf7b1f7bdc212b238") + assert.Nil(t, err) + + privateKey := DeserializePrivateKey(privateBytes) + assert.NotNil(t, privateKey) + privateKey.Curve = nil + + publicBytes, err := hex.DecodeString("040fe85dfc76083c4d3e9dda070df0ce6bc5b7a837c2b7975c32df26cca3f610725fa4d126cc2d0cc23762dbb199e5f7f4bc6281946f0086ef0800d288192aa1da") + anotherPublicKey, err := DeserializePublicKey(publicBytes) + assert.Nil(t, err) + assert.NotNil(t, anotherPublicKey) + + realZ, err := NewEcsvdpDhKeyAgreement().CalculateAgreement(privateKey, anotherPublicKey) + assert.Nil(t, realZ) + assert.Equal(t, invalidPrivateKey, err) + } } diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..cd2906a --- /dev/null +++ b/test.sh @@ -0,0 +1,2 @@ +go clean --cache +go test ./... \ No newline at end of file