diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 261ede31e2..15d312f663 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -276,7 +276,7 @@ func RecoverFromLocalFile( newDB, err := statesql.Open("file:" + cfg.DbPath()) if err != nil { - return nil, fmt.Errorf("creating new DB: %w", err) + return nil, fmt.Errorf("create new db: %w", err) } defer newDB.Close() logger.Info("populating new database", diff --git a/fetch/handler_test.go b/fetch/handler_test.go index bca3c7fac2..a1ff0f781f 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -330,7 +330,7 @@ func TestHandleEpochInfoReq(t *testing.T) { var resp server.Response require.NoError(t, codec.Decode(b.Bytes(), &resp)) require.Empty(t, resp.Data) - require.Contains(t, resp.Error, "exec epoch 11: database: no free connection") + require.Contains(t, resp.Error, "exec epoch 11: database closed") }) }) } diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index e92d5ea53b..1e56fb703f 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -270,7 +270,7 @@ func forStreamingCachedUncached( func TestP2PPeerEpochInfo(t *testing.T) { forStreamingCachedUncached( - t, "peer error: getting ATX IDs: exec epoch 11: database: no free connection", + t, "peer error: getting ATX IDs: exec epoch 11: database closed", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { epoch := types.EpochID(11) atxIDs := tpf.createATXs(epoch) @@ -291,7 +291,7 @@ func TestP2PPeerEpochInfo(t *testing.T) { func TestP2PPeerMeshHashes(t *testing.T) { forStreaming( - t, "peer error: get aggHashes from 7 to 23 by 5: database: no free connection", false, + t, "peer error: get aggHashes from 7 to 23 by 5: database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { req := &MeshHashRequest{ From: 7, @@ -324,7 +324,7 @@ func TestP2PPeerMeshHashes(t *testing.T) { func TestP2PMaliciousIDs(t *testing.T) { forStreaming( - t, "database: no free 
connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { var bad []types.NodeID for i := 0; i < 11; i++ { @@ -349,7 +349,7 @@ func TestP2PMaliciousIDs(t *testing.T) { func TestP2PGetATXs(t *testing.T) { forStreamingCachedUncached( - t, "database: no free connection", + t, "database closed", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { epoch := types.EpochID(11) atx := newAtx(tpf.t, epoch) @@ -365,7 +365,7 @@ func TestP2PGetATXs(t *testing.T) { func TestP2PGetPoet(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { ref := types.PoetProofRef{0x42, 0x43} require.NoError(t, poets.Add(tpf.serverCDB, ref, []byte("proof1"), []byte("sid1"), "rid1")) @@ -380,7 +380,7 @@ func TestP2PGetPoet(t *testing.T) { func TestP2PGetBallot(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -402,7 +402,7 @@ func TestP2PGetBallot(t *testing.T) { func TestP2PGetActiveSet(t *testing.T) { forStreamingCachedUncached( - t, "database: no free connection", + t, "database closed", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { id := types.RandomHash() set := &types.EpochActiveSet{ @@ -421,7 +421,7 @@ func TestP2PGetActiveSet(t *testing.T) { func TestP2PGetBlock(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { lid := types.LayerID(111) bk := types.NewExistingBlock(types.RandomBlockID(), types.InnerBlock{LayerIndex: lid}) @@ -472,7 +472,7 @@ func TestP2PGetProp(t *testing.T) { func TestP2PGetBlockTransactions(t *testing.T) { 
forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -488,7 +488,7 @@ func TestP2PGetBlockTransactions(t *testing.T) { func TestP2PGetProposalTransactions(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -506,7 +506,7 @@ func TestP2PGetProposalTransactions(t *testing.T) { func TestP2PGetMalfeasanceProofs(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { nid := types.RandomNodeID() proof := types.RandomBytes(11) diff --git a/node/node.go b/node/node.go index 065d5fa0fb..deafafbde8 100644 --- a/node/node.go +++ b/node/node.go @@ -1966,7 +1966,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { } sqlDB, err := statesql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) 
if err != nil { - return fmt.Errorf("open sqlite db %w", err) + return fmt.Errorf("open sqlite db: %w", err) } app.db = sqlDB if app.Config.CollectMetrics && app.Config.DatabaseSizeMeteringInterval != 0 { @@ -2012,7 +2012,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), ) if err != nil { - return fmt.Errorf("open sqlite db %w", err) + return fmt.Errorf("open sqlite db: %w", err) } app.localDB = localDB return nil diff --git a/sql/database.go b/sql/database.go index 7ae299b43e..a647086f4a 100644 --- a/sql/database.go +++ b/sql/database.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "maps" + "net/url" + "os" "strings" "sync" "sync/atomic" @@ -20,6 +22,8 @@ import ( ) var ( + // ErrClosed is returned if database is closed. + ErrClosed = errors.New("database closed") // ErrNoConnection is returned if pooled connection is not available. ErrNoConnection = errors.New("database: no free connection") // ErrNotFound is returned if requested record is not found. 
@@ -63,27 +67,32 @@ type Decoder func(*Statement) bool func defaultConf() *conf { return &conf{ - enableMigrations: true, - connections: 16, - logger: zap.NewNop(), - schema: &Schema{}, - checkSchemaDrift: true, + enableMigrations: true, + connections: 16, + logger: zap.NewNop(), + schema: &Schema{}, + checkSchemaDrift: true, + handleIncompleteMigrations: true, } } type conf struct { - enableMigrations bool - forceFresh bool - forceMigrations bool - connections int - vacuumState int - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger - schema *Schema - allowSchemaDrift bool - checkSchemaDrift bool + uri string + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema + allowSchemaDrift bool + checkSchemaDrift bool + temp bool + handleIncompleteMigrations bool + exclusive bool } // WithConnections overwrites number of pooled connections. @@ -178,6 +187,33 @@ func withForceFresh() Opt { } } +// WithTemp specifies temporary database mode. +// For the temporary database, the migrations are always run in place, and vacuuming is +// never done. PRAGMA journal_mode=OFF and PRAGMA synchronous=OFF are used. +func WithTemp() Opt { + return func(c *conf) { + c.temp = true + } +} + +func withDisableIncompleteMigrationHandling() Opt { + return func(c *conf) { + c.handleIncompleteMigrations = false + } +} + +// WithExclusive specifies that the database is to be open in exclusive mode. +// This means that no other processes can open the database at the same time. +// If the database is already open by any process, this Open will fail. +// Any subsequent attempts by other processes to open the database will fail until this db +// handle is closed. +// In Exclusive mode, the database supports just one concurrent connection. 
+func WithExclusive() Opt { + return func(c *conf) { + c.exclusive = true + } +} + // Opt for configuring database. type Opt func(c *conf) @@ -204,85 +240,124 @@ func InMemory(opts ...Opt) *sqliteDatabase { // https://www.sqlite.org/pragma.html#pragma_synchronous func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { config := defaultConf() + config.uri = uri for _, opt := range opts { opt(config) } - logger := config.logger.With(zap.String("uri", uri)) + if !config.temp && config.handleIncompleteMigrations && !config.forceFresh { + if err := handleIncompleteCopyMigration(config); err != nil { + return nil, err + } + } + return openDB(config) +} + +func openDB(config *conf) (db *sqliteDatabase, err error) { + logger := config.logger.With(zap.String("uri", config.uri)) var flags sqlite.OpenFlags if !config.forceFresh { flags = sqlite.SQLITE_OPEN_READWRITE | - sqlite.SQLITE_OPEN_WAL | sqlite.SQLITE_OPEN_URI | sqlite.SQLITE_OPEN_NOMUTEX + if !config.temp { + // Note that SQLITE_OPEN_WAL is not handled by SQLITE api itself, + // but rather by the crawshaw library which executes + // PRAGMA journal_mode=WAL in this case. 
+ // We don't want it for temporary databases as they're not + // using any journal + flags |= sqlite.SQLITE_OPEN_WAL + } } freshDB := config.forceFresh - pool, err := sqlitex.Open(uri, flags, config.connections) + if config.exclusive { + config.connections = 1 + } + pool, err := sqlitex.Open(config.uri, flags, config.connections) if err != nil { if config.forceFresh || sqlite.ErrCode(err) != sqlite.SQLITE_CANTOPEN { - return nil, fmt.Errorf("open db %s: %w", uri, err) + return nil, fmt.Errorf("open db %s: %w", config.uri, err) } flags |= sqlite.SQLITE_OPEN_CREATE freshDB = true - pool, err = sqlitex.Open(uri, flags, config.connections) + pool, err = sqlitex.Open(config.uri, flags, config.connections) if err != nil { - return nil, fmt.Errorf("create db %s: %w", uri, err) + return nil, fmt.Errorf("create db %s: %w", config.uri, err) } } - db := &sqliteDatabase{pool: pool} + db = &sqliteDatabase{pool: pool} + defer func() { + // If something goes wrong, close the database even in case of a + // panic. This is important for tests that verify incomplete migration. + if r := recover(); r != nil { + db.Close() + panic(r) + } + }() + // In case of VACUUM INTO based migration, prepareDB may close this database and + // open another one. + actualDB, err := prepareDB(logger, db, config, freshDB) + if err != nil { + db.Close() + return nil, err + } + return actualDB, nil +} + +func prepareDB(logger *zap.Logger, db *sqliteDatabase, config *conf, freshDB bool) (*sqliteDatabase, error) { + var err error + if config.enableLatency { db.latency = newQueryLatency() } + + if config.temp { + // Temporary database is used for migration and is deleted if migrations + // fail, so we make it faster by disabling journaling and synchronous + // writes. 
+ if _, err := db.Exec("PRAGMA journal_mode=OFF", nil, nil); err != nil { + return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) + } + if _, err := db.Exec("PRAGMA synchronous=OFF", nil, nil); err != nil { + return nil, fmt.Errorf("PRAGMA synchronous=OFF: %w", err) + } + } + + if config.exclusive { + if err := db.startExclusive(); err != nil { + return nil, fmt.Errorf("start exclusive: %w", err) + } + } + if freshDB && !config.forceMigrations { if err := config.schema.Apply(db); err != nil { - return nil, errors.Join( - fmt.Errorf("error running schema script: %w", err), - db.Close()) + return nil, fmt.Errorf("error running schema script: %w", err) } - } else { - before, after, err := config.schema.CheckDBVersion(logger, db) - switch { - case err != nil: - return nil, errors.Join(err, db.Close()) - case before != after && config.enableMigrations: - logger.Info("running migrations", - zap.Int("current version", before), - zap.Int("target version", after), - ) - if err := config.schema.Migrate( - logger, db, before, config.vacuumState, - ); err != nil { - return nil, errors.Join(err, db.Close()) - } - case before != after: - logger.Error("database version is too old", - zap.Int("current version", before), - zap.Int("target version", after), - ) - return nil, errors.Join( - fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after), - db.Close()) + } else if db, err = ensureDBSchemaUpToDate(logger, db, config); err != nil { + // ensureDBSchemaUpToDate may replace the original database and open the new one, + // in which case the original db is already closed but we must close the new one. + // If there are migrations to be done in place without vacuuming, + // the original db is returned and we must close it if there's an error. 
+ if db != nil { + db.Close() } + return nil, err } if config.checkSchemaDrift { loaded, err := LoadDBSchemaScript(db) if err != nil { - return nil, errors.Join( - fmt.Errorf("error loading database schema: %w", err), - db.Close()) + return nil, fmt.Errorf("error loading database schema: %w", err) } diff := config.schema.Diff(loaded) switch { case diff == "": // ok case config.allowSchemaDrift: logger.Warn("database schema drift detected", - zap.String("uri", uri), + zap.String("uri", config.uri), zap.String("diff", diff), ) default: - return nil, errors.Join( - fmt.Errorf("schema drift detected (uri %s):\n%s", uri, diff), - db.Close()) + return nil, fmt.Errorf("schema drift detected (uri %s):\n%s", config.uri, diff) } } @@ -294,16 +369,187 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { return db, nil } +func ensureDBSchemaUpToDate(logger *zap.Logger, db *sqliteDatabase, config *conf) (*sqliteDatabase, error) { + before, after, err := config.schema.CheckDBVersion(logger, db) + switch { + case err != nil: + return db, fmt.Errorf("check db version: %w", err) + case before == after: + return db, nil + case before > after: + return db, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + case !config.enableMigrations: + return db, fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) + case config.temp: + // Temporary database, do migrations without transactions + // and sync afterwards + return db, config.schema.MigrateTempDB(logger, db, before) + case config.vacuumState != 0 && + before <= config.vacuumState && + strings.HasPrefix(config.uri, "file:"): + logger.Info("running migrations", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return db.copyMigrateDB(config) + } + + logger.Info("running migrations in-place", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return db, config.schema.Migrate(logger, db, before, config.vacuumState) +} + func Version(uri string) (int, error) { pool, 
err := sqlitex.Open(uri, sqlite.SQLITE_OPEN_READONLY, 1) if err != nil { return 0, fmt.Errorf("open db %s: %w", uri, err) } db := &sqliteDatabase{pool: pool} - defer db.Close() - return version(db) + v, err := version(db) + if err != nil { + db.Close() + return 0, err + } + if err := db.Close(); err != nil { + return 0, fmt.Errorf("close db %s: %w", uri, err) + } + return v, nil +} + +// deleteDB deletes the database at the specified path by removing /path/to/DB* files. +// If the database doesn't exist, no error is returned. +// In addition to what DROP DATABASE does, this also removes the migration marker file. +func deleteDB(path string) error { + // https://www.sqlite.org/tempfiles.html plus marker *_done + for _, suffix := range []string{"", "-journal", "-wal", "-shm", "_done"} { + file := path + suffix + if err := os.Remove(file); err != nil { + if errors.Is(err, os.ErrNotExist) { + continue + } + return fmt.Errorf("remove %s: %w", file, err) + } + } + return nil +} + +// moveMigratedDB runs "VACUUM INTO" on the database at fromPath and +// replaces the database at toPath with the vacuumed one. The database +// at fromPath is deleted after the operation. +func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { + config.logger.Warn("finalizing migration by moving the temporary DB to the original path", + zap.String("fromPath", fromPath), + zap.String("toPath", toPath)) + // Try to open the temporary migrated DB in exclusive mode before deleting the + // original one. + // If the temporary DB is being copied to the original path by another + // process, this will fail and the original database will not be deleted. + // We don't use the proper database schema here because the temporary DB + // may have been created with a different set of migrations. 
+ db, err := Open("file:"+fromPath, + WithLogger(config.logger), + WithConnections(1), + WithTemp(), + WithNoCheckSchemaDrift(), + WithExclusive(), + ) + if err != nil { + return fmt.Errorf("open temporary DB %s: %w", fromPath, err) + } + if err := deleteDB(toPath); err != nil { + return err + } + if err := db.vacuumInto(toPath); err != nil { + db.Close() + return err + } + // Open the freshly vacuumed DB in exclusive mode to avoid race condition when + // another process also tries to vacuum the temporary DB into the original path + // after we close the temporary DB. + origDB, err := Open("file:"+toPath, + WithLogger(config.logger), + WithConnections(1), + WithMigrationsDisabled(), + WithNoCheckSchemaDrift(), + withDisableIncompleteMigrationHandling(), + WithExclusive(), + ) + if err != nil { + return fmt.Errorf("open vacuumed DB %s: %w", toPath, err) + } + if err := db.Close(); err != nil { + origDB.Close() + return fmt.Errorf("close temporary DB %s: %w", fromPath, err) + } + if err := deleteDB(fromPath); err != nil { + origDB.Close() + return err + } + if err := origDB.Close(); err != nil { + return fmt.Errorf("close DB %s after migration: %w", toPath, err) + } + return nil } +func dbMigrationPaths(uri string) (dbPath, migratedPath string, err error) { + url, err := url.Parse(uri) + if err != nil { + return "", "", fmt.Errorf("parse uri: %w", err) + } + if url.Scheme != "file" { + return "", "", nil + } + path := url.Opaque + if path == "" { + path = url.Path + } + return path, path + "_migrate", nil +} + +// handleIncompleteCopyMigration handles incomplete copy-based migrations. +// It only works for 'file:' URIs, doing nothing for other URIs. +// It first checks if there's a copy of the database with "_migrate" suffix. +// If it's there, it checks if the migration is complete by checking if +// DBNAME_migrate_done file exists. If it doesn't, the migration is considered +// incomplete and the migrated database is removed. 
If DBNAME_migrate_done +// file exists, the migration is finalized by running "VACUUM INTO" on the +// migrated database and replacing the original, after which the migrated +// database is deleted. +func handleIncompleteCopyMigration(config *conf) error { + dbPath, migratedPath, err := dbMigrationPaths(config.uri) + if err != nil { + return fmt.Errorf("getting DB migration paths: %w", err) + } + if migratedPath == "" { + return nil + } + if _, err := os.Stat(migratedPath); err != nil { + if errors.Is(err, os.ErrNotExist) { + // no migration in progress + return nil + } + return fmt.Errorf("stat %s: %w", migratedPath, err) + } + if _, err := os.Stat(migratedPath + "_done"); err != nil { + if errors.Is(err, os.ErrNotExist) { + // incomplete migration, delete the temporary DB to start over + // after that + config.logger.Warn("incomplete migration detected, deleting the temporary DB", + zap.String("path", migratedPath)) + return deleteDB(migratedPath) + } + } + + // the migration is complete except for the last step + return moveMigratedDB(config, migratedPath, dbPath) +} + +// Interceptor is invoked on every query after it's added to a database using +// PushIntercept. The query will fail if Interceptor returns an error. +type Interceptor func(query string) error + // Database represents a database. type Database interface { Executor @@ -315,6 +561,8 @@ type Database interface { WithTx(ctx context.Context, exec func(Transaction) error) error TxImmediate(ctx context.Context) (Transaction, error) WithTxImmediate(ctx context.Context, exec func(Transaction) error) error + Intercept(key string, fn Interceptor) + RemoveInterceptor(key string) } // Transaction represents a transaction. 
@@ -333,6 +581,9 @@ type sqliteDatabase struct { latency *prometheus.HistogramVec queryCount atomic.Int64 + + interceptMtx sync.Mutex + interceptors map[string]Interceptor } var _ Database = &sqliteDatabase{} @@ -347,6 +598,9 @@ func (db *sqliteDatabase) getConn(ctx context.Context) *sqlite.Conn { } func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx, error) { + if db.closed { + return nil, ErrClosed + } conn := db.getConn(ctx) if conn == nil { return nil, ErrNoConnection @@ -358,12 +612,16 @@ func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx return tx, nil } -func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) error { +func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) (err error) { tx, err := db.getTx(ctx, initstmt) if err != nil { return err } - defer tx.Release() + defer func() { + if rErr := tx.Release(); rErr != nil && err == nil { + err = fmt.Errorf("release tx: %w", rErr) + } + }() if err := exec(tx); err != nil { tx.queryCache.ClearCache() return err @@ -371,6 +629,39 @@ func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func return tx.Commit() } +func (db *sqliteDatabase) startExclusive() error { + conn := db.getConn(context.Background()) + if conn == nil { + return ErrNoConnection + } + defer db.pool.Put(conn) + // We don't need to wait for long if the database is busy + conn.SetBusyTimeout(1 * time.Millisecond) + // From SQLite docs: + // When the locking-mode is set to EXCLUSIVE, the database connection + // never releases file-locks. The first time the database is read in + // EXCLUSIVE mode, a shared lock is obtained and held. The first time the + // database is written, an exclusive lock is obtained and held. 
+ if _, err := exec(conn, "PRAGMA locking_mode=EXCLUSIVE", nil, nil); err != nil { + return fmt.Errorf("PRAGMA locking_mode=EXCLUSIVE: %w", err) + } + // We need to perform a transaction to have the database actually locked. + // From SQLite docs, regarding BEGIN EXCLUSIVE / BEGIN IMMEDIATE: + // EXCLUSIVE is similar to IMMEDIATE in that a write transaction is + // started immediately. EXCLUSIVE and IMMEDIATE are the same in WAL mode, + // but in other journaling modes, EXCLUSIVE prevents other database + // connections from reading the database while the transaction is + // underway. + _, err := exec(conn, "BEGIN EXCLUSIVE", nil, nil) + if err != nil { + return fmt.Errorf("error starting the EXCLUSIVE transaction: %w", err) + } + if _, err := exec(conn, "COMMIT", nil, nil); err != nil { + return fmt.Errorf("error committing the EXCLUSIVE transaction: %w", err) + } + return nil +} + // Tx creates deferred sqlite transaction. // // Deferred transactions are not started until the first statement. @@ -406,6 +697,17 @@ func (db *sqliteDatabase) WithTxImmediate( return db.withTx(ctx, beginImmediate, exec) } +func (db *sqliteDatabase) runInterceptors(query string) error { + db.interceptMtx.Lock() + defer db.interceptMtx.Unlock() + for _, interceptFn := range db.interceptors { + if err := interceptFn(query); err != nil { + return err + } + } + return nil +} + // Exec statement using one of the connection from the pool. // // If you care about atomicity of the operation (for example writing rewards to multiple accounts) @@ -415,6 +717,13 @@ func (db *sqliteDatabase) WithTxImmediate( // Note that Exec will block until database is closed or statement has finished. // If application needs to control statement execution lifetime use one of the transaction. 
func (db *sqliteDatabase) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { + if err := db.runInterceptors(query); err != nil { + return 0, err + } + + if db.closed { + return 0, ErrClosed + } db.queryCount.Add(1) conn := db.getConn(context.Background()) if conn == nil { @@ -438,12 +747,229 @@ func (db *sqliteDatabase) Close() error { return nil } if err := db.pool.Close(); err != nil { - return fmt.Errorf("close pool %w", err) + return fmt.Errorf("close pool: %w", err) } db.closed = true return nil } +// Intercept adds an interceptor function to the database. The interceptor functions +// are invoked upon each query. The query will fail if the interceptor returns an error. +// The interceptor can later be removed using RemoveInterceptor with the same key. +func (db *sqliteDatabase) Intercept(key string, fn Interceptor) { + db.interceptMtx.Lock() + defer db.interceptMtx.Unlock() + if db.interceptors == nil { + db.interceptors = make(map[string]Interceptor) + } + db.interceptors[key] = fn +} + +// RemoveInterceptor removes the interceptor function with specified key from the database. +// If there's no such interceptor, the function does nothing. +func (db *sqliteDatabase) RemoveInterceptor(key string) { + db.interceptMtx.Lock() + defer db.interceptMtx.Unlock() + delete(db.interceptors, key) +} + +// vacuumInto runs VACUUM INTO on the database and saves the vacuumed +// database at toPath. +func (db *sqliteDatabase) vacuumInto(toPath string) error { + if _, err := db.Exec("VACUUM INTO ?1", func(stmt *Statement) { + stmt.BindText(1, toPath) + }, nil); err != nil { + return fmt.Errorf("vacuum into %s: %w", toPath, err) + } + return nil +} + +// copyMigrateDB performs a copy-based migration of the database. +// The source database is always closed by this function. +// Upon success, the migrated database is opened. 
+func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, err error) { + dbPath, migratedPath, err := dbMigrationPaths(config.uri) + if err != nil { + return nil, fmt.Errorf("getting DB migration paths: %w", err) + } + if migratedPath == "" { + return nil, fmt.Errorf("cannot migrate database, only file DBs are supported: %s", config.uri) + } + + // Before we start the migration, re-open the database in exclusive mode + // so that no other connections will be able to use it. + // This will fail if another process is already using this database. + if err := db.Close(); err != nil { + return nil, fmt.Errorf("error closing DB: %w", err) + } + + excDB, err := Open("file:"+dbPath, + WithLogger(config.logger), + WithConnections(1), + WithNoCheckSchemaDrift(), + WithExclusive(), + ) + if err != nil { + return nil, fmt.Errorf("error opening the database in exclusive mode: %v", err) + } + defer excDB.Close() + + // instead of just copying the source database to the temporary migration DB, use VACUUM INTO. + // This is somewhat slower but achieves two goals: + // 1. The lock is held on the source database while it's being copied + // 2. If the source database has a lot of free pages for whatever reason, those + // are not copied, saving disk space + config.logger.Info("making a temporary copy of the database", + zap.String("path", dbPath), + zap.String("target", migratedPath)) + if err := excDB.vacuumInto(migratedPath); err != nil { + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } + return nil, err + } + + // Opening the temporary migrated DB runs the actual migrations on it. + // We disable vacuuming here because we're going to vacuum the temporary DB + // into the original one. 
+ opts := []Opt{ + WithLogger(config.logger), + WithConnections(1), + WithTemp(), + WithDatabaseSchema(config.schema), + WithExclusive(), + } + if !config.checkSchemaDrift { + opts = append(opts, WithNoCheckSchemaDrift()) + } + migratedDB, err := Open("file:"+migratedPath, opts...) + if err != nil { + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } + return nil, fmt.Errorf("process temporary DB %s: %w", migratedPath, err) + } + defer migratedDB.Close() + + // Make sure the temporary DB is fully synced to the disk before creating the marker file. + // We don't need wal_checkpoint(TRUNCATE) here as we're going to delete the temporary DB. + if _, err := migratedDB.Exec("PRAGMA wal_checkpoint(FULL)", nil, nil); err != nil { + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } + return nil, fmt.Errorf("checkpoint temporary DB %s: %w", migratedPath, err) + } + + // Create the marker file to indicate that the migration is complete and make sure + // the file is written to the disk before closing the database. + // We could create a table in the temporary database instead of the marker file, + // but as the temporary database is opened without PRAGMA journal_mode=OFF + // and PRAGMA synchronous=OFF, it may become corrupt in case of a crash or power + // outage, so we avoid trying to open it. 
+ if err := createMarkerFile(migratedPath); err != nil { + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } + // The errors returned by createMarkerFile are already descriptive enough + // so no need to augment them + return nil, err + } + + // At this point, the temporary database is complete and should not be deleted + // until we copy it to the original database location. + + // We only close the source database at the end of the migration process + // so that the lock is held. There's a possibility that right after we + // close the source database, another process will see the migrated database + // and the marker file and will try to open the migrated database. If the + if err := excDB.Close(); err != nil { + return nil, fmt.Errorf("close db: %w", err) + } + + // Delete the original database. VACUUM INTO will fail if the destination + // database exists. + if err := deleteDB(dbPath); err != nil { + return nil, fmt.Errorf("delete original DB %s: %w", dbPath, err) + } + + // Overwrite the original database with the migrated one. + // The lock is held on the temporary DB during this, preventing concurrent + // go-spacemesh instances to attempt the same operation. + config.logger.Info("moving the temporary DB to original location", zap.String("path", dbPath)) + if err := migratedDB.vacuumInto(dbPath); err != nil { + return nil, err + } + + // Open the final DB in the exclusive mode before deleting the source DB, so one of the locks + // is always held. The migrations are already run, so we're disabling them. 
+ origExclusive := config.exclusive + config.enableMigrations = false + config.exclusive = true + finalDB, err = openDB(config) + if err != nil { + return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) + } + + if err := migratedDB.Close(); err != nil { + finalDB.Close() + return nil, fmt.Errorf("close temporary DB %s: %w", migratedPath, err) + } + + // Now we can delete the temporary DB and the marker file. + if err := deleteDB(migratedPath); err != nil { + finalDB.Close() + return nil, err + } + + // If we were not intending to open the database in exclusive mode, + // reopen it in the normal mode + if !origExclusive { + if err := finalDB.Close(); err != nil { + return nil, fmt.Errorf("close final DB: %w", err) + } + config.exclusive = false + finalDB, err = openDB(config) + if err != nil { + return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) + } + } + + return finalDB, nil +} + +func createMarkerFile(basePath string) error { + markerPath := basePath + "_done" + f, err := os.Create(markerPath) + if err != nil { + return fmt.Errorf("create marker file %s: %w", markerPath, err) + } + if err := f.Sync(); err != nil { + f.Close() + os.Remove(markerPath) + return fmt.Errorf("sync/close marker file %s: %w", markerPath, err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("close marker file %s: %w", markerPath, err) + } + return nil +} + // QueryCount returns the number of queries executed, including failed // queries, but not counting transaction start / commit / rollback. func (db *sqliteDatabase) QueryCount() int { @@ -530,6 +1056,10 @@ func (tx *sqliteTx) Release() error { // Exec query. 
func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { + if err := tx.db.runInterceptors(query); err != nil { + return 0, fmt.Errorf("running query interceptors: %w", err) + } + tx.db.queryCount.Add(1) if tx.db.latency != nil { start := time.Now() diff --git a/sql/database_test.go b/sql/database_test.go index 088ef5e24b..e3d138b3f4 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -5,6 +5,7 @@ import ( "errors" "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/require" @@ -109,6 +110,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), + WithNoCheckSchemaDrift(), ) require.ErrorContains(t, err, "migration 2 failed") } @@ -167,43 +169,83 @@ func TestDatabaseSkipMigrations(t *testing.T) { require.NoError(t, db.Close()) } +func execSQL(t *testing.T, db Executor, sql string, col int) (result string) { + _, err := db.Exec(sql, nil, func(stmt *Statement) bool { + if col >= 0 { + result = stmt.ColumnText(col) + } + return true + }) + require.NoError(t, err) + return result +} + func TestDatabaseVacuumState(t *testing.T) { dir := t.TempDir() logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) + + // The first migration is done without vacuuming and thus it is performed + // in-place. migration1 := NewMockMigration(ctrl) migration1.EXPECT().Order().Return(1).AnyTimes() - migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil).Times(1) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(db Executor, logger *zap.Logger) error { + require.NotContains(t, execSQL(t, db, "PRAGMA database_list", 2), "_migrate") + require.Equal(t, "wal", execSQL(t, db, "PRAGMA journal_mode", 0)) + require.Equal(t, "1", execSQL(t, db, "PRAGMA synchronous", 0)) // NORMAL + execSQL(t, db, "create table foo(x int)", -1) + return nil + }).Times(1) migration2 := NewMockMigration(ctrl) migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil).Times(1) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()). + DoAndReturn(func(db Executor, logger *zap.Logger) error { + // We must be operating on a temp database. + require.Contains(t, execSQL(t, db, "PRAGMA database_list", 2), "_migrate") + // Journaling is off for the temp database as it is deleted in case + // of migration failure. + require.Equal(t, "off", execSQL(t, db, "PRAGMA journal_mode", 0)) + // Synchronous is off for the temp database as it is deleted in case + // of migration failure. 
+ require.Equal(t, "0", execSQL(t, db, "PRAGMA synchronous", 0)) // OFF + execSQL(t, db, "create table bar(y int)", -1) + return nil + }).Times(1) dbFile := filepath.Join(dir, "test.sql") db, err := Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - WithNoCheckSchemaDrift(), + WithConnections(10), ) require.NoError(t, err) + execSQL(t, db, "select * from foo", -1) // ensure table exists require.NoError(t, db.Close()) db, err = Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", Migrations: MigrationList{migration1, migration2}, }), WithVacuumState(2), - WithNoCheckSchemaDrift(), ) require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) require.NoError(t, db.Close()) - // we run pragma wal_checkpoint(TRUNCATE) after vacuum, which drops the wal file + // The wal file should be absent after the database is re-created + // with VACUUM INTO _, err = os.Open(dbFile + "-wal") require.ErrorIs(t, err, os.ErrNotExist) } @@ -222,6 +264,240 @@ func TestQueryCount(t *testing.T) { require.Equal(t, 2, db.QueryCount()) } +func TestDatabaseVacuumStateError(t *testing.T) { + dir := t.TempDir() + logger := zaptest.NewLogger(t) + + ctrl := gomock.NewController(t) + + migration1 := &sqlMigration{ + order: 1, + name: "0001_initial.sql", + content: "create table foo(x int)", + } + + fail := true + migration2 := NewMockMigration(ctrl) + migration2.EXPECT().Name().Return("0002_test.sql").AnyTimes() + migration2.EXPECT().Order().Return(2).AnyTimes() + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(db Executor, logger *zap.Logger) error { + if fail { + return errors.New("migration failed") + } + execSQL(t, db, "create table bar(y int)", -1) + return nil + }).Times(2) + + dbFile := filepath.Join(dir, "test.sql") + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1}, + }), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo", -1) // ensure table exists + require.NoError(t, db.Close()) + + schema := &Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1, migration2}, + } + _, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.Error(t, err) + + // All temporary files need to be deleted upon migration failure. + tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*_migrate*")) + require.NoError(t, err) + require.Empty(t, tmpDBFiles) + + // Make sure the initial DB is intact after failed migration, + // and the 2nd migration is applied on the second attempt. + fail = false + db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) + require.NoError(t, db.Close()) +} + +// faultyMigration is a migration that can be configured to panic during Apply. +// We don't use mock for this as it's not entirely clear what happens if a mocked method +// panics. 
+type faultyMigration struct { + panic, interceptVacuumInto bool + *sqlMigration +} + +var _ Migration = &faultyMigration{} + +func (m *faultyMigration) Apply(db Executor, logger *zap.Logger) error { + if m.interceptVacuumInto { + db.(Database).Intercept("crashOnVacuum", func(query string) error { + if strings.Contains(strings.ToLower(query), "vacuum into") { + panic("simulated crash") + } + return nil + }) + } + if m.panic { + panic("simulated crash") + } + return m.sqlMigration.Apply(db, logger) +} + +func TestDropIncompleteMigration(t *testing.T) { + dir := t.TempDir() + logger := zaptest.NewLogger(t) + migration1 := &sqlMigration{ + order: 1, + name: "0001_initial.sql", + content: "create table foo(x int)", + } + migration2 := &faultyMigration{ + panic: true, + sqlMigration: &sqlMigration{ + order: 2, + name: "0002_test.sql", + content: "create table bar(y int)", + }, + } + + dbFile := filepath.Join(dir, "test.sql") + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + schema := &Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1, migration2}, + } + + require.Panics(t, func() { + Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + }) + + // Check that temporary database exists after the simulated crash. + // Note that we're checking "*_migrate" not "*_migrate*" to avoid matching + // any erroneously created successful migration markers. + tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*_migrate")) + require.NoError(t, err) + require.NotEmpty(t, tmpDBFiles) + + // Retry migration. The incompletely migrated temporary database should be dropped. 
+ migration2.panic = false + db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) + require.NoError(t, db.Close()) +} + +func TestResumeCopyMigration(t *testing.T) { + dir := t.TempDir() + logger := zaptest.NewLogger(t) + migration1 := &sqlMigration{ + order: 1, + name: "0001_initial.sql", + content: "create table foo(x int)", + } + // This migration will panic when VACUUM INTO is attempted to copy + // the migrated database to the source database location. + migration2 := &faultyMigration{ + interceptVacuumInto: true, + sqlMigration: &sqlMigration{ + order: 2, + name: "0002_test.sql", + content: "create table bar(y int)", + }, + } + + dbFile := filepath.Join(dir, "test.sql") + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + schema := &Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1, migration2}, + } + + require.Panics(t, func() { + Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + }) + + // Check that temporary database exists after the simulated crash. + tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*")) + t.Logf("tmpDBFiles: %v", tmpDBFiles) + require.NoError(t, err) + require.NotEmpty(t, tmpDBFiles) + + // Retry migration. The migrated database should be copied + // to the source database location without invoking any further + // migrations. As the migration with fault injection is not called, + // the final VACUUM INTO must succeed. 
+ db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) + require.NoError(t, db.Close()) +} + +func TestDBClosed(t *testing.T) { + db := InMemory(WithLogger(zaptest.NewLogger(t)), WithNoCheckSchemaDrift()) + require.NoError(t, db.Close()) + _, err := db.Exec("select 1", nil, nil) + require.ErrorIs(t, err, ErrClosed) + err = db.WithTx(context.Background(), func(tx Transaction) error { return nil }) + require.ErrorIs(t, err, ErrClosed) +} + func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { dir := t.TempDir() logger := zaptest.NewLogger(t) @@ -302,3 +578,38 @@ func TestSchemaDrift(t *testing.T) { require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, observedLogs.All()[0].ContextMap()["diff"]) } + +func TestExclusive(t *testing.T) { + for _, tc := range []struct { + name string + optsA []Opt + optsB []Opt + }{ + { + name: "exclusive succeeds, non-exclusive fails", + optsA: []Opt{WithExclusive()}, + }, + { + name: "exclusive succeeds, non-exclusive fails", + optsB: []Opt{WithExclusive()}, + }, + { + name: "first exclusive succeeds, second exclusive fails", + optsA: []Opt{WithExclusive()}, + optsB: []Opt{WithExclusive()}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + db, err := Open(dbPath, append([]Opt{WithNoCheckSchemaDrift()}, tc.optsA...)...) + require.NoError(t, err) + _, err = Open(dbPath, append([]Opt{WithNoCheckSchemaDrift()}, tc.optsB...)...) 
+ require.ErrorContains(t, err, "SQLITE_BUSY: database is locked") + _, err = db.Exec("select count(*) from sqlite_master", nil, nil) + require.NoError(t, err) + require.NoError(t, db.Close()) + }) + } +} diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index 320702d7eb..08af28c66d 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -39,7 +39,7 @@ func TestIdempotentMigration(t *testing.T) { require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") l := observedLogs.All()[0] - require.Equal(t, "running migrations", l.Message) + require.Equal(t, "running migrations in-place", l.Message) require.Equal(t, int64(0), l.ContextMap()["current version"]) require.Equal(t, int64(versionA), l.ContextMap()["target version"]) diff --git a/sql/migrations.go b/sql/migrations.go index 3a4d37844d..5f4bc2c26f 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -66,10 +66,6 @@ func (m *sqlMigration) Apply(db Executor, logger *zap.Logger) error { } } } - // binding values in pragma statement is not allowed - if _, err := db.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.order), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.order, err) - } return nil } diff --git a/sql/schema.go b/sql/schema.go index 0ef49bff45..4dc667e78b 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -120,10 +120,22 @@ func (s *Schema) CheckDBVersion(logger *zap.Logger, db Database) (before, after return before, after, nil } +func (s *Schema) setVersion(db Executor, version int) error { + // binding values in pragma statement is not allowed + if _, err := db.Exec(fmt.Sprintf("PRAGMA user_version = %d;", version), nil, nil); err != nil { + return fmt.Errorf("update user_version to %d: %w", version, err) + } + return nil +} + // Migrate performs database migration. 
In case if migrations are disabled, the database // version is checked but no migrations are run, and if the database is too old and // migrations are disabled, an error is returned. func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState int) error { + if logger.Core().Enabled(zap.DebugLevel) { + db.Intercept("logQueries", logQueryInterceptor(logger)) + defer db.RemoveInterceptor("logQueries") + } for i, m := range s.Migrations { if m.Order() <= before { continue @@ -141,19 +153,16 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in return fmt.Errorf("apply %s: %w", m.Name(), err) } } - // version is set intentionally even if actual migration was skipped - if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.Order(), err) + if err := s.setVersion(tx, m.Order()); err != nil { + return err } return nil }); err != nil { - err = errors.Join(err, db.Close()) return err } if vacuumState != 0 && before <= vacuumState { if err := Vacuum(db); err != nil { - err = errors.Join(err, db.Close()) return err } } @@ -162,6 +171,53 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in return nil } +// MigrateTempDB performs database migration on the temporary database. +// It doesn't use transactions and the temporary database should be considered +// invalid and discarded if it fails. +// The database is switched into synchronous mode with WAL journal enabled and +// synced after the migrations are completed before setting the database version, +// which triggers file sync. 
+func (s *Schema) MigrateTempDB(logger *zap.Logger, db Database, before int) error { + if logger.Core().Enabled(zap.DebugLevel) { + db.Intercept("logQueries", logQueryInterceptor(logger)) + defer db.RemoveInterceptor("logQueries") + } + v := before + for _, m := range s.Migrations { + if m.Order() <= v { + continue + } + + if _, ok := s.skipMigration[m.Order()]; !ok { + if err := m.Apply(db, logger); err != nil { + return fmt.Errorf("apply %s: %w", m.Name(), err) + } + } + + // We don't set the version here as if any migration fails, + // the temporary database is considered invalid and should be discarded. + v = m.Order() + } + + logger.Info("syncing temporary database") + + // Enable WAL journal and synchronous mode to ensure the database is synced + if _, err := db.Exec("PRAGMA journal_mode=WAL", nil, nil); err != nil { + return fmt.Errorf("setting WAL journal mode: %w", err) + } + + if _, err := db.Exec("PRAGMA synchronous=FULL", nil, nil); err != nil { + return fmt.Errorf("setting synchronous mode: %w", err) + } + + // This should trigger file sync + if err := s.setVersion(db, v); err != nil { + return fmt.Errorf("setting DB schema version: %w", err) + } + + return nil +} + // SchemaGenOpt represents a schema generator option. 
type SchemaGenOpt func(g *SchemaGen) @@ -216,3 +272,14 @@ func (g *SchemaGen) Generate(outputFile string) error { } return nil } + +func logQueryInterceptor(logger *zap.Logger) Interceptor { + return func(query string) error { + query = strings.TrimSpace(query) + if p := strings.Index(query, "\n"); p >= 0 { + query = query[:p] + } + logger.Debug("executing query", zap.String("query", query)) + return nil + } +} diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index a59ce12079..7a5693d290 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -40,7 +40,7 @@ func TestIdempotentMigration(t *testing.T) { // "running migrations" require.Equal(t, 1, observedLogs.Len(), "expected count of log messages") l := observedLogs.All()[0] - require.Equal(t, "running migrations", l.Message) + require.Equal(t, "running migrations in-place", l.Message) require.Equal(t, int64(0), l.ContextMap()["current version"]) require.Equal(t, int64(versionA), l.ContextMap()["target version"])