Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions driver/connection.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package driver

import (
"cloud.google.com/go/bigquery"
"context"
"database/sql/driver"
"fmt"

"cloud.google.com/go/bigquery"
)

type bigQueryConnection struct {
Expand Down Expand Up @@ -94,7 +95,6 @@ func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (d
return statement.Exec(args)
}

func (bigQueryConnection) CheckNamedValue(*driver.NamedValue) error {
// TODO: Revise in the future
return nil
func (connection *bigQueryConnection) CheckNamedValue(namedValue *driver.NamedValue) error {
return unwrapValuer(namedValue)
}
4 changes: 2 additions & 2 deletions driver/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func (statement bigQueryStatement) NumInput() int {
return 0
}

func (bigQueryStatement) CheckNamedValue(*driver.NamedValue) error {
return nil
func (bigQueryStatement) CheckNamedValue(namedValue *driver.NamedValue) error {
return unwrapValuer(namedValue)
}

func (statement *bigQueryStatement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
Expand Down
28 changes: 28 additions & 0 deletions driver/valuer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package driver

import (
"database/sql/driver"
"fmt"
)

const maxValuerUnwrapDepth = 100

func unwrapValuer(namedValue *driver.NamedValue) error {
value := namedValue.Value
for depth := 0; depth < maxValuerUnwrapDepth; depth++ {
valuer, ok := value.(driver.Valuer)
if !ok {
namedValue.Value = value
return nil
}

unwrapped, err := valuer.Value()
if err != nil {
return err
}

value = unwrapped
}

return fmt.Errorf("valuer unwrap exceeded max depth %d", maxValuerUnwrapDepth)
}
64 changes: 64 additions & 0 deletions driver/valuer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package driver

import (
"database/sql/driver"
"errors"
"testing"

"github.com/stretchr/testify/require"
)

type staticValuer struct {
value driver.Value
}

func (valuer staticValuer) Value() (driver.Value, error) {
return valuer.value, nil
}

type errorValuer struct {
err error
}

func (valuer errorValuer) Value() (driver.Value, error) {
return nil, valuer.err
}

type recursiveValuer struct{}

func (recursiveValuer) Value() (driver.Value, error) {
return recursiveValuer{}, nil
}

func TestBigQueryConnectionCheckNamedValueUnwrapsNestedValuer(t *testing.T) {
namedValue := &driver.NamedValue{
Name: "value",
Value: staticValuer{value: staticValuer{value: "hello"}},
}

err := bigQueryConnection{}.CheckNamedValue(namedValue)

require.NoError(t, err)
require.Equal(t, "hello", namedValue.Value)
}

func TestBigQueryStatementCheckNamedValuePropagatesValuerError(t *testing.T) {
expectedErr := errors.New("fail to unwrap")
namedValue := &driver.NamedValue{
Value: errorValuer{err: expectedErr},
}

err := bigQueryStatement{}.CheckNamedValue(namedValue)

require.ErrorIs(t, err, expectedErr)
}

func TestBigQueryStatementCheckNamedValueGuardsAgainstInfiniteUnwrap(t *testing.T) {
namedValue := &driver.NamedValue{
Value: recursiveValuer{},
}

err := bigQueryStatement{}.CheckNamedValue(namedValue)

require.EqualError(t, err, "valuer unwrap exceeded max depth 100")
}