Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce SQL sanitizer allocations #2136

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
59 changes: 59 additions & 0 deletions internal/sanitize/benchmmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env bash

# Remember where the working tree currently points so it can be restored
# once the benchmarks have run.
current_branch="$(git rev-parse --abbrev-ref HEAD)"
case "$current_branch" in
HEAD)
  # The abbreviated ref came back as the literal "HEAD" (no branch name to
  # return to), so record the commit hash instead.
  current_branch="$(git rev-parse HEAD)"
  ;;
esac

# restore_branch checks out the branch or commit stored in $current_branch.
# It is installed as the EXIT trap right below, so it runs whenever the
# script terminates.
restore_branch() {
  printf 'Restoring original branch/commit: %s\n' "$current_branch"
  git checkout "${current_branch}"
}
trap 'restore_branch' EXIT

# Refuse to run when either the working tree or the index has changes: the
# script switches commits with `git checkout` further down.
if ! { git diff --quiet && git diff --cached --quiet; }; then
  echo "There are uncommitted changes. Please commit or stash them before running this script."
  exit 1
fi

# At least one commit to benchmark is required.
if (( $# < 1 )); then
  echo "Usage: $0 <commit1> <commit2> ... <commitN>"
  exit 1
fi

commits=("$@")
benchmarks_dir=benchmarks

# Directory that receives one .bench file per benchmarked commit.
mkdir -p "${benchmarks_dir}" || {
  echo "Unable to create dir for benchmarks data"
  exit 1
}

# Paths of the result files, in the order the commits were given.
bench_files=()

# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}

# Sanitized commit message, used as part of the benchmark file name.
# Every character other than a letter, digit, '-' or '_' becomes '_', so a
# subject containing '/' (or any other path-hostile character) cannot break
# the output path.
commit_message=$(git log -1 --pretty=format:"%s")
commit_message=${commit_message//[^[:alnum:]_-]/_}
Copy link
Contributor

@sean- sean- Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to escape /:

commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')


# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"

if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi

bench_files+=("$bench_file")
done

# Compare the collected results. benchstat is not shipped with the Go
# toolchain; install it with:
#   go install golang.org/x/perf/cmd/benchstat@latest
if ! command -v benchstat >/dev/null 2>&1; then
  # The raw .bench files are already written at this point, so they can be
  # compared later once benchstat is available.
  echo "benchstat not found; install it with: go install golang.org/x/perf/cmd/benchstat@latest" >&2
  exit 127
fi

benchstat "${bench_files[@]}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you prefix with a small comment: # go install golang.org/x/perf/cmd/benchstat@latest

165 changes: 138 additions & 27 deletions internal/sanitize/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
Expand All @@ -24,53 +26,75 @@ type Query struct {
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3

const maxBufSize = 16384 // 16 Ki

// bufPool recycles the *bytes.Buffer that Sanitize writes its output into.
// A buffer whose content length reached maxBufSize is reset but not handed
// back to the pool.
var bufPool = &pool[*bytes.Buffer]{
	new: func() *bytes.Buffer { return new(bytes.Buffer) },
	reset: func(buf *bytes.Buffer) bool {
		// Decide before Reset, which sets the length back to zero.
		keep := buf.Len() < maxBufSize
		buf.Reset()
		return keep
	},
}

var null = []byte("null")

func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
buf := bufPool.get()
defer bufPool.put(buf)

for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
buf.WriteString(part)
case int:
argIdx := part - 1

var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}

if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}

// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')

arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
p = null
case int64:
str = strconv.FormatInt(arg, 10)
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte:
str = QuoteBytes(arg)
p = QuoteBytes(buf.AvailableBuffer(), arg)
case string:
str = QuoteString(arg)
p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true

buf.Write(p)

// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}

for i, used := range argUse {
Expand All @@ -82,26 +106,81 @@ func (q *Query) Sanitize(args ...any) (string, error) {
}

func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
query := &Query{}
query.init(sql)

return query, nil
}

// sqlLexerPool recycles the sqlLexer values that Query.init uses to split a
// SQL string into parts. Every lexer is zeroed on return and always kept.
var sqlLexerPool = &pool[*sqlLexer]{
	new: func() *sqlLexer { return new(sqlLexer) },
	reset: func(lexer *sqlLexer) bool {
		// Overwrite every field with its zero value so no state from the
		// previous use survives in the pooled lexer.
		var zero sqlLexer
		*lexer = zero
		return true
	},
}

func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty but fast heuristic that preallocates enough for ~90% of use cases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}

l := sqlLexerPool.get()
defer sqlLexerPool.put(l)

l.src = sql
l.stateFn = rawState
l.parts = parts

for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}

query := &Query{Parts: l.parts}

return query, nil
q.Parts = l.parts
}

