Skip to content

Commit

Permalink
Fix parsing UUID params passed via PG extended protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Feb 24, 2025
1 parent c467de9 commit 60f16a4
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

const (
VERSION = "0.34.1"
VERSION = "0.34.2"

ENV_PORT = "BEMIDB_PORT"
ENV_DATABASE = "BEMIDB_DATABASE"
Expand Down
27 changes: 21 additions & 6 deletions src/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func (postgres *Postgres) handleExtendedQuery(queryHandler *QueryHandler, parseM
}
postgres.writeMessages(messages...)

var previousErr error
for {
message, err := postgres.backend.Receive()
if err != nil {
Expand All @@ -118,28 +119,40 @@ func (postgres *Postgres) handleExtendedQuery(queryHandler *QueryHandler, parseM

switch message := message.(type) {
case *pgproto3.Bind:
if previousErr != nil { // Skip processing the next message if there was an error in the previous message
continue
}

LogDebug(postgres.config, "Binding query", message.PreparedStatement)
messages, preparedStatement, err = queryHandler.HandleBindQuery(message, preparedStatement)
if err != nil {
postgres.writeError(err)
continue
previousErr = err
}
postgres.writeMessages(messages...)
case *pgproto3.Describe:
if previousErr != nil { // Skip processing the next message if there was an error in the previous message
continue
}

LogDebug(postgres.config, "Describing query", message.Name, "("+string(message.ObjectType)+")")
var messages []pgproto3.Message
messages, preparedStatement, err = queryHandler.HandleDescribeQuery(message, preparedStatement)
if err != nil {
postgres.writeError(err)
continue
previousErr = err
}
postgres.writeMessages(messages...)
case *pgproto3.Execute:
if previousErr != nil { // Skip processing the next message if there was an error in the previous message
continue
}

LogDebug(postgres.config, "Executing query", message.Portal)
messages, err := queryHandler.HandleExecuteQuery(message, preparedStatement)
if err != nil {
postgres.writeError(err)
continue
previousErr = err
}
postgres.writeMessages(messages...)
case *pgproto3.Sync:
Expand All @@ -148,11 +161,13 @@ func (postgres *Postgres) handleExtendedQuery(queryHandler *QueryHandler, parseM
&pgproto3.ReadyForQuery{TxStatus: PG_TX_STATUS_IDLE},
)

// If Bind step completed, it means that sync is the last message in the extended query protocol, we can exit handleExtendedQuery
// Otherwise, wait for Bind/Describe/Execute/Sync. For example, psycopg sends an extra Sync after Parse
if preparedStatement.Bound {
// If there was an error or Parse->Bind->Sync (...) or Parse->Describe->Sync (e.g., Metabase)
// it means that sync is the last message in the extended query protocol, we can exit handleExtendedQuery
if previousErr != nil || preparedStatement.Bound || preparedStatement.Described {
return nil
}
// Otherwise, wait for Bind/Describe/Execute/Sync.
// For example, psycopg sends Parse->[extra Sync]->Bind->Describe->Execute->Sync
}
}
}
Expand Down
13 changes: 11 additions & 2 deletions src/query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strconv"
"strings"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
duckDb "github.com/marcboeker/go-duckdb"
Expand Down Expand Up @@ -44,7 +45,10 @@ type PreparedStatement struct {
Variables []interface{}
Portal string

// Describe or Execute
// Describe
Described bool

// Describe/Execute
Rows *sql.Rows
}

Expand Down Expand Up @@ -309,6 +313,10 @@ func (queryHandler *QueryHandler) HandleBindQuery(message *pgproto3.Bind, prepar
variables = append(variables, int32(binary.BigEndian.Uint32(param)))
} else if len(param) == 8 {
variables = append(variables, int64(binary.BigEndian.Uint64(param)))
} else if len(param) == 16 {
variables = append(variables, uuid.UUID(param).String())
} else {
return nil, nil, fmt.Errorf("unsupported parameter format: %v (length %d). Original query: %s", param, len(param), preparedStatement.OriginalQuery)
}
}

Expand All @@ -334,13 +342,14 @@ func (queryHandler *QueryHandler) HandleDescribeQuery(message *pgproto3.Describe
}
}

preparedStatement.Described = true
if preparedStatement.Query == "" || !preparedStatement.Bound { // Empty query or Parse->[No Bind]->Describe
return []pgproto3.Message{&pgproto3.NoData{}}, preparedStatement, nil
}

rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("couldn't execute statement: %w. Original query: %s", err, preparedStatement.OriginalQuery)
}
preparedStatement.Rows = rows

