diff --git a/signer/cosigner_nonce_cache.go b/signer/cosigner_nonce_cache.go index d364a2a6..865e7572 100644 --- a/signer/cosigner_nonce_cache.go +++ b/signer/cosigner_nonce_cache.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" "github.com/strangelove-ventures/horcrux/pkg/metrics" @@ -24,7 +25,7 @@ type CosignerNonceCache struct { leader Leader - lastReconcileNonces lastCount + lastReconcileNonces atomic.Uint64 lastReconcileTime time.Time getNoncesInterval time.Duration @@ -33,7 +34,7 @@ type CosignerNonceCache struct { threshold uint8 - cache NonceCache + cache *NonceCache pruner NonceCachePruner @@ -88,29 +89,6 @@ func (m *movingAverage) average() float64 { return weightedSum / duration } -type lastCount struct { - count int - mu sync.RWMutex -} - -func (lc *lastCount) Set(n int) { - lc.mu.Lock() - defer lc.mu.Unlock() - lc.count = n -} - -func (lc *lastCount) Inc() { - lc.mu.Lock() - defer lc.mu.Unlock() - lc.count++ -} - -func (lc *lastCount) Get() int { - lc.mu.RLock() - defer lc.mu.RUnlock() - return lc.count -} - type NonceCachePruner interface { PruneNonces() int } @@ -136,6 +114,30 @@ func (nc *NonceCache) Delete(index int) { nc.cache = append(nc.cache[:index], nc.cache[index+1:]...) } +func (nc *NonceCache) PruneNonces() int { + nc.mu.Lock() + defer nc.mu.Unlock() + nonExpiredIndex := -1 + for i := 0; i < len(nc.cache); i++ { + if time.Now().Before(nc.cache[i].Expiration) { + nonExpiredIndex = i + break + } + } + + var deleteCount int + if nonExpiredIndex == -1 { + // No non-expired nonces, delete everything + deleteCount = len(nc.cache) + nc.cache = nil + } else { + // Prune everything up to the non-expired nonce + deleteCount = nonExpiredIndex + nc.cache = nc.cache[nonExpiredIndex:] + } + return deleteCount +} + type CosignerNoncesRel struct { Cosigner Cosigner Nonces CosignerNonces @@ -176,12 +178,14 @@ func NewCosignerNonceCache( nonceExpiration: nonceExpiration, threshold: threshold, pruner: pruner, - empty: make(chan struct{}, 1), - movingAverage: newMovingAverage(4 * getNoncesInterval), // weighted average over 4 intervals + cache: new(NonceCache), + // buffer up to 1000 empty events so that we don't ever block + empty: make(chan struct{}, 1000), + movingAverage: newMovingAverage(4 * getNoncesInterval), // weighted average over 4 intervals } // the only time pruner is expected to be non-nil is during tests, otherwise we use the cache logic. if pruner == nil { - cnc.pruner = cnc + cnc.pruner = cnc.cache } return cnc @@ -213,9 +217,9 @@ func (cnc *CosignerNonceCache) reconcile(ctx context.Context) { remainingNonces := cnc.cache.Size() timeSinceLastReconcile := time.Since(cnc.lastReconcileTime) - lastReconcileNonces := cnc.lastReconcileNonces.Get() + lastReconcileNonces := cnc.lastReconcileNonces.Load() // calculate nonces per minute - noncesPerMin := float64(lastReconcileNonces-remainingNonces-pruned) / timeSinceLastReconcile.Minutes() + noncesPerMin := float64(int(lastReconcileNonces)-remainingNonces-pruned) / timeSinceLastReconcile.Minutes() if noncesPerMin < 0 { noncesPerMin = 0 } @@ -232,7 +236,7 @@ func (cnc *CosignerNonceCache) reconcile(ctx context.Context) { additional := t - remainingNonces defer func() { - cnc.lastReconcileNonces.Set(remainingNonces + additional) + cnc.lastReconcileNonces.Store(uint64(remainingNonces + additional)) cnc.lastReconcileTime = time.Now() }() @@ -327,19 +331,23 @@ func (cnc *CosignerNonceCache) LoadN(ctx context.Context, n int) { } func (cnc *CosignerNonceCache) Start(ctx context.Context) { - cnc.lastReconcileNonces.Set(cnc.cache.Size()) + cnc.lastReconcileNonces.Store(uint64(cnc.cache.Size())) cnc.lastReconcileTime = time.Now() - ticker := time.NewTicker(cnc.getNoncesInterval) + ticker := time.NewTimer(cnc.getNoncesInterval) for { select { case <-ctx.Done(): return case <-ticker.C: - cnc.reconcile(ctx) case <-cnc.empty: - cnc.reconcile(ctx) + // clear out channel + for len(cnc.empty) > 0 { + <-cnc.empty + } } + cnc.reconcile(ctx) + ticker.Reset(cnc.getNoncesInterval) } } @@ -367,7 +375,7 @@ CheckNoncesLoop: // remove this set of nonces from the cache cnc.cache.Delete(i) - if len(cnc.cache.cache) == 0 { + if len(cnc.cache.cache) == 0 && len(cnc.empty) == 0 { cnc.logger.Debug("Nonce cache is empty, triggering reload") cnc.empty <- struct{}{} } @@ -380,7 +388,7 @@ CheckNoncesLoop: } // increment so it's taken into account in the nonce burn rate in the next reconciliation - cnc.lastReconcileNonces.Inc() + cnc.lastReconcileNonces.Add(1) // no nonces found cosignerInts := make([]int, len(fastestPeers)) @@ -390,28 +398,6 @@ CheckNoncesLoop: return nil, fmt.Errorf("no nonces found involving cosigners %+v", cosignerInts) } -func (cnc *CosignerNonceCache) PruneNonces() int { - cnc.cache.mu.Lock() - defer cnc.cache.mu.Unlock() - nonExpiredIndex := len(cnc.cache.cache) - 1 - for i := len(cnc.cache.cache) - 1; i >= 0; i-- { - if time.Now().Before(cnc.cache.cache[i].Expiration) { - nonExpiredIndex = i - break - } - if i == 0 { - deleteCount := len(cnc.cache.cache) - cnc.cache.cache = nil - return deleteCount - } - } - deleteCount := len(cnc.cache.cache) - nonExpiredIndex - 1 - if nonExpiredIndex != len(cnc.cache.cache)-1 { - cnc.cache.cache = cnc.cache.cache[:nonExpiredIndex+1] - } - return deleteCount -} - func (cnc *CosignerNonceCache) ClearNonces(cosigner Cosigner) { cnc.cache.mu.Lock() defer cnc.cache.mu.Unlock() diff --git a/signer/cosigner_nonce_cache_test.go b/signer/cosigner_nonce_cache_test.go index 756f5960..2c65f175 100644 --- a/signer/cosigner_nonce_cache_test.go +++ b/signer/cosigner_nonce_cache_test.go @@ -64,6 +64,7 @@ func TestClearNonces(t *testing.T) { cnc := CosignerNonceCache{ threshold: 2, + cache: new(NonceCache), } for i := 0; i < 10; i++ { @@ -98,18 +99,34 @@ func TestClearNonces(t *testing.T) { for _, n := range cnc.cache.cache { require.Len(t, n.Nonces, 2) + oneFound := false + twoFound := false + for _, cnr := range n.Nonces { + if cnr.Cosigner == cosigners[1] { + oneFound = true + } + if cnr.Cosigner == cosigners[2] { + twoFound = true + } + } + require.True(t, oneFound) + require.True(t, twoFound) } + + cnc.ClearNonces(cosigners[1]) + + require.Equal(t, 0, cnc.cache.Size()) } type mockPruner struct { - cnc *CosignerNonceCache + cache *NonceCache count int pruned int mu sync.Mutex } func (mp *mockPruner) PruneNonces() int { - pruned := mp.cnc.PruneNonces() + pruned := mp.cache.PruneNonces() mp.mu.Lock() defer mp.mu.Unlock() mp.count++ @@ -143,7 +160,7 @@ func TestNonceCacheDemand(t *testing.T) { mp, ) - mp.cnc = nonceCache + mp.cache = nonceCache.cache ctx, cancel := context.WithCancel(context.Background()) @@ -168,8 +185,8 @@ func TestNonceCacheDemand(t *testing.T) { count, pruned := mp.Result() - require.Greater(t, count, 0) - require.Equal(t, 0, pruned) + require.Greater(t, count, 0, "count of pruning calls must be greater than 0") + require.Equal(t, 0, pruned, "no nonces should have been pruned") } func TestNonceCacheExpiration(t *testing.T) { @@ -181,42 +198,206 @@ func TestNonceCacheExpiration(t *testing.T) { mp := &mockPruner{} + noncesExpiration := 1000 * time.Millisecond + getNoncesInterval := noncesExpiration / 5 + getNoncesTimeout := 10 * time.Millisecond nonceCache := NewCosignerNonceCache( cometlog.NewTMLogger(cometlog.NewSyncWriter(os.Stdout)), cosigners, &MockLeader{id: 1, leader: &ThresholdValidator{myCosigner: lcs[0]}}, - 250*time.Millisecond, - 10*time.Millisecond, - 500*time.Millisecond, + getNoncesInterval, + getNoncesTimeout, + noncesExpiration, 2, mp, ) - mp.cnc = nonceCache + mp.cache = nonceCache.cache ctx, cancel := context.WithCancel(context.Background()) - const loadN = 500 - + const loadN = 100 + // Load first set of 100 nonces nonceCache.LoadN(ctx, loadN) go nonceCache.Start(ctx) - time.Sleep(1 * time.Second) + // Sleep for 1/2 nonceExpiration, no nonces should have expired yet + time.Sleep(noncesExpiration / 2) + + // Load second set of 100 nonces + nonceCache.LoadN(ctx, loadN) + + // Wait for first set of nonces to expire + wait for the interval to have run + time.Sleep((noncesExpiration / 2) + getNoncesInterval) count, pruned := mp.Result() - // we should have pruned at least three times after - // waiting for a second with a reconcile interval of 250ms - require.GreaterOrEqual(t, count, 3) + // we should have pruned at least 5 times after + // waiting for 1200ms with a reconcile interval of 200ms + require.GreaterOrEqual(t, count, 5) - // we should have pruned at least the number of nonces we loaded and knew would expire - require.GreaterOrEqual(t, pruned, loadN) + // we should have pruned only the first set of nonces + // The second set of nonces should not have expired yet and we should not have load any more + require.Equal(t, pruned, loadN) cancel() - // the cache should be empty or 1 since no nonces are being consumed. - require.LessOrEqual(t, nonceCache.cache.Size(), 1) + // the cache should be 100 (loadN) as the second set should not have expired. + require.LessOrEqual(t, nonceCache.cache.Size(), loadN) +} + +func TestNonceCachePrune(t *testing.T) { + type testCase struct { + name string + nonces []*CachedNonce + expected []*CachedNonce + } + + now := time.Now() + + testCases := []testCase{ + { + name: "no nonces", + nonces: nil, + expected: nil, + }, + { + name: "no expired nonces", + nonces: []*CachedNonce{ + { + UUID: uuid.MustParse("d6ef381f-6234-432d-b204-d8957fe60360"), + Expiration: now.Add(1 * time.Second), + }, + { + UUID: uuid.MustParse("cdc3673d-7946-459a-b458-cbbde0eecd04"), + Expiration: now.Add(2 * time.Second), + }, + { + UUID: uuid.MustParse("38c6a201-0b8b-46eb-ab69-c7b2716d408e"), + Expiration: now.Add(3 * time.Second), + }, + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(4 * time.Second), + }, + }, + expected: []*CachedNonce{ + { + UUID: uuid.MustParse("d6ef381f-6234-432d-b204-d8957fe60360"), + Expiration: now.Add(1 * time.Second), + }, + { + UUID: uuid.MustParse("cdc3673d-7946-459a-b458-cbbde0eecd04"), + Expiration: now.Add(2 * time.Second), + }, + { + UUID: uuid.MustParse("38c6a201-0b8b-46eb-ab69-c7b2716d408e"), + Expiration: now.Add(3 * time.Second), + }, + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(4 * time.Second), + }, + }, + }, + { + name: "first nonce is expired", + nonces: []*CachedNonce{ + { + UUID: uuid.MustParse("d6ef381f-6234-432d-b204-d8957fe60360"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("cdc3673d-7946-459a-b458-cbbde0eecd04"), + Expiration: now.Add(2 * time.Second), + }, + { + UUID: uuid.MustParse("38c6a201-0b8b-46eb-ab69-c7b2716d408e"), + Expiration: now.Add(3 * time.Second), + }, + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(4 * time.Second), + }, + }, + expected: []*CachedNonce{ + { + UUID: uuid.MustParse("cdc3673d-7946-459a-b458-cbbde0eecd04"), + Expiration: now.Add(2 * time.Second), + }, + { + UUID: uuid.MustParse("38c6a201-0b8b-46eb-ab69-c7b2716d408e"), + Expiration: now.Add(3 * time.Second), + }, + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(4 * time.Second), + }, + }, + }, + { + name: "all but last nonce expired", + nonces: []*CachedNonce{ + { + UUID: uuid.MustParse("d6ef381f-6234-432d-b204-d8957fe60360"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("cdc3673d-7946-459a-b458-cbbde0eecd04"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("38c6a201-0b8b-46eb-ab69-c7b2716d408e"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(4 * time.Second), + }, + }, + expected: []*CachedNonce{ + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(4 * time.Second), + }, + }, + }, + { + name: "all nonces expired", + nonces: []*CachedNonce{ + { + UUID: uuid.MustParse("d6ef381f-6234-432d-b204-d8957fe60360"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("cdc3673d-7946-459a-b458-cbbde0eecd04"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("38c6a201-0b8b-46eb-ab69-c7b2716d408e"), + Expiration: now.Add(-1 * time.Second), + }, + { + UUID: uuid.MustParse("5caf5ab2-d460-430f-87fa-8ed2983ae8fb"), + Expiration: now.Add(-1 * time.Second), + }, + }, + expected: nil, + }, + } + + for _, tc := range testCases { + nc := NonceCache{ + cache: tc.nonces, + } + + pruned := nc.PruneNonces() + + require.Equal(t, len(tc.nonces)-len(tc.expected), pruned, tc.name) + + require.Equal(t, tc.expected, nc.cache, tc.name) + } } func TestNonceCacheDemandSlow(t *testing.T) { diff --git a/signer/local_cosigner.go b/signer/local_cosigner.go index 9768692c..f9ab6ee2 100644 --- a/signer/local_cosigner.go +++ b/signer/local_cosigner.go @@ -345,6 +345,11 @@ func (cosigner *LocalCosigner) GetNonces( u := u outerEg.Go(func() error { + meta, err := cosigner.generateNoncesIfNecessary(u) + if err != nil { + return err + } + var eg errgroup.Group nonces := make([]CosignerNonce, total-1) @@ -358,7 +363,7 @@ func (cosigner *LocalCosigner) GetNonces( i := i eg.Go(func() error { - secretPart, err := cosigner.getNonce(u, peerID) + secretPart, err := cosigner.getNonce(meta, peerID) if i >= id { nonces[i-1] = secretPart @@ -392,10 +397,10 @@ func (cosigner *LocalCosigner) GetNonces( func (cosigner *LocalCosigner) generateNoncesIfNecessary(uuid uuid.UUID) (*types.NoncesWithExpiration, error) { // protects the meta map - cosigner.noncesMu.Lock() - defer cosigner.noncesMu.Unlock() - - if nonces, ok := cosigner.nonces[uuid]; ok { + cosigner.noncesMu.RLock() + nonces, ok := cosigner.nonces[uuid] + cosigner.noncesMu.RUnlock() + if ok { return nonces, nil } @@ -409,25 +414,23 @@ func (cosigner *LocalCosigner) generateNoncesIfNecessary(uuid uuid.UUID) (*types Expiration: time.Now().Add(nonceExpiration), } + cosigner.noncesMu.Lock() cosigner.nonces[uuid] = &res + cosigner.noncesMu.Unlock() + return &res, nil } // Get the ephemeral secret part for an ephemeral share // The ephemeral secret part is encrypted for the receiver func (cosigner *LocalCosigner) getNonce( - uuid uuid.UUID, + meta *types.NoncesWithExpiration, peerID int, ) (CosignerNonce, error) { zero := CosignerNonce{} id := cosigner.GetIndex() - meta, err := cosigner.generateNoncesIfNecessary(uuid) - if err != nil { - return zero, err - } - ourCosignerMeta := meta.Nonces[id-1] nonce, err := cosigner.security.EncryptAndSign(peerID, ourCosignerMeta.PubKey, ourCosignerMeta.Shares[peerID-1]) if err != nil { @@ -437,6 +440,8 @@ func (cosigner *LocalCosigner) getNonce( return nonce, nil } +const errUnexpectedState = "unexpected state, metadata does not exist for U:" + // setNonce stores a nonce provided by another cosigner func (cosigner *LocalCosigner) setNonce(uuid uuid.UUID, nonce CosignerNonce) error { // Verify the source signature @@ -458,7 +463,8 @@ func (cosigner *LocalCosigner) setNonce(uuid uuid.UUID, nonce CosignerNonce) err // generate metadata placeholder if !ok { return fmt.Errorf( - "unexpected state, metadata does not exist for U: %s", + "%s %s", + errUnexpectedState, uuid, ) } diff --git a/signer/threshold_validator.go b/signer/threshold_validator.go index 0343286b..6572c67f 100644 --- a/signer/threshold_validator.go +++ b/signer/threshold_validator.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "strings" "sync" "time" @@ -685,6 +686,10 @@ func (pv *ThresholdValidator) Sign(ctx context.Context, chainID string, block ty "err", err.Error(), ) + if strings.Contains(err.Error(), errUnexpectedState) { + pv.nonceCache.ClearNonces(cosigner) + } + if cosigner.GetIndex() == pv.myCosigner.GetIndex() { return err }