From 87edeed4461d4e804455f34ed93cb7d01c8eaca9 Mon Sep 17 00:00:00 2001 From: Zhongyi Tong Date: Wed, 1 Nov 2023 19:45:02 +0000 Subject: [PATCH 01/10] Update reflectx to allow for optional nested structs --- reflectx/reflect.go | 188 +++++++++++++++++++++++++++++++++++++++++++ sqlx.go | 16 ++-- sqlx_context_test.go | 104 ++++++++++++++++++++++++ 3 files changed, 303 insertions(+), 5 deletions(-) diff --git a/reflectx/reflect.go b/reflectx/reflect.go index 8ec6a13..e8dc926 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -6,8 +6,11 @@ package reflectx import ( + "database/sql" + "fmt" "reflect" "runtime" + "strconv" "strings" "sync" ) @@ -200,6 +203,191 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in return nil } +// ObjectContext provides a single layer to abstract away +// nested struct scanning functionality +type ObjectContext struct { + value reflect.Value +} + +func NewObjectContext() *ObjectContext { + return &ObjectContext{} +} + +// NewRow updates the object reference. +// This ensures all columns point to the same object +func (o *ObjectContext) NewRow(value reflect.Value) { + o.value = value +} + +// FieldForIndexes returns the value for address. If the address is a nested struct, +// a nestedFieldScanner is returned instead of the standard value reference +func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { + if len(indexes) == 1 { + val := FieldByIndexes(o.value, indexes) + return val + } + + obj := &nestedFieldScanner{ + parent: o, + indexes: indexes, + } + + v := reflect.ValueOf(obj).Elem() + return v +} + +// nestedFieldScanner will only forward the Scan to the nested value if +// the database value is not nil. +type nestedFieldScanner struct { + parent *ObjectContext + indexes []int +} + +// Scan implements sql.Scanner. +// This method largely mirrors the sql.convertAssign() method with some minor changes +func (o *nestedFieldScanner) Scan(src interface{}) error { + if src == nil { + return nil + } + + dv := FieldByIndexes(o.parent.value, o.indexes) + // Dereference pointer fields to avoid double pointers **T + if dv.Kind() == reflect.Pointer { + dv.Set(reflect.New(dv.Type().Elem())) + dv = dv.Elem() + } + iface := dv.Addr().Interface() + + if scan, ok := iface.(sql.Scanner); ok { + return scan.Scan(src) + } + + sv := reflect.ValueOf(src) + + // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go + // with a few minor edits + + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(bytesClone(b))) + default: + dv.Set(sv) + } + + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) +} + +// returns internal conversion error if available +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +// converts value to it's string value +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +// bytesClone returns a copy of b[:len(b)]. +// The result may have additional unused capacity. +// Clone(nil) returns nil. +// +// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version +func bytesClone(b []byte) []byte { + if b == nil { + return nil + } + return append([]byte{}, b...) +} + // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { diff --git a/sqlx.go b/sqlx.go index 8259a4f..e0ef63d 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,7 +624,8 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - err := fieldsByTraversal(v, r.fields, r.values, true) + octx := reflectx.NewObjectContext() + err := fieldsByTraversal(octx, v, r.fields, r.values, true) if err != nil { return err } @@ -784,7 +785,9 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - err = fieldsByTraversal(v, fields, values, true) + octx := reflectx.NewObjectContext() + + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -951,13 +954,14 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) + octx := reflectx.NewObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -1023,18 +1027,20 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } + octx.NewRow(v) + for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) continue } - f := reflectx.FieldByIndexes(v, traversal) + f := octx.FieldForIndexes(traversal) if ptrs { values[i] = f.Addr().Interface() } else { diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 91c5cba..73e4f5d 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -643,6 +643,110 @@ func TestNamedQueryContext(t *testing.T) { t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) } } + + rows.Close() + + type Owner struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + + // Test optional nested structs with left join + type PlaceOwner struct { + Place Place `db:"place"` + Owner *Owner `db:"owner"` + } + + pl = Place{ + Name: sql.NullString{String: "the-house", Valid: true}, + } + + q4 := `INSERT INTO place (id, name) VALUES (2, :name)` + _, err = db.NamedExecContext(ctx, q4, pl) + if err != nil { + log.Fatal(err) + } + + id = 2 + pp.Place.ID = id + + q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q5, pp) + if err != nil { + log.Fatal(err) + } + + pp3 := &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + LEFT JOIN placeperson ON false -- null left join + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner != nil { + t.Error("Expected `Owner`, to be nil") + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } + + rows.Close() + + pp3 = &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + left JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner == nil { + t.Error("Expected `Owner`, to not be nil") + } + + if pp3.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName) + } + if pp3.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName) + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } }) } From 48580808cda90b2b1b62f64afd09737f0ffdffa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:07:13 +0200 Subject: [PATCH 02/10] Use go:linkname to call convertAssign instead of copying it --- reflectx/reflect.go | 154 +++----------------------------------------- 1 file changed, 8 insertions(+), 146 deletions(-) diff --git a/reflectx/reflect.go b/reflectx/reflect.go index e8dc926..ded9838 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -6,13 +6,11 @@ package reflectx import ( - "database/sql" - "fmt" "reflect" "runtime" - "strconv" "strings" "sync" + _ "unsafe" ) // A FieldInfo is metadata for a struct field. @@ -223,8 +221,7 @@ func (o *ObjectContext) NewRow(value reflect.Value) { // a nestedFieldScanner is returned instead of the standard value reference func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { if len(indexes) == 1 { - val := FieldByIndexes(o.value, indexes) - return val + return FieldByIndexes(o.value, indexes) } obj := &nestedFieldScanner{ @@ -232,8 +229,7 @@ func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { indexes: indexes, } - v := reflect.ValueOf(obj).Elem() - return v + return reflect.ValueOf(obj).Elem() } // nestedFieldScanner will only forward the Scan to the nested value if @@ -244,149 +240,16 @@ type nestedFieldScanner struct { } // Scan implements sql.Scanner. -// This method largely mirrors the sql.convertAssign() method with some minor changes func (o *nestedFieldScanner) Scan(src interface{}) error { if src == nil { return nil } - - dv := FieldByIndexes(o.parent.value, o.indexes) - // Dereference pointer fields to avoid double pointers **T - if dv.Kind() == reflect.Pointer { - dv.Set(reflect.New(dv.Type().Elem())) - dv = dv.Elem() - } - iface := dv.Addr().Interface() - - if scan, ok := iface.(sql.Scanner); ok { - return scan.Scan(src) - } - - sv := reflect.ValueOf(src) - - // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go - // with a few minor edits - - if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { - switch b := src.(type) { - case []byte: - dv.Set(reflect.ValueOf(bytesClone(b))) - default: - dv.Set(sv) - } - - return nil - } - - if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { - dv.Set(sv.Convert(dv.Type())) - return nil - } - - // The following conversions use a string value as an intermediate representation - // to convert between various numeric types. - // - // This also allows scanning into user defined types such as "type Int int64". - // For symmetry, also check for string destination types. - switch dv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetInt(i64) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetUint(u64) - return nil - case reflect.Float32, reflect.Float64: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetFloat(f64) - return nil - case reflect.String: - if src == nil { - return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) - } - switch v := src.(type) { - case string: - dv.SetString(v) - return nil - case []byte: - dv.SetString(string(v)) - return nil - } - } - - return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) + dest := FieldByIndexes(o.parent.value, o.indexes) + return convertAssign(dest.Addr().Interface(), src) } -// returns internal conversion error if available -// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go -func strconvErr(err error) error { - if ne, ok := err.(*strconv.NumError); ok { - return ne.Err - } - return err -} - -// converts value to it's string value -// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go -func asString(src interface{}) string { - switch v := src.(type) { - case string: - return v - case []byte: - return string(v) - } - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.FormatInt(rv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.FormatUint(rv.Uint(), 10) - case reflect.Float64: - return strconv.FormatFloat(rv.Float(), 'g', -1, 64) - case reflect.Float32: - return strconv.FormatFloat(rv.Float(), 'g', -1, 32) - case reflect.Bool: - return strconv.FormatBool(rv.Bool()) - } - return fmt.Sprintf("%v", src) -} - -// bytesClone returns a copy of b[:len(b)]. -// The result may have additional unused capacity. -// Clone(nil) returns nil. -// -// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version -func bytesClone(b []byte) []byte { - if b == nil { - return nil - } - return append([]byte{}, b...) -} +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src interface{}) error // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. @@ -395,8 +258,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { v = reflect.Indirect(v).Field(i) // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { - alloc := reflect.New(Deref(v.Type())) - v.Set(alloc) + v.Set(reflect.New(v.Type().Elem())) } if v.Kind() == reflect.Map && v.IsNil() { v.Set(reflect.MakeMap(v.Type())) From f537847a9abeada8c5e4933756eda77c28a32c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:14:46 +0200 Subject: [PATCH 03/10] Move ObjectContext out of reflectx where it doesn't belong --- convert.go | 8 +++++++ reflectx/reflect.go | 51 ----------------------------------------- sqlx.go | 55 +++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 55 deletions(-) create mode 100644 convert.go diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..3964a91 --- /dev/null +++ b/convert.go @@ -0,0 +1,8 @@ +package sqlx + +import ( + _ "unsafe" +) + +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src interface{}) error diff --git a/reflectx/reflect.go b/reflectx/reflect.go index ded9838..beaaa43 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -10,7 +10,6 @@ import ( "runtime" "strings" "sync" - _ "unsafe" ) // A FieldInfo is metadata for a struct field. @@ -201,56 +200,6 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in return nil } -// ObjectContext provides a single layer to abstract away -// nested struct scanning functionality -type ObjectContext struct { - value reflect.Value -} - -func NewObjectContext() *ObjectContext { - return &ObjectContext{} -} - -// NewRow updates the object reference. -// This ensures all columns point to the same object -func (o *ObjectContext) NewRow(value reflect.Value) { - o.value = value -} - -// FieldForIndexes returns the value for address. If the address is a nested struct, -// a nestedFieldScanner is returned instead of the standard value reference -func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { - if len(indexes) == 1 { - return FieldByIndexes(o.value, indexes) - } - - obj := &nestedFieldScanner{ - parent: o, - indexes: indexes, - } - - return reflect.ValueOf(obj).Elem() -} - -// nestedFieldScanner will only forward the Scan to the nested value if -// the database value is not nil. -type nestedFieldScanner struct { - parent *ObjectContext - indexes []int -} - -// Scan implements sql.Scanner. -func (o *nestedFieldScanner) Scan(src interface{}) error { - if src == nil { - return nil - } - dest := FieldByIndexes(o.parent.value, o.indexes) - return convertAssign(dest.Addr().Interface(), src) -} - -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src interface{}) error - // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { diff --git a/sqlx.go b/sqlx.go index e0ef63d..b0b1038 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,7 +624,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - octx := reflectx.NewObjectContext() + octx := newObjectContext() err := fieldsByTraversal(octx, v, r.fields, r.values, true) if err != nil { return err @@ -785,7 +785,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - octx := reflectx.NewObjectContext() + octx := newObjectContext() err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { @@ -954,7 +954,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) - octx := reflectx.NewObjectContext() + octx := newObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it @@ -1027,7 +1027,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -1058,3 +1058,50 @@ func missingFields(transversals [][]int) (field int, err error) { } return 0, nil } + +// objectContext provides a single layer to abstract away +// nested struct scanning functionality +type objectContext struct { + value reflect.Value +} + +func newObjectContext() *objectContext { + return &objectContext{} +} + +// NewRow updates the object reference. +// This ensures all columns point to the same object +func (o *objectContext) NewRow(value reflect.Value) { + o.value = value +} + +// FieldForIndexes returns the value for address. If the address is a nested struct, +// a nestedFieldScanner is returned instead of the standard value reference +func (o *objectContext) FieldForIndexes(indexes []int) reflect.Value { + if len(indexes) == 1 { + return reflectx.FieldByIndexes(o.value, indexes) + } + + obj := &nestedFieldScanner{ + parent: o, + indexes: indexes, + } + + return reflect.ValueOf(obj).Elem() +} + +// nestedFieldScanner will only forward the Scan to the nested value if +// the database value is not nil. +type nestedFieldScanner struct { + parent *objectContext + indexes []int +} + +// Scan implements sql.Scanner. +func (o *nestedFieldScanner) Scan(src interface{}) error { + if src == nil { + return nil + } + dest := reflectx.FieldByIndexes(o.parent.value, o.indexes) + return convertAssign(dest.Addr().Interface(), src) +} From f11fa570ece992a4356a7313f3d61d9dd7721d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:16:15 +0200 Subject: [PATCH 04/10] Simplify fieldsByTraversal, ptrs is always true --- sqlx.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/sqlx.go b/sqlx.go index b0b1038..1f286c3 100644 --- a/sqlx.go +++ b/sqlx.go @@ -625,7 +625,7 @@ func (r *Rows) StructScan(dest interface{}) error { } octx := newObjectContext() - err := fieldsByTraversal(octx, v, r.fields, r.values, true) + err := fieldsByTraversal(octx, v, r.fields, r.values) if err != nil { return err } @@ -787,7 +787,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { octx := newObjectContext() - err = fieldsByTraversal(octx, v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values) if err != nil { return err } @@ -961,7 +961,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(octx, v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values) if err != nil { return err } @@ -1027,7 +1027,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -1041,11 +1041,7 @@ func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, continue } f := octx.FieldForIndexes(traversal) - if ptrs { - values[i] = f.Addr().Interface() - } else { - values[i] = f.Interface() - } + values[i] = f.Addr().Interface() } return nil } From cb724a28bc9425fb28ba8f6f6cbe28899dd5588b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:16:47 +0200 Subject: [PATCH 05/10] Fix typo --- sqlx.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx.go b/sqlx.go index 1f286c3..a656626 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1046,8 +1046,8 @@ func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, return nil } -func missingFields(transversals [][]int) (field int, err error) { - for i, t := range transversals { +func missingFields(traversals [][]int) (field int, err error) { + for i, t := range traversals { if len(t) == 0 { return i, errors.New("missing field") } From a58a604a216833ea36792422b3214a622f8fae9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:27:53 +0200 Subject: [PATCH 06/10] Simplify the code by eliminating objectContext and using simple optDest --- sqlx.go | 67 +++++++++++++-------------------------------------------- 1 file changed, 15 insertions(+), 52 deletions(-) diff --git a/sqlx.go b/sqlx.go index a656626..c2f500c 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,8 +624,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - octx := newObjectContext() - err := fieldsByTraversal(octx, v, r.fields, r.values) + err := fieldsByTraversal(v, r.fields, r.values) if err != nil { return err } @@ -785,9 +784,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - octx := newObjectContext() - - err = fieldsByTraversal(octx, v, fields, values) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -954,14 +951,13 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) - octx := newObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(octx, v, fields, values) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -1027,21 +1023,23 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(octx *objectContext, v reflect.Value, traversals [][]int, values []interface{}) error { +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } - octx.NewRow(v) - for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) - continue + } else if len(traversal) == 1 { + values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface() + } else { + traversal := traversal + values[i] = optDest(func() interface{} { + return reflectx.FieldByIndexes(v, traversal).Addr().Interface() + }) } - f := octx.FieldForIndexes(traversal) - values[i] = f.Addr().Interface() } return nil } @@ -1055,49 +1053,14 @@ func missingFields(traversals [][]int) (field int, err error) { return 0, nil } -// objectContext provides a single layer to abstract away -// nested struct scanning functionality -type objectContext struct { - value reflect.Value -} - -func newObjectContext() *objectContext { - return &objectContext{} -} - -// NewRow updates the object reference. -// This ensures all columns point to the same object -func (o *objectContext) NewRow(value reflect.Value) { - o.value = value -} - -// FieldForIndexes returns the value for address. If the address is a nested struct, -// a nestedFieldScanner is returned instead of the standard value reference -func (o *objectContext) FieldForIndexes(indexes []int) reflect.Value { - if len(indexes) == 1 { - return reflectx.FieldByIndexes(o.value, indexes) - } - - obj := &nestedFieldScanner{ - parent: o, - indexes: indexes, - } - - return reflect.ValueOf(obj).Elem() -} - -// nestedFieldScanner will only forward the Scan to the nested value if +// optDest will only forward the Scan to the nested value if // the database value is not nil. -type nestedFieldScanner struct { - parent *objectContext - indexes []int -} +type optDest func() interface{} // Scan implements sql.Scanner. -func (o *nestedFieldScanner) Scan(src interface{}) error { +func (dest optDest) Scan(src interface{}) error { if src == nil { return nil } - dest := reflectx.FieldByIndexes(o.parent.value, o.indexes) - return convertAssign(dest.Addr().Interface(), src) + return convertAssign(dest(), src) } From e4499162e4642b7d741f8bc45cfc11baa9fe149a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 18:48:36 +0200 Subject: [PATCH 07/10] Add explanatory comment --- sqlx.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlx.go b/sqlx.go index c2f500c..dda1ce6 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1035,6 +1035,9 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{} } else if len(traversal) == 1 { values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface() } else { + // reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs. + // Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct. + // That way we can support LEFT JOINs with optional nested structs. traversal := traversal values[i] = optDest(func() interface{} { return reflectx.FieldByIndexes(v, traversal).Addr().Interface() From 26b1bb14f4ed5ee7f216300909584575e8fc869b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Wed, 16 Oct 2024 19:18:06 +0200 Subject: [PATCH 08/10] Add test for an optional struct inside an optional struct --- sqlx_context_test.go | 164 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 144 insertions(+), 20 deletions(-) diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 73e4f5d..c5e81bc 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -437,12 +437,17 @@ func TestNamedQueryContext(t *testing.T) { "FIRST" text NULL, last_name text NULL, "EMAIL" text NULL + ); + CREATE TABLE persondetails ( + email text NULL, + notes text NULL );`, drop: ` drop table person; drop table jsperson; drop table place; drop table placeperson; + drop table persondetails; `, } @@ -648,8 +653,8 @@ func TestNamedQueryContext(t *testing.T) { type Owner struct { Email *string `db:"email"` - FirstName string `db:"first_name"` - LastName string `db:"last_name"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` } // Test optional nested structs with left join @@ -680,11 +685,11 @@ func TestNamedQueryContext(t *testing.T) { pp3 := &PlaceOwner{} rows, err = db.NamedQueryContext(ctx, ` SELECT + place.id AS "place.id", + place.name AS "place.name", placeperson.first_name "owner.first_name", placeperson.last_name "owner.last_name", - placeperson.email "owner.email", - place.id AS "place.id", - place.name AS "place.name" + placeperson.email "owner.email" FROM place LEFT JOIN placeperson ON false -- null left join WHERE @@ -698,7 +703,7 @@ func TestNamedQueryContext(t *testing.T) { t.Error(err) } if pp3.Owner != nil { - t.Error("Expected `Owner`, to be nil") + t.Error("Expected `Owner` to be nil") } if pp3.Place.Name.String != "the-house" { t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) @@ -710,41 +715,160 @@ func TestNamedQueryContext(t *testing.T) { rows.Close() - pp3 = &PlaceOwner{} + pp4 := &PlaceOwner{} rows, err = db.NamedQueryContext(ctx, ` SELECT + place.id AS "place.id", + place.name AS "place.name", placeperson.first_name "owner.first_name", placeperson.last_name "owner.last_name", - placeperson.email "owner.email", + placeperson.email "owner.email" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp4) + if err != nil { + t.Error(err) + } + if pp4.Owner == nil { + t.Error("Expected `Owner` to not be nil") + } + if pp4.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp4.Owner.FirstName) + } + if pp4.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp4.Owner.LastName) + } + if pp4.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp4.Place.Name.String) + } + if pp4.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp4.Place.ID) + } + } + + type Details struct { + Email string `db:"email"` + Notes string `db:"notes"` + } + + type OwnerDetails struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Details *Details `db:"details"` + } + + type PlaceOwnerDetails struct { + Place Place `db:"place"` + Owner *OwnerDetails `db:"owner"` + } + + pp5 := &PlaceOwnerDetails{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT place.id AS "place.id", - place.name AS "place.name" + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + persondetails.email "owner.details.email", + persondetails.notes "owner.details.notes" FROM place - left JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON false WHERE place.id=:place.id`, pp) if err != nil { log.Fatal(err) } for rows.Next() { - err = rows.StructScan(pp3) + err = rows.StructScan(pp5) if err != nil { t.Error(err) } - if pp3.Owner == nil { + if pp5.Owner == nil { t.Error("Expected `Owner`, to not be nil") } + if pp5.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp5.Owner.FirstName) + } + if pp5.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp5.Owner.LastName) + } + if pp5.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp5.Place.Name.String) + } + if pp5.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp5.Place.ID) + } + if pp5.Owner.Details != nil { + t.Error("Expected `Details` to be nil") + } + } + + details := Details{ + Email: pp.Email.String, + Notes: "this is a test person", + } - if pp3.Owner.FirstName != "ben" { - t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName) + q6 := `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)` + _, err = db.NamedExecContext(ctx, q6, details) + if err != nil { + log.Fatal(err) + } + + pp6 := &PlaceOwnerDetails{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id AS "place.id", + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + persondetails.email "owner.details.email", + persondetails.notes "owner.details.notes" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON persondetails.email = placeperson.email + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp6) + if err != nil { + t.Error(err) } - if pp3.Owner.LastName != "doe" { - t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName) + if pp6.Owner == nil { + t.Error("Expected `Owner` to not be nil") } - if pp3.Place.Name.String != "the-house" { - t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + if pp6.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp6.Owner.FirstName) } - if pp3.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + if pp6.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp6.Owner.LastName) + } + if pp6.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp6.Place.Name.String) + } + if pp6.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp6.Place.ID) + } + if pp6.Owner.Details == nil { + t.Error("Expected `Details` to not be nil") + } + if pp6.Owner.Details.Email != details.Email { + t.Errorf("Expected details email of %v, got %v", details.Email, pp6.Owner.Details.Email) + } + if pp6.Owner.Details.Notes != details.Notes { + t.Errorf("Expected details notes of %v, got %v", details.Notes, pp6.Owner.Details.Notes) } } }) From 54f4c074857cd693d550f060aa1425eecc3e9296 Mon Sep 17 00:00:00 2001 From: Mike Johnson Date: Sat, 26 Jul 2025 14:36:05 -0700 Subject: [PATCH 09/10] update readme --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 3de3a89..eb9662d 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,9 @@ explains how to use `database/sql` along with sqlx. ## Changes compared to the original sqlx +* Better scanning in the case of outer joins. If a struct contains a nested + struct pointer, it will no longer be a scan error. + * Made complex joins easier to scan by using the position of the field to help map duplicate column names into structs. See the [joins example](./examples/joins/main.go). From 1f123ff78070bc71b25928c44d782f224858151d Mon Sep 17 00:00:00 2001 From: Mike Johnson Date: Sat, 26 Jul 2025 14:29:42 -0700 Subject: [PATCH 10/10] update tests with assert, range func and style --- examples/generics/main.go | 6 +- sqlx.go | 1 - sqlx_context_test.go | 346 ++++++++++++++------------------------ 3 files changed, 132 insertions(+), 221 deletions(-) diff --git a/examples/generics/main.go b/examples/generics/main.go index ed82b82..1837488 100644 --- a/examples/generics/main.go +++ b/examples/generics/main.go @@ -15,7 +15,8 @@ import ( // docker run --name sqlxpg -p 5444:5432 -e POSTGRES_PASSWORD=password -d docker.io/postgres:17.4 const schema = ` - CREATE TABLE IF NOT EXISTS person ( + DROP TABLE IF EXISTS person; + CREATE TABLE person ( id SERIAL PRIMARY KEY, first_name text, last_name text, @@ -23,7 +24,8 @@ const schema = ` ); TRUNCATE TABLE person; - CREATE TABLE IF NOT EXISTS place ( + DROP TABLE IF EXISTS place; + CREATE TABLE place ( country text, city text NULL, telcode integer diff --git a/sqlx.go b/sqlx.go index d598f75..341bccc 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1246,7 +1246,6 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{} // reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs. // Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct. // That way we can support LEFT JOINs with optional nested structs. - traversal := traversal values[i] = optDest(func() interface{} { return reflectx.FieldByIndexes(v, traversal).Addr().Interface() }) diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 1eaae9e..c0762cc 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -500,28 +500,28 @@ func TestNamedQueryContext(t *testing.T) { Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` - _, err := db.NamedExecContext(ctx, q1, p) - if err != nil { - log.Fatal(err) - } + _, err := db.NamedExecContext(ctx, `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`, p) + require.NoError(t, err) - p2 := &Person{} - rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(p2) + { + p2 := &Person{} + rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) if err != nil { - t.Error(err) - } - if p2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + log.Fatal(err) } - if p2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + p2.LastName.String) + for rows.Next() { + err = rows.StructScan(p2) + if err != nil { + t.Error(err) + } + if p2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + } + if p2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + p2.LastName.String) + } } + rows.Close() } // these are tests for #73; they verify that named queries work if you've @@ -553,8 +553,7 @@ func TestNamedQueryContext(t *testing.T) { return s } - q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` - _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) + _, err = db.NamedExecContext(ctx, pdb(`INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`, db), jp) if err != nil { t.Fatal(err, db.DriverName()) } @@ -586,16 +585,13 @@ func TestNamedQueryContext(t *testing.T) { last_name=:last_name AND "EMAIL"=:EMAIL `, db)) + require.NoError(t, err) - if err != nil { - t.Fatal(err) - } - rows, err = ns.QueryxContext(ctx, jp) - if err != nil { - t.Fatal(err) - } + rows, err := ns.QueryxContext(ctx, jp) + require.NoError(t, err) check(t, rows) + rows.Close() // Check exactly the same thing, but with db.NamedQuery, which does not go // through the PrepareNamed/NamedStmt path. @@ -606,11 +602,10 @@ func TestNamedQueryContext(t *testing.T) { last_name=:last_name AND "EMAIL"=:EMAIL `, db), jp) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) check(t, rows) + rows.Close() db.Mapper = old @@ -630,29 +625,23 @@ func TestNamedQueryContext(t *testing.T) { Name: sql.NullString{String: "myplace", Valid: true}, } - pp := PlacePerson{ + benDoe := PlacePerson{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "doe", Valid: true}, Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q2 := `INSERT INTO place (id, name) VALUES (1, :name)` - _, err = db.NamedExecContext(ctx, q2, pl) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (1, :name)`, pl) + require.NoError(t, err) id := 1 - pp.Place.ID = id + benDoe.Place.ID = id - q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` - _, err = db.NamedExecContext(ctx, q3, pp) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe) + require.NoError(t, err) - pp2 := &PlacePerson{} - rows, err = db.NamedQueryContext(ctx, ` + { + rows, err = db.NamedQueryContext(ctx, ` SELECT first_name, last_name, @@ -662,31 +651,18 @@ func TestNamedQueryContext(t *testing.T) { FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp2) - if err != nil { - t.Error(err) - } - if pp2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) - } - if pp2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + pp2.LastName.String) - } - if pp2.Place.Name.String != "myplace" { - t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) - } - if pp2.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp2, err := range AllRows[PlacePerson](rows) { + require.NoError(t, err) + assert.Equal(t, benDoe.FirstName.String, pp2.FirstName.String) + assert.Equal(t, benDoe.LastName.String, pp2.LastName.String) + assert.Equal(t, benDoe.Email.String, pp2.Email.String) + assert.Equal(t, benDoe.Place.ID, pp2.Place.ID) } } - rows.Close() - type Owner struct { Email *string `db:"email"` FirstName string `db:"first_name"` @@ -703,88 +679,58 @@ func TestNamedQueryContext(t *testing.T) { Name: sql.NullString{String: "the-house", Valid: true}, } - q4 := `INSERT INTO place (id, name) VALUES (2, :name)` - _, err = db.NamedExecContext(ctx, q4, pl) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (2, :name)`, pl) + require.NoError(t, err) id = 2 - pp.Place.ID = id + benDoe.Place.ID = id - q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` - _, err = db.NamedExecContext(ctx, q5, pp) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe) + require.NoError(t, err) - pp3 := &PlaceOwner{} - rows, err = db.NamedQueryContext(ctx, ` + { + rows, err = db.NamedQueryContext(ctx, ` SELECT - place.id AS "place.id", - place.name AS "place.name", - placeperson.first_name "owner.first_name", - placeperson.last_name "owner.last_name", - placeperson.email "owner.email" + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email FROM place LEFT JOIN placeperson ON false -- null left join WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp3) - if err != nil { - t.Error(err) - } - if pp3.Owner != nil { - t.Error("Expected `Owner` to be nil") - } - if pp3.Place.Name.String != "the-house" { - t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) - } - if pp3.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp3, err := range AllRows[PlaceOwner](rows) { + require.NoError(t, err) + assert.Nil(t, pp3.Owner, "Expected `Owner` to be nil") + assert.Equal(t, "the-house", pp3.Place.Name.String) + assert.Equal(t, benDoe.Place.ID, pp3.Place.ID) } } - rows.Close() - - pp4 := &PlaceOwner{} - rows, err = db.NamedQueryContext(ctx, ` + { + rows, err = db.NamedQueryContext(ctx, ` SELECT - place.id AS "place.id", - place.name AS "place.name", - placeperson.first_name "owner.first_name", - placeperson.last_name "owner.last_name", - placeperson.email "owner.email" + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email FROM place LEFT JOIN placeperson ON placeperson.place_id = place.id WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp4) - if err != nil { - t.Error(err) - } - if pp4.Owner == nil { - t.Error("Expected `Owner` to not be nil") - } - if pp4.Owner.FirstName != "ben" { - t.Error("Expected first name of `ben`, got " + pp4.Owner.FirstName) - } - if pp4.Owner.LastName != "doe" { - t.Error("Expected first name of `doe`, got " + pp4.Owner.LastName) - } - if pp4.Place.Name.String != "the-house" { - t.Error("Expected place name of `the-house`, got " + pp4.Place.Name.String) - } - if pp4.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp4.Place.ID) + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp4, err := range AllRows[PlaceOwner](rows) { + require.NoError(t, err) + assert.NotNil(t, pp4.Owner, "Expected `Owner` to not be nil") + assert.Equal(t, "ben", pp4.Owner.FirstName) + assert.Equal(t, "doe", pp4.Owner.LastName) + assert.Equal(t, "the-house", pp4.Place.Name.String) + assert.Equal(t, benDoe.Place.ID, pp4.Place.ID) } } @@ -805,106 +751,70 @@ func TestNamedQueryContext(t *testing.T) { Owner *OwnerDetails `db:"owner"` } - pp5 := &PlaceOwnerDetails{} - rows, err = db.NamedQueryContext(ctx, ` + { + rows, err = db.NamedQueryContext(ctx, ` SELECT - place.id AS "place.id", - place.name AS "place.name", - placeperson.first_name "owner.first_name", - placeperson.last_name "owner.last_name", - placeperson.email "owner.email", - persondetails.email "owner.details.email", - persondetails.notes "owner.details.notes" + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email, + persondetails.email, + persondetails.notes FROM place LEFT JOIN placeperson ON placeperson.place_id = place.id LEFT JOIN persondetails ON false WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp5) - if err != nil { - t.Error(err) - } - if pp5.Owner == nil { - t.Error("Expected `Owner`, to not be nil") - } - if pp5.Owner.FirstName != "ben" { - t.Error("Expected first name of `ben`, got " + pp5.Owner.FirstName) - } - if pp5.Owner.LastName != "doe" { - t.Error("Expected first name of `doe`, got " + pp5.Owner.LastName) - } - if pp5.Place.Name.String != "the-house" { - t.Error("Expected place name of `the-house`, got " + pp5.Place.Name.String) - } - if pp5.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp5.Place.ID) - } - if pp5.Owner.Details != nil { - t.Error("Expected `Details` to be nil") + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp5, err := range AllRows[PlaceOwnerDetails](rows) { + require.NoError(t, err) + assert.NotNil(t, pp5.Owner, "Expected `Owner`, to not be nil") + assert.Equal(t, "ben", pp5.Owner.FirstName) + assert.Equal(t, "doe", pp5.Owner.LastName) + assert.Equal(t, benDoe.Email.String, *pp5.Owner.Email) + assert.Equal(t, "the-house", pp5.Place.Name.String) + assert.Equal(t, pp5.Place.ID, benDoe.Place.ID) + assert.Nil(t, pp5.Owner.Details) } } - details := Details{ - Email: pp.Email.String, - Notes: "this is a test person", - } + { + details := Details{ + Email: benDoe.Email.String, + Notes: "this is a test person", + } - q6 := `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)` - _, err = db.NamedExecContext(ctx, q6, details) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)`, details) + require.NoError(t, err) - pp6 := &PlaceOwnerDetails{} - rows, err = db.NamedQueryContext(ctx, ` + rows, err = db.NamedQueryContext(ctx, ` SELECT - place.id AS "place.id", - place.name AS "place.name", - placeperson.first_name "owner.first_name", - placeperson.last_name "owner.last_name", - placeperson.email "owner.email", - persondetails.email "owner.details.email", - persondetails.notes "owner.details.notes" + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email, + persondetails.email, + persondetails.notes FROM place LEFT JOIN placeperson ON placeperson.place_id = place.id LEFT JOIN persondetails ON persondetails.email = placeperson.email WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(pp6) - if err != nil { - t.Error(err) - } - if pp6.Owner == nil { - t.Error("Expected `Owner` to not be nil") - } - if pp6.Owner.FirstName != "ben" { - t.Error("Expected first name of `ben`, got " + pp6.Owner.FirstName) - } - if pp6.Owner.LastName != "doe" { - t.Error("Expected first name of `doe`, got " + pp6.Owner.LastName) - } - if pp6.Place.Name.String != "the-house" { - t.Error("Expected place name of `the-house`, got " + pp6.Place.Name.String) - } - if pp6.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp6.Place.ID) - } - if pp6.Owner.Details == nil { - t.Error("Expected `Details` to not be nil") - } - if pp6.Owner.Details.Email != details.Email { - t.Errorf("Expected details email of %v, got %v", details.Email, pp6.Owner.Details.Email) - } - if pp6.Owner.Details.Notes != details.Notes { - t.Errorf("Expected details notes of %v, got %v", details.Notes, pp6.Owner.Details.Notes) + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp6, err := range AllRows[PlaceOwnerDetails](rows) { + require.NoError(t, err) + assert.NotNil(t, pp6.Owner, "Expected `Owner`, to not be nil") + assert.Equal(t, "ben", pp6.Owner.FirstName) + assert.Equal(t, "doe", pp6.Owner.LastName) + assert.Equal(t, "the-house", pp6.Place.Name.String) + assert.Equal(t, pp6.Place.ID, pp6.Place.ID) + assert.NotNil(t, pp6.Owner.Details, "Expected `Details` to not be nil") + assert.Equal(t, details.Email, pp6.Owner.Details.Email) + assert.Equal(t, details.Notes, pp6.Owner.Details.Notes) } } })