Skip to content

Commit

Permalink
Merge pull request #944 from buck54321/end-recovery-on-shutdown
Browse files Browse the repository at this point in the history
end recovery on shutdown
  • Loading branch information
Roasbeef authored Aug 15, 2024
2 parents 7d3434c + 8e2426a commit 6ecae9c
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 15 deletions.
10 changes: 8 additions & 2 deletions wallet/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
)

// mockChainClient is a test stand-in for a chain client. Its fields let a test
// configure what a few of its methods hand back.
type mockChainClient struct {
// getBestBlockHeight is the height for GetBestBlock to report.
getBestBlockHeight int32
// getBlockHashFunc, when non-nil, is invoked by GetBlockHash to produce its
// result.
getBlockHashFunc func() (*chainhash.Hash, error)
// getBlockHeader is the header for GetBlockHeader to return.
getBlockHeader *wire.BlockHeader
}

var _ chain.Interface = (*mockChainClient)(nil)
Expand All @@ -26,20 +29,23 @@ func (m *mockChainClient) Stop() {
// WaitForShutdown does nothing and returns immediately.
func (m *mockChainClient) WaitForShutdown() {}

// GetBestBlock reports the height configured in getBestBlockHeight. The
// returned hash is always nil and the error is always nil.
func (m *mockChainClient) GetBestBlock() (*chainhash.Hash, int32, error) {
	return nil, m.getBestBlockHeight, nil
}

// GetBlock is a stub: it ignores the requested hash and returns a nil block
// together with a nil error.
func (m *mockChainClient) GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) {
return nil, nil
}

// GetBlockHash ignores the requested height. When a getBlockHashFunc hook has
// been installed its result is returned verbatim; otherwise both the hash and
// the error are nil.
func (m *mockChainClient) GetBlockHash(int64) (*chainhash.Hash, error) {
	hook := m.getBlockHashFunc
	if hook == nil {
		return nil, nil
	}
	return hook()
}

// GetBlockHeader returns the header configured in getBlockHeader, regardless
// of which hash is requested. The error is always nil.
func (m *mockChainClient) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader,
	error) {

	return m.getBlockHeader, nil
}