func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
func QuoteString(dst []byte, str string) []byte {
const quote = "'"

n := strings.Count(str, quote)

dst = append(dst, quote...)

p := slices.Grow(dst[len(dst):], 2*len(quote)+len(str)+2*n)

for len(str) > 0 {
i := strings.Index(str, quote)
if i < 0 {
p = append(p, str...)
break
}
p = append(p, str[:i]...)
p = append(p, "''"...)
str = str[i+1:]
}

dst = append(dst, p...)

dst = append(dst, quote...)

return dst
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is purely a style nit, but I don't like reslicing for these types of functions because it's not idiomatic and hard to follow. I took the above QuoteString() and replaced it with something that uses an iterator:

// QuoteString appends str to dst as a single-quoted SQL string literal,
// doubling every single quote found in str, and returns the extended slice.
func QuoteString(dst []byte, str string) []byte {
	const quote = '\''

	// Worst case: every byte of str is a quote (doubled), plus the two
	// enclosing quotes. Reserving that much up front means the appends
	// below never have to grow the slice.
	out := slices.Grow(dst, 2*len(str)+2)

	out = append(out, quote)
	for _, c := range []byte(str) {
		if c == quote {
			// Escape by emitting the quote twice: once here, once below.
			out = append(out, quote)
		}
		out = append(out, c)
	}

	return append(out, quote)
}


func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
func QuoteBytes(dst, buf []byte) []byte {
dst = append(dst, `'\x`...)

n := hex.EncodedLen(len(buf))
p := slices.Grow(dst[len(dst):], n)[:n]
hex.Encode(p, buf)
dst = append(dst, p...)

dst = append(dst, `'`...)
return dst
}
Copy link
Contributor

@sean- sean- Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to measure an improvement by optimizing this function:

// QuoteBytes appends buf to dst as a hex-encoded bytea literal of the form
// '\x<hex digits>' and returns the extended slice. An empty or nil buf
// yields '\x'.
func QuoteBytes(dst, buf []byte) []byte {
	const prefix = `'\x`

	hexLen := hex.EncodedLen(len(buf))

	// Reserve room for prefix, hex digits and closing quote in one step so
	// that nothing below can reallocate.
	dst = slices.Grow(dst, len(prefix)+hexLen+1)

	dst = append(dst, prefix...)

	// Extend dst over the reserved space and encode straight into it,
	// avoiding a temporary buffer for the hex digits.
	start := len(dst)
	dst = dst[:start+hexLen]
	hex.Encode(dst[start:], buf)

	return append(dst, '\'')
}


type sqlLexer struct {
Expand Down Expand Up @@ -319,13 +398,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
}
}

// queryPool recycles the *Query values that SanitizeSQL parses into. On
// return a query's Parts slice is truncated to length zero; a query that
// held 64 or more parts is not put back into the pool.
var queryPool = &pool[*Query]{
	new: func() *Query { return new(Query) },
	reset: func(query *Query) bool {
		// Check the size before truncating.
		keep := len(query.Parts) < 64
		query.Parts = query.Parts[:0]
		return keep
	},
}

// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)

return query.Sanitize(args...)
}

// pool is a typed wrapper around sync.Pool.
//
// new creates a value when the underlying pool has none to hand out. reset
// is called on every value passed to put; it prepares the value for reuse
// and reports whether the value should be kept. Both fields must be set.
type pool[E any] struct {
	p     sync.Pool
	new   func() E
	reset func(E) bool
}

// get returns a value from the pool, creating one with new when the
// underlying sync.Pool has nothing to offer.
func (p *pool[E]) get() E {
	// sync.Pool.Get returns nil when it is empty (its New field is not
	// used here); the failed type assertion then routes to new.
	v, ok := p.p.Get().(E)
	if !ok {
		v = p.new()
	}

	return v
}

// put resets v and returns it to the pool. The value is dropped instead
// when reset reports false.
func (p *pool[E]) put(v E) {
	if p.reset(v) {
		p.p.Put(v)
	}
}
62 changes: 62 additions & 0 deletions internal/sanitize/sanitize_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// sanitize_bench_test.go
package sanitize_test

import (
"testing"
"time"

"github.com/jackc/pgx/v5/internal/sanitize"
)

// benchmarkSanitizeResult receives the output of every Sanitize call made by
// BenchmarkSanitize. It is package-level, presumably so the compiler cannot
// discard the call as dead code.
var benchmarkSanitizeResult string

// benchmarkQuery is the SQL text shared by both benchmarks. It uses the
// placeholders $1 through $7; the trailing line comment on each condition
// names the Go type of the matching entry in benchmarkArgs.
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`

// benchmarkArgs holds one argument per placeholder in benchmarkQuery, in
// placeholder order. The string argument contains single quotes.
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}

// BenchmarkSanitize measures Query.Sanitize alone: benchmarkQuery is parsed
// once before the timer is reset, and only the substitution of
// benchmarkArgs runs inside the timed loop.
func BenchmarkSanitize(b *testing.B) {
	q, err := sanitize.NewQuery(benchmarkQuery)
	if err != nil {
		b.Fatalf("failed to create query: %v", err)
	}

	b.ReportAllocs()
	b.ResetTimer()

	for n := 0; n < b.N; n++ {
		// Store into the package-level variable on every iteration.
		if benchmarkSanitizeResult, err = q.Sanitize(benchmarkArgs...); err != nil {
			b.Fatalf("failed to sanitize query: %v", err)
		}
	}
}

// benchmarkNewSQLResult receives the output of every SanitizeSQL call made
// by BenchmarkSanitizeSQL.
var benchmarkNewSQLResult string

// BenchmarkSanitizeSQL measures sanitize.SanitizeSQL end to end, calling it
// with the raw benchmarkQuery text on every iteration.
func BenchmarkSanitizeSQL(b *testing.B) {
	b.ReportAllocs()

	for n := 0; n < b.N; n++ {
		out, err := sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
		benchmarkNewSQLResult = out
		if err != nil {
			b.Fatalf("failed to sanitize SQL: %v", err)
		}
	}
}
Loading