-
Notifications
You must be signed in to change notification settings - Fork 838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reduce SQL sanitizer allocations #2136
base: master
Are you sure you want to change the base?
Changes from all commits
b9d0214
730c324
c57e0d6
22e4205
9639283
f5c9af0
4dcba02
cc0b941
3f27c12
01e234e
07809d5
48cc36a
1d50b82
339b193
e5053ee
8264272
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env bash | ||
|
||
current_branch=$(git rev-parse --abbrev-ref HEAD) | ||
if [ "$current_branch" == "HEAD" ]; then | ||
current_branch=$(git rev-parse HEAD) | ||
fi | ||
|
||
restore_branch() { | ||
echo "Restoring original branch/commit: $current_branch" | ||
git checkout "$current_branch" | ||
} | ||
trap restore_branch EXIT | ||
|
||
# Check if there are uncommitted changes | ||
if ! git diff --quiet || ! git diff --cached --quiet; then | ||
echo "There are uncommitted changes. Please commit or stash them before running this script." | ||
exit 1 | ||
fi | ||
|
||
# Ensure that at least one commit argument is passed | ||
if [ "$#" -lt 1 ]; then | ||
echo "Usage: $0 <commit1> <commit2> ... <commitN>" | ||
exit 1 | ||
fi | ||
|
||
commits=("$@") | ||
benchmarks_dir=benchmarks | ||
|
||
if ! mkdir -p "${benchmarks_dir}"; then | ||
echo "Unable to create dir for benchmarks data" | ||
exit 1 | ||
fi | ||
|
||
# Benchmark results | ||
bench_files=() | ||
|
||
# Run benchmark for each listed commit | ||
for i in "${!commits[@]}"; do | ||
commit="${commits[i]}" | ||
git checkout "$commit" || { | ||
echo "Failed to checkout $commit" | ||
exit 1 | ||
} | ||
|
||
# Sanitized commmit message | ||
commit_message=$(git log -1 --pretty=format:"%s" | tr ' ' '_') | ||
|
||
# Benchmark data will go there | ||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" | ||
|
||
if ! go test -bench=. -count=10 >"$bench_file"; then | ||
echo "Benchmarking failed for commit $commit" | ||
exit 1 | ||
fi | ||
|
||
bench_files+=("$bench_file") | ||
done | ||
|
||
benchstat "${bench_files[@]}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you prefix with a small comment: |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,10 @@ import ( | |
"bytes" | ||
"encoding/hex" | ||
"fmt" | ||
"slices" | ||
"strconv" | ||
"strings" | ||
"sync" | ||
"time" | ||
"unicode/utf8" | ||
) | ||
|
@@ -24,53 +26,75 @@ type Query struct { | |
// https://github.com/jackc/pgx/issues/1380 | ||
const replacementcharacterwidth = 3 | ||
|
||
const maxBufSize = 16384 // 16 Ki | ||
|
||
var bufPool = &pool[*bytes.Buffer]{ | ||
new: func() *bytes.Buffer { | ||
return &bytes.Buffer{} | ||
}, | ||
reset: func(b *bytes.Buffer) bool { | ||
n := b.Len() | ||
b.Reset() | ||
return n < maxBufSize | ||
}, | ||
} | ||
|
||
var null = []byte("null") | ||
|
||
func (q *Query) Sanitize(args ...any) (string, error) { | ||
argUse := make([]bool, len(args)) | ||
buf := &bytes.Buffer{} | ||
buf := bufPool.get() | ||
defer bufPool.put(buf) | ||
|
||
for _, part := range q.Parts { | ||
var str string | ||
switch part := part.(type) { | ||
case string: | ||
str = part | ||
buf.WriteString(part) | ||
case int: | ||
argIdx := part - 1 | ||
|
||
var p []byte | ||
if argIdx < 0 { | ||
return "", fmt.Errorf("first sql argument must be > 0") | ||
} | ||
|
||
if argIdx >= len(args) { | ||
return "", fmt.Errorf("insufficient arguments") | ||
} | ||
|
||
// Prevent SQL injection via Line Comment Creation | ||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p | ||
buf.WriteByte(' ') | ||
|
||
arg := args[argIdx] | ||
switch arg := arg.(type) { | ||
case nil: | ||
str = "null" | ||
p = null | ||
case int64: | ||
str = strconv.FormatInt(arg, 10) | ||
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) | ||
case float64: | ||
str = strconv.FormatFloat(arg, 'f', -1, 64) | ||
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) | ||
case bool: | ||
str = strconv.FormatBool(arg) | ||
p = strconv.AppendBool(buf.AvailableBuffer(), arg) | ||
case []byte: | ||
str = QuoteBytes(arg) | ||
p = QuoteBytes(buf.AvailableBuffer(), arg) | ||
case string: | ||
str = QuoteString(arg) | ||
p = QuoteString(buf.AvailableBuffer(), arg) | ||
case time.Time: | ||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") | ||
p = arg.Truncate(time.Microsecond). | ||
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") | ||
default: | ||
return "", fmt.Errorf("invalid arg type: %T", arg) | ||
} | ||
argUse[argIdx] = true | ||
|
||
buf.Write(p) | ||
|
||
// Prevent SQL injection via Line Comment Creation | ||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p | ||
str = " " + str + " " | ||
buf.WriteByte(' ') | ||
default: | ||
return "", fmt.Errorf("invalid Part type: %T", part) | ||
} | ||
buf.WriteString(str) | ||
} | ||
|
||
for i, used := range argUse { | ||
|
@@ -82,26 +106,81 @@ func (q *Query) Sanitize(args ...any) (string, error) { | |
} | ||
|
||
func NewQuery(sql string) (*Query, error) { | ||
l := &sqlLexer{ | ||
src: sql, | ||
stateFn: rawState, | ||
query := &Query{} | ||
query.init(sql) | ||
|
||
return query, nil | ||
} | ||
|
||
var sqlLexerPool = &pool[*sqlLexer]{ | ||
new: func() *sqlLexer { | ||
return &sqlLexer{} | ||
}, | ||
reset: func(sl *sqlLexer) bool { | ||
*sl = sqlLexer{} | ||
return true | ||
}, | ||
} | ||
|
||
func (q *Query) init(sql string) { | ||
parts := q.Parts[:0] | ||
if parts == nil { | ||
// dirty, but fast heuristic to preallocate for ~90% usecases | ||
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 | ||
parts = make([]Part, 0, n) | ||
} | ||
|
||
l := sqlLexerPool.get() | ||
defer sqlLexerPool.put(l) | ||
|
||
l.src = sql | ||
l.stateFn = rawState | ||
l.parts = parts | ||
|
||
for l.stateFn != nil { | ||
l.stateFn = l.stateFn(l) | ||
} | ||
|
||
query := &Query{Parts: l.parts} | ||
|
||
return query, nil | ||
q.Parts = l.parts | ||
} | ||
|
||
func QuoteString(str string) string { | ||
return "'" + strings.ReplaceAll(str, "'", "''") + "'" | ||
func QuoteString(dst []byte, str string) []byte { | ||
const quote = "'" | ||
|
||
n := strings.Count(str, quote) | ||
|
||
dst = append(dst, quote...) | ||
|
||
p := slices.Grow(dst[len(dst):], 2*len(quote)+len(str)+2*n) | ||
|
||
for len(str) > 0 { | ||
i := strings.Index(str, quote) | ||
if i < 0 { | ||
p = append(p, str...) | ||
break | ||
} | ||
p = append(p, str[:i]...) | ||
p = append(p, "''"...) | ||
str = str[i+1:] | ||
} | ||
|
||
dst = append(dst, p...) | ||
|
||
dst = append(dst, quote...) | ||
|
||
return dst | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is purely a style nit, but I don't like reslicing for these types of functions because it's not idiomatic and hard to follow. I took the above func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
} |
||
|
||
func QuoteBytes(buf []byte) string { | ||
return `'\x` + hex.EncodeToString(buf) + "'" | ||
func QuoteBytes(dst, buf []byte) []byte { | ||
dst = append(dst, `'\x`...) | ||
|
||
n := hex.EncodedLen(len(buf)) | ||
p := slices.Grow(dst[len(dst):], n)[:n] | ||
hex.Encode(p, buf) | ||
dst = append(dst, p...) | ||
|
||
dst = append(dst, `'`...) | ||
return dst | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was able to measure an improvement by optimizing this function: func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
} |
||
|
||
type sqlLexer struct { | ||
|
@@ -319,13 +398,45 @@ func multilineCommentState(l *sqlLexer) stateFn { | |
} | ||
} | ||
|
||
var queryPool = &pool[*Query]{ | ||
new: func() *Query { | ||
return &Query{} | ||
}, | ||
reset: func(q *Query) bool { | ||
n := len(q.Parts) | ||
q.Parts = q.Parts[:0] | ||
return n < 64 // drop too large queries | ||
}, | ||
} | ||
|
||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args | ||
// as necessary. This function is only safe when standard_conforming_strings is | ||
// on. | ||
func SanitizeSQL(sql string, args ...any) (string, error) { | ||
query, err := NewQuery(sql) | ||
if err != nil { | ||
return "", err | ||
} | ||
query := queryPool.get() | ||
query.init(sql) | ||
defer queryPool.put(query) | ||
|
||
return query.Sanitize(args...) | ||
} | ||
|
||
type pool[E any] struct { | ||
p sync.Pool | ||
new func() E | ||
reset func(E) bool | ||
} | ||
|
||
func (pool *pool[E]) get() E { | ||
v, ok := pool.p.Get().(E) | ||
if !ok { | ||
v = pool.new() | ||
} | ||
|
||
return v | ||
} | ||
|
||
func (p *pool[E]) put(v E) { | ||
if p.reset(v) { | ||
p.p.Put(v) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// sanitize_benchmark_test.go | ||
package sanitize_test | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/jackc/pgx/v5/internal/sanitize" | ||
) | ||
|
||
var benchmarkSanitizeResult string | ||
|
||
const benchmarkQuery = "" + | ||
`SELECT * | ||
FROM "water_containers" | ||
WHERE NOT "id" = $1 -- int64 | ||
AND "tags" NOT IN $2 -- nil | ||
AND "volume" > $3 -- float64 | ||
AND "transportable" = $4 -- bool | ||
AND position($5 IN "sign") -- bytes | ||
AND "label" LIKE $6 -- string | ||
AND "created_at" > $7; -- time.Time` | ||
|
||
var benchmarkArgs = []any{ | ||
int64(12345), | ||
nil, | ||
float64(500), | ||
true, | ||
[]byte("8BADF00D"), | ||
"kombucha's han'dy awokowa", | ||
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC), | ||
} | ||
|
||
func BenchmarkSanitize(b *testing.B) { | ||
query, err := sanitize.NewQuery(benchmarkQuery) | ||
if err != nil { | ||
b.Fatalf("failed to create query: %v", err) | ||
} | ||
|
||
b.ResetTimer() | ||
b.ReportAllocs() | ||
|
||
for i := 0; i < b.N; i++ { | ||
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...) | ||
if err != nil { | ||
b.Fatalf("failed to sanitize query: %v", err) | ||
} | ||
} | ||
} | ||
|
||
var benchmarkNewSQLResult string | ||
|
||
func BenchmarkSanitizeSQL(b *testing.B) { | ||
b.ReportAllocs() | ||
var err error | ||
for i := 0; i < b.N; i++ { | ||
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...) | ||
if err != nil { | ||
b.Fatalf("failed to sanitize SQL: %v", err) | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to escape
/
:commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')