func (m *mockChainClient) IsCurrent() bool {
Expand Down
33 changes: 20 additions & 13 deletions wallet/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ func (w *Wallet) quitChan() <-chan struct{} {

// Stop signals all wallet goroutines to shut down.
func (w *Wallet) Stop() {
<-w.endRecovery()

w.quitMu.Lock()
quit := w.quit
w.quitMu.Unlock()
Expand Down Expand Up @@ -1380,6 +1382,23 @@ type (
heldUnlock chan struct{}
)

// endRecovery asks a running (*Wallet).recovery to stop and hands back a
// channel the caller can wait on. If a recovery syncer has been stored, its
// quit flag is raised and its done channel is returned. If none has been
// stored, an already-closed channel is returned so that a receive on it never
// blocks.
func (w *Wallet) endRecovery() <-chan struct{} {
	stored := w.recovering.Load()
	if stored == nil {
		// Nothing to wait for: return a channel that is already closed.
		finished := make(chan struct{})
		close(finished)
		return finished
	}

	syncer := stored.(*recoverySyncer)

	// If recovery is still running, it will end early with an error once
	// the quit flag is set.
	atomic.StoreUint32(&syncer.quit, 1)

	return syncer.done
}

// walletLocker manages the locked/unlocked state of a wallet.
func (w *Wallet) walletLocker() {
var timeout <-chan time.Time
Expand Down Expand Up @@ -1472,19 +1491,7 @@ out:

// We can't lock the manager if recovery is active because we use
// cryptoKeyPriv and cryptoKeyScript in recovery.
if recoverySyncI := w.recovering.Load(); recoverySyncI != nil {
recoverySync := recoverySyncI.(*recoverySyncer)
// If recovery is still running, it will end early with an error
// once we set the quit flag.
atomic.StoreUint32(&recoverySync.quit, 1)

select {
case <-recoverySync.done:
case <-quit:
break out
}

}
<-w.endRecovery()

timeout = nil
err := w.Manager.Lock()
Expand Down
126 changes: 126 additions & 0 deletions wallet/wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package wallet
import (
"encoding/hex"
"fmt"
"math"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/wtxmgr"
Expand Down Expand Up @@ -359,3 +363,125 @@ func TestDuplicateAddressDerivation(t *testing.T) {
require.NoError(t, eg.Wait())
}
}

// TestEndRecovery checks that closing the wallet's quit channel alone does not
// stop a running (*Wallet).recovery loop, and that endRecovery does stop it,
// allowing WaitForShutdown to return.
func TestEndRecovery(t *testing.T) {
	// This is an unconventional unit test, but I'm trying to keep things as
	// succinct as possible so that this test is readable without having to
	// mock up literally everything.
	// The unmonitored goroutine we're looking at is pretty deep:
	// SynchronizeRPC -> handleChainNotifications -> syncWithChain -> recovery
	// The "deadlock" we're addressing isn't actually a deadlock, but the wallet
	// will hang on Stop() -> WaitForShutdown() until (*Wallet).recovery gets
	// every single block, which could be hours depending on hardware and
	// network factors. The WaitGroup is incremented in SynchronizeRPC, and
	// WaitForShutdown will not return until handleChainNotifications returns,
	// which is blocked by a running (*Wallet).recovery loop.
	// It is noted that the conditions for long recovery are difficult to hit
	// when using btcwallet with a fresh seed, because it requires an early
	// birthday to be set or established.

	w, cleanup := testWallet(t)

	// blockHashCalled is unbuffered, so every GetBlockHash call made by the
	// recovery loop blocks until the test receives from it.
	blockHashCalled := make(chan struct{})

	chainClient := &mockChainClient{
		// Force the loop to iterate about forever.
		getBestBlockHeight: math.MaxInt32,
		// Get control of when the loop iterates.
		getBlockHashFunc: func() (*chainhash.Hash, error) {
			blockHashCalled <- struct{}{}
			return &chainhash.Hash{}, nil
		},
		// Avoid a panic.
		getBlockHeader: &wire.BlockHeader{},
	}

	recoveryDone := make(chan struct{})
	go func() {
		defer close(recoveryDone)
		w.recovery(chainClient, &waddrmgr.BlockStamp{})
	}()

	// getBlockHashCalls waits for expCalls GetBlockHash invocations, failing
	// the test if any single one takes longer than a second to arrive.
	getBlockHashCalls := func(expCalls int) {
		var i int
		for {
			select {
			case <-blockHashCalled:
				i++
			case <-time.After(time.Second):
				t.Fatal("expected BlockHash to be called")
			}
			if i == expCalls {
				break
			}
		}
	}

	// Recovery is running.
	getBlockHashCalls(3)

	// Closing the quit channel, e.g. Stop() without endRecovery, alone will not
	// end the recovery loop.
	w.quitMu.Lock()
	close(w.quit)
	w.quitMu.Unlock()
	// Continues scanning.
	getBlockHashCalls(3)

	// We're done with this one
	atomic.StoreUint32(&w.recovering.Load().(*recoverySyncer).quit, 1)
	select {
	case <-blockHashCalled:
	case <-recoveryDone:
	}
	cleanup()

	// Try again.
	w, cleanup = testWallet(t)
	defer cleanup()

	// We'll catch the error to make sure we're hitting our desired path. The
	// WaitGroup isn't required for the test, but does show how it completes
	// shutdown at a higher level.
	var err error
	w.wg.Add(1)
	recoveryDone = make(chan struct{})
	go func() {
		defer w.wg.Done()
		defer close(recoveryDone)
		err = w.recovery(chainClient, &waddrmgr.BlockStamp{})
	}()

	waitedForShutdown := make(chan struct{})
	go func() {
		w.WaitForShutdown()
		close(waitedForShutdown)
	}()

	// Recovery is running.
	getBlockHashCalls(3)

	// endRecovery is required to exit the unmonitored goroutine.
	end := w.endRecovery()
	select {
	case <-blockHashCalled:
	case <-recoveryDone:
	}
	<-end

	// testWallet starts a couple of other unrelated goroutines that need to be
	// killed, so we still need to close the quit channel.
	w.quitMu.Lock()
	close(w.quit)
	w.quitMu.Unlock()

	select {
	case <-waitedForShutdown:
	case <-time.After(time.Second):
		t.Fatal("WaitForShutdown never returned")
	}

	// Guard against a nil error so that an unexpected clean return fails the
	// test instead of panicking on err.Error().
	if err == nil {
		t.Fatal("expected recovery to return an error after endRecovery")
	}
	if !strings.EqualFold(err.Error(), "recovery: forced shutdown") {
		t.Fatalf("wrong error: %v", err)
	}
}

0 comments on commit 6ecae9c

Please sign in to comment.