Skip to content

Commit

Permalink
CNF-15570: Fix transaction support in collector
Browse files Browse the repository at this point in the history
This addresses an issue in the collector persistence utilities with
using transactions to execute create/update of an object combined with
persisting of the data change event.

Signed-off-by: Allain Legacy <[email protected]>
  • Loading branch information
alegacy authored and openshift-merge-bot[bot] committed Dec 13, 2024
1 parent e3cb2d1 commit fc03dd6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
27 changes: 17 additions & 10 deletions internal/service/common/utils/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/dialect/psql"
"github.com/stephenafamo/bob/dialect/psql/dialect"
Expand All @@ -26,10 +26,17 @@ var ErrNotFound = errors.New("record not found")
// Following functions are meant to fulfill basic CRUD operations on the database. More complex queries or bulk operations
// for Insert or Update should be built in the repository files of the specific service and called one of the Execute helper functions.

// DBQuery is an abstraction to allow passing either a pool or a transaction query function to any of the utilities.
type DBQuery interface {
Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}

// Find retrieves a specific tuple from the database table specified.
// The `uuid` argument is the primary key of the record to retrieve.
// If no record is found ErrNotFound is returned as an error.
func Find[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (*T, error) {
func Find[T db.Model](ctx context.Context, db DBQuery, uuid uuid.UUID) (*T, error) {
var record T
tags := GetAllDBTagsFromStruct(record)

Expand All @@ -48,15 +55,15 @@ func Find[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (*T
// FindAll retrieves all tuples from the database table specified.
// The `fields` argument is a list of columns to retrieve. If no fields are specified then all columns are fetched.
// If no records are found then an empty array is returned.
func FindAll[T db.Model](ctx context.Context, db *pgxpool.Pool, fields ...string) ([]T, error) {
func FindAll[T db.Model](ctx context.Context, db DBQuery, fields ...string) ([]T, error) {
return Search[T](ctx, db, nil, fields...)
}

// Search retrieves tuples from the database table specified using a custom expression.
// The `fields` argument is a list of columns to retrieve. If no fields are specified then all columns are fetched.
// The `whereExpr` argument is a custom expression to filter the records.
// If no records are found then an empty array is returned.
func Search[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr bob.Expression, fields ...string) ([]T, error) {
func Search[T db.Model](ctx context.Context, db DBQuery, whereExpr bob.Expression, fields ...string) ([]T, error) {
// Build sql query
var record T
tags := GetAllDBTagsFromStruct(record)
Expand All @@ -83,7 +90,7 @@ func Search[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr bob.Exp
// Delete deletes a specific tuple from the database table specified using a custom expression.
// The `whereExpr` argument is a custom expression to filter the records.
// The number of rows affected is returned on success; otherwise an error is returned.
func Delete[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr psql.Expression) (int64, error) {
func Delete[T db.Model](ctx context.Context, db DBQuery, whereExpr psql.Expression) (int64, error) {
var record T
query := psql.Delete(
dm.From(record.TableName()),
Expand All @@ -108,7 +115,7 @@ func Delete[T db.Model](ctx context.Context, db *pgxpool.Pool, whereExpr psql.Ex
// The "record" argument is the record to store in the database.
// The "fields" argument is a list of columns to store. If no fields are specified only non-nil fields are stored.
// The stored record is returned on success; otherwise an error is returned.
func Create[T db.Model](ctx context.Context, db *pgxpool.Pool, record T, fields ...string) (*T, error) {
func Create[T db.Model](ctx context.Context, db DBQuery, record T, fields ...string) (*T, error) {
all := GetAllDBTagsFromStruct(record)
tags := GetNonNilDBTagsFromStruct(record)
if len(fields) > 0 {
Expand All @@ -135,7 +142,7 @@ func Create[T db.Model](ctx context.Context, db *pgxpool.Pool, record T, fields
// The `record` argument is the record to update in the database.
// The `fields` argument is a list of columns to update. If no fields are specified only non-nil fields are updated.
// The updated record is returned on success; otherwise an error is returned.
func Update[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID, record T, fields ...string) (*T, error) {
func Update[T db.Model](ctx context.Context, db DBQuery, uuid uuid.UUID, record T, fields ...string) (*T, error) {
all := GetAllDBTagsFromStruct(record)
tags := all
if len(fields) > 0 {
Expand Down Expand Up @@ -167,7 +174,7 @@ func Update[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID, r

// Exists checks whether a record exists in the database table specified.
// The `uuid` argument is the primary key of the record to check.
func Exists[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (bool, error) {
func Exists[T db.Model](ctx context.Context, db DBQuery, uuid uuid.UUID) (bool, error) {
var record T

query := psql.RawQuery(fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM %s WHERE %s=?)",
Expand All @@ -192,7 +199,7 @@ func Exists[T db.Model](ctx context.Context, db *pgxpool.Pool, uuid uuid.UUID) (
// Helper Execute Query functions

// ExecuteCollectExactlyOneRow executes a query and collects result using pgx.CollectExactlyOneRow.
func ExecuteCollectExactlyOneRow[T db.Model](ctx context.Context, db *pgxpool.Pool, sql string, args []any) (*T, error) {
func ExecuteCollectExactlyOneRow[T db.Model](ctx context.Context, db DBQuery, sql string, args []any) (*T, error) {
var record T
var err error

Expand All @@ -214,7 +221,7 @@ func ExecuteCollectExactlyOneRow[T db.Model](ctx context.Context, db *pgxpool.Po
}

// ExecuteCollectRows executes a query and collects result using pgx.CollectRows.
func ExecuteCollectRows[T db.Model](ctx context.Context, db *pgxpool.Pool, sql string, args []any) ([]T, error) {
func ExecuteCollectRows[T db.Model](ctx context.Context, db DBQuery, sql string, args []any) ([]T, error) {
var record T
var err error

Expand Down
33 changes: 22 additions & 11 deletions internal/service/resources/collector/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"

"github.com/openshift-kni/oran-o2ims/internal/model"
Expand All @@ -27,15 +28,15 @@ type GenericModelConverter func(object interface{}) any
// persistObject persists an object to its database table. If the object does not already have a
// persisted representation then it is created; otherwise any modified fields are updated in the
// database tuple. The function returns both the before and after versions of the object.
func persistObject[T db.Model](ctx context.Context, db *pgxpool.Pool,
func persistObject[T db.Model](ctx context.Context, tx pgx.Tx,
object T, uuid uuid.UUID) (*T, *T, error) {
var before, after *T
// Store the object into the database handling cases for both insert/update separately so that we have access to the
// before & after view of the data.
var record, err = utils.Find[T](ctx, db, uuid)
var record, err = utils.Find[T](ctx, tx, uuid)
if errors.Is(err, utils.ErrNotFound) {
// New object instance
after, err = utils.Create[T](ctx, db, object)
after, err = utils.Create[T](ctx, tx, object)
if err != nil {
return nil, nil, fmt.Errorf("failed to create object '%s/%s': %w", object.TableName(), uuid, err)
}
Expand All @@ -60,7 +61,7 @@ func persistObject[T db.Model](ctx context.Context, db *pgxpool.Pool,
return before, after, nil
}

after, err = utils.Update[T](ctx, db, uuid, object, tags.Fields()...)
after, err = utils.Update[T](ctx, tx, uuid, object, tags.Fields()...)
if err != nil {
return nil, nil, fmt.Errorf("failed to update object '%s/%s': %w", object.TableName(), uuid, err)
}
Expand All @@ -73,7 +74,7 @@ func persistObject[T db.Model](ctx context.Context, db *pgxpool.Pool,

// persistDataChangeEvent persists a data change object to its database table. The before and
// after model objects are marshaled to JSON prior to being stored.
func persistDataChangeEvent(ctx context.Context, db *pgxpool.Pool, tableName string, uuid uuid.UUID,
func persistDataChangeEvent(ctx context.Context, tx pgx.Tx, tableName string, uuid uuid.UUID,
parentUUID *uuid.UUID, before, after any) (*models.DataChangeEvent, error) {
var err error
var beforeJSON, afterJSON []byte
Expand Down Expand Up @@ -103,14 +104,23 @@ func persistDataChangeEvent(ctx context.Context, db *pgxpool.Pool, tableName str
dataChangeEvent.AfterState = &value
}

result, err := utils.Create[models.DataChangeEvent](ctx, db, dataChangeEvent)
result, err := utils.Create[models.DataChangeEvent](ctx, tx, dataChangeEvent)
if err != nil {
return nil, fmt.Errorf("failed to create data change event: %w", err)
}

return result, nil
}

// rollback attempts to execute a rollback on a transaction. It is safe to invoke this as a
// deferred function call even if the transaction has already been committed.
func rollback(ctx context.Context, tx pgx.Tx) {
err := tx.Rollback(ctx)
if err != nil && !errors.Is(err, pgx.ErrTxClosed) {
slog.Error("failed to rollback transaction", "error", err)
}
}

// persistObjectWithChangeEvent persists an object to its database table and if the external API
// model representation of the object has changed then a data change event is stored. Persisting
// of the object and its change event are captured under the same transaction to ensure we never
Expand All @@ -119,14 +129,15 @@ func persistObjectWithChangeEvent[T db.Model](ctx context.Context, db *pgxpool.P
uuid uuid.UUID, parentUUID *uuid.UUID,
converter GenericModelConverter) (*models.DataChangeEvent, error) {
var dataChangeEvent *models.DataChangeEvent
txCtx, cancel := context.WithCancel(ctx)
defer cancel()
tx, err := db.Begin(txCtx)

tx, err := db.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}

before, after, err := persistObject(ctx, db, record, uuid)
defer rollback(ctx, tx)

before, after, err := persistObject(ctx, tx, record, uuid)
if err != nil {
return nil, fmt.Errorf("failed to persist object: %w", err)
}
Expand All @@ -141,7 +152,7 @@ func persistObjectWithChangeEvent[T db.Model](ctx context.Context, db *pgxpool.P
if beforeModel == nil || !reflect.DeepEqual(beforeModel, afterModel) {
// Capture a change event if the data actually changed
dataChangeEvent, err = persistDataChangeEvent(
ctx, db, record.TableName(), uuid, parentUUID, beforeModel, afterModel)
ctx, tx, record.TableName(), uuid, parentUUID, beforeModel, afterModel)
if err != nil {
return nil, fmt.Errorf("failed to persist resource type data change object: %w", err)
}
Expand Down

0 comments on commit fc03dd6

Please sign in to comment.