Skip to content
Merged
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ explains how to use `database/sql` along with sqlx.

## Changes compared to the original sqlx

* Better scanning in the case of outer joins. If a struct contains a nested
struct pointer, it will no longer be a scan error.

* Made complex joins easier to scan by using the position of the field
to help map duplicate column names into structs. See the [joins
example](./examples/joins/main.go).
Expand Down
8 changes: 8 additions & 0 deletions convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package sqlx

import (
_ "unsafe"
)

//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src interface{}) error
6 changes: 4 additions & 2 deletions examples/generics/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ import (
// docker run --name sqlxpg -p 5444:5432 -e POSTGRES_PASSWORD=password -d docker.io/postgres:17.4

const schema = `
CREATE TABLE IF NOT EXISTS person (
DROP TABLE IF EXISTS person;
CREATE TABLE person (
id SERIAL PRIMARY KEY,
first_name text,
last_name text,
email text
);
TRUNCATE TABLE person;

CREATE TABLE IF NOT EXISTS place (
DROP TABLE IF EXISTS place;
CREATE TABLE place (
country text,
city text NULL,
telcode integer
Expand Down
3 changes: 1 addition & 2 deletions reflectx/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
v = reflect.Indirect(v).Field(i)
// if this is a pointer and it's nil, allocate a new value and set it
if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type()))
v.Set(alloc)
v.Set(reflect.New(v.Type().Elem()))
}
if v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
Expand Down
38 changes: 26 additions & 12 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ func (r *Rows) StructScan(dest interface{}) error {
r.started = true
}

err := fieldsByTraversal(v, r.fields, r.values, true)
err := fieldsByTraversal(v, r.fields, r.values)
if err != nil {
return err
}
Expand Down Expand Up @@ -990,7 +990,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
}
values := make([]interface{}, len(columns))

err = fieldsByTraversal(v, fields, values, true)
err = fieldsByTraversal(v, fields, values)
if err != nil {
return err
}
Expand Down Expand Up @@ -1165,7 +1165,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
vp = reflect.New(base)
v = reflect.Indirect(vp)

err = fieldsByTraversal(v, fields, values, true)
err = fieldsByTraversal(v, fields, values)
if err != nil {
return err
}
Expand Down Expand Up @@ -1231,7 +1231,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
Expand All @@ -1240,23 +1240,37 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else if len(traversal) == 1 {
values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface()
} else {
values[i] = f.Interface()
// reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs.
// Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct.
// That way we can support LEFT JOINs with optional nested structs.
values[i] = optDest(func() interface{} {
return reflectx.FieldByIndexes(v, traversal).Addr().Interface()
})
}
}
return nil
}

func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals {
func missingFields(traversals [][]int) (field int, err error) {
for i, t := range traversals {
if len(t) == 0 {
return i, errors.New("missing field")
}
}
return 0, nil
}

// optDest will only forward the Scan to the nested value if
// the database value is not nil.
type optDest func() interface{}

// Scan implements sql.Scanner.
func (dest optDest) Scan(src interface{}) error {
if src == nil {
return nil
}
return convertAssign(dest(), src)
}
Loading