diff --git a/pkg/storage/sqlstorage/store_ledger_test.go b/pkg/storage/sqlstorage/store_ledger_test.go index 67c2cf5d1..01ea008ec 100644 --- a/pkg/storage/sqlstorage/store_ledger_test.go +++ b/pkg/storage/sqlstorage/store_ledger_test.go @@ -93,6 +93,7 @@ func TestStore(t *testing.T) { {name: "GetBalancesAggregated", fn: testGetBalancesAggregated}, {name: "GetBalancesAggregatedByAccount", fn: testGetBalancesAggregatedByAccount}, {name: "CreateIK", fn: testIKS}, + {name: "GetTransactionsByAccount", fn: testGetTransactionsByAccount}, } { t.Run(fmt.Sprintf("%s/%s-singleInstance", ledgertesting.StorageDriverName(), tf.name), runTest((tf))) } diff --git a/pkg/storage/sqlstorage/transactions.go b/pkg/storage/sqlstorage/transactions.go index 4f12f85a3..0b017e5fc 100644 --- a/pkg/storage/sqlstorage/transactions.go +++ b/pkg/storage/sqlstorage/transactions.go @@ -99,16 +99,22 @@ func (s *Store) buildTransactionsQuery(flavor Flavor, p ledger.TransactionsQuery sb.Where(s.schema.Table("use_account") + "(postings, " + arg + ")") } else { // new wildcard handling - dst := strings.Split(account, ":") - sb.Where(fmt.Sprintf("(jsonb_array_length(postings.destination) = %d OR jsonb_array_length(postings.source) = %d)", len(dst), len(dst))) - for i, segment := range dst { - if segment == ".*" || segment == "*" || segment == "" { - continue + ands := make([]string, 0) + for _, column := range []string{"source", "destination"} { + forColumn := make([]string, 0) + forColumn = append(forColumn, fmt.Sprintf("(jsonb_array_length(postings.%s) = %d)", column, len(strings.Split(account, ":")))) + for i, segment := range strings.Split(account, ":") { + if segment == ".*" || segment == "*" || segment == "" { + continue + } + + arg := sb.Args.Add(segment) + forColumn = append(forColumn, fmt.Sprintf("postings.%s @@ ('$[%d] == \"' || %s::text || '\"')::jsonpath", column, i, arg)) } - - arg := sb.Args.Add(segment) - sb.Where(fmt.Sprintf("(postings.source @@ ('$[%d] == \"' || %s::text || '\"')::jsonpath OR postings.destination @@ ('$[%d] == \"' || %s::text || '\"')::jsonpath)", i, arg, i, arg)) + ands = append(ands, sb.And(forColumn...)) } + + sb.Where(sb.Or(ands...)) } t.AccountFilter = account } diff --git a/pkg/storage/sqlstorage/transactions_test.go b/pkg/storage/sqlstorage/transactions_test.go index 3fa52f019..98dc31b58 100644 --- a/pkg/storage/sqlstorage/transactions_test.go +++ b/pkg/storage/sqlstorage/transactions_test.go @@ -523,3 +523,28 @@ func testTransactionsQueryAddress(t *testing.T, store *sqlstorage.Store) { assert.Equal(t, cursor.Data[0].ID, tx5.ID) }) } + +func testGetTransactionsByAccount(t *testing.T, store *sqlstorage.Store) { + now := time.Now() + err := store.Commit(context.Background(), core.ExpandedTransaction{ + Transaction: core.Transaction{ + TransactionData: core.TransactionData{ + Postings: core.Postings{ + { + Source: "a:b:c", + Destination: "d:e:f", + Amount: core.NewMonetaryInt(10), + Asset: "USD", + }, + }, + Timestamp: now, + }, + ID: 0, + }, + }) + require.NoError(t, err) + + txs, err := store.GetTransactions(context.Background(), *ledger.NewTransactionsQuery().WithAccountFilter("a:e:c")) + require.NoError(t, err) + require.Empty(t, txs.Data) +}