From ddde166a122219bf117fbd29facb0f2142ac7b14 Mon Sep 17 00:00:00 2001 From: Kasvi Date: Mon, 25 Nov 2024 16:02:29 +1100 Subject: [PATCH 1/5] packing sqlx to work with safesql --- go.mod | 1 + go.sum | 51 +++++++++++++++++++++++++++++++++++++++++++++++++ sqlx.go | 13 +++++++------ sqlx_context.go | 4 +++- 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 9b164a1..15ffd49 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.10 require ( github.com/go-sql-driver/mysql v1.8.1 + github.com/google/go-safeweb v0.0.0-20240727104708-c2d1215a6a24 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 ) diff --git a/go.sum b/go.sum index 31d5aba..cdbb2ae 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,58 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-safeweb v0.0.0-20240727104708-c2d1215a6a24 h1:84uaHf8KLrWKVuJOfQfGZMMGdC8ujrPDDeekLGxkTss= +github.com/google/go-safeweb v0.0.0-20240727104708-c2d1215a6a24/go.mod h1:ukNyX9TdScmbBtYBtWwaU+n7MtodX/Wr6rBKBADTvGQ= +github.com/google/safehtml v0.0.2/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/sqlx.go b/sqlx.go index 8259a4f..0464ba2 100644 --- a/sqlx.go +++ b/sqlx.go @@ -11,6 +11,7 @@ import ( "strings" "sync" + "github.com/google/go-safeweb/safesql" "github.com/jmoiron/sqlx/reflectx" ) @@ -241,7 +242,7 @@ func (r *Row) Err() error { // DB is a wrapper around sql.DB which keeps track of the driverName upon Open, // used mostly to automatically bind named queries using the right bindvars. type DB struct { - *sql.DB + *safesql.DB driverName string unsafe bool Mapper *reflectx.Mapper @@ -251,7 +252,7 @@ type DB struct { // driverName of the original database is required for named query support. // //lint:ignore ST1003 changing this would break the package interface. -func NewDb(db *sql.DB, driverName string) *DB { +func NewDb(db *safesql.DB, driverName string) *DB { return &DB{DB: db, driverName: driverName, Mapper: mapper()} } @@ -262,11 +263,11 @@ func (db *DB) DriverName() string { // Open is the same as sql.Open, but returns an *sqlx.DB instead. func Open(driverName, dataSourceName string) (*DB, error) { - db, err := sql.Open(driverName, dataSourceName) + db, err := safesql.Open(driverName, dataSourceName) if err != nil { return nil, err } - return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err + return &DB{DB: &db, driverName: driverName, Mapper: mapper()}, err } // MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. @@ -348,7 +349,7 @@ func (db *DB) Beginx() (*Tx, error) { // Queryx queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. -func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { +func (db *DB) Queryx(query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := db.DB.Query(query, args...) if err != nil { return nil, err @@ -358,7 +359,7 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { // QueryRowx queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. -func (db *DB) QueryRowx(query string, args ...interface{}) *Row { +func (db *DB) QueryRowx(query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := db.DB.Query(query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } diff --git a/sqlx_context.go b/sqlx_context.go index 32621d5..918ea43 100644 --- a/sqlx_context.go +++ b/sqlx_context.go @@ -10,6 +10,8 @@ import ( "io/ioutil" "path/filepath" "reflect" + + "github.com/google/go-safeweb/safesql" ) // ConnectContext to a database and verify with a ping. @@ -168,7 +170,7 @@ func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface // QueryRowxContext queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. -func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { +func (db *DB) QueryRowxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := db.DB.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } From 7e3281745bef297583665fcdbb1827cf1992edf9 Mon Sep 17 00:00:00 2001 From: Kasvi Date: Mon, 25 Nov 2024 17:10:21 +1100 Subject: [PATCH 2/5] packing sqlx to work with safesql --- sqlx_context.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx_context.go b/sqlx_context.go index 918ea43..33d86ef 100644 --- a/sqlx_context.go +++ b/sqlx_context.go @@ -160,7 +160,7 @@ func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) // QueryxContext queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. -func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { +func (db *DB) QueryxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := db.DB.QueryContext(ctx, query, args...) if err != nil { return nil, err From 644a08f800d0155436d0345456f6612ea5ffd7e4 Mon Sep 17 00:00:00 2001 From: Peter Arts <77252233+sc-peter@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:05:22 +1100 Subject: [PATCH 3/5] Refactor to safesql --- bind.go | 20 +++++++++---- named.go | 72 ++++++++++++++++++++++++---------------------- sqlx.go | 88 +++++++++++++++++++++++++++++--------------------------- 3 files changed, 97 insertions(+), 83 deletions(-) diff --git a/bind.go b/bind.go index e698039..5743a9b 100644 --- a/bind.go +++ b/bind.go @@ -9,6 +9,8 @@ import ( "strings" "sync" + "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" "github.com/jmoiron/sqlx/reflectx" ) @@ -57,19 +59,23 @@ func BindDriver(driverName string, bindType int) { // losing much speed, and should be to avoid confusion. // Rebind a query from the default bindtype (QUESTION) to the target bindtype. -func Rebind(bindType int, query string) string { +func Rebind(bindType int, query safesql.TrustedSQLString) safesql.TrustedSQLString { switch bindType { case QUESTION, UNKNOWN: return query } + // Work with the raw string and convert back via unchecked conversion + // Make sure that no untrusted data is making its way into the string in the meantime! + rawQuery := query.String() + // Add space enough for 10 params before we have to allocate - rqb := make([]byte, 0, len(query)+10) + rqb := make([]byte, 0, len(rawQuery)+10) var i, j int - for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { - rqb = append(rqb, query[:i]...) + for i = strings.Index(rawQuery, "?"); i != -1; i = strings.Index(rawQuery, "?") { + rqb = append(rqb, rawQuery[:i]...) switch bindType { case DOLLAR: @@ -83,10 +89,12 @@ func Rebind(bindType int, query string) string { j++ rqb = strconv.AppendInt(rqb, int64(j), 10) - query = query[i+1:] + rawQuery = rawQuery[i+1:] } - return string(append(rqb, query...)) + rawQuery = string(append(rqb, rawQuery...)) + + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(rawQuery) } // Experimental implementation of Rebind which uses a bytes.Buffer. The code is diff --git a/named.go b/named.go index 6ac4477..ae99074 100644 --- a/named.go +++ b/named.go @@ -21,6 +21,8 @@ import ( "strconv" "unicode" + "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" "github.com/jmoiron/sqlx/reflectx" ) @@ -28,7 +30,7 @@ import ( // how you would execute a NamedQuery, but pass in a struct or map when executing. type NamedStmt struct { Params []string - QueryString string + QueryString safesql.TrustedSQLString Stmt *Stmt } @@ -129,9 +131,9 @@ type namedPreparer interface { binder } -func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { +func prepareNamed(p namedPreparer, query safesql.TrustedSQLString) (*NamedStmt, error) { bindType := BindType(p.DriverName()) - q, args, err := compileNamedQuery([]byte(query), bindType) + q, args, err := compileNamedQuery(query, bindType) if err != nil { return nil, err } @@ -210,15 +212,15 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err // bindStruct binds a named parameter query with fields from a struct argument. // The rules for binding field names to parameter names follow the same // conventions as for StructScan, including obeying the `db` struct tags. -func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { - bound, names, err := compileNamedQuery([]byte(query), bindType) +func bindStruct(bindType int, query safesql.TrustedSQLString, arg interface{}, m *reflectx.Mapper) (safesql.TrustedSQLString, []interface{}, error) { + bound, names, err := compileNamedQuery(query, bindType) if err != nil { - return "", []interface{}{}, err + return safesql.New(""), []interface{}{}, err } arglist, err := bindAnyArgs(names, arg, m) if err != nil { - return "", []interface{}{}, err + return safesql.New(""), []interface{}{}, err } return bound, arglist, nil @@ -242,15 +244,16 @@ func findMatchingClosingBracketIndex(s string) int { return 0 } -func fixBound(bound string, loop int) string { - loc := valuesReg.FindStringIndex(bound) +func fixBound(bound safesql.TrustedSQLString, loop int) safesql.TrustedSQLString { + rawBound := bound.String() + loc := valuesReg.FindStringIndex(rawBound) // defensive guard when "VALUES (...)" not found if len(loc) < 2 { return bound } openingBracketIndex := loc[1] - 1 - index := findMatchingClosingBracketIndex(bound[openingBracketIndex:]) + index := findMatchingClosingBracketIndex(rawBound[openingBracketIndex:]) // defensive guard. must have closing bracket if index == 0 { return bound @@ -259,34 +262,34 @@ func fixBound(bound string, loop int) string { var buffer bytes.Buffer - buffer.WriteString(bound[0:closingBracketIndex]) + buffer.WriteString(rawBound[0:closingBracketIndex]) for i := 0; i < loop-1; i++ { buffer.WriteString(",") - buffer.WriteString(bound[openingBracketIndex:closingBracketIndex]) + buffer.WriteString(rawBound[openingBracketIndex:closingBracketIndex]) } - buffer.WriteString(bound[closingBracketIndex:]) - return buffer.String() + buffer.WriteString(rawBound[closingBracketIndex:]) + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(buffer.String()) } // bindArray binds a named parameter query with fields from an array or slice of // structs argument. -func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { +func bindArray(bindType int, query safesql.TrustedSQLString, arg interface{}, m *reflectx.Mapper) (safesql.TrustedSQLString, []interface{}, error) { // do the initial binding with QUESTION; if bindType is not question, // we can rebind it at the end. - bound, names, err := compileNamedQuery([]byte(query), QUESTION) + bound, names, err := compileNamedQuery(query, QUESTION) if err != nil { - return "", []interface{}{}, err + return safesql.New(""), []interface{}{}, err } arrayValue := reflect.ValueOf(arg) arrayLen := arrayValue.Len() if arrayLen == 0 { - return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) + return safesql.New(""), []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) } var arglist = make([]interface{}, 0, len(names)*arrayLen) for i := 0; i < arrayLen; i++ { elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m) if err != nil { - return "", []interface{}{}, err + return safesql.New(""), []interface{}{}, err } arglist = append(arglist, elemArglist...) } @@ -301,10 +304,10 @@ func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) } // bindMap binds a named parameter query with a map of arguments. -func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { - bound, names, err := compileNamedQuery([]byte(query), bindType) +func bindMap(bindType int, query safesql.TrustedSQLString, args map[string]interface{}) (safesql.TrustedSQLString, []interface{}, error) { + bound, names, err := compileNamedQuery(query, bindType) if err != nil { - return "", []interface{}{}, err + return safesql.New(""), []interface{}{}, err } arglist, err := bindMapArgs(names, args) @@ -328,20 +331,21 @@ var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} // compile a NamedQuery into an unbound query (using the '?' bindvar) and // a list of names. -func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { +func compileNamedQuery(qs safesql.TrustedSQLString, bindType int) (query safesql.TrustedSQLString, names []string, err error) { names = make([]string, 0, 10) - rebound := make([]byte, 0, len(qs)) + rawQs := []byte(qs.String()) + rebound := make([]byte, 0, len(rawQs)) inName := false - last := len(qs) - 1 + last := len(rawQs) - 1 currentVar := 1 name := make([]byte, 0, 10) - for i, b := range qs { + for i, b := range rawQs { // a ':' while we're in a name is an error if b == ':' { // if this is the second ':' in a '::' escape sequence, append a ':' - if inName && i > 0 && qs[i-1] == ':' { + if inName && i > 0 && rawQs[i-1] == ':' { rebound = append(rebound, ':') inName = false continue @@ -402,30 +406,30 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e } } - return string(rebound), names, err + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(string(rebound)), names, err } // BindNamed binds a struct or a map to a query with named parameters. // DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. -func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { +func BindNamed(bindType int, query safesql.TrustedSQLString, arg interface{}) (safesql.TrustedSQLString, []interface{}, error) { return bindNamedMapper(bindType, query, arg, mapper()) } // Named takes a query using named parameters and an argument and // returns a new query with a list of args that can be executed by // a database. The return value uses the `?` bindvar. -func Named(query string, arg interface{}) (string, []interface{}, error) { +func Named(query safesql.TrustedSQLString, arg interface{}) (safesql.TrustedSQLString, []interface{}, error) { return bindNamedMapper(QUESTION, query, arg, mapper()) } -func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { +func bindNamedMapper(bindType int, query safesql.TrustedSQLString, arg interface{}, m *reflectx.Mapper) (safesql.TrustedSQLString, []interface{}, error) { t := reflect.TypeOf(arg) k := t.Kind() switch { case k == reflect.Map && t.Key().Kind() == reflect.String: m, ok := convertMapStringInterface(arg) if !ok { - return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg) + return safesql.New(""), nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg) } return bindMap(bindType, query, m) case k == reflect.Array || k == reflect.Slice: @@ -438,7 +442,7 @@ func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Ma // NamedQuery binds a named query and then runs Query on the result using the // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with // map[string]interface{} types. -func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { +func NamedQuery(e Ext, query safesql.TrustedSQLString, arg interface{}) (*Rows, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err @@ -449,7 +453,7 @@ func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { // NamedExec uses BindStruct to get a query executable by the driver and // then runs Exec on the result. Returns an error from the binding // or the query execution itself. -func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { +func NamedExec(e Ext, query safesql.TrustedSQLString, arg interface{}) (sql.Result, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err diff --git a/sqlx.go b/sqlx.go index 0464ba2..51174a1 100644 --- a/sqlx.go +++ b/sqlx.go @@ -12,6 +12,7 @@ import ( "sync" "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" "github.com/jmoiron/sqlx/reflectx" ) @@ -76,21 +77,21 @@ type ColScanner interface { // Queryer is an interface used by Get and Select type Queryer interface { - Query(query string, args ...interface{}) (*sql.Rows, error) - Queryx(query string, args ...interface{}) (*Rows, error) - QueryRowx(query string, args ...interface{}) *Row + Query(query safesql.TrustedSQLString, args ...interface{}) (*sql.Rows, error) + Queryx(query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) + QueryRowx(query safesql.TrustedSQLString, args ...interface{}) *Row } // Execer is an interface used by MustExec and LoadFile type Execer interface { - Exec(query string, args ...interface{}) (sql.Result, error) + Exec(query safesql.TrustedSQLString, args ...interface{}) (sql.Result, error) } // Binder is an interface for something which can bind queries (Tx, DB) type binder interface { DriverName() string - Rebind(string) string - BindNamed(string, interface{}) (string, []interface{}, error) + Rebind(safesql.TrustedSQLString) safesql.TrustedSQLString + BindNamed(safesql.TrustedSQLString, interface{}) (safesql.TrustedSQLString, []interface{}, error) } // Ext is a union interface which can bind, query, and exec, used by @@ -103,7 +104,7 @@ type Ext interface { // Preparer is an interface used by Preparex. type Preparer interface { - Prepare(query string) (*sql.Stmt, error) + Prepare(query safesql.TrustedSQLString) (*sql.Stmt, error) } // determine if any of our extensions are unsafe @@ -286,7 +287,7 @@ func (db *DB) MapperFunc(mf func(string) string) { } // Rebind transforms a query from QUESTION to the DB driver's bindvar type. -func (db *DB) Rebind(query string) string { +func (db *DB) Rebind(query safesql.TrustedSQLString) safesql.TrustedSQLString { return Rebind(BindType(db.driverName), query) } @@ -299,32 +300,32 @@ func (db *DB) Unsafe() *DB { } // BindNamed binds a query using the DB driver's bindvar type. -func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { +func (db *DB) BindNamed(query safesql.TrustedSQLString, arg interface{}) (safesql.TrustedSQLString, []interface{}, error) { return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) } // NamedQuery using this DB. // Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { +func (db *DB) NamedQuery(query safesql.TrustedSQLString, arg interface{}) (*Rows, error) { return NamedQuery(db, query, arg) } // NamedExec using this DB. // Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { +func (db *DB) NamedExec(query safesql.TrustedSQLString, arg interface{}) (sql.Result, error) { return NamedExec(db, query, arg) } // Select using this DB. // Any placeholder parameters are replaced with supplied args. -func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { +func (db *DB) Select(dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return Select(db, dest, query, args...) } // Get using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { +func (db *DB) Get(dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return Get(db, dest, query, args...) } @@ -344,7 +345,7 @@ func (db *DB) Beginx() (*Tx, error) { if err != nil { return nil, err } - return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err + return &Tx{Tx: &tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err } // Queryx queries the database and returns an *sqlx.Rows. @@ -366,17 +367,17 @@ func (db *DB) QueryRowx(query safesql.TrustedSQLString, args ...interface{}) *Ro // MustExec (panic) runs MustExec using this database. // Any placeholder parameters are replaced with supplied args. -func (db *DB) MustExec(query string, args ...interface{}) sql.Result { +func (db *DB) MustExec(query safesql.TrustedSQLString, args ...interface{}) sql.Result { return MustExec(db, query, args...) } // Preparex returns an sqlx.Stmt instead of a sql.Stmt -func (db *DB) Preparex(query string) (*Stmt, error) { +func (db *DB) Preparex(query safesql.TrustedSQLString) (*Stmt, error) { return Preparex(db, query) } // PrepareNamed returns an sqlx.NamedStmt -func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { +func (db *DB) PrepareNamed(query safesql.TrustedSQLString) (*NamedStmt, error) { return prepareNamed(db, query) } @@ -390,7 +391,7 @@ type Conn struct { // Tx is an sqlx wrapper around sql.Tx with extra functionality type Tx struct { - *sql.Tx + *safesql.Tx driverName string unsafe bool Mapper *reflectx.Mapper @@ -402,7 +403,7 @@ func (tx *Tx) DriverName() string { } // Rebind a query within a transaction's bindvar type. -func (tx *Tx) Rebind(query string) string { +func (tx *Tx) Rebind(query safesql.TrustedSQLString) safesql.TrustedSQLString { return Rebind(BindType(tx.driverName), query) } @@ -413,31 +414,31 @@ func (tx *Tx) Unsafe() *Tx { } // BindNamed binds a query within a transaction's bindvar type. -func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { +func (tx *Tx) BindNamed(query safesql.TrustedSQLString, arg interface{}) (safesql.TrustedSQLString, []interface{}, error) { return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) } // NamedQuery within a transaction. // Any named placeholder parameters are replaced with fields from arg. -func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { +func (tx *Tx) NamedQuery(query safesql.TrustedSQLString, arg interface{}) (*Rows, error) { return NamedQuery(tx, query, arg) } // NamedExec a named query within a transaction. // Any named placeholder parameters are replaced with fields from arg. -func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { +func (tx *Tx) NamedExec(query safesql.TrustedSQLString, arg interface{}) (sql.Result, error) { return NamedExec(tx, query, arg) } // Select within a transaction. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { +func (tx *Tx) Select(dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return Select(tx, dest, query, args...) } // Queryx within a transaction. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { +func (tx *Tx) Queryx(query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := tx.Tx.Query(query, args...) if err != nil { return nil, err @@ -447,7 +448,7 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { // QueryRowx within a transaction. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { +func (tx *Tx) QueryRowx(query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := tx.Tx.Query(query, args...) return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} } @@ -455,18 +456,18 @@ func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { // Get within a transaction. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { +func (tx *Tx) Get(dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return Get(tx, dest, query, args...) } // MustExec runs MustExec within a transaction. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { +func (tx *Tx) MustExec(query safesql.TrustedSQLString, args ...interface{}) sql.Result { return MustExec(tx, query, args...) } // Preparex a statement within a transaction. -func (tx *Tx) Preparex(query string) (*Stmt, error) { +func (tx *Tx) Preparex(query safesql.TrustedSQLString) (*Stmt, error) { return Preparex(tx, query) } @@ -497,7 +498,7 @@ func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { } // PrepareNamed returns an sqlx.NamedStmt -func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { +func (tx *Tx) PrepareNamed(query safesql.TrustedSQLString) (*NamedStmt, error) { return prepareNamed(tx, query) } @@ -517,46 +518,46 @@ func (s *Stmt) Unsafe() *Stmt { // Select using the prepared statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) Select(dest interface{}, args ...interface{}) error { - return Select(&qStmt{s}, dest, "", args...) + return Select(&qStmt{s}, dest, safesql.New(""), args...) } // Get using the prepared statement. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (s *Stmt) Get(dest interface{}, args ...interface{}) error { - return Get(&qStmt{s}, dest, "", args...) + return Get(&qStmt{s}, dest, safesql.New(""), args...) } // MustExec (panic) using this statement. Note that the query portion of the error // output will be blank, as Stmt does not expose its query. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExec(args ...interface{}) sql.Result { - return MustExec(&qStmt{s}, "", args...) + return MustExec(&qStmt{s}, safesql.New(""), args...) } // QueryRowx using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowx(args ...interface{}) *Row { qs := &qStmt{s} - return qs.QueryRowx("", args...) + return qs.QueryRowx(safesql.New(""), args...) } // Queryx using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { qs := &qStmt{s} - return qs.Queryx("", args...) + return qs.Queryx(safesql.New(""), args...) } // qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by // implementing those interfaces and ignoring the `query` argument. type qStmt struct{ *Stmt } -func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { +func (q *qStmt) Query(query safesql.TrustedSQLString, args ...interface{}) (*sql.Rows, error) { return q.Stmt.Query(args...) } -func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { +func (q *qStmt) Queryx(query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := q.Stmt.Query(args...) if err != nil { return nil, err @@ -564,12 +565,12 @@ func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err } -func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { +func (q *qStmt) QueryRowx(query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := q.Stmt.Query(args...) return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} } -func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { +func (q *qStmt) Exec(query safesql.TrustedSQLString, args ...interface{}) (sql.Result, error) { return q.Stmt.Exec(args...) } @@ -661,7 +662,7 @@ func MustConnect(driverName, dataSourceName string) *DB { } // Preparex prepares a statement. -func Preparex(p Preparer, query string) (*Stmt, error) { +func Preparex(p Preparer, query safesql.TrustedSQLString) (*Stmt, error) { s, err := p.Prepare(query) if err != nil { return nil, err @@ -674,7 +675,7 @@ func Preparex(p Preparer, query string) (*Stmt, error) { // the result set must have only one column. Otherwise, StructScan is used. // The *sql.Rows are closed automatically. // Any placeholder parameters are replaced with supplied args. -func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { +func Select(q Queryer, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { rows, err := q.Queryx(query, args...) if err != nil { return err @@ -689,7 +690,7 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { +func Get(q Queryer, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { r := q.QueryRowx(query, args...) return r.scanAny(dest, false) } @@ -714,13 +715,14 @@ func LoadFile(e Execer, path string) (*sql.Result, error) { if err != nil { return nil, err } - res, err := e.Exec(string(contents)) + // UNSAFE: for now assuming that if we use this pattern it does not contain untrusted data + res, err := e.Exec(uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(string(contents))) return &res, err } // MustExec execs the query using e and panics if there was an error. // Any placeholder parameters are replaced with supplied args. -func MustExec(e Execer, query string, args ...interface{}) sql.Result { +func MustExec(e Execer, query safesql.TrustedSQLString, args ...interface{}) sql.Result { res, err := e.Exec(query, args...) if err != nil { panic(err) From 75e1b2fddc00c39e4c806a66b55314341cc02357 Mon Sep 17 00:00:00 2001 From: Peter Arts <77252233+sc-peter@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:00:45 +1100 Subject: [PATCH 4/5] Refactor tests --- bind.go | 32 ++--- named_context.go | 10 +- named_context_test.go | 22 ++-- named_test.go | 176 +++++++++++++------------ sqlx.go | 2 +- sqlx_context.go | 88 +++++++------ sqlx_context_test.go | 232 ++++++++++++++++---------------- sqlx_test.go | 298 +++++++++++++++++++++--------------------- 8 files changed, 443 insertions(+), 417 deletions(-) diff --git a/bind.go b/bind.go index 5743a9b..e0406ad 100644 --- a/bind.go +++ b/bind.go @@ -101,15 +101,16 @@ func Rebind(bindType int, query safesql.TrustedSQLString) safesql.TrustedSQLStri // much simpler and should be more resistant to odd unicode, but it is twice as // slow. Kept here for benchmarking purposes and to possibly replace Rebind if // problems arise with its somewhat naive handling of unicode. -func rebindBuff(bindType int, query string) string { +func rebindBuff(bindType int, query safesql.TrustedSQLString) safesql.TrustedSQLString { if bindType != DOLLAR { return query } - b := make([]byte, 0, len(query)) + rawQuery := query.String() + b := make([]byte, 0, len(rawQuery)) rqb := bytes.NewBuffer(b) j := 1 - for _, r := range query { + for _, r := range rawQuery { if r == '?' { rqb.WriteRune('$') rqb.WriteString(strconv.Itoa(j)) @@ -119,7 +120,7 @@ func rebindBuff(bindType int, query string) string { } } - return rqb.String() + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(rqb.String()) } func asSliceForIn(i interface{}) (v reflect.Value, ok bool) { @@ -147,7 +148,7 @@ func asSliceForIn(i interface{}) (v reflect.Value, ok bool) { // In expands slice values in args, returning the modified query string // and a new arg list that can be executed by a database. The `query` should // use the `?` bindVar. The return value uses the `?` bindVar. -func In(query string, args ...interface{}) (string, []interface{}, error) { +func In(query safesql.TrustedSQLString, args ...interface{}) (safesql.TrustedSQLString, []interface{}, error) { // argMeta stores reflect.Value and length for slices and // the value itself for non-slice arguments type argMeta struct { @@ -173,7 +174,7 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { var err error arg, err = a.Value() if err != nil { - return "", nil, err + return safesql.New(""), nil, err } } @@ -185,7 +186,7 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { flatArgsCount += meta[i].length if meta[i].length == 0 { - return "", nil, errors.New("empty slice passed to 'in' query") + return safesql.New(""), nil, errors.New("empty slice passed to 'in' query") } } else { meta[i].i = arg @@ -202,17 +203,18 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { newArgs := make([]interface{}, 0, flatArgsCount) var buf strings.Builder - buf.Grow(len(query) + len(", ?")*flatArgsCount) + rawQuery := query.String() + buf.Grow(len(rawQuery) + len(", ?")*flatArgsCount) var arg, offset int - for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { + for i := strings.IndexByte(rawQuery[offset:], '?'); i != -1; i = strings.IndexByte(rawQuery[offset:], '?') { if arg >= len(meta) { // if an argument wasn't passed, lets return an error; this is // not actually how database/sql Exec/Query works, but since we are // creating an argument list programmatically, we want to be able // to catch these programmer errors earlier. - return "", nil, errors.New("number of bindVars exceeds arguments") + return safesql.New(""), nil, errors.New("number of bindVars exceeds arguments") } argMeta := meta[arg] @@ -228,7 +230,7 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { } // write everything up to and including our ? character - buf.WriteString(query[:offset+i+1]) + buf.WriteString(rawQuery[:offset+i+1]) for si := 1; si < argMeta.length; si++ { buf.WriteString(", ?") @@ -238,17 +240,17 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { // slice the query and reset the offset. this avoids some bookkeeping for // the write after the loop - query = query[offset+i+1:] + rawQuery = rawQuery[offset+i+1:] offset = 0 } - buf.WriteString(query) + buf.WriteString(rawQuery) if arg < len(meta) { - return "", nil, errors.New("number of bindVars less than number arguments") + return safesql.New(""), nil, errors.New("number of bindVars less than number arguments") } - return buf.String(), newArgs, nil + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(buf.String()), newArgs, nil } func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { diff --git a/named_context.go b/named_context.go index 9ad23f4..edba7e4 100644 --- a/named_context.go +++ b/named_context.go @@ -6,6 +6,8 @@ package sqlx import ( "context" "database/sql" + + "github.com/google/go-safeweb/safesql" ) // A union interface of contextPreparer and binder, required to be able to @@ -15,9 +17,9 @@ type namedPreparerContext interface { binder } -func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { +func prepareNamedContext(ctx context.Context, p namedPreparerContext, query safesql.TrustedSQLString) (*NamedStmt, error) { bindType := BindType(p.DriverName()) - q, args, err := compileNamedQuery([]byte(query), bindType) + q, args, err := compileNamedQuery(query, bindType) if err != nil { return nil, err } @@ -113,7 +115,7 @@ func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interf // NamedQueryContext binds a named query and then runs Query on the result using the // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with // map[string]interface{} types. -func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { +func NamedQueryContext(ctx context.Context, e ExtContext, query safesql.TrustedSQLString, arg interface{}) (*Rows, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err @@ -124,7 +126,7 @@ func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg inte // NamedExecContext uses BindStruct to get a query executable by the driver and // then runs Exec on the result. Returns an error from the binding // or the query execution itself. -func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { +func NamedExecContext(ctx context.Context, e ExtContext, query safesql.TrustedSQLString, arg interface{}) (sql.Result, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err diff --git a/named_context_test.go b/named_context_test.go index 03f933d..f980c67 100644 --- a/named_context_test.go +++ b/named_context_test.go @@ -7,6 +7,8 @@ import ( "context" "database/sql" "testing" + + "github.com/google/go-safeweb/safesql" ) func TestNamedContextQueries(t *testing.T) { @@ -19,25 +21,25 @@ func TestNamedContextQueries(t *testing.T) { ctx := context.Background() // Check that invalid preparations fail - _, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") + _, err = db.PrepareNamedContext(ctx, safesql.New("SELECT * FROM person WHERE first_name=:first:name")) if err == nil { t.Error("Expected an error with invalid prepared statement.") } - _, err = db.PrepareNamedContext(ctx, "invalid sql") + _, err = db.PrepareNamedContext(ctx, safesql.New("invalid sql")) if err == nil { t.Error("Expected an error with invalid prepared statement.") } // Check closing works as anticipated - ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") + ns, err = db.PrepareNamedContext(ctx, safesql.New("SELECT * FROM person WHERE first_name=:first_name")) test.Error(err) err = ns.Close() test.Error(err) - ns, err = db.PrepareNamedContext(ctx, ` + ns, err = db.PrepareNamedContext(ctx, safesql.New(` SELECT first_name, last_name, email - FROM person WHERE first_name=:first_name AND email=:email`) + FROM person WHERE first_name=:first_name AND email=:email`)) test.Error(err) // test Queryx w/ uses Query @@ -78,9 +80,9 @@ func TestNamedContextQueries(t *testing.T) { } // test Exec - ns, err = db.PrepareNamedContext(ctx, ` + ns, err = db.PrepareNamedContext(ctx, safesql.New(` INSERT INTO person (first_name, last_name, email) - VALUES (:first_name, :last_name, :email)`) + VALUES (:first_name, :last_name, :email)`)) test.Error(err) js := Person{ @@ -93,7 +95,7 @@ func TestNamedContextQueries(t *testing.T) { // Make sure we can pull him out again p2 := Person{} - db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) + db.GetContext(ctx, &p2, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), js.Email) if p2.Email != js.Email { t.Errorf("expected %s, got %s", js.Email, p2.Email) } @@ -114,7 +116,7 @@ func TestNamedContextQueries(t *testing.T) { // then rollback... tx.Rollback() // looking for Steven after a rollback should fail - err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + err = db.GetContext(ctx, &p2, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), sl.Email) if err != sql.ErrNoRows { t.Errorf("expected no rows error, got %v", err) } @@ -127,7 +129,7 @@ func TestNamedContextQueries(t *testing.T) { tx.Commit() // looking for Steven after a Commit should succeed - err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + err = db.GetContext(ctx, &p2, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), sl.Email) test.Error(err) if p2.Email != sl.Email { t.Errorf("expected %s, got %s", sl.Email, p2.Email) diff --git a/named_test.go b/named_test.go index 0ee5b85..f5b8b2f 100644 --- a/named_test.go +++ b/named_test.go @@ -4,53 +4,56 @@ import ( "database/sql" "fmt" "testing" + + "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" ) func TestCompileQuery(t *testing.T) { table := []struct { - Q, R, D, T, N string + Q, R, D, T, N safesql.TrustedSQLString V []string }{ // basic test for named parameters, invalid char ',' terminating { - Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, - R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, - D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, - T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`, - N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + Q: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`), + R: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`), + D: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`), + T: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`), + N: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`), V: []string{"name", "age", "first", "last"}, }, // This query tests a named parameter ending the string as well as numbers { - Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, - R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, - D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, - T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`, - N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + Q: safesql.New(`SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`), + R: safesql.New(`SELECT * FROM a WHERE first_name=? AND last_name=?`), + D: safesql.New(`SELECT * FROM a WHERE first_name=$1 AND last_name=$2`), + T: safesql.New(`SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`), + N: safesql.New(`SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`), V: []string{"name1", "name2"}, }, { - Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, - R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, - D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, - T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`, - N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + Q: safesql.New(`SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`), + R: safesql.New(`SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`), + D: safesql.New(`SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`), + T: safesql.New(`SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`), + N: safesql.New(`SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`), V: []string{"name1", "name2"}, }, { - Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, - R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, - D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, - T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`, - N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + Q: safesql.New(`SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`), + R: safesql.New(`SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`), + D: safesql.New(`SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`), + T: safesql.New(`SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`), + N: safesql.New(`SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`), V: []string{"first_name", "last_name"}, }, { - Q: `SELECT @name := "name", :age, :first, :last`, - R: `SELECT @name := "name", ?, ?, ?`, - D: `SELECT @name := "name", $1, $2, $3`, - N: `SELECT @name := "name", :age, :first, :last`, - T: `SELECT @name := "name", @p1, @p2, @p3`, + Q: safesql.New(`SELECT @name := "name", :age, :first, :last`), + R: safesql.New(`SELECT @name := "name", ?, ?, ?`), + D: safesql.New(`SELECT @name := "name", $1, $2, $3`), + N: safesql.New(`SELECT @name := "name", :age, :first, :last`), + T: safesql.New(`SELECT @name := "name", @p1, @p2, @p3`), V: []string{"age", "first", "last"}, }, /* This unicode awareness test sadly fails, because of our byte-wise worldview. @@ -66,7 +69,7 @@ func TestCompileQuery(t *testing.T) { } for _, test := range table { - qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION) + qr, names, err := compileNamedQuery(test.Q, QUESTION) if err != nil { t.Error(err) } @@ -82,19 +85,19 @@ func TestCompileQuery(t *testing.T) { } } } - qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR) + qd, _, _ := compileNamedQuery(test.Q, DOLLAR) if qd != test.D { t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) } - qt, _, _ := compileNamedQuery([]byte(test.Q), AT) + qt, _, _ := compileNamedQuery(test.Q, AT) if qt != test.T { t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) } - qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED) + qq, _, _ := compileNamedQuery(test.Q, NAMED) if qq != test.N { - t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) + t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N.String()), len(qq.String())) } } } @@ -123,9 +126,9 @@ func (t Test) Errorf(err error, format string, args ...interface{}) { func TestEscapedColons(t *testing.T) { t.Skip("not sure it is possible to support this in general case without an SQL parser") - var qs = `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND - (now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id` - _, _, err := compileNamedQuery([]byte(qs), DOLLAR) + var qs = safesql.New(`SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND + (now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id`) + _, _, err := compileNamedQuery(qs, DOLLAR) if err != nil { t.Error("Didn't handle colons correctly when inside a string") } @@ -139,25 +142,25 @@ func TestNamedQueries(t *testing.T) { var err error // Check that invalid preparations fail - _, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name") + _, err = db.PrepareNamed(safesql.New("SELECT * FROM person WHERE first_name=:first:name")) if err == nil { t.Error("Expected an error with invalid prepared statement.") } - _, err = db.PrepareNamed("invalid sql") + _, err = db.PrepareNamed(safesql.New("invalid sql")) if err == nil { t.Error("Expected an error with invalid prepared statement.") } // Check closing works as anticipated - ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name") + ns, err = db.PrepareNamed(safesql.New("SELECT * FROM person WHERE first_name=:first_name")) test.Error(err) err = ns.Close() test.Error(err) - ns, err = db.PrepareNamed(` + ns, err = db.PrepareNamed(safesql.New(` SELECT first_name, last_name, email - FROM person WHERE first_name=:first_name AND email=:email`) + FROM person WHERE first_name=:first_name AND email=:email`)) test.Error(err) // test Queryx w/ uses Query @@ -204,10 +207,12 @@ func TestNamedQueries(t *testing.T) { {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"}, } - insert := fmt.Sprintf( + // UNSAFE: this is fine for tests when there's no easy alternative + // Never do this in production code! + insert := uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(fmt.Sprintf( "INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n", now, - ) + )) _, err = db.NamedExec(insert, sls) test.Error(err) @@ -218,8 +223,8 @@ func TestNamedQueries(t *testing.T) { {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, } - _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) - VALUES (:first_name, :last_name, :email) ;--`, slsMap) + _, err = db.NamedExec(safesql.New(`INSERT INTO person (first_name, last_name, email) + VALUES (:first_name, :last_name, :email) ;--`), slsMap) test.Error(err) type A map[string]interface{} @@ -230,13 +235,13 @@ func TestNamedQueries(t *testing.T) { {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, } - _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) - VALUES (:first_name, :last_name, :email) ;--`, typedMap) + _, err = db.NamedExec(safesql.New(`INSERT INTO person (first_name, last_name, email) + VALUES (:first_name, :last_name, :email) ;--`), typedMap) test.Error(err) for _, p := range sls { dest := Person{} - err = db.Get(&dest, db.Rebind("SELECT * FROM person WHERE email=?"), p.Email) + err = db.Get(&dest, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), p.Email) test.Error(err) if dest.Email != p.Email { t.Errorf("expected %s, got %s", p.Email, dest.Email) @@ -244,9 +249,9 @@ func TestNamedQueries(t *testing.T) { } // test Exec - ns, err = db.PrepareNamed(` + ns, err = db.PrepareNamed(safesql.New(` INSERT INTO person (first_name, last_name, email) - VALUES (:first_name, :last_name, :email)`) + VALUES (:first_name, :last_name, :email)`)) test.Error(err) js := Person{ @@ -259,7 +264,7 @@ func TestNamedQueries(t *testing.T) { // Make sure we can pull him out again p2 := Person{} - db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) + db.Get(&p2, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), js.Email) if p2.Email != js.Email { t.Errorf("expected %s, got %s", js.Email, p2.Email) } @@ -280,7 +285,7 @@ func TestNamedQueries(t *testing.T) { // then rollback... tx.Rollback() // looking for Steven after a rollback should fail - err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + err = db.Get(&p2, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), sl.Email) if err != sql.ErrNoRows { t.Errorf("expected no rows error, got %v", err) } @@ -293,7 +298,7 @@ func TestNamedQueries(t *testing.T) { tx.Commit() // looking for Steven after a Commit should succeed - err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + err = db.Get(&p2, db.Rebind(safesql.New("SELECT * FROM person WHERE email=?")), sl.Email) test.Error(err) if p2.Email != sl.Email { t.Errorf("expected %s, got %s", sl.Email, p2.Email) @@ -304,96 +309,97 @@ func TestNamedQueries(t *testing.T) { func TestFixBounds(t *testing.T) { table := []struct { - name, query, expect string - loop int + name string + query, expect safesql.TrustedSQLString + loop int }{ { name: `named syntax`, - query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, - expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`, + query: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`), + expect: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`), loop: 2, }, { name: `mysql syntax`, - query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, - expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`, + query: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`), + expect: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`), loop: 2, }, { name: `named syntax w/ trailer`, - query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`, - expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`, + query: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`), + expect: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`), loop: 2, }, { name: `mysql syntax w/ trailer`, - query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`, - expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) ;--`, + query: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`), + expect: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) ;--`), loop: 2, }, { name: `not found test`, - query: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`, - expect: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`, + query: safesql.New(`INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`), + expect: safesql.New(`INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`), loop: 2, }, { name: `found twice test`, - query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, - expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, + query: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`), + expect: safesql.New(`INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`), loop: 2, }, { name: `nospace`, - query: `INSERT INTO foo (a,b) VALUES(:a, :b)`, - expect: `INSERT INTO foo (a,b) VALUES(:a, :b),(:a, :b)`, + query: safesql.New(`INSERT INTO foo (a,b) VALUES(:a, :b)`), + expect: safesql.New(`INSERT INTO foo (a,b) VALUES(:a, :b),(:a, :b)`), loop: 2, }, { name: `lowercase`, - query: `INSERT INTO foo (a,b) values(:a, :b)`, - expect: `INSERT INTO foo (a,b) values(:a, :b),(:a, :b)`, + query: safesql.New(`INSERT INTO foo (a,b) values(:a, :b)`), + expect: safesql.New(`INSERT INTO foo (a,b) values(:a, :b),(:a, :b)`), loop: 2, }, { name: `on duplicate key using VALUES`, - query: `INSERT INTO foo (a,b) VALUES (:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, - expect: `INSERT INTO foo (a,b) VALUES (:a, :b),(:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, + query: safesql.New(`INSERT INTO foo (a,b) VALUES (:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`), + expect: safesql.New(`INSERT INTO foo (a,b) VALUES (:a, :b),(:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`), loop: 2, }, { name: `single column`, - query: `INSERT INTO foo (a) VALUES (:a)`, - expect: `INSERT INTO foo (a) VALUES (:a),(:a)`, + query: safesql.New(`INSERT INTO foo (a) VALUES (:a)`), + expect: safesql.New(`INSERT INTO foo (a) VALUES (:a),(:a)`), loop: 2, }, { name: `call now`, - query: `INSERT INTO foo (a, b) VALUES (:a, NOW())`, - expect: `INSERT INTO foo (a, b) VALUES (:a, NOW()),(:a, NOW())`, + query: safesql.New(`INSERT INTO foo (a, b) VALUES (:a, NOW())`), + expect: safesql.New(`INSERT INTO foo (a, b) VALUES (:a, NOW()),(:a, NOW())`), loop: 2, }, { name: `two level depth function call`, - query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW()))`, - expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())),(:a, YEAR(NOW()))`, + query: safesql.New(`INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW()))`), + expect: safesql.New(`INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())),(:a, YEAR(NOW()))`), loop: 2, }, { name: `missing closing bracket`, - query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, - expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, + query: safesql.New(`INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`), + expect: safesql.New(`INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`), loop: 2, }, { name: `table with "values" at the end`, - query: `INSERT INTO table_values (a, b) VALUES (:a, :b)`, - expect: `INSERT INTO table_values (a, b) VALUES (:a, :b),(:a, :b)`, + query: safesql.New(`INSERT INTO table_values (a, b) VALUES (:a, :b)`), + expect: safesql.New(`INSERT INTO table_values (a, b) VALUES (:a, :b),(:a, :b)`), loop: 2, }, { name: `multiline indented query`, - query: `INSERT INTO foo ( + query: safesql.New(`INSERT INTO foo ( a, b, c, @@ -403,8 +409,8 @@ func TestFixBounds(t *testing.T) { :age, :first, :last - )`, - expect: `INSERT INTO foo ( + )`), + expect: safesql.New(`INSERT INTO foo ( a, b, c, @@ -419,7 +425,7 @@ func TestFixBounds(t *testing.T) { :age, :first, :last - )`, + )`), loop: 2, }, } diff --git a/sqlx.go b/sqlx.go index 51174a1..90887dc 100644 --- a/sqlx.go +++ b/sqlx.go @@ -383,7 +383,7 @@ func (db *DB) PrepareNamed(query safesql.TrustedSQLString) (*NamedStmt, error) { // Conn is a wrapper around sql.Conn with extra functionality type Conn struct { - *sql.Conn + *safesql.Conn driverName string unsafe bool Mapper *reflectx.Mapper diff --git a/sqlx_context.go b/sqlx_context.go index 33d86ef..f0330cd 100644 --- a/sqlx_context.go +++ b/sqlx_context.go @@ -12,6 +12,7 @@ import ( "reflect" "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" ) // ConnectContext to a database and verify with a ping. @@ -26,19 +27,19 @@ func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB // QueryerContext is an interface used by GetContext and SelectContext type QueryerContext interface { - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) - QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row + QueryContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*sql.Rows, error) + QueryxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) + QueryRowxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) *Row } // PreparerContext is an interface used by PreparexContext. type PreparerContext interface { - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query safesql.TrustedSQLString) (*sql.Stmt, error) } // ExecerContext is an interface used by MustExecContext and LoadFileContext type ExecerContext interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (sql.Result, error) } // ExtContext is a union interface which can bind, query, and exec, with Context @@ -54,7 +55,7 @@ type ExtContext interface { // scannable, then the result set must have only one column. Otherwise, // StructScan is used. The *sql.Rows are closed automatically. // Any placeholder parameters are replaced with supplied args. -func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { +func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { rows, err := q.QueryxContext(ctx, query, args...) if err != nil { return err @@ -68,7 +69,7 @@ func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, quer // // The provided context is used for the preparation of the statement, not for // the execution of the statement. -func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { +func PreparexContext(ctx context.Context, p PreparerContext, query safesql.TrustedSQLString) (*Stmt, error) { s, err := p.PrepareContext(ctx, query) if err != nil { return nil, err @@ -81,7 +82,7 @@ func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stm // column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like // row.Scan would. Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { +func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { r := q.QueryRowxContext(ctx, query, args...) return r.scanAny(dest, false) } @@ -106,13 +107,14 @@ func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Re if err != nil { return nil, err } - res, err := e.ExecContext(ctx, string(contents)) + // UNSAFE: for now assuming that if we use this pattern it does not contain untrusted data + res, err := e.ExecContext(ctx, uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(string(contents))) return &res, err } // MustExecContext execs the query using e and panics if there was an error. // Any placeholder parameters are replaced with supplied args. -func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { +func MustExecContext(ctx context.Context, e ExecerContext, query safesql.TrustedSQLString, args ...interface{}) sql.Result { res, err := e.ExecContext(ctx, query, args...) if err != nil { panic(err) @@ -121,32 +123,32 @@ func MustExecContext(ctx context.Context, e ExecerContext, query string, args .. } // PrepareNamedContext returns an sqlx.NamedStmt -func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { +func (db *DB) PrepareNamedContext(ctx context.Context, query safesql.TrustedSQLString) (*NamedStmt, error) { return prepareNamedContext(ctx, db, query) } // NamedQueryContext using this DB. // Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { +func (db *DB) NamedQueryContext(ctx context.Context, query safesql.TrustedSQLString, arg interface{}) (*Rows, error) { return NamedQueryContext(ctx, db, query, arg) } // NamedExecContext using this DB. // Any named placeholder parameters are replaced with fields from arg. -func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { +func (db *DB) NamedExecContext(ctx context.Context, query safesql.TrustedSQLString, arg interface{}) (sql.Result, error) { return NamedExecContext(ctx, db, query, arg) } // SelectContext using this DB. // Any placeholder parameters are replaced with supplied args. -func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (db *DB) SelectContext(ctx context.Context, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return SelectContext(ctx, db, dest, query, args...) } // GetContext using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (db *DB) GetContext(ctx context.Context, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return GetContext(ctx, db, dest, query, args...) } @@ -154,7 +156,7 @@ func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, ar // // The provided context is used for the preparation of the statement, not for // the execution of the statement. -func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { +func (db *DB) PreparexContext(ctx context.Context, query safesql.TrustedSQLString) (*Stmt, error) { return PreparexContext(ctx, db, query) } @@ -192,7 +194,7 @@ func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { // MustExecContext (panic) runs MustExec using this database. // Any placeholder parameters are replaced with supplied args. -func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { +func (db *DB) MustExecContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) sql.Result { return MustExecContext(ctx, db, query, args...) } @@ -208,7 +210,7 @@ func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { if err != nil { return nil, err } - return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err + return &Tx{Tx: &tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err } // Connx returns an *sqlx.Conn instead of an *sql.Conn. @@ -218,7 +220,7 @@ func (db *DB) Connx(ctx context.Context) (*Conn, error) { return nil, err } - return &Conn{Conn: conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil + return &Conn{Conn: &conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil } // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an @@ -233,19 +235,19 @@ func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { if err != nil { return nil, err } - return &Tx{Tx: tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err + return &Tx{Tx: &tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err } // SelectContext using this Conn. // Any placeholder parameters are replaced with supplied args. -func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return SelectContext(ctx, c, dest, query, args...) } // GetContext using this Conn. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (c *Conn) GetContext(ctx context.Context, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return GetContext(ctx, c, dest, query, args...) } @@ -253,13 +255,13 @@ func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, a // // The provided context is used for the preparation of the statement, not for // the execution of the statement. -func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) { +func (c *Conn) PreparexContext(ctx context.Context, query safesql.TrustedSQLString) (*Stmt, error) { return PreparexContext(ctx, c, query) } // QueryxContext queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. -func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { +func (c *Conn) QueryxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := c.Conn.QueryContext(ctx, query, args...) if err != nil { return nil, err @@ -269,13 +271,13 @@ func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interfac // QueryRowxContext queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. -func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { +func (c *Conn) QueryRowxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := c.Conn.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: c.unsafe, Mapper: c.Mapper} } // Rebind a query within a Conn's bindvar type. -func (c *Conn) Rebind(query string) string { +func (c *Conn) Rebind(query safesql.TrustedSQLString) safesql.TrustedSQLString { return Rebind(BindType(c.driverName), query) } @@ -310,24 +312,24 @@ func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt // // The provided context is used for the preparation of the statement, not for // the execution of the statement. -func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) { +func (tx *Tx) PreparexContext(ctx context.Context, query safesql.TrustedSQLString) (*Stmt, error) { return PreparexContext(ctx, tx, query) } // PrepareNamedContext returns an sqlx.NamedStmt -func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { +func (tx *Tx) PrepareNamedContext(ctx context.Context, query safesql.TrustedSQLString) (*NamedStmt, error) { return prepareNamedContext(ctx, tx, query) } // MustExecContext runs MustExecContext within a transaction. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { +func (tx *Tx) MustExecContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) sql.Result { return MustExecContext(ctx, tx, query, args...) } // QueryxContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { +func (tx *Tx) QueryxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := tx.Tx.QueryContext(ctx, query, args...) if err != nil { return nil, err @@ -337,69 +339,69 @@ func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface // SelectContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return SelectContext(ctx, tx, dest, query, args...) } // GetContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. -func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query safesql.TrustedSQLString, args ...interface{}) error { return GetContext(ctx, tx, dest, query, args...) } // QueryRowxContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. -func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { +func (tx *Tx) QueryRowxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := tx.Tx.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} } // NamedExecContext using this Tx. // Any named placeholder parameters are replaced with fields from arg. -func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { +func (tx *Tx) NamedExecContext(ctx context.Context, query safesql.TrustedSQLString, arg interface{}) (sql.Result, error) { return NamedExecContext(ctx, tx, query, arg) } // SelectContext using the prepared statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { - return SelectContext(ctx, &qStmt{s}, dest, "", args...) + return SelectContext(ctx, &qStmt{s}, dest, safesql.New(""), args...) } // GetContext using the prepared statement. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { - return GetContext(ctx, &qStmt{s}, dest, "", args...) + return GetContext(ctx, &qStmt{s}, dest, safesql.New(""), args...) } // MustExecContext (panic) using this statement. Note that the query portion of // the error output will be blank, as Stmt does not expose its query. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { - return MustExecContext(ctx, &qStmt{s}, "", args...) + return MustExecContext(ctx, &qStmt{s}, safesql.New(""), args...) } // QueryRowxContext using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { qs := &qStmt{s} - return qs.QueryRowxContext(ctx, "", args...) + return qs.QueryRowxContext(ctx, safesql.New(""), args...) } // QueryxContext using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { qs := &qStmt{s} - return qs.QueryxContext(ctx, "", args...) + return qs.QueryxContext(ctx, safesql.New(""), args...) } -func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (q *qStmt) QueryContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*sql.Rows, error) { return q.Stmt.QueryContext(ctx, args...) } -func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { +func (q *qStmt) QueryxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (*Rows, error) { r, err := q.Stmt.QueryContext(ctx, args...) if err != nil { return nil, err @@ -407,11 +409,11 @@ func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interfa return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err } -func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { +func (q *qStmt) QueryRowxContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) *Row { rows, err := q.Stmt.QueryContext(ctx, args...) return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} } -func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (q *qStmt) ExecContext(ctx context.Context, query safesql.TrustedSQLString, args ...interface{}) (sql.Result, error) { return q.Stmt.ExecContext(ctx, args...) } diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 91c5cba..19039db 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -23,6 +23,8 @@ import ( "time" _ "github.com/go-sql-driver/mysql" + "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -35,7 +37,10 @@ func MultiExecContext(ctx context.Context, e ExecerContext, query string) { stmts = stmts[:len(stmts)-1] } for _, s := range stmts { - _, err := e.ExecContext(ctx, s) + // UNSAFE: this is fine for tests when there's no easy alternative + // Never do this in production code! + ss := uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(s) + _, err := e.ExecContext(ctx, ss) if err != nil { fmt.Println(err, s) } @@ -68,19 +73,19 @@ func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { tx := db.MustBeginTx(ctx, nil) - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)")), "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)")), "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)")), "United States", "New York", "1") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO place (country, telcode) VALUES (?, ?)")), "Hong Kong", "852") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO place (country, telcode) VALUES (?, ?)")), "Singapore", "65") if db.DriverName() == "mysql" { - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)")), "Sarf Efrica", "27") } else { - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)")), "Sarf Efrica", "27") } - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") - tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO employees (name, id) VALUES (?, ?)")), "Peter", "4444") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)")), "Joe", "1", "4444") + tx.MustExecContext(ctx, tx.Rebind(safesql.New("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)")), "Martin", "2", "4444") tx.Commit() } @@ -99,21 +104,21 @@ func TestMissingNamesContextContext(t *testing.T) { // test Select first pps := []PersonPlus{} // pps lacks added_at destination - err := db.SelectContext(ctx, &pps, "SELECT * FROM person") + err := db.SelectContext(ctx, &pps, safesql.New("SELECT * FROM person")) if err == nil { t.Error("Expected missing name from Select to fail, but it did not.") } // test Get pp := PersonPlus{} - err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + err = db.GetContext(ctx, &pp, safesql.New("SELECT * FROM person LIMIT 1")) if err == nil { t.Error("Expected missing name Get to fail, but it did not.") } // test naked StructScan pps = []PersonPlus{} - rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") + rows, err := db.QueryContext(ctx, safesql.New("SELECT * FROM person LIMIT 1")) if err != nil { t.Fatal(err) } @@ -127,21 +132,21 @@ func TestMissingNamesContextContext(t *testing.T) { // now try various things with unsafe set. db = db.Unsafe() pps = []PersonPlus{} - err = db.SelectContext(ctx, &pps, "SELECT * FROM person") + err = db.SelectContext(ctx, &pps, safesql.New("SELECT * FROM person")) if err != nil { t.Error(err) } // test Get pp = PersonPlus{} - err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + err = db.GetContext(ctx, &pp, safesql.New("SELECT * FROM person LIMIT 1")) if err != nil { t.Error(err) } // test naked StructScan pps = []PersonPlus{} - rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") + rowsx, err := db.QueryxContext(ctx, safesql.New("SELECT * FROM person LIMIT 1")) if err != nil { t.Fatal(err) } @@ -156,7 +161,7 @@ func TestMissingNamesContextContext(t *testing.T) { if !isUnsafe(db) { t.Error("Expected db to be unsafe, but it isn't") } - nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + nstmt, err := db.PrepareNamedContext(ctx, safesql.New(`SELECT * FROM person WHERE first_name != :name`)) if err != nil { t.Fatal(err) } @@ -178,7 +183,7 @@ func TestMissingNamesContextContext(t *testing.T) { if isUnsafe(db) { t.Error("expected db to be safe but it isn't") } - nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + nstmt, err = db.PrepareNamedContext(ctx, safesql.New(`SELECT * FROM person WHERE first_name != :name`)) if err != nil { t.Fatal(err) } @@ -213,8 +218,8 @@ func TestEmbeddedStructsContextContext(t *testing.T) { err := db.SelectContext( ctx, &peopleAndPlaces, - `SELECT person.*, place.* FROM - person natural join place`) + safesql.New(`SELECT person.*, place.* FROM + person natural join place`)) if err != nil { t.Fatal(err) } @@ -230,8 +235,8 @@ func TestEmbeddedStructsContextContext(t *testing.T) { // test embedded structs with StructScan rows, err := db.QueryxContext( ctx, - `SELECT person.*, place.* FROM - person natural join place`) + safesql.New(`SELECT person.*, place.* FROM + person natural join place`)) if err != nil { t.Error(err) } @@ -257,8 +262,8 @@ func TestEmbeddedStructsContextContext(t *testing.T) { err = db.SelectContext( ctx, &peopleAndPlacesPtrs, - `SELECT person.*, place.* FROM - person natural join place`) + safesql.New(`SELECT person.*, place.* FROM + person natural join place`)) if err != nil { t.Fatal(err) } @@ -273,7 +278,7 @@ func TestEmbeddedStructsContextContext(t *testing.T) { // test "deep nesting" l3s := []Loop3{} - err = db.SelectContext(ctx, &l3s, `select * from person`) + err = db.SelectContext(ctx, &l3s, safesql.New(`select * from person`)) if err != nil { t.Fatal(err) } @@ -285,7 +290,7 @@ func TestEmbeddedStructsContextContext(t *testing.T) { // test "embed conflicts" ec := []EmbedConflict{} - err = db.SelectContext(ctx, &ec, `select * from person`) + err = db.SelectContext(ctx, &ec, safesql.New(`select * from person`)) // I'm torn between erroring here or having some kind of working behavior // in order to allow for more flexibility in destination structs if err != nil { @@ -313,8 +318,8 @@ func TestJoinQueryContext(t *testing.T) { err := db.SelectContext(ctx, &employees, - `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees - JOIN employees AS boss ON employees.boss_id = boss.id`) + safesql.New(`SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees + JOIN employees AS boss ON employees.boss_id = boss.id`)) if err != nil { t.Fatal(err) } @@ -350,11 +355,11 @@ func TestJoinQueryNamedPointerStructsContext(t *testing.T) { err := db.SelectContext(ctx, &employees, - `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", + safesql.New(`SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", boss.id "boss.id", boss.name "boss.name" FROM employees AS emp JOIN employees AS boss ON emp.boss_id = boss.id - `) + `)) if err != nil { t.Fatal(err) } @@ -373,7 +378,7 @@ func TestJoinQueryNamedPointerStructsContext(t *testing.T) { func TestSelectSliceMapTimeContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) - rows, err := db.QueryxContext(ctx, "SELECT * FROM person") + rows, err := db.QueryxContext(ctx, safesql.New("SELECT * FROM person")) if err != nil { t.Fatal(err) } @@ -384,7 +389,7 @@ func TestSelectSliceMapTimeContext(t *testing.T) { } } - rows, err = db.QueryxContext(ctx, "SELECT * FROM person") + rows, err = db.QueryxContext(ctx, safesql.New("SELECT * FROM person")) if err != nil { t.Fatal(err) } @@ -403,12 +408,12 @@ func TestNilReceiverContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) var p *Person - err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") + err := db.GetContext(ctx, p, safesql.New("SELECT * FROM person LIMIT 1")) if err == nil { t.Error("Expected error when getting into nil struct ptr.") } var pp *[]Person - err = db.SelectContext(ctx, pp, "SELECT * FROM person") + err = db.SelectContext(ctx, pp, safesql.New("SELECT * FROM person")) if err == nil { t.Error("Expected an error when selecting into nil slice ptr.") } @@ -459,14 +464,14 @@ func TestNamedQueryContext(t *testing.T) { Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` + q1 := safesql.New(`INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`) _, err := db.NamedExecContext(ctx, q1, p) if err != nil { log.Fatal(err) } p2 := &Person{} - rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) + rows, err := db.NamedQueryContext(ctx, safesql.New("SELECT * FROM person WHERE first_name=:first_name"), p) if err != nil { log.Fatal(err) } @@ -505,14 +510,15 @@ func TestNamedQueryContext(t *testing.T) { // prepare queries for case sensitivity to test our ToUpper function. // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line // strings are `` we use "" by default and swap out for MySQL - pdb := func(s string, db *DB) string { + pdb := func(s safesql.TrustedSQLString, db *DB) safesql.TrustedSQLString { if db.DriverName() == "mysql" { - return strings.Replace(s, `"`, "`", -1) + rawS := s.String() + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(strings.Replace(rawS, `"`, "`", -1)) } return s } - q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` + q1 = safesql.New(`INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`) _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) if err != nil { t.Fatal(err, db.DriverName()) @@ -538,13 +544,13 @@ func TestNamedQueryContext(t *testing.T) { } } - ns, err := db.PrepareNamed(pdb(` + ns, err := db.PrepareNamed(pdb(safesql.New(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL - `, db)) + `), db)) if err != nil { t.Fatal(err) @@ -558,13 +564,13 @@ func TestNamedQueryContext(t *testing.T) { // Check exactly the same thing, but with db.NamedQuery, which does not go // through the PrepareNamed/NamedStmt path. - rows, err = db.NamedQueryContext(ctx, pdb(` + rows, err = db.NamedQueryContext(ctx, pdb(safesql.New(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL - `, db), jp) + `), db), jp) if err != nil { t.Fatal(err) } @@ -595,7 +601,7 @@ func TestNamedQueryContext(t *testing.T) { Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + q2 := safesql.New(`INSERT INTO place (id, name) VALUES (1, :name)`) _, err = db.NamedExecContext(ctx, q2, pl) if err != nil { log.Fatal(err) @@ -604,14 +610,14 @@ func TestNamedQueryContext(t *testing.T) { id := 1 pp.Place.ID = id - q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + q3 := safesql.New(`INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`) _, err = db.NamedExecContext(ctx, q3, pp) if err != nil { log.Fatal(err) } pp2 := &PlacePerson{} - rows, err = db.NamedQueryContext(ctx, ` + rows, err = db.NamedQueryContext(ctx, safesql.New(` SELECT first_name, last_name, @@ -621,7 +627,7 @@ func TestNamedQueryContext(t *testing.T) { FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE - place.id=:place.id`, pp) + place.id=:place.id`), pp) if err != nil { log.Fatal(err) } @@ -664,8 +670,8 @@ func TestNilInsertsContext(t *testing.T) { var v, v2 TT r := db.Rebind - db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) - db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) + db.MustExecContext(ctx, r(safesql.New(`INSERT INTO tt (id) VALUES (1)`))) + db.GetContext(ctx, &v, r(safesql.New(`SELECT * FROM tt`))) if v.ID != 1 { t.Errorf("Expecting id of 1, got %v", v.ID) } @@ -679,9 +685,9 @@ func TestNilInsertsContext(t *testing.T) { // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly // function. This next line is important as it provides the only coverage for this. - db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) + db.NamedExecContext(ctx, safesql.New(`INSERT INTO tt (id, value) VALUES (:id, :value)`), v) - db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) + db.GetContext(ctx, &v2, r(safesql.New(`SELECT * FROM tt WHERE id=2`))) if v.ID != v2.ID { t.Errorf("%v != %v", v.ID, v2.ID) } @@ -706,12 +712,12 @@ func TestScanErrorContext(t *testing.T) { K int V string } - _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) + _, err := db.Exec(db.Rebind(safesql.New("INSERT INTO kv (k, v) VALUES (?, ?)")), "hi", 1) if err != nil { t.Error(err) } - rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") + rows, err := db.QueryxContext(ctx, safesql.New("SELECT * FROM kv")) if err != nil { t.Error(err) } @@ -732,14 +738,14 @@ func TestUsageContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) slicemembers := []SliceMember{} - err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") + err := db.SelectContext(ctx, &slicemembers, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err != nil { t.Fatal(err) } people := []Person{} - err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") + err = db.SelectContext(ctx, &people, safesql.New("SELECT * FROM person ORDER BY first_name ASC")) if err != nil { t.Fatal(err) } @@ -759,7 +765,7 @@ func TestUsageContext(t *testing.T) { } jason = Person{} - err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") + err = db.GetContext(ctx, &jason, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?")), "Jason") if err != nil { t.Fatal(err) @@ -768,7 +774,7 @@ func TestUsageContext(t *testing.T) { t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) } - err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") + err = db.GetContext(ctx, &jason, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?")), "Foobar") if err == nil { t.Errorf("Expecting an error, got nil\n") } @@ -778,7 +784,7 @@ func TestUsageContext(t *testing.T) { // The following tests check statement reuse, which was actually a problem // due to copying being done when creating Stmt's which was eventually removed - stmt1, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt1, err := db.PreparexContext(ctx, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Fatal(err) } @@ -798,7 +804,7 @@ func TestUsageContext(t *testing.T) { t.Fatal(err) } - stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt2, err := db.PreparexContext(ctx, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Fatal(err) } @@ -816,7 +822,7 @@ func TestUsageContext(t *testing.T) { tx.Commit() places := []*Place{} - err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") + err = db.SelectContext(ctx, &places, safesql.New("SELECT telcode FROM place ORDER BY telcode ASC")) if err != nil { t.Fatal(err) } @@ -828,7 +834,7 @@ func TestUsageContext(t *testing.T) { } placesptr := []PlacePtr{} - err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") + err = db.SelectContext(ctx, &placesptr, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err != nil { t.Error(err) } @@ -837,7 +843,7 @@ func TestUsageContext(t *testing.T) { // if you have null fields and use SELECT *, you must use sql.Null* in your struct // this test also verifies that you can use either a []Struct{} or a []*Struct{} places2 := []Place{} - err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") + err = db.SelectContext(ctx, &places2, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err != nil { t.Fatal(err) } @@ -846,14 +852,14 @@ func TestUsageContext(t *testing.T) { // this should return a type error that &p is not a pointer to a struct slice p := Place{} - err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") + err = db.SelectContext(ctx, &p, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") } // this should be an error pl := []Place{} - err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") + err = db.SelectContext(ctx, pl, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") } @@ -862,7 +868,7 @@ func TestUsageContext(t *testing.T) { t.Errorf("Expected integer telcodes to work, got %#v", places) } - stmt, err := db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) + stmt, err := db.PreparexContext(ctx, db.Rebind(safesql.New("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC"))) if err != nil { t.Error(err) } @@ -880,7 +886,7 @@ func TestUsageContext(t *testing.T) { t.Errorf("Expected the right telcodes, got %#v", places) } - rows, err := db.QueryxContext(ctx, "SELECT * FROM place") + rows, err := db.QueryxContext(ctx, safesql.New("SELECT * FROM place")) if err != nil { t.Fatal(err) } @@ -892,7 +898,7 @@ func TestUsageContext(t *testing.T) { } } - rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + rows, err = db.QueryxContext(ctx, safesql.New("SELECT * FROM place")) if err != nil { t.Fatal(err) } @@ -908,7 +914,7 @@ func TestUsageContext(t *testing.T) { } } - rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + rows, err = db.QueryxContext(ctx, safesql.New("SELECT * FROM place")) if err != nil { t.Fatal(err) } @@ -924,7 +930,7 @@ func TestUsageContext(t *testing.T) { // test advanced querying // test that NamedExec works with a map as well as a struct - _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ + _, err = db.NamedExecContext(ctx, safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)"), map[string]interface{}{ "first": "Bin", "last": "Smuth", "email": "bensmith@allblacks.nz", @@ -935,7 +941,7 @@ func TestUsageContext(t *testing.T) { // ensure that if the named param happens right at the end it still works // ensure that NamedQuery works with a map[string]interface{} - rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) + rows, err = db.NamedQueryContext(ctx, safesql.New("SELECT * FROM person WHERE first_name=:first"), map[string]interface{}{"first": "Bin"}) if err != nil { t.Fatal(err) } @@ -959,13 +965,13 @@ func TestUsageContext(t *testing.T) { ben.Email = "binsmuth@allblacks.nz" // Insert via a named query using the struct - _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) + _, err = db.NamedExecContext(ctx, safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)"), ben) if err != nil { t.Fatal(err) } - rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) + rows, err = db.NamedQueryContext(ctx, safesql.New("SELECT * FROM person WHERE first_name=:first_name"), ben) if err != nil { t.Fatal(err) } @@ -983,14 +989,14 @@ func TestUsageContext(t *testing.T) { } // ensure that Get does not panic on emppty result set person := &Person{} - err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") + err = db.GetContext(ctx, person, safesql.New("SELECT * FROM person WHERE first_name=$1"), "does-not-exist") if err == nil { t.Fatal("Should have got an error for Get on non-existent row.") } // lets test prepared statements some more - stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt, err = db.PreparexContext(ctx, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Fatal(err) } @@ -1012,7 +1018,7 @@ func TestUsageContext(t *testing.T) { } john = Person{} - stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt, err = db.PreparexContext(ctx, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Error(err) } @@ -1025,7 +1031,7 @@ func TestUsageContext(t *testing.T) { // THIS USED TO WORK BUT WILL NO LONGER WORK. db.MapperFunc(strings.ToUpper) rsa := CPlace{} - err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + err = db.GetContext(ctx, &rsa, safesql.New("SELECT * FROM capplace;")) if err != nil { t.Error(err, "in db:", db.DriverName()) } @@ -1035,20 +1041,20 @@ func TestUsageContext(t *testing.T) { // differently from the original. dbCopy := NewDb(db.DB, db.DriverName()) dbCopy.MapperFunc(strings.ToUpper) - err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + err = dbCopy.GetContext(ctx, &rsa, safesql.New("SELECT * FROM capplace;")) if err != nil { fmt.Println(db.DriverName()) t.Error(err) } - err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") + err = db.GetContext(ctx, &rsa, safesql.New("SELECT * FROM cappplace;")) if err == nil { t.Error("Expected no error, got ", err) } // test base type slices var sdest []string - rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") + rows, err = db.QueryxContext(ctx, safesql.New("SELECT email FROM person ORDER BY email ASC;")) if err != nil { t.Error(err) } @@ -1059,7 +1065,7 @@ func TestUsageContext(t *testing.T) { // test Get with base types var count int - err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") + err = db.GetContext(ctx, &count, safesql.New("SELECT count(*) FROM person;")) if err != nil { t.Error(err) } @@ -1069,20 +1075,20 @@ func TestUsageContext(t *testing.T) { // test Get and Select with time.Time, #84 var addedAt time.Time - err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") + err = db.GetContext(ctx, &addedAt, safesql.New("SELECT added_at FROM person LIMIT 1;")) if err != nil { t.Error(err) } var addedAts []time.Time - err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") + err = db.SelectContext(ctx, &addedAts, safesql.New("SELECT added_at FROM person;")) if err != nil { t.Error(err) } // test it on a double pointer var pcount *int - err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") + err = db.GetContext(ctx, &pcount, safesql.New("SELECT count(*) FROM person;")) if err != nil { t.Error(err) } @@ -1092,7 +1098,7 @@ func TestUsageContext(t *testing.T) { // test Select... sdest = []string{} - err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") + err = db.SelectContext(ctx, &sdest, safesql.New("SELECT first_name FROM person ORDER BY first_name ASC;")) if err != nil { t.Error(err) } @@ -1104,7 +1110,7 @@ func TestUsageContext(t *testing.T) { } var nsdest []sql.NullString - err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC") + err = db.SelectContext(ctx, &nsdest, safesql.New("SELECT city FROM place ORDER BY city ASC")) if err != nil { t.Error(err) } @@ -1140,7 +1146,7 @@ func TestEmbeddedMapsContext(t *testing.T) { {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, {"Thanks, Joy", PropertyMap{"pull": "request"}}, } - q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` + q1 := safesql.New(`INSERT INTO message (string, properties) VALUES (:string, :properties);`) for _, m := range messages { _, err := db.NamedExecContext(ctx, q1, m) if err != nil { @@ -1148,7 +1154,7 @@ func TestEmbeddedMapsContext(t *testing.T) { } } var count int - err := db.GetContext(ctx, &count, "SELECT count(*) FROM message") + err := db.GetContext(ctx, &count, safesql.New("SELECT count(*) FROM message")) if err != nil { t.Fatal(err) } @@ -1157,7 +1163,7 @@ func TestEmbeddedMapsContext(t *testing.T) { } var m Message - err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;") + err = db.GetContext(ctx, &m, safesql.New("SELECT * FROM message LIMIT 1;")) if err != nil { t.Fatal(err) } @@ -1182,26 +1188,26 @@ func TestIssue197Context(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { var err error var v, q Var - if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil { + if err = db.GetContext(ctx, &v, safesql.New(`SELECT '{"a": "b"}' AS raw`)); err != nil { t.Fatal(err) } - if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil { + if err = db.GetContext(ctx, &q, safesql.New(`SELECT 'null' AS raw`)); err != nil { t.Fatal(err) } var v2, q2 Var2 - if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil { + if err = db.GetContext(ctx, &v2, safesql.New(`SELECT '{"a": "b"}' AS raw`)); err != nil { t.Fatal(err) } - if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil { + if err = db.GetContext(ctx, &q2, safesql.New(`SELECT 'null' AS raw`)); err != nil { t.Fatal(err) } var v3, q3 Var3 - if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { + if err = db.QueryRowContext(ctx, safesql.New(`SELECT '{"a": "b"}' AS raw`)).Scan(&v3.Raw); err != nil { t.Fatal(err) } - if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { + if err = db.QueryRowContext(ctx, safesql.New(`SELECT '{"c": "d"}' AS raw`)).Scan(&q3.Raw); err != nil { t.Fatal(err) } t.Fail() @@ -1211,15 +1217,15 @@ func TestIssue197Context(t *testing.T) { func TestInContext(t *testing.T) { // some quite normal situations type tr struct { - q string + q safesql.TrustedSQLString args []interface{} c int } tests := []tr{ - {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", + {safesql.New("SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?"), []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, 7}, - {"SELECT * FROM foo WHERE x in (?)", + {safesql.New("SELECT * FROM foo WHERE x in (?)"), []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, 8}, } @@ -1231,8 +1237,8 @@ func TestInContext(t *testing.T) { if len(a) != test.c { t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) } - if strings.Count(q, "?") != test.c { - t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) + if strings.Count(q.String(), "?") != test.c { + t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q.String(), "?")) } } @@ -1240,7 +1246,7 @@ func TestInContext(t *testing.T) { // i'm not sure if this is the right behavior; this query/arg combo // might not work, but we shouldn't parse if we don't need to { - orig := "SELECT * FROM foo WHERE x = ? AND y = ?" + orig := safesql.New("SELECT * FROM foo WHERE x = ? AND y = ?") q, a, err := In(orig, "foo", "bar", "baz") if err != nil { t.Error(err) @@ -1255,15 +1261,15 @@ func TestInContext(t *testing.T) { tests = []tr{ // too many bindvars; slice present so should return error during parse - {"SELECT * FROM foo WHERE x = ? and y = ?", + {safesql.New("SELECT * FROM foo WHERE x = ? and y = ?"), []interface{}{"foo", []int{1, 2, 3}, "bar"}, 0}, // empty slice, should return error before parse - {"SELECT * FROM foo WHERE x = ?", + {safesql.New("SELECT * FROM foo WHERE x = ?"), []interface{}{[]int{}}, 0}, // too *few* bindvars, should return an error - {"SELECT * FROM foo WHERE x = ? AND y in (?)", + {safesql.New("SELECT * FROM foo WHERE x = ? AND y in (?)"), []interface{}{[]int{1, 2, 3}}, 0}, } @@ -1279,7 +1285,7 @@ func TestInContext(t *testing.T) { // tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") // tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") telcodes := []int{852, 65} - q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" + q := safesql.New("SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode") query, args, err := In(q, telcodes) if err != nil { t.Error(err) @@ -1322,10 +1328,10 @@ func TestEmbeddedLiteralsContext(t *testing.T) { K *string } - db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") + db.MustExecContext(ctx, db.Rebind(safesql.New("INSERT INTO x (k) VALUES (?), (?), (?);")), "one", "two", "three") target := t1{} - err := db.GetContext(ctx, &target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + err := db.GetContext(ctx, &target, db.Rebind(safesql.New("SELECT * FROM x WHERE k=?")), "one") if err != nil { t.Error(err) } @@ -1334,7 +1340,7 @@ func TestEmbeddedLiteralsContext(t *testing.T) { } target2 := t2{} - err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + err = db.GetContext(ctx, &target2, db.Rebind(safesql.New("SELECT * FROM x WHERE k=?")), "one") if err != nil { t.Error(err) } @@ -1361,7 +1367,7 @@ func TestConn(t *testing.T) { t.Fatal(err) } - _, err = conn.ExecContext(ctx, conn.Rebind(`INSERT INTO tt_conn (id, value) VALUES (?, ?), (?, ?)`), 1, "a", 2, "b") + _, err = conn.ExecContext(ctx, conn.Rebind(safesql.New(`INSERT INTO tt_conn (id, value) VALUES (?, ?), (?, ?)`)), 1, "a", 2, "b") if err != nil { t.Fatal(err) } @@ -1373,7 +1379,7 @@ func TestConn(t *testing.T) { v := []s{} - err = conn.SelectContext(ctx, &v, "SELECT * FROM tt_conn ORDER BY id ASC") + err = conn.SelectContext(ctx, &v, safesql.New("SELECT * FROM tt_conn ORDER BY id ASC")) if err != nil { t.Fatal(err) } @@ -1383,7 +1389,7 @@ func TestConn(t *testing.T) { } v1 := s{} - err = conn.GetContext(ctx, &v1, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"), 1) + err = conn.GetContext(ctx, &v1, conn.Rebind(safesql.New("SELECT * FROM tt_conn WHERE id=?")), 1) if err != nil { t.Fatal(err) @@ -1392,7 +1398,7 @@ func TestConn(t *testing.T) { t.Errorf("Expecting to get back 1, but got %v\n", v1.ID) } - stmt, err := conn.PreparexContext(ctx, conn.Rebind("SELECT * FROM tt_conn WHERE id=?")) + stmt, err := conn.PreparexContext(ctx, conn.Rebind(safesql.New("SELECT * FROM tt_conn WHERE id=?"))) if err != nil { t.Fatal(err) } @@ -1412,7 +1418,7 @@ func TestConn(t *testing.T) { t.Errorf("Expecting to get back 1, but got %v\n", v1.ID) } - rows, err := conn.QueryxContext(ctx, "SELECT * FROM tt_conn") + rows, err := conn.QueryxContext(ctx, safesql.New("SELECT * FROM tt_conn")) if err != nil { t.Fatal(err) } diff --git a/sqlx_test.go b/sqlx_test.go index 9fac2cd..0d05036 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -22,6 +22,8 @@ import ( "time" _ "github.com/go-sql-driver/mysql" + "github.com/google/go-safeweb/safesql" + "github.com/google/go-safeweb/safesql/uncheckedconversions" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -210,7 +212,10 @@ func MultiExec(e Execer, query string) { stmts = stmts[:len(stmts)-1] } for _, s := range stmts { - _, err := e.Exec(s) + // UNSAFE: this is fine for tests when there's no easy alternative + // Never do this in production code! + ss := uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(s) + _, err := e.Exec(ss) if err != nil { fmt.Println(err, s) } @@ -243,19 +248,19 @@ func RunWithSchema(schema Schema, t *testing.T, test func(db *DB, t *testing.T, func loadDefaultFixture(db *DB, t *testing.T) { tx := db.MustBegin() - tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") - tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") - tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") - tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") - tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)")), "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)")), "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)")), "United States", "New York", "1") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO place (country, telcode) VALUES (?, ?)")), "Hong Kong", "852") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO place (country, telcode) VALUES (?, ?)")), "Singapore", "65") if db.DriverName() == "mysql" { - tx.MustExec(tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)")), "Sarf Efrica", "27") } else { - tx.MustExec(tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)")), "Sarf Efrica", "27") } - tx.MustExec(tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") - tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") - tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO employees (name, id) VALUES (?, ?)")), "Peter", "4444") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)")), "Joe", "1", "4444") + tx.MustExec(tx.Rebind(safesql.New("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)")), "Martin", "2", "4444") tx.Commit() } @@ -274,21 +279,21 @@ func TestMissingNames(t *testing.T) { // test Select first pps := []PersonPlus{} // pps lacks added_at destination - err := db.Select(&pps, "SELECT * FROM person") + err := db.Select(&pps, safesql.New("SELECT * FROM person")) if err == nil { t.Error("Expected missing name from Select to fail, but it did not.") } // test Get pp := PersonPlus{} - err = db.Get(&pp, "SELECT * FROM person LIMIT 1") + err = db.Get(&pp, safesql.New("SELECT * FROM person LIMIT 1")) if err == nil { t.Error("Expected missing name Get to fail, but it did not.") } // test naked StructScan pps = []PersonPlus{} - rows, err := db.Query("SELECT * FROM person LIMIT 1") + rows, err := db.Query(safesql.New("SELECT * FROM person LIMIT 1")) if err != nil { t.Fatal(err) } @@ -302,21 +307,21 @@ func TestMissingNames(t *testing.T) { // now try various things with unsafe set. db = db.Unsafe() pps = []PersonPlus{} - err = db.Select(&pps, "SELECT * FROM person") + err = db.Select(&pps, safesql.New("SELECT * FROM person")) if err != nil { t.Error(err) } // test Get pp = PersonPlus{} - err = db.Get(&pp, "SELECT * FROM person LIMIT 1") + err = db.Get(&pp, safesql.New("SELECT * FROM person LIMIT 1")) if err != nil { t.Error(err) } // test naked StructScan pps = []PersonPlus{} - rowsx, err := db.Queryx("SELECT * FROM person LIMIT 1") + rowsx, err := db.Queryx(safesql.New("SELECT * FROM person LIMIT 1")) if err != nil { t.Fatal(err) } @@ -331,7 +336,7 @@ func TestMissingNames(t *testing.T) { if !isUnsafe(db) { t.Error("Expected db to be unsafe, but it isn't") } - nstmt, err := db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) + nstmt, err := db.PrepareNamed(safesql.New(`SELECT * FROM person WHERE first_name != :name`)) if err != nil { t.Fatal(err) } @@ -353,7 +358,7 @@ func TestMissingNames(t *testing.T) { if isUnsafe(db) { t.Error("expected db to be safe but it isn't") } - nstmt, err = db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) + nstmt, err = db.PrepareNamed(safesql.New(`SELECT * FROM person WHERE first_name != :name`)) if err != nil { t.Fatal(err) } @@ -387,8 +392,8 @@ func TestEmbeddedStructs(t *testing.T) { peopleAndPlaces := []PersonPlace{} err := db.Select( &peopleAndPlaces, - `SELECT person.*, place.* FROM - person natural join place`) + safesql.New(`SELECT person.*, place.* FROM + person natural join place`)) if err != nil { t.Fatal(err) } @@ -403,8 +408,8 @@ func TestEmbeddedStructs(t *testing.T) { // test embedded structs with StructScan rows, err := db.Queryx( - `SELECT person.*, place.* FROM - person natural join place`) + safesql.New(`SELECT person.*, place.* FROM + person natural join place`)) if err != nil { t.Error(err) } @@ -429,8 +434,8 @@ func TestEmbeddedStructs(t *testing.T) { peopleAndPlacesPtrs := []PersonPlacePtr{} err = db.Select( &peopleAndPlacesPtrs, - `SELECT person.*, place.* FROM - person natural join place`) + safesql.New(`SELECT person.*, place.* FROM + person natural join place`)) if err != nil { t.Fatal(err) } @@ -445,7 +450,7 @@ func TestEmbeddedStructs(t *testing.T) { // test "deep nesting" l3s := []Loop3{} - err = db.Select(&l3s, `select * from person`) + err = db.Select(&l3s, safesql.New(`select * from person`)) if err != nil { t.Fatal(err) } @@ -457,7 +462,7 @@ func TestEmbeddedStructs(t *testing.T) { // test "embed conflicts" ec := []EmbedConflict{} - err = db.Select(&ec, `select * from person`) + err = db.Select(&ec, safesql.New(`select * from person`)) // I'm torn between erroring here or having some kind of working behavior // in order to allow for more flexibility in destination structs if err != nil { @@ -485,8 +490,8 @@ func TestJoinQuery(t *testing.T) { err := db.Select( &employees, - `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees - JOIN employees AS boss ON employees.boss_id = boss.id`) + safesql.New(`SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees + JOIN employees AS boss ON employees.boss_id = boss.id`)) if err != nil { t.Fatal(err) } @@ -522,11 +527,11 @@ func TestJoinQueryNamedPointerStructs(t *testing.T) { err := db.Select( &employees, - `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", + safesql.New(`SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", boss.id "boss.id", boss.name "boss.name" FROM employees AS emp JOIN employees AS boss ON emp.boss_id = boss.id - `) + `)) if err != nil { t.Fatal(err) } @@ -545,7 +550,7 @@ func TestJoinQueryNamedPointerStructs(t *testing.T) { func TestSelectSliceMapTime(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) - rows, err := db.Queryx("SELECT * FROM person") + rows, err := db.Queryx(safesql.New("SELECT * FROM person")) if err != nil { t.Fatal(err) } @@ -556,7 +561,7 @@ func TestSelectSliceMapTime(t *testing.T) { } } - rows, err = db.Queryx("SELECT * FROM person") + rows, err = db.Queryx(safesql.New("SELECT * FROM person")) if err != nil { t.Fatal(err) } @@ -575,12 +580,12 @@ func TestNilReceiver(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) var p *Person - err := db.Get(p, "SELECT * FROM person LIMIT 1") + err := db.Get(p, safesql.New("SELECT * FROM person LIMIT 1")) if err == nil { t.Error("Expected error when getting into nil struct ptr.") } var pp *[]Person - err = db.Select(pp, "SELECT * FROM person") + err = db.Select(pp, safesql.New("SELECT * FROM person")) if err == nil { t.Error("Expected an error when selecting into nil slice ptr.") } @@ -631,14 +636,14 @@ func TestNamedQuery(t *testing.T) { Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` + q1 := safesql.New(`INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`) _, err := db.NamedExec(q1, p) if err != nil { log.Fatal(err) } p2 := &Person{} - rows, err := db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", p) + rows, err := db.NamedQuery(safesql.New("SELECT * FROM person WHERE first_name=:first_name"), p) if err != nil { log.Fatal(err) } @@ -677,14 +682,15 @@ func TestNamedQuery(t *testing.T) { // prepare queries for case sensitivity to test our ToUpper function. // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line // strings are `` we use "" by default and swap out for MySQL - pdb := func(s string, db *DB) string { + pdb := func(s safesql.TrustedSQLString, db *DB) safesql.TrustedSQLString { if db.DriverName() == "mysql" { - return strings.Replace(s, `"`, "`", -1) + rawS := s.String() + return uncheckedconversions.TrustedSQLStringFromStringKnownToSatisfyTypeContract(strings.Replace(rawS, `"`, "`", -1)) } return s } - q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` + q1 = safesql.New(`INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`) _, err = db.NamedExec(pdb(q1, db), jp) if err != nil { t.Fatal(err, db.DriverName()) @@ -710,13 +716,13 @@ func TestNamedQuery(t *testing.T) { } } - ns, err := db.PrepareNamed(pdb(` + ns, err := db.PrepareNamed(pdb(safesql.New(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL - `, db)) + `), db)) if err != nil { t.Fatal(err) @@ -730,13 +736,13 @@ func TestNamedQuery(t *testing.T) { // Check exactly the same thing, but with db.NamedQuery, which does not go // through the PrepareNamed/NamedStmt path. - rows, err = db.NamedQuery(pdb(` + rows, err = db.NamedQuery(pdb(safesql.New(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL - `, db), jp) + `), db), jp) if err != nil { t.Fatal(err) } @@ -767,7 +773,7 @@ func TestNamedQuery(t *testing.T) { Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + q2 := safesql.New(`INSERT INTO place (id, name) VALUES (1, :name)`) _, err = db.NamedExec(q2, pl) if err != nil { log.Fatal(err) @@ -776,14 +782,14 @@ func TestNamedQuery(t *testing.T) { id := 1 pp.Place.ID = id - q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + q3 := safesql.New(`INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`) _, err = db.NamedExec(q3, pp) if err != nil { log.Fatal(err) } pp2 := &PlacePerson{} - rows, err = db.NamedQuery(` + rows, err = db.NamedQuery(safesql.New(` SELECT first_name, last_name, @@ -793,7 +799,7 @@ func TestNamedQuery(t *testing.T) { FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE - place.id=:place.id`, pp) + place.id=:place.id`), pp) if err != nil { log.Fatal(err) } @@ -836,8 +842,8 @@ func TestNilInserts(t *testing.T) { var v, v2 TT r := db.Rebind - db.MustExec(r(`INSERT INTO tt (id) VALUES (1)`)) - db.Get(&v, r(`SELECT * FROM tt`)) + db.MustExec(r(safesql.New(`INSERT INTO tt (id) VALUES (1)`))) + db.Get(&v, r(safesql.New(`SELECT * FROM tt`))) if v.ID != 1 { t.Errorf("Expecting id of 1, got %v", v.ID) } @@ -851,9 +857,9 @@ func TestNilInserts(t *testing.T) { // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly // function. This next line is important as it provides the only coverage for this. - db.NamedExec(`INSERT INTO tt (id, value) VALUES (:id, :value)`, v) + db.NamedExec(safesql.New(`INSERT INTO tt (id, value) VALUES (:id, :value)`), v) - db.Get(&v2, r(`SELECT * FROM tt WHERE id=2`)) + db.Get(&v2, r(safesql.New(`SELECT * FROM tt WHERE id=2`))) if v.ID != v2.ID { t.Errorf("%v != %v", v.ID, v2.ID) } @@ -878,12 +884,12 @@ func TestScanError(t *testing.T) { K int V string } - _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) + _, err := db.Exec(db.Rebind(safesql.New("INSERT INTO kv (k, v) VALUES (?, ?)")), "hi", 1) if err != nil { t.Error(err) } - rows, err := db.Queryx("SELECT * FROM kv") + rows, err := db.Queryx(safesql.New("SELECT * FROM kv")) if err != nil { t.Error(err) } @@ -900,7 +906,7 @@ func TestScanError(t *testing.T) { func TestMultiInsert(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) - q := db.Rebind(`INSERT INTO employees (name, id) VALUES (?, ?), (?, ?);`) + q := db.Rebind(safesql.New(`INSERT INTO employees (name, id) VALUES (?, ?), (?, ?);`)) db.MustExec(q, "Name1", 400, "name2", 500, @@ -915,14 +921,14 @@ func TestUsage(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) slicemembers := []SliceMember{} - err := db.Select(&slicemembers, "SELECT * FROM place ORDER BY telcode ASC") + err := db.Select(&slicemembers, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err != nil { t.Fatal(err) } people := []Person{} - err = db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") + err = db.Select(&people, safesql.New("SELECT * FROM person ORDER BY first_name ASC")) if err != nil { t.Fatal(err) } @@ -942,7 +948,7 @@ func TestUsage(t *testing.T) { } jason = Person{} - err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") + err = db.Get(&jason, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?")), "Jason") if err != nil { t.Fatal(err) @@ -951,7 +957,7 @@ func TestUsage(t *testing.T) { t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) } - err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") + err = db.Get(&jason, db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?")), "Foobar") if err == nil { t.Errorf("Expecting an error, got nil\n") } @@ -961,7 +967,7 @@ func TestUsage(t *testing.T) { // The following tests check statement reuse, which was actually a problem // due to copying being done when creating Stmt's which was eventually removed - stmt1, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt1, err := db.Preparex(db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Fatal(err) } @@ -981,7 +987,7 @@ func TestUsage(t *testing.T) { t.Fatal(err) } - stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt2, err := db.Preparex(db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Fatal(err) } @@ -999,7 +1005,7 @@ func TestUsage(t *testing.T) { tx.Commit() places := []*Place{} - err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC") + err = db.Select(&places, safesql.New("SELECT telcode FROM place ORDER BY telcode ASC")) if err != nil { t.Fatal(err) } @@ -1011,7 +1017,7 @@ func TestUsage(t *testing.T) { } placesptr := []PlacePtr{} - err = db.Select(&placesptr, "SELECT * FROM place ORDER BY telcode ASC") + err = db.Select(&placesptr, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err != nil { t.Error(err) } @@ -1020,7 +1026,7 @@ func TestUsage(t *testing.T) { // if you have null fields and use SELECT *, you must use sql.Null* in your struct // this test also verifies that you can use either a []Struct{} or a []*Struct{} places2 := []Place{} - err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") + err = db.Select(&places2, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err != nil { t.Fatal(err) } @@ -1029,14 +1035,14 @@ func TestUsage(t *testing.T) { // this should return a type error that &p is not a pointer to a struct slice p := Place{} - err = db.Select(&p, "SELECT * FROM place ORDER BY telcode ASC") + err = db.Select(&p, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") } // this should be an error pl := []Place{} - err = db.Select(pl, "SELECT * FROM place ORDER BY telcode ASC") + err = db.Select(pl, safesql.New("SELECT * FROM place ORDER BY telcode ASC")) if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") } @@ -1045,7 +1051,7 @@ func TestUsage(t *testing.T) { t.Errorf("Expected integer telcodes to work, got %#v", places) } - stmt, err := db.Preparex(db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) + stmt, err := db.Preparex(db.Rebind(safesql.New("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC"))) if err != nil { t.Error(err) } @@ -1063,7 +1069,7 @@ func TestUsage(t *testing.T) { t.Errorf("Expected the right telcodes, got %#v", places) } - rows, err := db.Queryx("SELECT * FROM place") + rows, err := db.Queryx(safesql.New("SELECT * FROM place")) if err != nil { t.Fatal(err) } @@ -1075,7 +1081,7 @@ func TestUsage(t *testing.T) { } } - rows, err = db.Queryx("SELECT * FROM place") + rows, err = db.Queryx(safesql.New("SELECT * FROM place")) if err != nil { t.Fatal(err) } @@ -1091,7 +1097,7 @@ func TestUsage(t *testing.T) { } } - rows, err = db.Queryx("SELECT * FROM place") + rows, err = db.Queryx(safesql.New("SELECT * FROM place")) if err != nil { t.Fatal(err) } @@ -1107,7 +1113,7 @@ func TestUsage(t *testing.T) { // test advanced querying // test that NamedExec works with a map as well as a struct - _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ + _, err = db.NamedExec(safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)"), map[string]interface{}{ "first": "Bin", "last": "Smuth", "email": "bensmith@allblacks.nz", @@ -1118,7 +1124,7 @@ func TestUsage(t *testing.T) { // ensure that if the named param happens right at the end it still works // ensure that NamedQuery works with a map[string]interface{} - rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) + rows, err = db.NamedQuery(safesql.New("SELECT * FROM person WHERE first_name=:first"), map[string]interface{}{"first": "Bin"}) if err != nil { t.Fatal(err) } @@ -1142,13 +1148,13 @@ func TestUsage(t *testing.T) { ben.Email = "binsmuth@allblacks.nz" // Insert via a named query using the struct - _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) + _, err = db.NamedExec(safesql.New("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)"), ben) if err != nil { t.Fatal(err) } - rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", ben) + rows, err = db.NamedQuery(safesql.New("SELECT * FROM person WHERE first_name=:first_name"), ben) if err != nil { t.Fatal(err) } @@ -1166,14 +1172,14 @@ func TestUsage(t *testing.T) { } // ensure that Get does not panic on emppty result set person := &Person{} - err = db.Get(person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") + err = db.Get(person, safesql.New("SELECT * FROM person WHERE first_name=$1"), "does-not-exist") if err == nil { t.Fatal("Should have got an error for Get on non-existent row.") } // lets test prepared statements some more - stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt, err = db.Preparex(db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Fatal(err) } @@ -1195,7 +1201,7 @@ func TestUsage(t *testing.T) { } john = Person{} - stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) + stmt, err = db.Preparex(db.Rebind(safesql.New("SELECT * FROM person WHERE first_name=?"))) if err != nil { t.Error(err) } @@ -1208,7 +1214,7 @@ func TestUsage(t *testing.T) { // THIS USED TO WORK BUT WILL NO LONGER WORK. db.MapperFunc(strings.ToUpper) rsa := CPlace{} - err = db.Get(&rsa, "SELECT * FROM capplace;") + err = db.Get(&rsa, safesql.New("SELECT * FROM capplace;")) if err != nil { t.Error(err, "in db:", db.DriverName()) } @@ -1218,20 +1224,20 @@ func TestUsage(t *testing.T) { // differently from the original. dbCopy := NewDb(db.DB, db.DriverName()) dbCopy.MapperFunc(strings.ToUpper) - err = dbCopy.Get(&rsa, "SELECT * FROM capplace;") + err = dbCopy.Get(&rsa, safesql.New("SELECT * FROM capplace;")) if err != nil { fmt.Println(db.DriverName()) t.Error(err) } - err = db.Get(&rsa, "SELECT * FROM cappplace;") + err = db.Get(&rsa, safesql.New("SELECT * FROM cappplace;")) if err == nil { t.Error("Expected no error, got ", err) } // test base type slices var sdest []string - rows, err = db.Queryx("SELECT email FROM person ORDER BY email ASC;") + rows, err = db.Queryx(safesql.New("SELECT email FROM person ORDER BY email ASC;")) if err != nil { t.Error(err) } @@ -1242,7 +1248,7 @@ func TestUsage(t *testing.T) { // test Get with base types var count int - err = db.Get(&count, "SELECT count(*) FROM person;") + err = db.Get(&count, safesql.New("SELECT count(*) FROM person;")) if err != nil { t.Error(err) } @@ -1252,20 +1258,20 @@ func TestUsage(t *testing.T) { // test Get and Select with time.Time, #84 var addedAt time.Time - err = db.Get(&addedAt, "SELECT added_at FROM person LIMIT 1;") + err = db.Get(&addedAt, safesql.New("SELECT added_at FROM person LIMIT 1;")) if err != nil { t.Error(err) } var addedAts []time.Time - err = db.Select(&addedAts, "SELECT added_at FROM person;") + err = db.Select(&addedAts, safesql.New("SELECT added_at FROM person;")) if err != nil { t.Error(err) } // test it on a double pointer var pcount *int - err = db.Get(&pcount, "SELECT count(*) FROM person;") + err = db.Get(&pcount, safesql.New("SELECT count(*) FROM person;")) if err != nil { t.Error(err) } @@ -1275,7 +1281,7 @@ func TestUsage(t *testing.T) { // test Select... sdest = []string{} - err = db.Select(&sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") + err = db.Select(&sdest, safesql.New("SELECT first_name FROM person ORDER BY first_name ASC;")) if err != nil { t.Error(err) } @@ -1287,7 +1293,7 @@ func TestUsage(t *testing.T) { } var nsdest []sql.NullString - err = db.Select(&nsdest, "SELECT city FROM place ORDER BY city ASC") + err = db.Select(&nsdest, safesql.New("SELECT city FROM place ORDER BY city ASC")) if err != nil { t.Error(err) } @@ -1316,41 +1322,41 @@ func TestDoNotPanicOnConnect(t *testing.T) { } func TestRebind(t *testing.T) { - q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` - q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + q2 := safesql.New(`INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)`) s1 := Rebind(DOLLAR, q1) s2 := Rebind(DOLLAR, q2) - if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` { + if s1 != safesql.New(`INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`) { t.Errorf("q1 failed") } - if s2 != `INSERT INTO foo (a, b, c) VALUES ($1, $2, "foo"), ("Hi", $3, $4)` { + if s2 != safesql.New(`INSERT INTO foo (a, b, c) VALUES ($1, $2, "foo"), ("Hi", $3, $4)`) { t.Errorf("q2 failed") } s1 = Rebind(AT, q1) s2 = Rebind(AT, q2) - if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)` { + if s1 != safesql.New(`INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)`) { t.Errorf("q1 failed") } - if s2 != `INSERT INTO foo (a, b, c) VALUES (@p1, @p2, "foo"), ("Hi", @p3, @p4)` { + if s2 != safesql.New(`INSERT INTO foo (a, b, c) VALUES (@p1, @p2, "foo"), ("Hi", @p3, @p4)`) { t.Errorf("q2 failed") } s1 = Rebind(NAMED, q1) s2 = Rebind(NAMED, q2) - ex1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ` + - `(:arg1, :arg2, :arg3, :arg4, :arg5, :arg6, :arg7, :arg8, :arg9, :arg10)` + ex1 := safesql.New(`INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ` + + `(:arg1, :arg2, :arg3, :arg4, :arg5, :arg6, :arg7, :arg8, :arg9, :arg10)`) if s1 != ex1 { t.Error("q1 failed on Named params") } - ex2 := `INSERT INTO foo (a, b, c) VALUES (:arg1, :arg2, "foo"), ("Hi", :arg3, :arg4)` + ex2 := safesql.New(`INSERT INTO foo (a, b, c) VALUES (:arg1, :arg2, "foo"), ("Hi", :arg3, :arg4)`) if s2 != ex2 { t.Error("q2 failed on Named params") } @@ -1358,7 +1364,7 @@ func TestRebind(t *testing.T) { func TestBindMap(t *testing.T) { // Test that it works.. - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`) am := map[string]interface{}{ "name": "Jason Moiron", "age": 30, @@ -1367,7 +1373,7 @@ func TestBindMap(t *testing.T) { } bq, args, _ := bindMap(QUESTION, q1, am) - expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` + expect := safesql.New(`INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)`) if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } @@ -1436,7 +1442,7 @@ func TestEmbeddedMaps(t *testing.T) { {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, {"Thanks, Joy", PropertyMap{"pull": "request"}}, } - q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` + q1 := safesql.New(`INSERT INTO message (string, properties) VALUES (:string, :properties);`) for _, m := range messages { _, err := db.NamedExec(q1, m) if err != nil { @@ -1444,7 +1450,7 @@ func TestEmbeddedMaps(t *testing.T) { } } var count int - err := db.Get(&count, "SELECT count(*) FROM message") + err := db.Get(&count, safesql.New("SELECT count(*) FROM message")) if err != nil { t.Fatal(err) } @@ -1453,7 +1459,7 @@ func TestEmbeddedMaps(t *testing.T) { } var m Message - err = db.Get(&m, "SELECT * FROM message LIMIT 1;") + err = db.Get(&m, safesql.New("SELECT * FROM message LIMIT 1;")) if err != nil { t.Fatal(err) } @@ -1478,26 +1484,26 @@ func TestIssue197(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { var err error var v, q Var - if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil { + if err = db.Get(&v, safesql.New(`SELECT '{"a": "b"}' AS raw`)); err != nil { t.Fatal(err) } - if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil { + if err = db.Get(&q, safesql.New(`SELECT 'null' AS raw`)); err != nil { t.Fatal(err) } var v2, q2 Var2 - if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil { + if err = db.Get(&v2, safesql.New(`SELECT '{"a": "b"}' AS raw`)); err != nil { t.Fatal(err) } - if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil { + if err = db.Get(&q2, safesql.New(`SELECT 'null' AS raw`)); err != nil { t.Fatal(err) } var v3, q3 Var3 - if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { + if err = db.QueryRow(safesql.New(`SELECT '{"a": "b"}' AS raw`)).Scan(&v3.Raw); err != nil { t.Fatal(err) } - if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { + if err = db.QueryRow(safesql.New(`SELECT '{"c": "d"}' AS raw`)).Scan(&q3.Raw); err != nil { t.Fatal(err) } t.Fail() @@ -1507,21 +1513,21 @@ func TestIssue197(t *testing.T) { func TestIn(t *testing.T) { // some quite normal situations type tr struct { - q string + q safesql.TrustedSQLString args []interface{} c int } tests := []tr{ - {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", + {safesql.New("SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?"), []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, 7}, - {"SELECT * FROM foo WHERE x in (?)", + {safesql.New("SELECT * FROM foo WHERE x in (?)"), []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, 8}, - {"SELECT * FROM foo WHERE x = ? AND y in (?)", + {safesql.New("SELECT * FROM foo WHERE x = ? AND y in (?)"), []interface{}{[]byte("foo"), []int{0, 5, 3}}, 4}, - {"SELECT * FROM foo WHERE x = ? AND y IN (?)", + {safesql.New("SELECT * FROM foo WHERE x = ? AND y IN (?)"), []interface{}{sql.NullString{Valid: false}, []string{"a", "b"}}, 3}, } @@ -1533,8 +1539,8 @@ func TestIn(t *testing.T) { if len(a) != test.c { t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) } - if strings.Count(q, "?") != test.c { - t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) + if strings.Count(q.String(), "?") != test.c { + t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q.String(), "?")) } } @@ -1542,7 +1548,7 @@ func TestIn(t *testing.T) { // i'm not sure if this is the right behavior; this query/arg combo // might not work, but we shouldn't parse if we don't need to { - orig := "SELECT * FROM foo WHERE x = ? AND y = ?" + orig := safesql.New("SELECT * FROM foo WHERE x = ? AND y = ?") q, a, err := In(orig, "foo", "bar", "baz") if err != nil { t.Error(err) @@ -1557,15 +1563,15 @@ func TestIn(t *testing.T) { tests = []tr{ // too many bindvars; slice present so should return error during parse - {"SELECT * FROM foo WHERE x = ? and y = ?", + {safesql.New("SELECT * FROM foo WHERE x = ? and y = ?"), []interface{}{"foo", []int{1, 2, 3}, "bar"}, 0}, // empty slice, should return error before parse - {"SELECT * FROM foo WHERE x = ?", + {safesql.New("SELECT * FROM foo WHERE x = ?"), []interface{}{[]int{}}, 0}, // too *few* bindvars, should return an error - {"SELECT * FROM foo WHERE x = ? AND y in (?)", + {safesql.New("SELECT * FROM foo WHERE x = ? AND y in (?)"), []interface{}{[]int{1, 2, 3}}, 0}, } @@ -1581,7 +1587,7 @@ func TestIn(t *testing.T) { // tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") // tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") telcodes := []int{852, 65} - q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" + q := safesql.New("SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode") query, args, err := In(q, telcodes) if err != nil { t.Error(err) @@ -1607,7 +1613,7 @@ func TestIn(t *testing.T) { func TestBindStruct(t *testing.T) { var err error - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`) type tt struct { Name string @@ -1629,7 +1635,7 @@ func TestBindStruct(t *testing.T) { am := tt{"Jason Moiron", 30, "Jason", "Moiron"} bq, args, _ := bindStruct(QUESTION, q1, am, mapper()) - expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` + expect := safesql.New(`INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)`) if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } @@ -1651,8 +1657,8 @@ func TestBindStruct(t *testing.T) { } am2 := tt2{"Hello", "World"} - bq, args, _ = bindStruct(QUESTION, "INSERT INTO foo (a, b) VALUES (:field_2, :field_1)", am2, mapper()) - expect = `INSERT INTO foo (a, b) VALUES (?, ?)` + bq, args, _ = bindStruct(QUESTION, safesql.New("INSERT INTO foo (a, b) VALUES (:field_2, :field_1)"), am2, mapper()) + expect = safesql.New(`INSERT INTO foo (a, b) VALUES (?, ?)`) if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } @@ -1668,13 +1674,13 @@ func TestBindStruct(t *testing.T) { am3.Field1 = "Hello" am3.Field2 = "World" - bq, args, err = bindStruct(QUESTION, "INSERT INTO foo (a, b, c) VALUES (:name, :field_1, :field_2)", am3, mapper()) + bq, args, err = bindStruct(QUESTION, safesql.New("INSERT INTO foo (a, b, c) VALUES (:name, :field_1, :field_2)"), am3, mapper()) if err != nil { t.Fatal(err) } - expect = `INSERT INTO foo (a, b, c) VALUES (?, ?, ?)` + expect = safesql.New(`INSERT INTO foo (a, b, c) VALUES (?, ?, ?)`) if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } @@ -1710,10 +1716,10 @@ func TestEmbeddedLiterals(t *testing.T) { K *string } - db.MustExec(db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") + db.MustExec(db.Rebind(safesql.New("INSERT INTO x (k) VALUES (?), (?), (?);")), "one", "two", "three") target := t1{} - err := db.Get(&target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + err := db.Get(&target, db.Rebind(safesql.New("SELECT * FROM x WHERE k=?")), "one") if err != nil { t.Error(err) } @@ -1722,7 +1728,7 @@ func TestEmbeddedLiterals(t *testing.T) { } target2 := t2{} - err = db.Get(&target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + err = db.Get(&target2, db.Rebind(safesql.New("SELECT * FROM x WHERE k=?")), "one") if err != nil { t.Error(err) } @@ -1734,7 +1740,7 @@ func TestEmbeddedLiterals(t *testing.T) { func BenchmarkBindStruct(b *testing.B) { b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`) type t struct { Name string Age int @@ -1751,7 +1757,7 @@ func BenchmarkBindStruct(b *testing.B) { func TestBindNamedMapper(t *testing.T) { type A map[string]interface{} m := reflectx.NewMapperFunc("db", NameMapper) - query, args, err := bindNamedMapper(DOLLAR, `select :x`, A{ + query, args, err := bindNamedMapper(DOLLAR, safesql.New(`select :x`), A{ "x": "X!", }, m) if err != nil { @@ -1764,7 +1770,7 @@ func TestBindNamedMapper(t *testing.T) { t.Errorf("\ngot: %q\nwant: %q", got, want) } - _, _, err = bindNamedMapper(DOLLAR, `select :x`, map[string]string{ + _, _, err = bindNamedMapper(DOLLAR, safesql.New(`select :x`), map[string]string{ "x": "X!", }, m) if err == nil { @@ -1777,7 +1783,7 @@ func TestBindNamedMapper(t *testing.T) { func BenchmarkBindMap(b *testing.B) { b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`) am := map[string]interface{}{ "name": "Jason Moiron", "age": 30, @@ -1791,7 +1797,7 @@ func BenchmarkBindMap(b *testing.B) { } func BenchmarkIn(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + q := safesql.New(`SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`) for i := 0; i < b.N; i++ { _, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...) @@ -1799,7 +1805,7 @@ func BenchmarkIn(b *testing.B) { } func BenchmarkIn1k(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + q := safesql.New(`SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`) var vals [1000]interface{} @@ -1809,7 +1815,7 @@ func BenchmarkIn1k(b *testing.B) { } func BenchmarkIn1kInt(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + q := safesql.New(`SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`) var vals [1000]int @@ -1819,7 +1825,7 @@ func BenchmarkIn1kInt(b *testing.B) { } func BenchmarkIn1kString(b *testing.B) { - q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` + q := safesql.New(`SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`) var vals [1000]string @@ -1830,8 +1836,8 @@ func BenchmarkIn1kString(b *testing.B) { func BenchmarkRebind(b *testing.B) { b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` - q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`) + q2 := safesql.New(`INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)`) b.StartTimer() for i := 0; i < b.N; i++ { @@ -1842,8 +1848,8 @@ func BenchmarkRebind(b *testing.B) { func BenchmarkRebindBuffer(b *testing.B) { b.StopTimer() - q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` - q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` + q1 := safesql.New(`INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`) + q2 := safesql.New(`INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)`) b.StartTimer() for i := 0; i < b.N; i++ { @@ -1854,11 +1860,11 @@ func BenchmarkRebindBuffer(b *testing.B) { func TestIn130Regression(t *testing.T) { t.Run("[]interface{}{}", func(t *testing.T) { - q, args, err := In("SELECT * FROM people WHERE name IN (?)", []interface{}{[]string{"gopher"}}...) + q, args, err := In(safesql.New("SELECT * FROM people WHERE name IN (?)"), []interface{}{[]string{"gopher"}}...) if err != nil { t.Fatal(err) } - if q != "SELECT * FROM people WHERE name IN (?)" { + if q != safesql.New("SELECT * FROM people WHERE name IN (?)") { t.Errorf("got=%v", q) } t.Log(args) @@ -1873,11 +1879,11 @@ func TestIn130Regression(t *testing.T) { }) t.Run("[]string{}", func(t *testing.T) { - q, args, err := In("SELECT * FROM people WHERE name IN (?)", []string{"gopher"}) + q, args, err := In(safesql.New("SELECT * FROM people WHERE name IN (?)"), []string{"gopher"}) if err != nil { t.Fatal(err) } - if q != "SELECT * FROM people WHERE name IN (?)" { + if q != safesql.New("SELECT * FROM people WHERE name IN (?)") { t.Errorf("got=%v", q) } t.Log(args) @@ -1897,7 +1903,7 @@ func TestSelectReset(t *testing.T) { loadDefaultFixture(db, t) filledDest := []string{"a", "b", "c"} - err := db.Select(&filledDest, "SELECT first_name FROM person ORDER BY first_name ASC;") + err := db.Select(&filledDest, safesql.New("SELECT first_name FROM person ORDER BY first_name ASC;")) if err != nil { t.Fatal(err) } @@ -1912,7 +1918,7 @@ func TestSelectReset(t *testing.T) { } var emptyDest []string - err = db.Select(&emptyDest, "SELECT first_name FROM person WHERE first_name = 'Jack';") + err = db.Select(&emptyDest, safesql.New("SELECT first_name FROM person WHERE first_name = 'Jack';")) if err != nil { t.Fatal(err) } From 882c72c4432203e9af86697b958734590727d7c1 Mon Sep 17 00:00:00 2001 From: Kasvi Date: Wed, 27 Nov 2024 13:23:14 +1100 Subject: [PATCH 5/5] update package name --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 15ffd49..a602fdd 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/jmoiron/sqlx +module github.com/kasluthra-sec/sqlx go 1.10