diff --git a/driver/connection.go b/driver/connection.go index 2d8303a..c3fde7a 100644 --- a/driver/connection.go +++ b/driver/connection.go @@ -1,10 +1,11 @@ package driver import ( - "cloud.google.com/go/bigquery" "context" "database/sql/driver" "fmt" + + "cloud.google.com/go/bigquery" ) type bigQueryConnection struct { @@ -94,7 +95,6 @@ func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (d return statement.Exec(args) } -func (bigQueryConnection) CheckNamedValue(*driver.NamedValue) error { - // TODO: Revise in the future - return nil +func (connection *bigQueryConnection) CheckNamedValue(namedValue *driver.NamedValue) error { + return unwrapValuer(namedValue) } diff --git a/driver/statement.go b/driver/statement.go index dd28080..89ef479 100644 --- a/driver/statement.go +++ b/driver/statement.go @@ -22,8 +22,8 @@ func (statement bigQueryStatement) NumInput() int { return 0 } -func (bigQueryStatement) CheckNamedValue(*driver.NamedValue) error { - return nil +func (bigQueryStatement) CheckNamedValue(namedValue *driver.NamedValue) error { + return unwrapValuer(namedValue) } func (statement *bigQueryStatement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { diff --git a/driver/valuer.go b/driver/valuer.go new file mode 100644 index 0000000..868bfdc --- /dev/null +++ b/driver/valuer.go @@ -0,0 +1,28 @@ +package driver + +import ( + "database/sql/driver" + "fmt" +) + +const maxValuerUnwrapDepth = 100 + +func unwrapValuer(namedValue *driver.NamedValue) error { + value := namedValue.Value + for depth := 0; depth < maxValuerUnwrapDepth; depth++ { + valuer, ok := value.(driver.Valuer) + if !ok { + namedValue.Value = value + return nil + } + + unwrapped, err := valuer.Value() + if err != nil { + return err + } + + value = unwrapped + } + + return fmt.Errorf("valuer unwrap exceeded max depth %d", maxValuerUnwrapDepth) +} diff --git a/driver/valuer_test.go b/driver/valuer_test.go new file mode 100644 index 0000000..7770f67 --- /dev/null +++ b/driver/valuer_test.go @@ -0,0 +1,64 @@ +package driver + +import ( + "database/sql/driver" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type staticValuer struct { + value driver.Value +} + +func (valuer staticValuer) Value() (driver.Value, error) { + return valuer.value, nil +} + +type errorValuer struct { + err error +} + +func (valuer errorValuer) Value() (driver.Value, error) { + return nil, valuer.err +} + +type recursiveValuer struct{} + +func (recursiveValuer) Value() (driver.Value, error) { + return recursiveValuer{}, nil +} + +func TestBigQueryConnectionCheckNamedValueUnwrapsNestedValuer(t *testing.T) { + namedValue := &driver.NamedValue{ + Name: "value", + Value: staticValuer{value: staticValuer{value: "hello"}}, + } + + err := bigQueryConnection{}.CheckNamedValue(namedValue) + + require.NoError(t, err) + require.Equal(t, "hello", namedValue.Value) +} + +func TestBigQueryStatementCheckNamedValuePropagatesValuerError(t *testing.T) { + expectedErr := errors.New("fail to unwrap") + namedValue := &driver.NamedValue{ + Value: errorValuer{err: expectedErr}, + } + + err := bigQueryStatement{}.CheckNamedValue(namedValue) + + require.ErrorIs(t, err, expectedErr) +} + +func TestBigQueryStatementCheckNamedValueGuardsAgainstInfiniteUnwrap(t *testing.T) { + namedValue := &driver.NamedValue{ + Value: recursiveValuer{}, + } + + err := bigQueryStatement{}.CheckNamedValue(namedValue) + + require.EqualError(t, err, "valuer unwrap exceeded max depth 100") +}