Skip to content

Commit

Permalink
Group common test utililty in an internal package
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713949926
Change-Id: Ief74f897f2397fbf55c9e463b0ea8468fd39d7d7
  • Loading branch information
morambro authored and copybara-github committed Jan 10, 2025
1 parent 3a315f7 commit 8cca13f
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 115 deletions.
41 changes: 15 additions & 26 deletions aead/aesgcm/key_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ import (

"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/insecuresecretdataaccess"
"github.com/tink-crypto/tink-go/v2/internal/aead"
"github.com/tink-crypto/tink-go/v2/internal/protoserialization"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/secretdata"
"github.com/tink-crypto/tink-go/v2/subtle/random"
gcmpb "github.com/tink-crypto/tink-go/v2/proto/aes_gcm_go_proto"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
Expand All @@ -47,37 +46,27 @@ var _ registry.KeyManager = (*keyManager)(nil)

// Primitive creates an AESGCM subtle for the given serialized AESGCMKey proto.
func (km *keyManager) Primitive(serializedKey []byte) (any, error) {
if len(serializedKey) == 0 {
return nil, errInvalidKey
}
protoKey := new(gcmpb.AesGcmKey)
if err := proto.Unmarshal(serializedKey, protoKey); err != nil {
return nil, errInvalidKey
}
if err := km.validateKey(protoKey); err != nil {
keySerialization, err := protoserialization.NewKeySerialization(&tinkpb.KeyData{
TypeUrl: typeURL,
Value: serializedKey,
KeyMaterialType: tinkpb.KeyData_SYMMETRIC,
}, tinkpb.OutputPrefixType_RAW, 0)
if err != nil {
return nil, err
}

keyBytes := secretdata.NewBytesFromData(protoKey.GetKeyValue(), insecuresecretdataaccess.Token{})
opts := ParametersOpts{
KeySizeInBytes: keyBytes.Len(),
IVSizeInBytes: ivSize,
TagSizeInBytes: tagSize,
Variant: VariantNoPrefix,
}
parameters, err := NewParameters(opts)
key, err := protoserialization.ParseKey(keySerialization)
if err != nil {
return nil, fmt.Errorf("aes_gcm_key_manager: cannot create new parameters: %s", err)
return nil, err
}
key, err := NewKey(keyBytes, 0, parameters)
if err != nil {
return nil, fmt.Errorf("aes_gcm_key_manager: cannot create new key: %s", err)
aesGCMKey, ok := key.(*Key)
if !ok {
return nil, fmt.Errorf("xaesgcm_key_manager: invalid key type: got %T, want %T", key, (*Key)(nil))
}
primitive, err := NewAEAD(key)
ret, err := NewAEAD(aesGCMKey)
if err != nil {
return nil, fmt.Errorf("aes_gcm_key_manager: cannot create new AEAD: %s", err)
return nil, fmt.Errorf("xaesgcm_key_manager: %v", err)
}
return primitive, nil
return ret, nil
}

// NewKey creates a new key according to specification the given serialized AESGCMKeyFormat.
Expand Down
35 changes: 10 additions & 25 deletions aead/aesgcm/key_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/aead/aesgcm"
aeadtestutil "github.com/tink-crypto/tink-go/v2/aead/internal/testutil"
"github.com/tink-crypto/tink-go/v2/aead/subtle"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/insecuresecretdataaccess"
Expand Down Expand Up @@ -51,15 +52,19 @@ func TestAESGCMGetPrimitiveBasic(t *testing.T) {
if err != nil {
t.Errorf("unexpected error: %s", err)
}
aesGCM, ok := p.(tink.AEAD)
if !ok {
t.Errorf("Primitive() = %T, want tink.AEAD", p)
}
subtleAESGCM, err := subtle.NewAESGCM(key.GetKeyValue())
if err != nil {
t.Errorf("subtle.NewAESGCM(key.GetKeyValue()) err = %v, want nil", err)
}
if err := validateAESGCMPrimitive(p, subtleAESGCM); err != nil {
t.Errorf("validateAESGCMPrimitive(p, subtleAESGCM) err = %v, want nil", err)
if err := aeadtestutil.EncryptDecrypt(aesGCM, subtleAESGCM); err != nil {
t.Errorf("aeadtestutil.EncryptDecrypt(aesGCM, subtleAESGCM) err = %v, want nil", err)
}
if err := validateAESGCMPrimitive(subtleAESGCM, p); err != nil {
t.Errorf("validateAESGCMPrimitive(subtleAESGCM, p) err = %v, want nil", err)
if err := aeadtestutil.EncryptDecrypt(subtleAESGCM, aesGCM); err != nil {
t.Errorf("aeadtestutil.EncryptDecrypt(subtleAESGCM, aesGCM) err = %v, want nil", err)
}
}
}
Expand Down Expand Up @@ -463,25 +468,5 @@ func validateAESGCMKey(key *gcmpb.AesGcmKey, format *gcmpb.AesGcmKeyFormat) erro
if err != nil {
return fmt.Errorf("aesgcm.NewAEAD() err = %v, want nil", err)
}
return validateAESGCMPrimitive(p, p)
}

func validateAESGCMPrimitive(encryptor any, decryptor any) error {
aesGCMEncryptor := encryptor.(tink.AEAD)
aesGCMDecryptor := encryptor.(tink.AEAD)
// try to encrypt and decrypt
pt := random.GetRandomBytes(32)
aad := random.GetRandomBytes(32)
ct, err := aesGCMEncryptor.Encrypt(pt, aad)
if err != nil {
return fmt.Errorf("encryption failed")
}
decrypted, err := aesGCMDecryptor.Decrypt(ct, aad)
if err != nil {
return fmt.Errorf("decryption failed")
}
if !bytes.Equal(decrypted, pt) {
return fmt.Errorf("decryption failed")
}
return nil
return aeadtestutil.EncryptDecrypt(p, p)
}
39 changes: 10 additions & 29 deletions aead/aesgcmsiv/key_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
package aesgcmsiv_test

import (
"bytes"
"fmt"
"testing"

"google.golang.org/protobuf/proto"
aeadtestutil "github.com/tink-crypto/tink-go/v2/aead/internal/testutil"
"github.com/tink-crypto/tink-go/v2/aead/subtle"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/subtle/random"
"github.com/tink-crypto/tink-go/v2/testutil"
"github.com/tink-crypto/tink-go/v2/tink"
gcmsivpb "github.com/tink-crypto/tink-go/v2/proto/aes_gcm_siv_go_proto"
Expand Down Expand Up @@ -56,11 +55,11 @@ func TestKeyManagerGetPrimitiveBasic(t *testing.T) {
if err != nil {
t.Fatalf("subtle.NewAESGCMSIV(key.GetKeyValue()) err = %v, want nil", err)
}
if err := encryptDecrypt(aesGCMSIV, subtleAESGCMSIV); err != nil {
t.Errorf("encryptDecrypt(aesGCMSIV, subtleAESGCMSIV) err = %v, want nil", err)
if err := aeadtestutil.EncryptDecrypt(aesGCMSIV, subtleAESGCMSIV); err != nil {
t.Errorf("aeadtestutil.EncryptDecrypt(aesGCMSIV, subtleAESGCMSIV) err = %v, want nil", err)
}
if err := encryptDecrypt(subtleAESGCMSIV, aesGCMSIV); err != nil {
t.Errorf("encryptDecrypt(subtleAESGCMSIV, aesGCMSIV) err = %v, want nil", err)
if err := aeadtestutil.EncryptDecrypt(subtleAESGCMSIV, aesGCMSIV); err != nil {
t.Errorf("aeadtestutil.EncryptDecrypt(subtleAESGCMSIV, aesGCMSIV) err = %v, want nil", err)
}
})
}
Expand Down Expand Up @@ -223,11 +222,11 @@ func TestKeyManagerNewKeyDataBasic(t *testing.T) {
t.Errorf("subtle.NewAESGCMSIV(key.GetKeyValue()) err = %v, want nil", err)
continue
}
if err := encryptDecrypt(aesGCMSIV, subtleAESGCMSIV); err != nil {
t.Errorf("encryptDecrypt(aesGCMSIV, subtleAESGCMSIV) err = %v, want nil", err)
if err := aeadtestutil.EncryptDecrypt(aesGCMSIV, subtleAESGCMSIV); err != nil {
t.Errorf("aeadtestutil.EncryptDecrypt(aesGCMSIV, subtleAESGCMSIV) err = %v, want nil", err)
}
if err := encryptDecrypt(subtleAESGCMSIV, aesGCMSIV); err != nil {
t.Errorf("encryptDecrypt(subtleAESGCMSIV, aesGCMSIV) err = %v, want nil", err)
if err := aeadtestutil.EncryptDecrypt(subtleAESGCMSIV, aesGCMSIV); err != nil {
t.Errorf("aeadtestutil.EncryptDecrypt(subtleAESGCMSIV, aesGCMSIV) err = %v, want nil", err)
}
}
}
Expand Down Expand Up @@ -317,23 +316,5 @@ func validateAESGCMSIVKey(key *gcmsivpb.AesGcmSivKey, format *gcmsivpb.AesGcmSiv
if err != nil {
return fmt.Errorf("subtle.NewAESGCMSIV(key=%v): Invalid key; err=%v", key.KeyValue, err)
}
return encryptDecrypt(p, p)
}

func encryptDecrypt(encryptor, decryptor tink.AEAD) error {
// Try to encrypt and decrypt random data.
pt := random.GetRandomBytes(32)
aad := random.GetRandomBytes(32)
ct, err := encryptor.Encrypt(pt, aad)
if err != nil {
return fmt.Errorf("encryptor.Encrypt() err = %v, want nil", err)
}
decrypted, err := decryptor.Decrypt(ct, aad)
if err != nil {
return fmt.Errorf("decryptor.Decrypt() err = %v, want nil", err)
}
if !bytes.Equal(decrypted, pt) {
return fmt.Errorf("decryptor.Decrypt() = %v, want %v", decrypted, pt)
}
return nil
return aeadtestutil.EncryptDecrypt(p, p)
}
44 changes: 44 additions & 0 deletions aead/internal/testutil/testutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package testutil provides testing utilities for AEAD primitives.
package testutil

import (
"bytes"
"fmt"

"github.com/tink-crypto/tink-go/v2/subtle/random"
"github.com/tink-crypto/tink-go/v2/tink"
)

// EncryptDecrypt encrypts and decrypts random data using the given AEAD
// primitives.
func EncryptDecrypt(encryptor, decryptor tink.AEAD) error {
// Try to encrypt and decrypt random data.
pt := random.GetRandomBytes(32)
aad := random.GetRandomBytes(32)
ct, err := encryptor.Encrypt(pt, aad)
if err != nil {
return fmt.Errorf("encryptor.Encrypt() err = %v, want nil", err)
}
decrypted, err := decryptor.Decrypt(ct, aad)
if err != nil {
return fmt.Errorf("decryptor.Decrypt() err = %v, want nil", err)
}
if !bytes.Equal(decrypted, pt) {
return fmt.Errorf("decryptor.Decrypt() = %v, want %v", decrypted, pt)
}
return nil
}
57 changes: 57 additions & 0 deletions aead/internal/testutil/testutil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testutil_test

import (
"fmt"
"testing"

"github.com/tink-crypto/tink-go/v2/aead/internal/testutil"
"github.com/tink-crypto/tink-go/v2/aead/subtle"
"github.com/tink-crypto/tink-go/v2/subtle/random"
tinktestutil "github.com/tink-crypto/tink-go/v2/testutil"
)

func TestEncryptDecryptFailsWithFailingAEAD(t *testing.T) {
failingAEAD := tinktestutil.NewAlwaysFailingAead(fmt.Errorf("test error"))
a, err := subtle.NewAESGCM(random.GetRandomBytes(32))
if err != nil {
t.Fatalf("subtle.NewAESGCM() err = %v, want nil", err)
}
if err := testutil.EncryptDecrypt(failingAEAD, a); err == nil {
t.Errorf("EncryptDecrypt(failingAEAD, a) err = nil, want non-nil")
}
if err := testutil.EncryptDecrypt(a, failingAEAD); err == nil {
t.Errorf("EncryptDecrypt(a, failingAEAD) err = nil, want non-nil")
}
if err := testutil.EncryptDecrypt(failingAEAD, failingAEAD); err == nil {
t.Errorf("EncryptDecrypt(failingAEAD, failingAEAD) err = nil, want non-nil")
}
}

func TestEncryptDecryptWorks(t *testing.T) {
keyBytes := random.GetRandomBytes(32)
encryptor, err := subtle.NewAESGCM(keyBytes)
if err != nil {
t.Fatalf("subtle.NewAESGCM() err = %v, want nil", err)
}
decryptor, err := subtle.NewAESGCM(keyBytes)
if err != nil {
t.Fatalf("subtle.NewAESGCM() err = %v, want nil", err)
}
if err := testutil.EncryptDecrypt(encryptor, decryptor); err != nil {
t.Fatalf("subtle.NewAESGCM() err = %v, want nil", err)
}
}
4 changes: 2 additions & 2 deletions aead/xaesgcm/key_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ func (km *keyManager) Primitive(serializedKey []byte) (any, error) {
}
xAESGCMKey, ok := key.(*Key)
if !ok {
return nil, fmt.Errorf("invalid key type: got %T, want *xaesgcm.Key", key)
return nil, fmt.Errorf("xaesgcm_key_manager: invalid key type: got %T, want %T", key, (*Key)(nil))
}
ret, err := NewAEAD(xAESGCMKey, internalapi.Token{})
if err != nil {
return nil, fmt.Errorf("xaesgcm_key_manager: cannot create new primitive: %v", err)
return nil, fmt.Errorf("xaesgcm_key_manager: %v", err)
}
return ret, nil
}
Expand Down
Loading

0 comments on commit 8cca13f

Please sign in to comment.