Expand Down
64 changes: 59 additions & 5 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"testing"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
)
Expand Down Expand Up @@ -1143,8 +1144,7 @@ func TestHandleParseQuery(t *testing.T) {
func TestHandleBindQuery(t *testing.T) {
t.Run("Handles BIND extended query step with text format parameter", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
parseMessage := &pgproto3.Parse{Query: query}
parseMessage := &pgproto3.Parse{Query: "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"}
_, preparedStatement, err := queryHandler.HandleParseQuery(parseMessage)
testNoError(t, err)

Expand All @@ -1166,10 +1166,37 @@ func TestHandleBindQuery(t *testing.T) {
}
})

t.Run("Handles BIND extended query step with binary format parameter", func(t *testing.T) {
t.Run("Handles BIND extended query step with binary format 4-byte parameter", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT c.oid FROM pg_catalog.pg_class c WHERE c.relnamespace = $1"
parseMessage := &pgproto3.Parse{Query: query}
parseMessage := &pgproto3.Parse{Query: "SELECT c.oid FROM pg_catalog.pg_class c WHERE c.relnamespace = $1"}
_, preparedStatement, err := queryHandler.HandleParseQuery(parseMessage)
testNoError(t, err)

paramValue := int32(2200)
paramBytes := make([]byte, 4)
binary.BigEndian.PutUint32(paramBytes, uint32(paramValue))

bindMessage := &pgproto3.Bind{
Parameters: [][]byte{paramBytes},
ParameterFormatCodes: []int16{1}, // Binary format
}
messages, preparedStatement, err := queryHandler.HandleBindQuery(bindMessage, preparedStatement)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.BindComplete{},
})
if len(preparedStatement.Variables) != 1 {
t.Errorf("Expected the prepared statement to have 1 variable, got %v", len(preparedStatement.Variables))
}
if preparedStatement.Variables[0] != paramValue {
t.Errorf("Expected the prepared statement variable to be %v, got %v", paramValue, preparedStatement.Variables[0])
}
})

t.Run("Handles BIND extended query step with binary format 8-byte parameter", func(t *testing.T) {
queryHandler := initQueryHandler()
parseMessage := &pgproto3.Parse{Query: "SELECT c.oid FROM pg_catalog.pg_class c WHERE c.relnamespace = $1"}
_, preparedStatement, err := queryHandler.HandleParseQuery(parseMessage)
testNoError(t, err)

Expand All @@ -1194,6 +1221,33 @@ func TestHandleBindQuery(t *testing.T) {
t.Errorf("Expected the prepared statement variable to be %v, got %v", paramValue, preparedStatement.Variables[0])
}
})

t.Run("Handles BIND extended query step with binary format 16-byte (uuid) parameter", func(t *testing.T) {
queryHandler := initQueryHandler()
parseMessage := &pgproto3.Parse{Query: "SELECT uuid_column FROM public.test_table WHERE uuid_column = $1"}
_, preparedStatement, err := queryHandler.HandleParseQuery(parseMessage)
testNoError(t, err)

uuidParam := "58a7c845-af77-44b2-8664-7ca613d92f04"
paramBytes, _ := uuid.Must(uuid.Parse(uuidParam)).MarshalBinary()

bindMessage := &pgproto3.Bind{
Parameters: [][]byte{paramBytes},
ParameterFormatCodes: []int16{1}, // Binary format
}
messages, preparedStatement, err := queryHandler.HandleBindQuery(bindMessage, preparedStatement)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.BindComplete{},
})
if len(preparedStatement.Variables) != 1 {
t.Errorf("Expected the prepared statement to have 1 variable, got %v", len(preparedStatement.Variables))
}
if preparedStatement.Variables[0] != uuidParam {
t.Errorf("Expected the prepared statement variable to be %v, got %v", uuidParam, preparedStatement.Variables[0])
}
})
}

func TestHandleDescribeQuery(t *testing.T) {
Expand Down

0 comments on commit 60f16a4

Please sign in to comment.