diff --git a/internal/service/common/utils/repository.go b/internal/service/common/utils/repository.go index 55c1a559..b9799867 100644 --- a/internal/service/common/utils/repository.go +++ b/internal/service/common/utils/repository.go @@ -8,7 +8,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/pgconn" "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/dialect/psql" "github.com/stephenafamo/bob/dialect/psql/dialect" @@ -26,10 +26,17 @@ var ErrNotFound = errors.New("record not found") // Following functions are meant to fulfill basic CRUD operations on the database. More complex queries or bulk operations // for Insert or Update should be built in the repository files of the specific service and called one of the Execute helper functions. +// DBQuery is an abstraction to allow passing either a pool or a transaction query function to any of the utilities. +type DBQuery interface { + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row +} + // Find retrieves a specific tuple from the database table specified. // The `uuid` argument is the primary key of the record to retrieve. // If no record is found ErrNotFound is returned as an error. -func Find[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (*T, error) { +func Find[T db.Model](ctx context.Context, db DBQuery, uuid uuid.UUID) (*T, error) { var record T tags := GetAllDBTagsFromStruct(record) @@ -48,7 +55,7 @@ func Find[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (*T // FindAll retrieves all tuples from the database table specified. // The `fields` argument is a list of columns to retrieve. If no fields are specified then all columns are fetched. // If no records are found then an empty array is returned. -func FindAll[T db.Model](ctx context.Context, db *pgxpool.Pool, fields ...string) ([]T, error) { +func FindAll[T db.Model](ctx context.Context, db DBQuery, fields ...string) ([]T, error) { return Search[T](ctx, db, nil, fields...) } @@ -56,7 +63,7 @@ func FindAll[T db.Model](ctx context.Context, db *pgxpool.Pool, fields ...string // The `fields` argument is a list of columns to retrieve. If no fields are specified then all columns are fetched. // The `whereExpr` argument is a custom expression to filter the records. // If no records are found then an empty array is returned. -func Search[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr bob.Expression, fields ...string) ([]T, error) { +func Search[T db.Model](ctx context.Context, db DBQuery, whereExpr bob.Expression, fields ...string) ([]T, error) { // Build sql query var record T tags := GetAllDBTagsFromStruct(record) @@ -83,7 +90,7 @@ func Search[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr bob.Exp // Delete deletes a specific tuple from the database table specified using a custom expression. // The `whereExpr` argument is a custom expression to filter the records. // The number of rows affected is returned on success; otherwise an error is returned. -func Delete[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr psql.Expression) (int64, error) { +func Delete[T db.Model](ctx context.Context, db DBQuery, whereExpr psql.Expression) (int64, error) { var record T query := psql.Delete( dm.From(record.TableName()), @@ -108,7 +115,7 @@ func Delete[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr psql.Ex // The "record" argument is the record to store in the database. // The "fields" argument is a list of columns to store. If no fields are specified only non-nil fields are stored. // The stored record is returned on success; otherwise an error is returned. -func Create[T db.Model](ctx context.Context, db *pgxpool.Pool, record T, fields ...string) (*T, error) { +func Create[T db.Model](ctx context.Context, db DBQuery, record T, fields ...string) (*T, error) { all := GetAllDBTagsFromStruct(record) tags := GetNonNilDBTagsFromStruct(record) if len(fields) > 0 { @@ -135,7 +142,7 @@ func Create[T db.Model](ctx context.Context, db *pgxpool.Pool, record T, fields // The `record` argument is the record to update in the database. // The `fields` argument is a list of columns to update. If no fields are specified only non-nil fields are updated. // The updated record is returned on success; otherwise an error is returned. -func Update[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID, record T, fields ...string) (*T, error) { +func Update[T db.Model](ctx context.Context, db DBQuery, uuid uuid.UUID, record T, fields ...string) (*T, error) { all := GetAllDBTagsFromStruct(record) tags := all if len(fields) > 0 { @@ -167,7 +174,7 @@ func Update[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID, r // Exists checks whether a record exists in the database table specified. // The `uuid` argument is the primary key of the record to check. -func Exists[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (bool, error) { +func Exists[T db.Model](ctx context.Context, db DBQuery, uuid uuid.UUID) (bool, error) { var record T query := psql.RawQuery(fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM %s WHERE %s=?)", @@ -192,7 +199,7 @@ func Exists[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) ( // Helper Execute Query functions // ExecuteCollectExactlyOneRow executes a query and collects result using pgx.CollectExactlyOneRow. -func ExecuteCollectExactlyOneRow[T db.Model](ctx context.Context, db *pgxpool.Pool, sql string, args []any) (*T, error) { +func ExecuteCollectExactlyOneRow[T db.Model](ctx context.Context, db DBQuery, sql string, args []any) (*T, error) { var record T var err error @@ -214,7 +221,7 @@ func ExecuteCollectExactlyOneRow[T db.Model](ctx context.Context, db *pgxpool.Po } // ExecuteCollectRows executes a query and collects result using pgx.CollectRows. -func ExecuteCollectRows[T db.Model](ctx context.Context, db *pgxpool.Pool, sql string, args []any) ([]T, error) { +func ExecuteCollectRows[T db.Model](ctx context.Context, db DBQuery, sql string, args []any) ([]T, error) { var record T var err error diff --git a/internal/service/resources/collector/utils.go b/internal/service/resources/collector/utils.go index 78c5ef0c..e1127542 100644 --- a/internal/service/resources/collector/utils.go +++ b/internal/service/resources/collector/utils.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/openshift-kni/oran-o2ims/internal/model" @@ -27,15 +28,15 @@ type GenericModelConverter func(object interface{}) any // persistObject persists an object to its database table. If the object does not already have a // persisted representation then it is created; otherwise any modified fields are updated in the // database tuple. The function returns both the before and after versions of the object. -func persistObject[T db.Model](ctx context.Context, db *pgxpool.Pool, +func persistObject[T db.Model](ctx context.Context, tx pgx.Tx, object T, uuid uuid.UUID) (*T, *T, error) { var before, after *T // Store the object into the database handling cases for both insert/update separately so that we have access to the // before & after view of the data. - var record, err = utils.Find[T](ctx, db, uuid) + var record, err = utils.Find[T](ctx, tx, uuid) if errors.Is(err, utils.ErrNotFound) { // New object instance - after, err = utils.Create[T](ctx, db, object) + after, err = utils.Create[T](ctx, tx, object) if err != nil { return nil, nil, fmt.Errorf("failed to create object '%s/%s': %w", object.TableName(), uuid, err) } @@ -60,7 +61,7 @@ func persistObject[T db.Model](ctx context.Context, db *pgxpool.Pool, return before, after, nil } - after, err = utils.Update[T](ctx, db, uuid, object, tags.Fields()...) + after, err = utils.Update[T](ctx, tx, uuid, object, tags.Fields()...) if err != nil { return nil, nil, fmt.Errorf("failed to update object '%s/%s': %w", object.TableName(), uuid, err) } @@ -73,7 +74,7 @@ func persistObject[T db.Model](ctx context.Context, db *pgxpool.Pool, // persistDataChangeEvent persists a data change object to its database table. The before and // after model objects are marshaled to JSON prior to being stored. -func persistDataChangeEvent(ctx context.Context, db *pgxpool.Pool, tableName string, uuid uuid.UUID, +func persistDataChangeEvent(ctx context.Context, tx pgx.Tx, tableName string, uuid uuid.UUID, parentUUID *uuid.UUID, before, after any) (*models.DataChangeEvent, error) { var err error var beforeJSON, afterJSON []byte @@ -103,7 +104,7 @@ func persistDataChangeEvent(ctx context.Context, db *pgxpool.Pool, tableName str dataChangeEvent.AfterState = &value } - result, err := utils.Create[models.DataChangeEvent](ctx, db, dataChangeEvent) + result, err := utils.Create[models.DataChangeEvent](ctx, tx, dataChangeEvent) if err != nil { return nil, fmt.Errorf("failed to create data change event: %w", err) } @@ -111,6 +112,15 @@ func persistDataChangeEvent(ctx context.Context, db *pgxpool.Pool, tableName str return result, nil } +// rollback attempts to execute a rollback on a transaction. It is safe to invoke this as a +// deferred function call even if the transaction has already been committed. +func rollback(ctx context.Context, tx pgx.Tx) { + err := tx.Rollback(ctx) + if err != nil && !errors.Is(err, pgx.ErrTxClosed) { + slog.Error("failed to rollback transaction", "error", err) + } +} + // persistObjectWithChangeEvent persists an object to its database table and if the external API // model representation of the object has changed then a data change event is stored. Persisting // of the object and its change event are captured under the same transaction to ensure we never @@ -119,14 +129,15 @@ func persistObjectWithChangeEvent[T db.Model](ctx context.Context, db *pgxpool.P uuid uuid.UUID, parentUUID *uuid.UUID, converter GenericModelConverter) (*models.DataChangeEvent, error) { var dataChangeEvent *models.DataChangeEvent - txCtx, cancel := context.WithCancel(ctx) - defer cancel() - tx, err := db.Begin(txCtx) + + tx, err := db.Begin(ctx) if err != nil { return nil, fmt.Errorf("failed to begin transaction: %w", err) } - before, after, err := persistObject(ctx, db, record, uuid) + defer rollback(ctx, tx) + + before, after, err := persistObject(ctx, tx, record, uuid) if err != nil { return nil, fmt.Errorf("failed to persist object: %w", err) } @@ -141,7 +152,7 @@ func persistObjectWithChangeEvent[T db.Model](ctx context.Context, db *pgxpool.P if beforeModel == nil || !reflect.DeepEqual(beforeModel, afterModel) { // Capture a change event if the data actually changed dataChangeEvent, err = persistDataChangeEvent( - ctx, db, record.TableName(), uuid, parentUUID, beforeModel, afterModel) + ctx, tx, record.TableName(), uuid, parentUUID, beforeModel, afterModel) if err != nil { return nil, fmt.Errorf("failed to persist resource type data change object: %w", err) }