Skip to content

Commit

Permalink
Speed up ATX cache warmup (#6241)
Browse files Browse the repository at this point in the history
## Motivation

Speeding up the in-memory ATX cache warmup that is especially slow on HDDs.



Co-authored-by: Jedrzej Nowak <[email protected]>
  • Loading branch information
poszu and pigmej committed Aug 12, 2024
1 parent 16d8ac9 commit be1305a
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 72 deletions.
2 changes: 1 addition & 1 deletion activation/e2e/checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) {
// 3. Spawn new ATX handler and builder using the new DB
poetDb = activation.NewPoetDb(newDB, logger.Named("poetDb"))
cdb = datastore.NewCachedDB(newDB, logger)
atxdata, err = atxsdata.Warm(newDB, 1)
atxdata, err = atxsdata.Warm(newDB, 1, logger)
poetService = activation.NewPoetServiceWithClient(poetDb, client, poetCfg, logger)
validator = activation.NewValidator(newDB, poetDb, cfg, opts.Scrypt, verifier)
require.NoError(t, err)
Expand Down
25 changes: 8 additions & 17 deletions atxsdata/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ type ATX struct {
Weight uint64
BaseHeight, Height uint64
Nonce types.VRFPostIndex
// unexported to avoid accidental unsynchronized access
// (this field is mutated by the Data under a lock and
// might only be safely read under the same lock)
malicious bool
}

func New() *Data {
Expand Down Expand Up @@ -107,9 +103,6 @@ func (d *Data) AddAtx(target types.EpochID, id types.ATXID, atx *ATX) bool {
atxsCounter.WithLabelValues(target.String()).Inc()

ecache.index[id] = atx
if atx.malicious {
d.malicious[atx.Node] = struct{}{}
}
return true
}

Expand All @@ -131,7 +124,9 @@ func (d *Data) Add(
BaseHeight: baseHeight,
Height: height,
Nonce: nonce,
malicious: malicious,
}
if malicious {
d.SetMalicious(node)
}
if d.AddAtx(epoch, atxid, atx) {
return atx
Expand Down Expand Up @@ -165,8 +160,6 @@ func (d *Data) Get(epoch types.EpochID, atx types.ATXID) *ATX {
if !exists {
return nil
}
_, exists = d.malicious[data.Node]
data.malicious = exists
return data
}

Expand All @@ -185,10 +178,11 @@ type lockGuard struct{}
// AtxFilter is a function that filters atxs.
// The `lockGuard` prevents using the filter functions outside of the allowed context
// to prevent data races.
type AtxFilter func(*ATX, lockGuard) bool
type AtxFilter func(*Data, *ATX, lockGuard) bool

func NotMalicious(data *ATX, _ lockGuard) bool {
return !data.malicious
func NotMalicious(d *Data, atx *ATX, _ lockGuard) bool {
_, m := d.malicious[atx.Node]
return !m
}

// IterateInEpoch calls `fn` for every ATX in epoch.
Expand All @@ -202,12 +196,9 @@ func (d *Data) IterateInEpoch(epoch types.EpochID, fn func(types.ATXID, *ATX), f
return
}
for id, atx := range ecache.index {
if _, exists := d.malicious[atx.Node]; exists {
atx.malicious = true
}
ok := true
for _, filter := range filters {
ok = ok && filter(atx, lockGuard{})
ok = ok && filter(d, atx, lockGuard{})
}
if ok {
fn(id, atx)
Expand Down
13 changes: 5 additions & 8 deletions atxsdata/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ func TestData(t *testing.T) {
d.BaseHeight,
d.Height,
d.Nonce,
d.malicious,
false,
)
}
}
for epoch := 0; epoch < epochs; epoch++ {
for i := range atxids[epoch] {
byatxid := c.Get(types.EpochID(epoch)+1, atxids[epoch][i])
require.Equal(t, &data[epoch][i], byatxid)
atx := c.Get(types.EpochID(epoch)+1, atxids[epoch][i])
require.Equal(t, &data[epoch][i], atx)
require.False(t, c.IsMalicious(atx.Node))
}
}
}
Expand All @@ -71,13 +72,9 @@ func TestData(t *testing.T) {
)
data := c.Get(types.EpochID(epoch), types.ATXID{byte(epoch)})
require.NotNil(t, data)
require.False(t, data.malicious)
require.False(t, c.IsMalicious(data.Node))
}
c.SetMalicious(node)
for epoch := 1; epoch <= 10; epoch++ {
data := c.Get(types.EpochID(epoch), types.ATXID{byte(epoch)})
require.True(t, data.malicious)
}
require.True(t, c.IsMalicious(node))
})
t.Run("eviction", func(t *testing.T) {
Expand Down
40 changes: 34 additions & 6 deletions atxsdata/warmup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,31 @@ package atxsdata
import (
"context"
"fmt"
"time"

"go.uber.org/zap"

"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/atxs"
"github.com/spacemeshos/go-spacemesh/sql/identities"
"github.com/spacemeshos/go-spacemesh/sql/layers"
)

func Warm(db *sql.Database, keep types.EpochID) (*Data, error) {
func Warm(db *sql.Database, keep types.EpochID, logger *zap.Logger) (*Data, error) {
cache := New()
tx, err := db.Tx(context.Background())
if err != nil {
return nil, err
}
defer tx.Release()
if err := Warmup(tx, cache, keep); err != nil {
if err := Warmup(tx, cache, keep, logger); err != nil {
return nil, fmt.Errorf("warmup %w", err)
}
return cache, nil
}

func Warmup(db sql.Executor, cache *Data, keep types.EpochID) error {
func Warmup(db sql.Executor, cache *Data, keep types.EpochID, logger *zap.Logger) error {
latest, err := atxs.LatestEpoch(db)
if err != nil {
return err
Expand All @@ -38,7 +42,14 @@ func Warmup(db sql.Executor, cache *Data, keep types.EpochID) error {
}
cache.EvictEpoch(evict)

return atxs.IterateAtxsData(db, cache.Evicted(), latest,
from := cache.Evicted()
logger.Info("Reading ATXs from DB",
zap.Uint32("from epoch", from.Uint32()),
zap.Uint32("to epoch", latest.Uint32()),
)
start := time.Now()
var processed int
err = atxs.IterateAtxsData(db, cache.Evicted(), latest,
func(
id types.ATXID,
node types.NodeID,
Expand All @@ -48,7 +59,6 @@ func Warmup(db sql.Executor, cache *Data, keep types.EpochID) error {
base,
height uint64,
nonce types.VRFPostIndex,
malicious bool,
) bool {
cache.Add(
epoch+1,
Expand All @@ -59,8 +69,26 @@ func Warmup(db sql.Executor, cache *Data, keep types.EpochID) error {
base,
height,
nonce,
malicious,
false,
)
processed += 1
if processed%1_000_000 == 0 {
logger.Debug("Processed 1M", zap.Int("total", processed))
}
return true
})
if err != nil {
return fmt.Errorf("warming up atxdata with ATXs: %w", err)
}
logger.Info("Finished reading ATXs. Starting reading malfeasance", zap.Duration("duration", time.Since(start)))
start = time.Now()
err = identities.IterateMalicious(db, func(_ int, id types.NodeID) error {
cache.SetMalicious(id)
return nil
})
if err != nil {
return fmt.Errorf("warming up atxdata with malfeasance: %w", err)
}
logger.Info("Finished reading malfeasance", zap.Duration("duration", time.Since(start)))
return nil
}
9 changes: 5 additions & 4 deletions atxsdata/warmup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap/zaptest"

"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/sql"
Expand Down Expand Up @@ -53,21 +54,21 @@ func TestWarmup(t *testing.T) {
}
require.NoError(t, layers.SetApplied(db, applied, types.BlockID{1}))

c, err := Warm(db, 1)
c, err := Warm(db, 1, zaptest.NewLogger(t))
require.NoError(t, err)
for _, atx := range data[2:] {
require.NotNil(t, c.Get(atx.TargetEpoch(), atx.ID()))
}
})
t.Run("no data", func(t *testing.T) {
c, err := Warm(sql.InMemory(), 1)
c, err := Warm(sql.InMemory(), 1, zaptest.NewLogger(t))
require.NoError(t, err)
require.NotNil(t, c)
})
t.Run("closed db", func(t *testing.T) {
db := sql.InMemory()
require.NoError(t, db.Close())
c, err := Warm(db, 1)
c, err := Warm(db, 1, zaptest.NewLogger(t))
require.Error(t, err)
require.Nil(t, c)
})
Expand All @@ -94,7 +95,7 @@ func TestWarmup(t *testing.T) {
AnyTimes()
for range 3 {
c := New()
require.Error(t, Warmup(exec, c, 1))
require.Error(t, Warmup(exec, c, 1, zaptest.NewLogger(t)))
fail++
call = 0
}
Expand Down
34 changes: 19 additions & 15 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -1985,24 +1985,28 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error {
app.Config.DatabaseSizeMeteringInterval,
)
}
app.log.Info("starting cache warmup")
applied, err := layers.GetLastApplied(app.db)
if err != nil {
return err
}
start := time.Now()
data, err := atxsdata.Warm(
app.db,
app.Config.Tortoise.WindowSizeEpochs(applied),
)
if err != nil {
return err
{
warmupLog := app.log.Zap().Named("warmup")
app.log.Info("starting cache warmup")
applied, err := layers.GetLastApplied(app.db)
if err != nil {
return err
}
start := time.Now()
data, err := atxsdata.Warm(
app.db,
app.Config.Tortoise.WindowSizeEpochs(applied),
warmupLog,
)
if err != nil {
return err
}
app.atxsdata = data
app.log.With().Info("cache warmup", log.Duration("duration", time.Since(start)))
}
app.atxsdata = data
app.log.With().Info("cache warmup", log.Duration("duration", time.Since(start)))
app.cachedDB = datastore.NewCachedDB(sqlDB, app.addLogger(CachedDBLogger, lg).Zap(),
datastore.WithConfig(app.Config.Cache),
datastore.WithConsensusCache(data),
datastore.WithConsensusCache(app.atxsdata),
)

migrations, err = sql.LocalMigrations()
Expand Down
28 changes: 8 additions & 20 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,28 +727,18 @@ func IterateAtxsData(
base uint64,
height uint64,
nonce types.VRFPostIndex,
isMalicious bool,
) bool,
) error {
_, err := db.Exec(
`select
a.id, a.pubkey, a.epoch, a.coinbase, a.effective_num_units,
a.base_tick_height, a.tick_count, a.nonce,
iif(idn.proof is null, 0, 1) as is_malicious
from atxs a left join identities idn on a.pubkey = idn.pubkey`,
// SQLite happens to process the query much faster if we don't
// filter it by epoch
// where a.epoch between ? and ?`,
// func(stmt *sql.Statement) {
// stmt.BindInt64(1, int64(from.Uint32()))
// stmt.BindInt64(2, int64(to.Uint32()))
// },
nil,
`SELECT id, pubkey, epoch, coinbase, effective_num_units, base_tick_height, tick_count, nonce FROM atxs
WHERE epoch between ?1 and ?2`,
// filtering in CODE is no longer effective on some machines in epoch 29
func(stmt *sql.Statement) {
stmt.BindInt64(1, int64(from.Uint32()))
stmt.BindInt64(2, int64(to.Uint32()))
},
func(stmt *sql.Statement) bool {
epoch := types.EpochID(uint32(stmt.ColumnInt64(2)))
if epoch < from || epoch > to {
return true
}
var id types.ATXID
stmt.ColumnBytes(0, id[:])
var node types.NodeID
Expand All @@ -759,9 +749,7 @@ func IterateAtxsData(
baseHeight := uint64(stmt.ColumnInt64(5))
ticks := uint64(stmt.ColumnInt64(6))
nonce := types.VRFPostIndex(stmt.ColumnInt64(7))
isMalicious := stmt.ColumnInt(8) != 0
return fn(id, node, epoch, coinbase, effectiveUnits*ticks,
baseHeight, baseHeight+ticks, nonce, isMalicious)
return fn(id, node, epoch, coinbase, effectiveUnits*ticks, baseHeight, baseHeight+ticks, nonce)
},
)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion tortoise/replay/replay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestReplayMainnet(t *testing.T) {
require.NoError(t, err)

start := time.Now()
atxsdata, err := atxsdata.Warm(db, cfg.Tortoise.WindowSizeEpochs(applied))
atxsdata, err := atxsdata.Warm(db, cfg.Tortoise.WindowSizeEpochs(applied), logger)
require.NoError(t, err)
trtl, err := tortoise.Recover(
context.Background(),
Expand Down

0 comments on commit be1305a

Please sign in to comment.