From 87edeed4461d4e804455f34ed93cb7d01c8eaca9 Mon Sep 17 00:00:00 2001 From: Zhongyi Tong Date: Wed, 1 Nov 2023 19:45:02 +0000 Subject: [PATCH 1/8] Update reflectx to allow for optional nested structs --- reflectx/reflect.go | 188 +++++++++++++++++++++++++++++++++++++++++++ sqlx.go | 16 ++-- sqlx_context_test.go | 104 ++++++++++++++++++++++++ 3 files changed, 303 insertions(+), 5 deletions(-) diff --git a/reflectx/reflect.go b/reflectx/reflect.go index 8ec6a13..e8dc926 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -6,8 +6,11 @@ package reflectx import ( + "database/sql" + "fmt" "reflect" "runtime" + "strconv" "strings" "sync" ) @@ -200,6 +203,191 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in return nil } +// ObjectContext provides a single layer to abstract away +// nested struct scanning functionality +type ObjectContext struct { + value reflect.Value +} + +func NewObjectContext() *ObjectContext { + return &ObjectContext{} +} + +// NewRow updates the object reference. +// This ensures all columns point to the same object +func (o *ObjectContext) NewRow(value reflect.Value) { + o.value = value +} + +// FieldForIndexes returns the value for address. If the address is a nested struct, +// a nestedFieldScanner is returned instead of the standard value reference +func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { + if len(indexes) == 1 { + val := FieldByIndexes(o.value, indexes) + return val + } + + obj := &nestedFieldScanner{ + parent: o, + indexes: indexes, + } + + v := reflect.ValueOf(obj).Elem() + return v +} + +// nestedFieldScanner will only forward the Scan to the nested value if +// the database value is not nil. +type nestedFieldScanner struct { + parent *ObjectContext + indexes []int +} + +// Scan implements sql.Scanner. +// This method largely mirrors the sql.convertAssign() method with some minor changes +func (o *nestedFieldScanner) Scan(src interface{}) error { + if src == nil { + return nil + } + + dv := FieldByIndexes(o.parent.value, o.indexes) + // Dereference pointer fields to avoid double pointers **T + if dv.Kind() == reflect.Pointer { + dv.Set(reflect.New(dv.Type().Elem())) + dv = dv.Elem() + } + iface := dv.Addr().Interface() + + if scan, ok := iface.(sql.Scanner); ok { + return scan.Scan(src) + } + + sv := reflect.ValueOf(src) + + // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go + // with a few minor edits + + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(bytesClone(b))) + default: + dv.Set(sv) + } + + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) +} + +// returns internal conversion error if available +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +// converts value to it's string value +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +// bytesClone returns a copy of b[:len(b)]. +// The result may have additional unused capacity. +// Clone(nil) returns nil. +// +// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version +func bytesClone(b []byte) []byte { + if b == nil { + return nil + } + return append([]byte{}, b...) +} + // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { diff --git a/sqlx.go b/sqlx.go index 8259a4f..e0ef63d 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,7 +624,8 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - err := fieldsByTraversal(v, r.fields, r.values, true) + octx := reflectx.NewObjectContext() + err := fieldsByTraversal(octx, v, r.fields, r.values, true) if err != nil { return err } @@ -784,7 +785,9 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - err = fieldsByTraversal(v, fields, values, true) + octx := reflectx.NewObjectContext() + + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -951,13 +954,14 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) + octx := reflectx.NewObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -1023,18 +1027,20 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } + octx.NewRow(v) + for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) continue } - f := reflectx.FieldByIndexes(v, traversal) + f := octx.FieldForIndexes(traversal) if ptrs { values[i] = f.Addr().Interface() } else { diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 91c5cba..73e4f5d 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -643,6 +643,110 @@ func TestNamedQueryContext(t *testing.T) { t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) } } + + rows.Close() + + type Owner struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + + // Test optional nested structs with left join + type PlaceOwner struct { + Place Place `db:"place"` + Owner *Owner `db:"owner"` + } + + pl = Place{ + Name: sql.NullString{String: "the-house", Valid: true}, + } + + q4 := `INSERT INTO place (id, name) VALUES (2, :name)` + _, err = db.NamedExecContext(ctx, q4, pl) + if err != nil { + log.Fatal(err) + } + + id = 2 + pp.Place.ID = id + + q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q5, pp) + if err != nil { + log.Fatal(err) + } + + pp3 := &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + LEFT JOIN placeperson ON false -- null left join + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner != nil { + t.Error("Expected `Owner`, to be nil") + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } + + rows.Close() + + pp3 = &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + left JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner == nil { + t.Error("Expected `Owner`, to not be nil") + } + + if pp3.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName) + } + if pp3.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName) + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } }) } From 48580808cda90b2b1b62f64afd09737f0ffdffa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:07:13 +0200 Subject: [PATCH 2/8] Use go:linkname to call convertAssign instead of copying it --- reflectx/reflect.go | 154 +++----------------------------------------- 1 file changed, 8 insertions(+), 146 deletions(-) diff --git a/reflectx/reflect.go b/reflectx/reflect.go index e8dc926..ded9838 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -6,13 +6,11 @@ package reflectx import ( - "database/sql" - "fmt" "reflect" "runtime" - "strconv" "strings" "sync" + _ "unsafe" ) // A FieldInfo is metadata for a struct field. @@ -223,8 +221,7 @@ func (o *ObjectContext) NewRow(value reflect.Value) { // a nestedFieldScanner is returned instead of the standard value reference func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { if len(indexes) == 1 { - val := FieldByIndexes(o.value, indexes) - return val + return FieldByIndexes(o.value, indexes) } obj := &nestedFieldScanner{ @@ -232,8 +229,7 @@ func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { indexes: indexes, } - v := reflect.ValueOf(obj).Elem() - return v + return reflect.ValueOf(obj).Elem() } // nestedFieldScanner will only forward the Scan to the nested value if @@ -244,149 +240,16 @@ type nestedFieldScanner struct { } // Scan implements sql.Scanner. -// This method largely mirrors the sql.convertAssign() method with some minor changes func (o *nestedFieldScanner) Scan(src interface{}) error { if src == nil { return nil } - - dv := FieldByIndexes(o.parent.value, o.indexes) - // Dereference pointer fields to avoid double pointers **T - if dv.Kind() == reflect.Pointer { - dv.Set(reflect.New(dv.Type().Elem())) - dv = dv.Elem() - } - iface := dv.Addr().Interface() - - if scan, ok := iface.(sql.Scanner); ok { - return scan.Scan(src) - } - - sv := reflect.ValueOf(src) - - // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go - // with a few minor edits - - if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { - switch b := src.(type) { - case []byte: - dv.Set(reflect.ValueOf(bytesClone(b))) - default: - dv.Set(sv) - } - - return nil - } - - if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { - dv.Set(sv.Convert(dv.Type())) - return nil - } - - // The following conversions use a string value as an intermediate representation - // to convert between various numeric types. - // - // This also allows scanning into user defined types such as "type Int int64". - // For symmetry, also check for string destination types. - switch dv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetInt(i64) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetUint(u64) - return nil - case reflect.Float32, reflect.Float64: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetFloat(f64) - return nil - case reflect.String: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - switch v := src.(type) { - case string: - dv.SetString(v) - return nil - case []byte: - dv.SetString(string(v)) - return nil - } - } - - return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) + dest := FieldByIndexes(o.parent.value, o.indexes) + return convertAssign(dest.Addr().Interface(), src) } -// returns internal conversion error if available -// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go -func strconvErr(err error) error { - if ne, ok := err.(*strconv.NumError); ok { - return ne.Err - } - return err -} - -// converts value to it's string value -// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go -func asString(src interface{}) string { - switch v := src.(type) { - case string: - return v - case []byte: - return string(v) - } - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.FormatInt(rv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.FormatUint(rv.Uint(), 10) - case reflect.Float64: - return strconv.FormatFloat(rv.Float(), 'g', -1, 64) - case reflect.Float32: - return strconv.FormatFloat(rv.Float(), 'g', -1, 32) - case reflect.Bool: - return strconv.FormatBool(rv.Bool()) - } - return fmt.Sprintf("%v", src) -} - -// bytesClone returns a copy of b[:len(b)]. -// The result may have additional unused capacity. -// Clone(nil) returns nil. -// -// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version -func bytesClone(b []byte) []byte { - if b == nil { - return nil - } - return append([]byte{}, b...) -} +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src interface{}) error // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. @@ -395,8 +258,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { v = reflect.Indirect(v).Field(i) // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { - alloc := reflect.New(Deref(v.Type())) - v.Set(alloc) + v.Set(reflect.New(v.Type().Elem())) } if v.Kind() == reflect.Map && v.IsNil() { v.Set(reflect.MakeMap(v.Type())) From f537847a9abeada8c5e4933756eda77c28a32c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:14:46 +0200 Subject: [PATCH 3/8] Move ObjectContext out of reflectx where it doesn't belong --- convert.go | 8 +++++++ reflectx/reflect.go | 51 ----------------------------------------- sqlx.go | 55 +++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 55 deletions(-) create mode 100644 convert.go diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..3964a91 --- /dev/null +++ b/convert.go @@ -0,0 +1,8 @@ +package sqlx + +import ( + _ "unsafe" +) + +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src interface{}) error diff --git a/reflectx/reflect.go b/reflectx/reflect.go index ded9838..beaaa43 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -10,7 +10,6 @@ import ( "runtime" "strings" "sync" - _ "unsafe" ) // A FieldInfo is metadata for a struct field. @@ -201,56 +200,6 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in return nil } -// ObjectContext provides a single layer to abstract away -// nested struct scanning functionality -type ObjectContext struct { - value reflect.Value -} - -func NewObjectContext() *ObjectContext { - return &ObjectContext{} -} - -// NewRow updates the object reference. -// This ensures all columns point to the same object -func (o *ObjectContext) NewRow(value reflect.Value) { - o.value = value -} - -// FieldForIndexes returns the value for address. If the address is a nested struct, -// a nestedFieldScanner is returned instead of the standard value reference -func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { - if len(indexes) == 1 { - return FieldByIndexes(o.value, indexes) - } - - obj := &nestedFieldScanner{ - parent: o, - indexes: indexes, - } - - return reflect.ValueOf(obj).Elem() -} - -// nestedFieldScanner will only forward the Scan to the nested value if -// the database value is not nil. -type nestedFieldScanner struct { - parent *ObjectContext - indexes []int -} - -// Scan implements sql.Scanner. -func (o *nestedFieldScanner) Scan(src interface{}) error { - if src == nil { - return nil - } - dest := FieldByIndexes(o.parent.value, o.indexes) - return convertAssign(dest.Addr().Interface(), src) -} - -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src interface{}) error - // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { diff --git a/sqlx.go b/sqlx.go index e0ef63d..b0b1038 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,7 +624,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - octx := reflectx.NewObjectContext() + octx := newObjectContext() err := fieldsByTraversal(octx, v, r.fields, r.values, true) if err != nil { return err @@ -785,7 +785,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - octx := reflectx.NewObjectContext() + octx := newObjectContext() err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { @@ -954,7 +954,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) - octx := reflectx.NewObjectContext() + octx := newObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it @@ -1027,7 +1027,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -1058,3 +1058,50 @@ func missingFields(transversals [][]int) (field int, err error) { } return 0, nil } + +// objectContext provides a single layer to abstract away +// nested struct scanning functionality +type objectContext struct { + value reflect.Value +} + +func newObjectContext() *objectContext { + return &objectContext{} +} + +// NewRow updates the object reference. +// This ensures all columns point to the same object +func (o *objectContext) NewRow(value reflect.Value) { + o.value = value +} + +// FieldForIndexes returns the value for address. If the address is a nested struct, +// a nestedFieldScanner is returned instead of the standard value reference +func (o *objectContext) FieldForIndexes(indexes []int) reflect.Value { + if len(indexes) == 1 { + return reflectx.FieldByIndexes(o.value, indexes) + } + + obj := &nestedFieldScanner{ + parent: o, + indexes: indexes, + } + + return reflect.ValueOf(obj).Elem() +} + +// nestedFieldScanner will only forward the Scan to the nested value if +// the database value is not nil. +type nestedFieldScanner struct { + parent *objectContext + indexes []int +} + +// Scan implements sql.Scanner. +func (o *nestedFieldScanner) Scan(src interface{}) error { + if src == nil { + return nil + } + dest := reflectx.FieldByIndexes(o.parent.value, o.indexes) + return convertAssign(dest.Addr().Interface(), src) +} From f11fa570ece992a4356a7313f3d61d9dd7721d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:16:15 +0200 Subject: [PATCH 4/8] Simplify fieldsByTraversal, ptrs is always true --- sqlx.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/sqlx.go b/sqlx.go index b0b1038..1f286c3 100644 --- a/sqlx.go +++ b/sqlx.go @@ -625,7 +625,7 @@ func (r *Rows) StructScan(dest interface{}) error { } octx := newObjectContext() - err := fieldsByTraversal(octx, v, r.fields, r.values, true) + err := fieldsByTraversal(octx, v, r.fields, r.values) if err != nil { return err } @@ -787,7 +787,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { octx := newObjectContext() - err = fieldsByTraversal(octx, v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values) if err != nil { return err } @@ -961,7 +961,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(octx, v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values) if err != nil { return err } @@ -1027,7 +1027,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -1041,11 +1041,7 @@ func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, continue } f := octx.FieldForIndexes(traversal) - if ptrs { - values[i] = f.Addr().Interface() - } else { - values[i] = f.Interface() - } + values[i] = f.Addr().Interface() } return nil } From cb724a28bc9425fb28ba8f6f6cbe28899dd5588b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:16:47 +0200 Subject: [PATCH 5/8] Fix typo --- sqlx.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx.go b/sqlx.go index 1f286c3..a656626 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1046,8 +1046,8 @@ func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, return nil } -func missingFields(transversals [][]int) (field int, err error) { - for i, t := range transversals { +func missingFields(traversals [][]int) (field int, err error) { + for i, t := range traversals { if len(t) == 0 { return i, errors.New("missing field") } From a58a604a216833ea36792422b3214a622f8fae9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:27:53 +0200 Subject: [PATCH 6/8] Simplify the code by eliminating objectContext and using simple optDest --- sqlx.go | 67 +++++++++++++-------------------------------------------- 1 file changed, 15 insertions(+), 52 deletions(-) diff --git a/sqlx.go b/sqlx.go index a656626..c2f500c 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,8 +624,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - octx := newObjectContext() - err := fieldsByTraversal(octx, v, r.fields, r.values) + err := fieldsByTraversal(v, r.fields, r.values) if err != nil { return err } @@ -785,9 +784,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - octx := newObjectContext() - - err = fieldsByTraversal(octx, v, fields, values) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -954,14 +951,13 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) - octx := newObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(octx, v, fields, values) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -1027,21 +1023,23 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}) error { +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } - octx.NewRow(v) - for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) - continue + } else if len(traversal) == 1 { + values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface() + } else { + traversal := traversal + values[i] = optDest(func() interface{} { + return reflectx.FieldByIndexes(v, traversal).Addr().Interface() + }) } - f := octx.FieldForIndexes(traversal) - values[i] = f.Addr().Interface() } return nil } @@ -1055,49 +1053,14 @@ func missingFields(traversals [][]int) (field int, err error) { return 0, nil } -// objectContext provides a single layer to abstract away -// nested struct scanning functionality -type objectContext struct { - value reflect.Value -} - -func newObjectContext() *objectContext { - return &objectContext{} -} - -// NewRow updates the object reference. -// This ensures all columns point to the same object -func (o *objectContext) NewRow(value reflect.Value) { - o.value = value -} - -// FieldForIndexes returns the value for address. If the address is a nested struct, -// a nestedFieldScanner is returned instead of the standard value reference -func (o *objectContext) FieldForIndexes(indexes []int) reflect.Value { - if len(indexes) == 1 { - return reflectx.FieldByIndexes(o.value, indexes) - } - - obj := &nestedFieldScanner{ - parent: o, - indexes: indexes, - } - - return reflect.ValueOf(obj).Elem() -} - -// nestedFieldScanner will only forward the Scan to the nested value if +// optDest will only forward the Scan to the nested value if // the database value is not nil. -type nestedFieldScanner struct { - parent *objectContext - indexes []int -} +type optDest func() interface{} // Scan implements sql.Scanner. -func (o *nestedFieldScanner) Scan(src interface{}) error { +func (dest optDest) Scan(src interface{}) error { if src == nil { return nil } - dest := reflectx.FieldByIndexes(o.parent.value, o.indexes) - return convertAssign(dest.Addr().Interface(), src) + return convertAssign(dest(), src) } From e4499162e4642b7d741f8bc45cfc11baa9fe149a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:48:36 +0200 Subject: [PATCH 7/8] Add explanatory comment --- sqlx.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlx.go b/sqlx.go index c2f500c..dda1ce6 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1035,6 +1035,9 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{} } else if len(traversal) == 1 { values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface() } else { + // reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs. + // Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct. + // That way we can support LEFT JOINs with optional nested structs. traversal := traversal values[i] = optDest(func() interface{} { return reflectx.FieldByIndexes(v, traversal).Addr().Interface() From 26b1bb14f4ed5ee7f216300909584575e8fc869b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 19:18:06 +0200 Subject: [PATCH 8/8] Add test for an optional struct inside an optional struct --- sqlx_context_test.go | 164 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 144 insertions(+), 20 deletions(-) diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 73e4f5d..c5e81bc 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -437,12 +437,17 @@ func TestNamedQueryContext(t *testing.T) { "FIRST" text NULL, last_name text NULL, "EMAIL" text NULL + ); + CREATE TABLE persondetails ( + email text NULL, + notes text NULL );`, drop: ` drop table person; drop table jsperson; drop table place; drop table placeperson; + drop table persondetails; `, } @@ -648,8 +653,8 @@ func TestNamedQueryContext(t *testing.T) { type Owner struct { Email *string `db:"email"` - FirstName string `db:"first_name"` - LastName string `db:"last_name"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` } // Test optional nested structs with left join @@ -680,11 +685,11 @@ func TestNamedQueryContext(t *testing.T) { pp3 := &PlaceOwner{} rows, err = db.NamedQueryContext(ctx, ` SELECT + place.id AS "place.id", + place.name AS "place.name", placeperson.first_name "owner.first_name", placeperson.last_name "owner.last_name", - placeperson.email "owner.email", - place.id AS "place.id", - place.name AS "place.name" + placeperson.email "owner.email" FROM place LEFT JOIN placeperson ON false -- null left join WHERE @@ -698,7 +703,7 @@ func TestNamedQueryContext(t *testing.T) { t.Error(err) } if pp3.Owner != nil { - t.Error("Expected `Owner`, to be nil") + t.Error("Expected `Owner` to be nil") } if pp3.Place.Name.String != "the-house" { t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) @@ -710,41 +715,160 @@ func TestNamedQueryContext(t *testing.T) { rows.Close() - pp3 = &PlaceOwner{} + pp4 := &PlaceOwner{} rows, err = db.NamedQueryContext(ctx, ` SELECT + place.id AS "place.id", + place.name AS "place.name", placeperson.first_name "owner.first_name", placeperson.last_name "owner.last_name", - placeperson.email "owner.email", + placeperson.email "owner.email" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp4) + if err != nil { + t.Error(err) + } + if pp4.Owner == nil { + t.Error("Expected `Owner` to not be nil") + } + if pp4.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp4.Owner.FirstName) + } + if pp4.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp4.Owner.LastName) + } + if pp4.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp4.Place.Name.String) + } + if pp4.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp4.Place.ID) + } + } + + type Details struct { + Email string `db:"email"` + Notes string `db:"notes"` + } + + type OwnerDetails struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Details *Details `db:"details"` + } + + type PlaceOwnerDetails struct { + Place Place `db:"place"` + Owner *OwnerDetails `db:"owner"` + } + + pp5 := &PlaceOwnerDetails{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT place.id AS "place.id", - place.name AS "place.name" + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + persondetails.email "owner.details.email", + persondetails.notes "owner.details.notes" FROM place - left JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON false WHERE place.id=:place.id`, pp) if err != nil { log.Fatal(err) } for rows.Next() { - err = rows.StructScan(pp3) + err = rows.StructScan(pp5) if err != nil { t.Error(err) } - if pp3.Owner == nil { + if pp5.Owner == nil { t.Error("Expected `Owner`, to not be nil") } + if pp5.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp5.Owner.FirstName) + } + if pp5.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp5.Owner.LastName) + } + if pp5.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp5.Place.Name.String) + } + if pp5.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp5.Place.ID) + } + if pp5.Owner.Details != nil { + t.Error("Expected `Details` to be nil") + } + } + + details := Details{ + Email: pp.Email.String, + Notes: "this is a test person", + } - if pp3.Owner.FirstName != "ben" { - t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName) + q6 := `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)` + _, err = db.NamedExecContext(ctx, q6, details) + if err != nil { + log.Fatal(err) + } + + pp6 := &PlaceOwnerDetails{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id AS "place.id", + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + persondetails.email "owner.details.email", + persondetails.notes "owner.details.notes" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON persondetails.email = placeperson.email + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp6) + if err != nil { + t.Error(err) } - if pp3.Owner.LastName != "doe" { - t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName) + if pp6.Owner == nil { + t.Error("Expected `Owner` to not be nil") } - if pp3.Place.Name.String != "the-house" { - t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + if pp6.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp6.Owner.FirstName) } - if pp3.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + if pp6.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp6.Owner.LastName) + } + if pp6.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp6.Place.Name.String) + } + if pp6.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp6.Place.ID) + } + if pp6.Owner.Details == nil { + t.Error("Expected `Details` to not be nil") + } + if pp6.Owner.Details.Email != details.Email { + t.Errorf("Expected details email of %v, got %v", details.Email, pp6.Owner.Details.Email) + } + if pp6.Owner.Details.Notes != details.Notes { + t.Errorf("Expected details notes of %v, got %v", details.Notes, pp6.Owner.Details.Notes) } } })