Skip to content

Commit

Permalink
Rewrite merkletree tests based on work of @liamsi from refactor_testi…
Browse files Browse the repository at this point in the history
…ng_code branch
  • Loading branch information
vqhuy committed Oct 28, 2017
1 parent 31c2259 commit 3bbce60
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 265 deletions.
2 changes: 1 addition & 1 deletion coniksclient/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func UnmarshalResponse(t int, msg []byte) *protocol.Response {
}

switch t {
case protocol.RegistrationType, protocol.KeyLookupInEpochType, protocol.MonitoringType:
case protocol.RegistrationType, protocol.KeyLookupType, protocol.KeyLookupInEpochType, protocol.MonitoringType:
response := new(protocol.DirectoryProof)
if err := json.Unmarshal(resp.DirectoryResponse, &response); err != nil {
return &protocol.Response{
Expand Down
32 changes: 12 additions & 20 deletions merkletree/merkletree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,12 @@ import (
"bytes"
"testing"

"github.com/coniks-sys/coniks-go/crypto/vrf"
"github.com/coniks-sys/coniks-go/crypto"
"github.com/coniks-sys/coniks-go/utils"
"golang.org/x/crypto/sha3"
)

var vrfPrivKey1, _ = vrf.GenerateKey(bytes.NewReader(
[]byte("deterministic tests need 256 bit")))

var vrfPrivKey2, _ = vrf.GenerateKey(bytes.NewReader(
[]byte("deterministic tests need 32 byte")))

// TODO: When #178 is merged, 3 tests below should be removed.
func TestOneEntry(t *testing.T) {
m, err := NewMerkleTree()
if err != nil {
Expand All @@ -26,7 +21,7 @@ func TestOneEntry(t *testing.T) {

key := "key"
val := []byte("value")
index := vrfPrivKey1.Compute([]byte(key))
index := crypto.StaticVRF(t).Compute([]byte(key))
if err := m.Set(index, key, val); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -90,10 +85,10 @@ func TestTwoEntries(t *testing.T) {
}

key1 := "key1"
index1 := vrfPrivKey1.Compute([]byte(key1))
index1 := crypto.StaticVRF(t).Compute([]byte(key1))
val1 := []byte("value1")
key2 := "key2"
index2 := vrfPrivKey1.Compute([]byte(key2))
index2 := crypto.StaticVRF(t).Compute([]byte(key2))
val2 := []byte("value2")

if err := m.Set(index1, key1, val1); err != nil {
Expand Down Expand Up @@ -130,13 +125,13 @@ func TestThreeEntries(t *testing.T) {
}

key1 := "key1"
index1 := vrfPrivKey1.Compute([]byte(key1))
index1 := crypto.StaticVRF(t).Compute([]byte(key1))
val1 := []byte("value1")
key2 := "key2"
index2 := vrfPrivKey1.Compute([]byte(key2))
index2 := crypto.StaticVRF(t).Compute([]byte(key2))
val2 := []byte("value2")
key3 := "key3"
index3 := vrfPrivKey1.Compute([]byte(key3))
index3 := crypto.StaticVRF(t).Compute([]byte(key3))
val3 := []byte("value3")

if err := m.Set(index1, key1, val1); err != nil {
Expand Down Expand Up @@ -191,13 +186,10 @@ func TestThreeEntries(t *testing.T) {
}

func TestInsertExistedKey(t *testing.T) {
m, err := NewMerkleTree()
if err != nil {
t.Fatal(err)
}
m := newTestTree(t)

key1 := "key"
index1 := vrfPrivKey1.Compute([]byte(key1))
index1 := crypto.StaticVRF(t).Compute([]byte(key1))
val1 := append([]byte(nil), "value"...)

if err := m.Set(index1, key1, val1); err != nil {
Expand Down Expand Up @@ -241,10 +233,10 @@ func TestInsertExistedKey(t *testing.T) {

func TestTreeClone(t *testing.T) {
key1 := "key1"
index1 := vrfPrivKey1.Compute([]byte(key1))
index1 := crypto.StaticVRF(t).Compute([]byte(key1))
val1 := []byte("value1")
key2 := "key2"
index2 := vrfPrivKey1.Compute([]byte(key2))
index2 := crypto.StaticVRF(t).Compute([]byte(key2))
val2 := []byte("value2")

m1, err := NewMerkleTree()
Expand Down
166 changes: 67 additions & 99 deletions merkletree/pad_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package merkletree

import (
"bytes"
"strconv"
"testing"

"crypto/rand"
Expand All @@ -10,16 +11,22 @@ import (
"io"

"github.com/coniks-sys/coniks-go/crypto/sign"
"github.com/coniks-sys/coniks-go/crypto/vrf"
)

var signKey sign.PrivateKey
var vrfKey vrf.PrivateKey

func init() {
var err error
signKey, err = sign.GenerateKey(nil)
if err != nil {
panic(err)
}
vrfKey, err = vrf.GenerateKey(nil)
if err != nil {
panic(err)
}
}

type TestAd struct {
Expand All @@ -35,48 +42,29 @@ func (t TestAd) Serialize() []byte {
// 3rd: epoch = 2 (key1, key2)
// 4th: epoch = 3 (key1, key2, key3) (latest STR)
func TestPADHashChain(t *testing.T) {
key1 := "key"
val1 := []byte("value")

key2 := "key2"
val2 := []byte("value2")

key3 := "key3"
val3 := []byte("value3")

N := uint64(3)
treeHashes := make(map[uint64][]byte)

pad, err := NewPAD(TestAd{""}, signKey, vrfPrivKey1, 10)
if err != nil {
t.Fatal(err)
afterCreate := func(pad *PAD) {
treeHashes[0] = append([]byte{}, pad.tree.hash...)
}
treeHashes[0] = append([]byte{}, pad.tree.hash...)

if err := pad.Set(key1, val1); err != nil {
t.Fatal(err)
}
pad.Update(nil)
treeHashes[1] = append([]byte{}, pad.tree.hash...)

if err := pad.Set(key2, val2); err != nil {
t.Fatal(err)
afterInsert := func(i uint64, pad *PAD) {
pad.Update(nil)
treeHashes[i+1] = append([]byte{}, pad.tree.hash...)
}
pad.Update(nil)
treeHashes[2] = append([]byte{}, pad.tree.hash...)

if err := pad.Set(key3, val3); err != nil {
pad, err := createPad(N, keyPrefix, valuePrefix, 10, afterCreate, afterInsert)
if err != nil {
t.Fatal(err)
}
pad.Update(nil)
treeHashes[3] = append([]byte{}, pad.tree.hash...)

for i := 0; i < 4; i++ {
for i := uint64(0); i < N; i++ {
str := pad.GetSTR(uint64(i))
if str == nil {
t.Fatal("Cannot get STR #", i)
}
if !bytes.Equal(str.TreeHash, treeHashes[uint64(i)]) {
t.Fatal("Malformed PAD Update")
t.Fatal("Malformed PAD Update:", i)
}

if str.Epoch != uint64(i) {
Expand All @@ -93,66 +81,38 @@ func TestPADHashChain(t *testing.T) {
t.Error("Got invalid STR", "want", 3, "got", str.Epoch)
}

// lookup
ap, _ := pad.Lookup(key1)
if ap.Leaf.Value == nil {
t.Error("Cannot find key:", key1)
return
}
if !bytes.Equal(ap.Leaf.Value, val1) {
t.Error(key1, "value mismatch")
}

ap, _ = pad.Lookup(key2)
if ap.Leaf.Value == nil {
t.Error("Cannot find key:", key2)
return
}
if !bytes.Equal(ap.Leaf.Value, val2) {
t.Error(key2, "value mismatch")
}

ap, _ = pad.Lookup(key3)
if ap.Leaf.Value == nil {
t.Error("Cannot find key:", key3)
return
}
if !bytes.Equal(ap.Leaf.Value, val3) {
t.Error(key3, "value mismatch")
}

ap, err = pad.LookupInEpoch(key2, 1)
if err != nil {
t.Error(err)
} else if ap.Leaf.Value != nil {
t.Error("Found unexpected key", key2, "in STR #", 1)
}
ap, err = pad.LookupInEpoch(key2, 2)
if err != nil {
t.Error(err)
} else if ap.Leaf.Value == nil {
t.Error("Cannot find key", key2, "in STR #", 2)
}
for i := uint64(0); i < N; i++ {
key := keyPrefix + strconv.FormatUint(i, 10)
val := append(valuePrefix, byte(i))
ap, _ := pad.Lookup(key)
if ap.Leaf.Value == nil {
t.Fatal("Cannot find key:", key)
}
if !bytes.Equal(ap.Leaf.Value, val) {
t.Error(key, "value mismatch")
}

ap, err = pad.LookupInEpoch(key3, 2)
if err != nil {
t.Error(err)
} else if ap.Leaf.Value != nil {
t.Error("Found unexpected key", key3, "in STR #", 2)
}

ap, err = pad.LookupInEpoch(key3, 3)
if err != nil {
t.Error(err)
} else if ap.Leaf.Value == nil {
t.Error("Cannot find key", key3, "in STR #", 3)
for epoch := uint64(0); epoch < N; epoch++ {
for keyNum := uint64(0); keyNum < N; keyNum++ {
key := keyPrefix + strconv.FormatUint(keyNum, 10)
ap, err := pad.LookupInEpoch(key, epoch)
if err != nil {
t.Error(err)
} else if keyNum < epoch && ap.Leaf.Value == nil {
t.Error("Cannot find key", key, "in STR #", epoch)
} else if keyNum >= epoch && ap.Leaf.Value != nil {
t.Error("Found unexpected key", key, "in STR #", epoch)
}
}
}
}

func TestHashChainExceedsMaximumSize(t *testing.T) {
var hashChainLimit uint64 = 4

pad, err := NewPAD(TestAd{""}, signKey, vrfPrivKey2, hashChainLimit)
pad, err := NewPAD(TestAd{""}, signKey, vrfKey, hashChainLimit)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -193,7 +153,7 @@ func TestAssocDataChange(t *testing.T) {
key3 := "key3"
val3 := []byte("value3")

pad, err := NewPAD(TestAd{""}, signKey, vrfPrivKey1, 10)
pad, err := NewPAD(TestAd{""}, signKey, vrfKey, 10)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -256,16 +216,15 @@ func TestNewPADMissingAssocData(t *testing.T) {
t.Fatal("Expected NewPAD to panic if ad are missing.")
}
}()
if _, err := NewPAD(nil, signKey, vrfPrivKey1, 10); err != nil {
if _, err := NewPAD(nil, signKey, vrfKey, 10); err != nil {
t.Fatal("Expected NewPAD to panic but got error.")
}
}

// TODO move the following to some (internal?) testutils package
type testErrorRandReader struct{}

func (er testErrorRandReader) Read([]byte) (int, error) {
return 0, errors.New("Not enough entropy!")
return 0, errors.New("not enough entropy")
}

func mockRandReadWithErroringReader() (orig io.Reader) {
Expand All @@ -282,7 +241,7 @@ func TestNewPADErrorWhileCreatingTree(t *testing.T) {
origRand := mockRandReadWithErroringReader()
defer unMockRandReader(origRand)

pad, err := NewPAD(TestAd{""}, signKey, vrfPrivKey1, 3)
pad, err := NewPAD(TestAd{""}, signKey, vrfKey, 3)
if err == nil || pad != nil {
t.Fatal("NewPad should return an error in case the tree creation failed")
}
Expand All @@ -295,14 +254,11 @@ func BenchmarkCreateLargePAD(b *testing.B) {

// total number of entries in tree:
NumEntries := uint64(1000000)
// tree.Clone and update STR every:
noUpdate := uint64(NumEntries + 1)

b.ResetTimer()
// benchmark creating a large tree (don't Update tree)
for n := 0; n < b.N; n++ {
_, err := createPad(NumEntries, keyPrefix, valuePrefix, snapLen,
noUpdate)
_, err := createPadSimple(NumEntries, keyPrefix, valuePrefix, snapLen)
if err != nil {
b.Fatal(err)
}
Expand All @@ -327,9 +283,8 @@ func benchPADUpdate(b *testing.B, entries uint64) {
keyPrefix := "key"
valuePrefix := []byte("value")
snapLen := uint64(10)
noUpdate := uint64(entries + 1)
// This takes a lot of time for a large number of entries:
pad, err := createPad(uint64(entries), keyPrefix, valuePrefix, snapLen, noUpdate)
pad, err := createPadSimple(uint64(entries), keyPrefix, valuePrefix, snapLen)
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -374,9 +329,13 @@ func benchPADLookup(b *testing.B, entries uint64) {
snapLen := uint64(10)
keyPrefix := "key"
valuePrefix := []byte("value")
updateOnce := uint64(entries - 1)
pad, err := createPad(entries, keyPrefix, valuePrefix, snapLen,
updateOnce)
updateOnce := func(iteration uint64, pad *PAD) {
if iteration == entries-1 {
pad.Update(nil)
}
}

pad, err := createPad(entries, keyPrefix, valuePrefix, snapLen, nil, updateOnce)
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -407,22 +366,31 @@ func benchPADLookup(b *testing.B, entries uint64) {
// The STR will get updated every epoch defined by every multiple of
// `updateEvery`. If `updateEvery > N` createPAD won't update the STR.
func createPad(N uint64, keyPrefix string, valuePrefix []byte, snapLen uint64,
updateEvery uint64) (*PAD, error) {
pad, err := NewPAD(TestAd{""}, signKey, vrfPrivKey1, snapLen)
afterCreateCB func(pad *PAD),
afterInsertCB func(iteration uint64, pad *PAD)) (*PAD, error) {
pad, err := NewPAD(TestAd{""}, signKey, vrfKey, snapLen)
if err != nil {
return nil, err
}
if afterCreateCB != nil {
afterCreateCB(pad)
}

for i := uint64(0); i < N; i++ {
key := keyPrefix + string(i)
key := keyPrefix + strconv.FormatUint(i, 10)
value := append(valuePrefix, byte(i))
if err := pad.Set(key, value); err != nil {
return nil, fmt.Errorf("Couldn't set key=%s and value=%s. Error: %v",
key, value, err)
}
if i != 0 && (i%updateEvery == 0) {
pad.Update(nil)
if afterInsertCB != nil {
afterInsertCB(i, pad)
}
}
return pad, nil
}

func createPadSimple(N uint64, keyPrefix string, valuePrefix []byte,
snapLen uint64) (*PAD, error) {
return createPad(N, keyPrefix, valuePrefix, snapLen, nil, nil)
}
Loading

0 comments on commit 3bbce60

Please sign in to comment.