diff --git a/.gitignore b/.gitignore index 4fb2829fe..b7f34a72c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,28 @@ -# Binaries for programs and plugins -*.exe -*.exe~ -*.dll -*.so -*.dylib - -# Test binary, built with `go test -c` -*.test - -# Output of the go coverage tool, specifically when used with LiteIDE -*.out - -# Dependency directories (remove the comment below to include it) -vendor/ -dist - -# Locally built emulator binary (produced by `make emulator/build`) -/bigquery-emulator - -# Local output of the e2e conformance suite -/test/e2e/.out/ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +vendor/ +dist + +# Locally built emulator binary (produced by `make emulator/build`) +/bigquery-emulator + +# Local output of the e2e conformance suite +/test/e2e/.out/ + +# Local reproduction scripts and logs +/repro/ + +# Local development log +devlog.md diff --git a/internal/connection/manager.go b/internal/connection/manager.go index 53dbff844..713bb92a6 100644 --- a/internal/connection/manager.go +++ b/internal/connection/manager.go @@ -1,117 +1,121 @@ -package connection - -import ( - "context" - "database/sql" - "fmt" - - "github.com/goccy/googlesqlite" -) - -type Manager struct { - db *sql.DB -} - -func NewManager(db *sql.DB) *Manager { - return &Manager{db: db} -} - -func (m *Manager) Connection(ctx context.Context, projectID, datasetID string) (*Conn, error) { - if projectID == "" { - return nil, fmt.Errorf("invalid projectID. 
projectID is empty") - } - conn, err := m.db.Conn(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - return &Conn{ - ProjectID: projectID, - DatasetID: datasetID, - Conn: conn, - }, nil -} - -type Tx struct { - tx *sql.Tx - conn *Conn - committed bool -} - -func (t *Tx) Tx() *sql.Tx { - return t.tx -} - -func (t *Tx) RollbackIfNotCommitted() error { - if t.committed { - return nil - } - defer t.conn.Conn.Close() - return t.tx.Rollback() -} - -func (t *Tx) Commit() error { - if err := t.tx.Commit(); err != nil { - return err - } - t.committed = true - t.conn.Conn.Close() - return nil -} - -func (t *Tx) SetProjectAndDataset(projectID, datasetID string) { - t.conn.ProjectID = projectID - t.conn.DatasetID = datasetID -} - -func (t *Tx) MetadataRepoMode() error { - if err := t.conn.Conn.Raw(func(c interface{}) error { - gsqlConn, ok := c.(*googlesqlite.Conn) - if !ok { - return fmt.Errorf("failed to get *googlesqlite.Conn from %T", c) - } - _ = gsqlConn.SetNamePath([]string{}) - return nil - }); err != nil { - return fmt.Errorf("failed to setup connection: %w", err) - } - return nil -} - -func (t *Tx) ContentRepoMode() error { - if err := t.conn.Conn.Raw(func(c interface{}) error { - gsqlConn, ok := c.(*googlesqlite.Conn) - if !ok { - return fmt.Errorf("failed to get *googlesqlite.Conn from %T", c) - } - if t.conn.DatasetID == "" { - _ = gsqlConn.SetNamePath([]string{t.conn.ProjectID}) - } else { - _ = gsqlConn.SetNamePath([]string{t.conn.ProjectID, t.conn.DatasetID}) - } - const maxNamePath = 3 // projectID and datasetID and tableID - gsqlConn.SetMaxNamePath(maxNamePath) - return nil - }); err != nil { - return fmt.Errorf("failed to setup connection: %w", err) - } - return nil -} - -type Conn struct { - ProjectID string - DatasetID string - Conn *sql.Conn -} - -func (c *Conn) Begin(ctx context.Context) (*Tx, error) { - tx, err := c.Conn.BeginTx(ctx, nil) - if err != nil { - // The pooled connection is owned by the Tx once BeginTx 
succeeds and - // is released by Commit/RollbackIfNotCommitted. When BeginTx fails no - // Tx is created, so the connection must be returned to the pool here - // or it leaks for the lifetime of the process. - _ = c.Conn.Close() - return nil, err - } - return &Tx{tx: tx, conn: c}, nil -} +package connection + +import ( + "context" + "database/sql" + "fmt" + + "github.com/goccy/googlesqlite" +) + +type Manager struct { + db *sql.DB +} + +func NewManager(db *sql.DB) *Manager { + return &Manager{db: db} +} + +func (m *Manager) Connection(ctx context.Context, projectID, datasetID string) (*Conn, error) { + if projectID == "" { + return nil, fmt.Errorf("invalid projectID. projectID is empty") + } + conn, err := m.db.Conn(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + return &Conn{ + ProjectID: projectID, + DatasetID: datasetID, + Conn: conn, + }, nil +} + +type Tx struct { + tx *sql.Tx + conn *Conn + committed bool +} + +func (t *Tx) Tx() *sql.Tx { + return t.tx +} + +func (t *Tx) RollbackIfNotCommitted() error { + if t.committed { + return nil + } + defer t.conn.Conn.Close() + return t.tx.Rollback() +} + +func (t *Tx) Commit() error { + if err := t.tx.Commit(); err != nil { + return err + } + t.committed = true + t.conn.Conn.Close() + return nil +} + +func (t *Tx) SetProjectAndDataset(projectID, datasetID string) { + t.conn.ProjectID = projectID + t.conn.DatasetID = datasetID +} + +func (t *Tx) MetadataRepoMode() error { + if err := t.conn.Conn.Raw(func(c interface{}) error { + gsqlConn, ok := c.(*googlesqlite.Conn) + if !ok { + return fmt.Errorf("failed to get *googlesqlite.Conn from %T", c) + } + _ = gsqlConn.SetNamePath([]string{}) + return nil + }); err != nil { + return fmt.Errorf("failed to setup connection: %w", err) + } + return nil +} + +func (t *Tx) ContentRepoMode() error { + if err := t.conn.Conn.Raw(func(c interface{}) error { + gsqlConn, ok := c.(*googlesqlite.Conn) + if !ok { + return fmt.Errorf("failed to get 
*googlesqlite.Conn from %T", c) + } + if t.conn.ProjectID == "" { + return fmt.Errorf("invalid projectID. projectID is empty") + } + namePath := []string{t.conn.ProjectID} + if t.conn.DatasetID != "" { + namePath = append(namePath, t.conn.DatasetID) + } + _ = gsqlConn.SetNamePath(namePath) + + const maxNamePath = 3 // projectID and datasetID and tableID + gsqlConn.SetMaxNamePath(maxNamePath) + return nil + }); err != nil { + return fmt.Errorf("failed to setup connection: %w", err) + } + return nil +} + +type Conn struct { + ProjectID string + DatasetID string + Conn *sql.Conn +} + +func (c *Conn) Begin(ctx context.Context) (*Tx, error) { + tx, err := c.Conn.BeginTx(ctx, nil) + if err != nil { + // The pooled connection is owned by the Tx once BeginTx succeeds and + // is released by Commit/RollbackIfNotCommitted. When BeginTx fails no + // Tx is created, so the connection must be returned to the pool here + // or it leaks for the lifetime of the process. + _ = c.Conn.Close() + return nil, err + } + return &Tx{tx: tx, conn: c}, nil +} diff --git a/internal/contentdata/repository.go b/internal/contentdata/repository.go index 9458739ed..1470e302a 100644 --- a/internal/contentdata/repository.go +++ b/internal/contentdata/repository.go @@ -541,7 +541,7 @@ func (r *Repository) convertValueToCell(value interface{}, schema *bigqueryv2.Ta func (r *Repository) CreateOrReplaceTable(ctx context.Context, tx *connection.Tx, projectID, datasetID string, table *types.Table) error { tx.SetProjectAndDataset(projectID, datasetID) - if err := tx.ContentRepoMode(); err != nil { + if err := tx.MetadataRepoMode(); err != nil { return err } defer func() { @@ -577,7 +577,7 @@ func (r *Repository) AddTableData(ctx context.Context, tx *connection.Tx, projec return nil } tx.SetProjectAndDataset(projectID, datasetID) - if err := tx.ContentRepoMode(); err != nil { + if err := tx.MetadataRepoMode(); err != nil { return err } defer func() { diff --git a/internal/types/types.go 
b/internal/types/types.go index 8b7f5d86d..ef8faeb31 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -1,266 +1,329 @@ -package types - -import ( - "fmt" - "time" - - "github.com/apache/arrow-go/v18/arrow/array" - "github.com/goccy/bigquery-emulator/types" - "github.com/goccy/googlesqlite" - bigqueryv2 "google.golang.org/api/bigquery/v2" -) - -type ( - GetQueryResultsResponse struct { - JobReference *bigqueryv2.JobReference `json:"jobReference"` - Schema *bigqueryv2.TableSchema `json:"schema"` - Rows []*TableRow `json:"rows"` - TotalRows uint64 `json:"totalRows,string"` - JobComplete bool `json:"jobComplete"` - PageToken string `json:"pageToken,omitempty"` - TotalBytes uint64 `json:"-"` - } - - QueryResponse struct { - JobReference *bigqueryv2.JobReference `json:"jobReference"` - Schema *bigqueryv2.TableSchema `json:"schema"` - Rows []*TableRow `json:"rows"` - TotalRows uint64 `json:"totalRows,string"` - JobComplete bool `json:"jobComplete"` - TotalBytes int64 `json:"-"` - ChangedCatalog *googlesqlite.ChangedCatalog `json:"-"` - } - - TableDataList struct { - Rows []*TableRow `json:"rows"` - TotalRows uint64 `json:"totalRows,string"` - } - - TableRow struct { - F []*TableCell `json:"f,omitempty"` - } - - // Redefines the TableCell type to return null explicitly - // because TableCell for bigqueryv2 is omitted if V is nil, - TableCell struct { - V interface{} `json:"v"` - Bytes int64 `json:"-"` - Name string `json:"-"` - } -) - -func (r *TableRow) Data() (map[string]interface{}, error) { - rowMap := map[string]interface{}{} - for _, cell := range r.F { - v, err := cell.Data() - if err != nil { - return nil, err - } - rowMap[cell.Name] = v - } - return rowMap, nil -} - -func (r *TableRow) AVROValue(namespace string, fields []*types.AVROFieldSchema) (map[string]interface{}, error) { - rowMap := map[string]interface{}{} - for idx, cell := range r.F { - v, err := cell.AVROValue(namespace, fields[idx]) - if err != nil { - return nil, err - } - 
rowMap[cell.Name] = v - } - return rowMap, nil -} - -// avroRecordUnionKey returns the Avro union branch name for a nullable record -// field. goavro identifies a record branch within a union by the record's -// full name, which is the enclosing namespace joined with the record name. -func avroRecordUnionKey(namespace, name string) string { - if namespace == "" { - return name - } - return namespace + "." + name -} - -func (c *TableCell) Data() (interface{}, error) { - switch v := c.V.(type) { - case TableRow: - return v.Data() - case []*TableCell: - ret := make([]interface{}, 0, len(v)) - for _, vv := range v { - data, err := vv.Data() - if err != nil { - return nil, err - } - ret = append(ret, data) - } - return ret, nil - default: - if v == nil { - return nil, nil - } - text, ok := v.(string) - if !ok { - return nil, fmt.Errorf("failed to cast to string from %s", v) - } - return text, nil - } -} - -func (c *TableCell) AVROValue(namespace string, schema *types.AVROFieldSchema) (interface{}, error) { - switch v := c.V.(type) { - case TableRow: - fields := types.TableFieldSchemasToAVRO(schema.Type.TypeSchema.Fields) - recordValue, err := v.AVROValue(namespace, fields) - if err != nil { - return nil, err - } - // The schema for a record field mirrors AVROType.MarshalJSON: - // REQUIRED/REPEATED records are encoded bare (REPEATED records are - // the bare element type of an array), while a nullable record is an - // Avro union and goavro requires its value to be wrapped in a - // single-key map keyed by the record's full name. 
- switch types.Mode(schema.Type.TypeSchema.Mode) { - case types.RequiredMode, types.RepeatedMode: - return recordValue, nil - default: - return map[string]interface{}{ - avroRecordUnionKey(namespace, schema.Type.TypeSchema.Name): recordValue, - }, nil - } - case []*TableCell: - ret := make([]interface{}, 0, len(v)) - for _, vv := range v { - avrov, err := vv.AVROValue(namespace, schema) - if err != nil { - return nil, err - } - ret = append(ret, avrov) - } - return ret, nil - default: - if v == nil { - return map[string]interface{}{schema.Type.Key(): nil}, nil - } - text, ok := v.(string) - if !ok { - return nil, fmt.Errorf("failed to cast to string from %s", v) - } - value, err := schema.Type.CastValue(text) - if err != nil { - return nil, err - } - // A bare value for REQUIRED fields and for the element type of - // REPEATED arrays; a union-wrapped value for nullable fields. - switch types.Mode(schema.Type.TypeSchema.Mode) { - case types.RequiredMode, types.RepeatedMode: - return value, nil - default: - return map[string]interface{}{schema.Type.Key(): value}, nil - } - } -} - -func (r *TableRow) AppendValueToARROWBuilder(builder *array.RecordBuilder) error { - for idx, cell := range r.F { - if err := cell.AppendValueToARROWBuilder(builder.Field(idx)); err != nil { - return err - } - } - return nil -} - -func (r *TableRow) appendValueToARROWBuilder(builder *array.StructBuilder) error { - for idx, cell := range r.F { - if err := cell.AppendValueToARROWBuilder(builder.FieldBuilder(idx)); err != nil { - return err - } - } - return nil -} - -func (c *TableCell) AppendValueToARROWBuilder(builder array.Builder) error { - switch v := c.V.(type) { - case TableRow: - b, ok := builder.(*array.StructBuilder) - if !ok { - return fmt.Errorf("failed to convert to struct builder from %T", builder) - } - b.Append(true) - return v.appendValueToARROWBuilder(b) - case []*TableCell: - listBuilder, ok := builder.(*array.ListBuilder) - if !ok { - return fmt.Errorf("failed to convert to 
list builder from %T", builder) - } - b := listBuilder.ValueBuilder() - for _, vv := range v { - listBuilder.Append(true) - if err := vv.AppendValueToARROWBuilder(b); err != nil { - return err - } - } - return nil - default: - if v == nil { - return types.AppendValueToARROWBuilder(nil, builder) - } - text, ok := v.(string) - if !ok { - return fmt.Errorf("failed to cast to string from %s", v) - } - return types.AppendValueToARROWBuilder(&text, builder) - } -} - -// Format converts TIMESTAMP result cells from the raw canonical timestamp -// produced by the SQL backend into the representation the BigQuery REST API -// returns: int64 microseconds-since-epoch when useInt64Timestamp is set, and -// otherwise a float seconds-since-epoch string. Every official client decodes -// a TIMESTAMP value as one of those two numeric forms, never as a formatted -// datetime string. -func Format(schema *bigqueryv2.TableSchema, rows []*TableRow, useInt64Timestamp bool) []*TableRow { - formattedRows := make([]*TableRow, 0, len(rows)) - for _, row := range rows { - cells := make([]*TableCell, 0, len(row.F)) - for colIdx, cell := range row.F { - if schema.Fields[colIdx].Type == "TIMESTAMP" && cell.V != nil { - cells = append(cells, &TableCell{ - V: formatTimestampCell(cell.V, useInt64Timestamp), - }) - } else { - cells = append(cells, cell) - } - } - formattedRows = append(formattedRows, &TableRow{ - F: cells, - }) - } - return formattedRows -} - -// formatTimestampCell renders one TIMESTAMP cell value. A non-string value or -// an unparseable timestamp is passed through unchanged. 
-func formatTimestampCell(v interface{}, useInt64Timestamp bool) interface{} { - raw, ok := v.(string) - if !ok { - return v - } - t, err := googlesqlite.TimeFromTimestampValue(raw) - if err != nil { - return v - } - micros := t.UnixMicro() - if useInt64Timestamp { - return fmt.Sprint(micros) - } - sec := micros / int64(time.Second/time.Microsecond) - frac := micros % int64(time.Second/time.Microsecond) - if frac < 0 { - frac += int64(time.Second / time.Microsecond) - sec-- - } - return fmt.Sprintf("%d.%06d", sec, frac) -} +package types + +import ( + "fmt" + "time" + + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/goccy/bigquery-emulator/types" + "github.com/goccy/googlesqlite" + bigqueryv2 "google.golang.org/api/bigquery/v2" +) + +type ( + GetQueryResultsResponse struct { + JobReference *bigqueryv2.JobReference `json:"jobReference"` + Schema *bigqueryv2.TableSchema `json:"schema"` + Rows []*TableRow `json:"rows"` + TotalRows uint64 `json:"totalRows,string"` + JobComplete bool `json:"jobComplete"` + PageToken string `json:"pageToken,omitempty"` + TotalBytes uint64 `json:"-"` + } + + QueryResponse struct { + JobReference *bigqueryv2.JobReference `json:"jobReference"` + Schema *bigqueryv2.TableSchema `json:"schema"` + Rows []*TableRow `json:"rows"` + TotalRows uint64 `json:"totalRows,string"` + JobComplete bool `json:"jobComplete"` + TotalBytes int64 `json:"-"` + ChangedCatalog *googlesqlite.ChangedCatalog `json:"-"` + } + + TableDataList struct { + Rows []*TableRow `json:"rows"` + TotalRows uint64 `json:"totalRows,string"` + } + + TableRow struct { + F []*TableCell `json:"f,omitempty"` + } + + // Redefines the TableCell type to return null explicitly + // because TableCell for bigqueryv2 is omitted if V is nil, + TableCell struct { + V interface{} `json:"v"` + Bytes int64 `json:"-"` + Name string `json:"-"` + } +) + +func (r *TableRow) Data() (map[string]interface{}, error) { + rowMap := map[string]interface{}{} + for _, cell := range r.F { + v, err := 
cell.Data() + if err != nil { + return nil, err + } + rowMap[cell.Name] = v + } + return rowMap, nil +} + +func (r *TableRow) AVROValue(namespace string, fields []*types.AVROFieldSchema) (map[string]interface{}, error) { + rowMap := map[string]interface{}{} + for idx, cell := range r.F { + v, err := cell.AVROValue(namespace, fields[idx]) + if err != nil { + return nil, err + } + rowMap[cell.Name] = v + } + return rowMap, nil +} + +// avroRecordUnionKey returns the Avro union branch name for a nullable record +// field. goavro identifies a record branch within a union by the record's +// full name, which is the enclosing namespace joined with the record name. +func avroRecordUnionKey(namespace, name string) string { + if namespace == "" { + return name + } + return namespace + "." + name +} + +func (c *TableCell) Data() (interface{}, error) { + switch v := c.V.(type) { + case TableRow: + return v.Data() + case []*TableCell: + ret := make([]interface{}, 0, len(v)) + for _, vv := range v { + data, err := vv.Data() + if err != nil { + return nil, err + } + ret = append(ret, data) + } + return ret, nil + default: + if v == nil { + return nil, nil + } + text, ok := v.(string) + if !ok { + return nil, fmt.Errorf("failed to cast to string from %s", v) + } + return text, nil + } +} + +func (c *TableCell) AVROValue(namespace string, schema *types.AVROFieldSchema) (interface{}, error) { + switch v := c.V.(type) { + case TableRow: + fields := types.TableFieldSchemasToAVRO(schema.Type.TypeSchema.Fields) + recordValue, err := v.AVROValue(namespace, fields) + if err != nil { + return nil, err + } + // The schema for a record field mirrors AVROType.MarshalJSON: + // REQUIRED/REPEATED records are encoded bare (REPEATED records are + // the bare element type of an array), while a nullable record is an + // Avro union and goavro requires its value to be wrapped in a + // single-key map keyed by the record's full name. 
+ switch types.Mode(schema.Type.TypeSchema.Mode) { + case types.RequiredMode, types.RepeatedMode: + return recordValue, nil + default: + return map[string]interface{}{ + avroRecordUnionKey(namespace, schema.Type.TypeSchema.Name): recordValue, + }, nil + } + case []*TableCell: + ret := make([]interface{}, 0, len(v)) + for _, vv := range v { + avrov, err := vv.AVROValue(namespace, schema) + if err != nil { + return nil, err + } + ret = append(ret, avrov) + } + return ret, nil + default: + if v == nil { + return map[string]interface{}{schema.Type.Key(): nil}, nil + } + text, ok := v.(string) + if !ok { + return nil, fmt.Errorf("failed to cast to string from %s", v) + } + value, err := schema.Type.CastValue(text) + if err != nil { + return nil, err + } + // A bare value for REQUIRED fields and for the element type of + // REPEATED arrays; a union-wrapped value for nullable fields. + switch types.Mode(schema.Type.TypeSchema.Mode) { + case types.RequiredMode, types.RepeatedMode: + return value, nil + default: + return map[string]interface{}{schema.Type.Key(): value}, nil + } + } +} + +func (r *TableRow) AppendValueToARROWBuilder(builder *array.RecordBuilder) error { + for idx, cell := range r.F { + if err := cell.AppendValueToARROWBuilder(builder.Field(idx)); err != nil { + return err + } + } + return nil +} + +func (r *TableRow) appendValueToARROWBuilder(builder *array.StructBuilder) error { + for idx, cell := range r.F { + if err := cell.AppendValueToARROWBuilder(builder.FieldBuilder(idx)); err != nil { + return err + } + } + return nil +} + +func (c *TableCell) AppendValueToARROWBuilder(builder array.Builder) error { + switch v := c.V.(type) { + case TableRow: + b, ok := builder.(*array.StructBuilder) + if !ok { + return fmt.Errorf("failed to convert to struct builder from %T", builder) + } + b.Append(true) + return v.appendValueToARROWBuilder(b) + case []*TableCell: + listBuilder, ok := builder.(*array.ListBuilder) + if !ok { + return fmt.Errorf("failed to convert to 
list builder from %T", builder) + } + b := listBuilder.ValueBuilder() + for _, vv := range v { + listBuilder.Append(true) + if err := vv.AppendValueToARROWBuilder(b); err != nil { + return err + } + } + return nil + default: + if v == nil { + return types.AppendValueToARROWBuilder(nil, builder) + } + text, ok := v.(string) + if !ok { + return fmt.Errorf("failed to cast to string from %s", v) + } + return types.AppendValueToARROWBuilder(&text, builder) + } +} + +// Format converts TIMESTAMP result cells from the raw canonical timestamp +// produced by the SQL backend into the representation the BigQuery REST API +// returns: int64 microseconds-since-epoch when useInt64Timestamp is set, and +// otherwise a float seconds-since-epoch string. Every official client decodes +// a TIMESTAMP value as one of those two numeric forms, never as a formatted +// datetime string. +func Format(schema *bigqueryv2.TableSchema, rows []*TableRow, useInt64Timestamp bool) []*TableRow { + formattedRows := make([]*TableRow, 0, len(rows)) + for _, row := range rows { + cells := make([]*TableCell, 0, len(row.F)) + for colIdx, cell := range row.F { + cells = append(cells, formatCell(schema.Fields[colIdx], cell, useInt64Timestamp)) + } + formattedRows = append(formattedRows, &TableRow{ + F: cells, + }) + } + return formattedRows +} + +func formatCell(field *bigqueryv2.TableFieldSchema, cell *TableCell, useInt64Timestamp bool) *TableCell { + if cell.V == nil { + return cell + } + if field == nil { + return cell + } + + if field.Mode == "REPEATED" { + if cells, ok := cell.V.([]*TableCell); ok { + elemField := *field + elemField.Mode = "NULLABLE" + formattedCells := make([]*TableCell, 0, len(cells)) + for _, c := range cells { + formattedCells = append(formattedCells, formatCell(&elemField, c, useInt64Timestamp)) + } + return &TableCell{ + V: formattedCells, + Bytes: cell.Bytes, + Name: cell.Name, + } + } + } + + if field.Type == "RECORD" { + switch row := cell.V.(type) { + case TableRow: + 
formattedRow := formatRecordRow(field, &row, useInt64Timestamp) + return &TableCell{ + V: *formattedRow, + Bytes: cell.Bytes, + Name: cell.Name, + } + case *TableRow: + if row == nil { + return cell + } + formattedRow := formatRecordRow(field, row, useInt64Timestamp) + return &TableCell{ + V: formattedRow, + Bytes: cell.Bytes, + Name: cell.Name, + } + } + } + + if field.Type == "TIMESTAMP" { + return &TableCell{ + V: formatTimestampCell(cell.V, useInt64Timestamp), + Bytes: cell.Bytes, + Name: cell.Name, + } + } + + return cell +} + +func formatRecordRow(field *bigqueryv2.TableFieldSchema, row *TableRow, useInt64Timestamp bool) *TableRow { + formattedF := make([]*TableCell, 0, len(row.F)) + for i, c := range row.F { + var nestedField *bigqueryv2.TableFieldSchema + if i < len(field.Fields) { + nestedField = field.Fields[i] + } + formattedF = append(formattedF, formatCell(nestedField, c, useInt64Timestamp)) + } + return &TableRow{F: formattedF} +} + +// formatTimestampCell renders one TIMESTAMP cell value. A non-string value or +// an unparseable timestamp is passed through unchanged. 
+func formatTimestampCell(v interface{}, useInt64Timestamp bool) interface{} { + raw, ok := v.(string) + if !ok { + return v + } + t, err := googlesqlite.TimeFromTimestampValue(raw) + if err != nil { + return v + } + micros := t.UnixMicro() + if useInt64Timestamp { + return fmt.Sprint(micros) + } + sec := micros / int64(time.Second/time.Microsecond) + frac := micros % int64(time.Second/time.Microsecond) + if frac < 0 { + frac += int64(time.Second / time.Microsecond) + sec-- + } + return fmt.Sprintf("%d.%06d", sec, frac) +} diff --git a/internal/types/types_test.go b/internal/types/types_test.go new file mode 100644 index 000000000..20a8f7055 --- /dev/null +++ b/internal/types/types_test.go @@ -0,0 +1,47 @@ +package types + +import ( + "fmt" + "testing" + "time" + + bigqueryv2 "google.golang.org/api/bigquery/v2" +) + +func TestFormatCellHandlesNilField(t *testing.T) { + cell := &TableCell{V: "value"} + got := formatCell(nil, cell, true) + if got != cell { + t.Fatalf("formatCell(nil, cell) = %#v, want same cell pointer", got) + } +} + +func TestFormatCellRecordPointerFormatsNestedTimestamp(t *testing.T) { + field := &bigqueryv2.TableFieldSchema{ + Type: "RECORD", + Fields: []*bigqueryv2.TableFieldSchema{ + {Type: "TIMESTAMP"}, + }, + } + row := &TableRow{ + F: []*TableCell{{V: "2026-06-15T00:00:00Z"}}, + } + cell := &TableCell{V: row} + + got := formatCell(field, cell, true) + gotRow, ok := got.V.(*TableRow) + if !ok { + t.Fatalf("got.V type = %T, want *TableRow", got.V) + } + if len(gotRow.F) != 1 { + t.Fatalf("len(gotRow.F) = %d, want 1", len(gotRow.F)) + } + tm, err := time.Parse(time.RFC3339, "2026-06-15T00:00:00Z") + if err != nil { + t.Fatalf("time.Parse failed: %v", err) + } + want := fmt.Sprint(tm.UnixMicro()) + if gotRow.F[0].V != want { + t.Fatalf("nested timestamp = %v, want %v", gotRow.F[0].V, want) + } +} diff --git a/server/handler.go b/server/handler.go index 32a2c7d50..eed9dd177 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1,3503 +1,3539 
@@ -package server - -import ( - "bytes" - "context" - _ "embed" - "encoding/csv" - "errors" - "fmt" - "html" - "io" - "mime" - "mime/multipart" - "net/http" - "os" - "reflect" - "strconv" - "strings" - "sync" - "time" - - "cloud.google.com/go/storage" - "github.com/goccy/go-json" - "github.com/goccy/googlesqlite" - "go.uber.org/zap" - bigqueryv2 "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" - "google.golang.org/api/option" - - "github.com/goccy/bigquery-emulator/internal/connection" - "github.com/goccy/bigquery-emulator/internal/contentdata" - "github.com/goccy/bigquery-emulator/internal/logger" - "github.com/goccy/bigquery-emulator/internal/metadata" - internaltypes "github.com/goccy/bigquery-emulator/internal/types" - "github.com/goccy/bigquery-emulator/types" - "github.com/parquet-go/parquet-go" -) - -func errorResponse(ctx context.Context, w http.ResponseWriter, e *ServerError) { - logger.Logger(ctx).WithOptions(zap.AddCallerSkip(1)).Error(string(e.Reason), zap.Error(e)) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(e.Status) - w.Write(e.Response()) -} - -// uploadErrorResponse renders an error returned by the upload content handler. -// When the handler produced a typed *ServerError (e.g. a missing dataset), its -// HTTP status is preserved; any other failure is reported as a job error. 
-func uploadErrorResponse(ctx context.Context, w http.ResponseWriter, err error) { - var serr *ServerError - if errors.As(err, &serr) { - errorResponse(ctx, w, serr) - return - } - errorResponse(ctx, w, errJobInternalError(err.Error())) -} - -func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) { - b, err := json.Marshal(response) - if err != nil { - errorResponse(ctx, w, errInternalError(fmt.Sprintf("failed to encode json: %s", err.Error()))) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write(b) -} - -const ( - discoveryAPIEndpoint = "/discovery/v1/apis/bigquery/v2/rest" - newDiscoveryAPIEndpoint = "/$discovery/rest" - uploadAPIEndpoint = "/upload/bigquery/v2/projects/{projectId}/jobs" -) - -//go:embed resources/discovery.json -var bigqueryAPIJSON []byte - -var ( - discoveryAPIOnce sync.Once - discoveryAPIResponse map[string]interface{} -) - -type discoveryHandler struct { - server *Server -} - -func newDiscoveryHandler(server *Server) *discoveryHandler { - return &discoveryHandler{server: server} -} - -func (h *discoveryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - var decodeJSONErr error - discoveryAPIOnce.Do(func() { - if err := json.Unmarshal(bigqueryAPIJSON, &discoveryAPIResponse); err != nil { - decodeJSONErr = err - return - } - addr := h.server.httpServer.Addr - if !strings.HasPrefix(addr, "http") { - addr = "http://" + addr - } - discoveryAPIResponse["mtlsRootUrl"] = addr - discoveryAPIResponse["rootUrl"] = addr - discoveryAPIResponse["baseUrl"] = addr - }) - if decodeJSONErr != nil { - errorResponse(ctx, w, errInternalError(decodeJSONErr.Error())) - return - } - encodeResponse(ctx, w, discoveryAPIResponse) -} - -type uploadHandler struct{} - -type UploadJobConfigurationLoad struct { - AllowJaggedRows bool `json:"allowJaggedRows,omitempty"` - AllowQuotedNewlines bool `json:"allowQuotedNewlines,omitempty"` - Autodetect bool `json:"autodetect,omitempty"` - 
Clustering *bigqueryv2.Clustering `json:"clustering,omitempty"` - CreateDisposition string `json:"createDisposition,omitempty"` - DecimalTargetTypes []string `json:"decimalTargetTypes,omitempty"` - DestinationEncryptionConfiguration *bigqueryv2.EncryptionConfiguration `json:"destinationEncryptionConfiguration,omitempty"` - DestinationTable *bigqueryv2.TableReference `json:"destinationTable,omitempty"` - DestinationTableProperties *bigqueryv2.DestinationTableProperties `json:"destinationTableProperties,omitempty"` - Encoding string `json:"encoding,omitempty"` - FieldDelimiter string `json:"fieldDelimiter,omitempty"` - HivePartitioningOptions *bigqueryv2.HivePartitioningOptions `json:"hivePartitioningOptions,omitempty"` - IgnoreUnknownValues bool `json:"ignoreUnknownValues,omitempty"` - JsonExtension string `json:"jsonExtension,omitempty"` - MaxBadRecords int64 `json:"maxBadRecords,omitempty"` - NullMarker string `json:"nullMarker,omitempty"` - ParquetOptions *bigqueryv2.ParquetOptions `json:"parquetOptions,omitempty"` - PreserveAsciiControlCharacters bool `json:"preserveAsciiControlCharacters,omitempty"` - ProjectionFields []string `json:"projectionFields,omitempty"` - Quote *string `json:"quote,omitempty"` - RangePartitioning *bigqueryv2.RangePartitioning `json:"rangePartitioning,omitempty"` - Schema *bigqueryv2.TableSchema `json:"schema,omitempty"` - SchemaInline string `json:"schemaInline,omitempty"` - SchemaInlineFormat string `json:"schemaInlineFormat,omitempty"` - SchemaUpdateOptions []string `json:"schemaUpdateOptions,omitempty"` - SkipLeadingRows json.Number `json:"skipLeadingRows,omitempty"` - SourceFormat string `json:"sourceFormat,omitempty"` - SourceUris []string `json:"sourceUris,omitempty"` - TimePartitioning *bigqueryv2.TimePartitioning `json:"timePartitioning,omitempty"` - UseAvroLogicalTypes bool `json:"useAvroLogicalTypes,omitempty"` - WriteDisposition string `json:"writeDisposition,omitempty"` -} - -type UploadJobConfiguration struct { - Load 
*UploadJobConfigurationLoad `json:"load"` -} - -type UploadJob struct { - JobReference *bigqueryv2.JobReference `json:"jobReference"` - Configuration *UploadJobConfiguration `json:"configuration"` -} - -// normalize fills in fields that some client libraries omit from the upload -// metadata. The Node.js client in particular sends no jobReference, which -// previously caused a nil pointer dereference when the handler read the job -// id. A missing job id is generated so the upload still gets a stable handle. -func (j *UploadJob) normalize(projectID string) *ServerError { - if j.Configuration == nil || j.Configuration.Load == nil { - return errInvalid("upload job is missing configuration.load") - } - if j.JobReference == nil { - j.JobReference = &bigqueryv2.JobReference{} - } - if j.JobReference.JobId == "" { - j.JobReference.JobId = randomID() - } - if j.JobReference.ProjectId == "" { - j.JobReference.ProjectId = projectID - } - return nil -} - -func (j *UploadJob) ToJob() *bigqueryv2.Job { - load := j.Configuration.Load - skipLeadingRows, _ := load.SkipLeadingRows.Int64() - return &bigqueryv2.Job{ - JobReference: j.JobReference, - Configuration: &bigqueryv2.JobConfiguration{ - Load: &bigqueryv2.JobConfigurationLoad{ - AllowJaggedRows: load.AllowJaggedRows, - AllowQuotedNewlines: load.AllowQuotedNewlines, - Autodetect: load.Autodetect, - Clustering: load.Clustering, - CreateDisposition: load.CreateDisposition, - DecimalTargetTypes: load.DecimalTargetTypes, - DestinationEncryptionConfiguration: load.DestinationEncryptionConfiguration, - DestinationTable: load.DestinationTable, - DestinationTableProperties: load.DestinationTableProperties, - Encoding: load.Encoding, - FieldDelimiter: load.FieldDelimiter, - HivePartitioningOptions: load.HivePartitioningOptions, - IgnoreUnknownValues: load.IgnoreUnknownValues, - JsonExtension: load.JsonExtension, - MaxBadRecords: load.MaxBadRecords, - NullMarker: load.NullMarker, - ParquetOptions: load.ParquetOptions, - 
PreserveAsciiControlCharacters: load.PreserveAsciiControlCharacters, - ProjectionFields: load.ProjectionFields, - Quote: load.Quote, - RangePartitioning: load.RangePartitioning, - Schema: load.Schema, - SchemaInline: load.SchemaInline, - SchemaInlineFormat: load.SchemaInlineFormat, - SchemaUpdateOptions: load.SchemaUpdateOptions, - SkipLeadingRows: skipLeadingRows, - SourceFormat: load.SourceFormat, - SourceUris: load.SourceUris, - TimePartitioning: load.TimePartitioning, - UseAvroLogicalTypes: load.UseAvroLogicalTypes, - WriteDisposition: load.WriteDisposition, - }, - }, - } -} - -func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch r.URL.Query().Get("uploadType") { - case "multipart": - h.serveMultipart(w, r) - case "resumable": - h.serveResumable(w, r) - default: - errorResponse(r.Context(), w, errInvalid(`uploadType should be "multipart" or "resumable"`)) - } -} - -func (h *uploadHandler) serveMultipart(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - contentType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil || !strings.HasPrefix(contentType, "multipart/") { - errorResponse(ctx, w, errInvalid("expecting a multipart message")) - return - } - mul := multipart.NewReader(r.Body, params["boundary"]) - p, err := mul.NextPart() - if err != nil { - errorResponse(ctx, w, errInvalid(fmt.Sprintf("failed to load metadata: %s", err.Error()))) - return - } - var job UploadJob - if err := json.NewDecoder(p).Decode(&job); err != nil { - errorResponse(ctx, w, errInvalid(fmt.Sprintf("failed to decode job: %s", err.Error()))) - return - } - if serr := job.normalize(project.ID); serr != nil { - errorResponse(ctx, w, serr) - return - } - uploadJob, err := h.Handle(ctx, &uploadRequest{ - server: server, - project: project, - job: &job, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - - 
p, err = mul.NextPart() - if err != nil { - errorResponse(ctx, w, errInvalid(fmt.Sprintf("multipart request is invalid: %s", err.Error()))) - return - } - u := &uploadContentHandler{} - err = u.Handle(ctx, &uploadContentRequest{ - server: server, - project: project, - job: uploadJob, - reader: p, - }) - if err != nil { - uploadErrorResponse(ctx, w, err) - return - } - encodeResponse(ctx, w, uploadJob.Content()) -} - -func (h *uploadHandler) serveResumable(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - var job UploadJob - if err := json.NewDecoder(r.Body).Decode(&job); err != nil { - errorResponse(ctx, w, errInvalid(fmt.Sprintf("failed to decode job: %s", err.Error()))) - return - } - if serr := job.normalize(project.ID); serr != nil { - errorResponse(ctx, w, serr) - return - } - res, err := h.Handle(ctx, &uploadRequest{ - server: server, - project: project, - job: &job, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - addr := server.httpServer.Addr - if !strings.HasPrefix(addr, "http") { - addr = "http://" + addr - } - addr = strings.TrimRight(addr, "/") - w.Header().Add( - "Location", - fmt.Sprintf( - "%s/upload/bigquery/v2/projects/%s/jobs?uploadType=resumable&upload_id=%s", - addr, - project.ID, - job.JobReference.JobId, - ), - ) - encodeResponse(ctx, w, res.Content()) -} - -type uploadRequest struct { - server *Server - project *metadata.Project - job *UploadJob -} - -func (h *uploadHandler) Handle(ctx context.Context, r *uploadRequest) (*metadata.Job, error) { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "") - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - job := metadata.NewJob(r.server.metaRepo, r.project.ID, 
r.job.JobReference.JobId, r.job.ToJob(), nil, nil) - if err := r.project.AddJob(ctx, tx.Tx(), job); err != nil { - return nil, fmt.Errorf("failed to add job: %w", err) - } - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("failed to commit job: %w", err) - } - return job, nil -} - -type uploadContentHandler struct{} - -func (h *uploadContentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - query := r.URL.Query() - uploadType := query["uploadType"] - if len(uploadType) == 0 { - errorResponse(ctx, w, errInvalid("uploadType parameter is not found")) - return - } - if uploadType[0] != "resumable" { - errorResponse(ctx, w, errInvalid(fmt.Sprintf("uploadType parameter is not resumable %s", uploadType[0]))) - return - } - uploadID := query["upload_id"] - if len(uploadID) == 0 { - errorResponse(ctx, w, errInvalid("upload_id parameter is not found")) - return - } - jobID := uploadID[0] - job := project.Job(jobID) - if job == nil { - errorResponse(ctx, w, errNotFound(fmt.Sprintf("upload job %s is not found", jobID))) - return - } - if err := h.Handle(ctx, &uploadContentRequest{ - server: server, - project: project, - job: job, - reader: r.Body, - }); err != nil { - uploadErrorResponse(ctx, w, err) - return - } - content := job.Content() - content.Status = &bigqueryv2.JobStatus{State: "DONE"} - encodeResponse(ctx, w, content) -} - -type uploadContentRequest struct { - server *Server - project *metadata.Project - job *metadata.Job - reader io.Reader -} - -func (h *uploadContentHandler) getCandidateName(col string, columnNames []string) string { - var ( - foundName string - foundCount int - ) - for _, name := range columnNames { - if strings.Contains(name, col) { - foundName = name - foundCount++ - } - } - if foundCount == 1 { - return foundName - } - return "" -} - -func (h *uploadContentHandler) existsColumnNameInCSVHeader(col string, header []string) bool { 
- for _, h := range header { - if col == h { - return true - } - } - return false -} - -func (h *uploadContentHandler) normalizeColumnNameForJSONData(columnMap map[string]*types.Column, data map[string]interface{}) { - for k, v := range data { - if _, exists := columnMap[k]; exists { - continue - } - lowerKey := strings.ToLower(k) - var ( - foundCount int - columnName string - ) - for colName := range columnMap { - if lowerKey == strings.ToLower(colName) { - foundCount++ - columnName = colName - } - } - if foundCount == 1 { - delete(data, k) - data[columnName] = v - } - } -} - -func newLoadCSVReader(reader io.Reader, fieldDelimiter string) (*csv.Reader, error) { - csvReader := csv.NewReader(reader) - csvReader.FieldsPerRecord = -1 - if fieldDelimiter == "" { - return csvReader, nil - } - delimiters := []rune(fieldDelimiter) - if len(delimiters) != 1 { - return nil, fmt.Errorf("fieldDelimiter must be a single character") - } - csvReader.Comma = delimiters[0] - return csvReader, nil -} - -func schemaColumns(fields []*bigqueryv2.TableFieldSchema) []*types.Column { - columns := make([]*types.Column, 0, len(fields)) - for _, field := range fields { - columns = append(columns, types.NewColumnWithSchema(field)) - } - return columns -} - -func columnsFromCSVHeader(header []string, columnToType map[string]types.Type) ([]*types.Column, bool) { - columns := make([]*types.Column, 0, len(header)) - for _, col := range header { - columnType, exists := columnToType[col] - if !exists { - return nil, false - } - columns = append(columns, &types.Column{ - Name: col, - Type: columnType, - }) - } - return columns, true -} - -func csvLoadColumnsAndRows(records [][]string, schemaFields []*bigqueryv2.TableFieldSchema, columnToType map[string]types.Type, skipLeadingRows int64) ([]*types.Column, [][]string, error) { - window, err := csvRowWindowFor(len(records), skipLeadingRows) - if err != nil { - return nil, nil, err - } - if !window.hasHeader { - return schemaColumns(schemaFields), nil, 
nil - } - - dataStart := window.dataStart - if window.headerIndex < len(records) { - if columns, ok := columnsFromCSVHeader(records[window.headerIndex], columnToType); ok { - if dataStart <= window.headerIndex { - dataStart = window.headerIndex + 1 - } - return columns, records[dataStart:], nil - } - } - return schemaColumns(schemaFields), records[dataStart:], nil -} - -func csvRowsToTableData(records [][]string, schemaFields []*bigqueryv2.TableFieldSchema, skipLeadingRows int64, allowJaggedRows bool) ([]*types.Column, types.Data, error) { - columnToType := map[string]types.Type{} - for _, field := range schemaFields { - columnToType[field.Name] = types.Type(field.Type) - } - columns, dataRows, err := csvLoadColumnsAndRows(records, schemaFields, columnToType, skipLeadingRows) - if err != nil { - return nil, nil, err - } - data := types.Data{} - for _, record := range dataRows { - rowData := map[string]interface{}{} - if len(record) > len(columns) || (!allowJaggedRows && len(record) != len(columns)) { - return nil, nil, fmt.Errorf("invalid column number: found broken row data: %v", record) - } - for i := 0; i < len(columns); i++ { - var colData string - if i < len(record) { - colData = record[i] - } - if colData == "" { - rowData[columns[i].Name] = nil - } else { - rowData[columns[i].Name] = colData - } - } - data = append(data, rowData) - } - return columns, data, nil -} - -func (h *uploadContentHandler) Handle(ctx context.Context, r *uploadContentRequest) error { - load := r.job.Content().Configuration.Load - tableRef := load.DestinationTable - if tableRef == nil { - return errInvalid("load job is missing configuration.load.destinationTable") - } - dataset := r.project.Dataset(tableRef.DatasetId) - if dataset == nil { - return errNotFound(fmt.Sprintf("dataset %q is not found", tableRef.DatasetId)) - } - table := dataset.Table(tableRef.TableId) - // The write disposition only matters for a table that already exists; a - // freshly created one is empty regardless. 
- tableExisted := table != nil - - // Read CSV content up front so an autodetect load can infer the schema - // before the destination table is created. - var csvRecords [][]string - var csvColumns []*types.Column - var csvData types.Data - var csvDataReady bool - if load.SourceFormat == "CSV" { - csvReader, err := newLoadCSVReader(r.reader, load.FieldDelimiter) - if err != nil { - return err - } - records, err := csvReader.ReadAll() - if err != nil { - return fmt.Errorf("failed to read csv: %w", err) - } - csvRecords = records - if !tableExisted && load.Schema == nil && load.Autodetect { - schema, err := inferCSVSchema(csvRecords, load.SkipLeadingRows) - if err != nil { - return err - } - load.Schema = schema - } - } - if table == nil { - if load.CreateDisposition == "CREATE_NEVER" { - return fmt.Errorf("`%s` is not found", tableRef.TableId) - } - if load.SourceFormat == "CSV" && load.Schema != nil { - var err error - csvColumns, csvData, err = csvRowsToTableData(csvRecords, load.Schema.Fields, load.SkipLeadingRows, load.AllowJaggedRows) - if err != nil { - return err - } - csvDataReady = true - } - if _, err := (&tablesInsertHandler{}).Handle(ctx, &tablesInsertRequest{ - server: r.server, - project: r.project, - dataset: dataset, - table: &bigqueryv2.Table{ - Schema: load.Schema, - TableReference: tableRef, - }, - }); err != nil { - return err - } - table = dataset.Table(tableRef.TableId) - } - - tableContent, err := table.Content() - if err != nil { - return err - } - - sourceFormat := load.SourceFormat - columns := []*types.Column{} - data := types.Data{} - switch sourceFormat { - case "CSV": - if csvDataReady { - columns = csvColumns - data = csvData - break - } - var err error - columns, data, err = csvRowsToTableData(csvRecords, tableContent.Schema.Fields, load.SkipLeadingRows, load.AllowJaggedRows) - if err != nil { - return err - } - case "PARQUET": - b, err := io.ReadAll(r.reader) - if err != nil { - return err - } - reader := 
parquet.NewReader(bytes.NewReader(b)) - defer reader.Close() - - columns = schemaColumns(load.Schema.Fields) - - for i := 0; i < int(reader.NumRows()); i++ { - var rowData interface{} - err := reader.Read(&rowData) - if err != nil { - return err - } - - data = append(data, rowData.(map[string]interface{})) - } - case "NEWLINE_DELIMITED_JSON": - columns = schemaColumns(tableContent.Schema.Fields) - columnMap := map[string]*types.Column{} - for _, col := range columns { - columnMap[col.Name] = col - } - decoder := json.NewDecoder(r.reader) - decoder.UseNumber() - for decoder.More() { - d := make(map[string]interface{}) - if err := decoder.Decode(&d); err != nil { - return err - } - h.normalizeColumnNameForJSONData(columnMap, d) - data = append(data, d) - } - default: - return fmt.Errorf("not support sourceFormat: %s", sourceFormat) - } - tableDef := &types.Table{ - ID: tableRef.TableId, - Columns: columns, - Data: data, - } - conn, err := r.server.connMgr.Connection(ctx, tableRef.ProjectId, tableRef.DatasetId) - if err != nil { - return fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return err - } - defer tx.RollbackIfNotCommitted() - if tableExisted { - switch load.WriteDisposition { - case "WRITE_TRUNCATE": - if err := r.server.contentRepo.TruncateTable(ctx, tx, tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId); err != nil { - return err - } - case "WRITE_EMPTY": - count, err := r.server.contentRepo.CountTableRows(ctx, tx, tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId) - if err != nil { - return err - } - if count > 0 { - return fmt.Errorf("table %s already exists and contains data (WRITE_EMPTY)", tableRef.TableId) - } - } - } - if err := r.server.contentRepo.AddTableData(ctx, tx, tableRef.ProjectId, tableRef.DatasetId, tableDef); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -const ( - formatOptionsUseInt64TimestampParam = 
"formatOptions.useInt64Timestamp" - deleteContentsParam = "deleteContents" -) - -func isDeleteContents(r *http.Request) bool { - return parseQueryValueAsBool(r, deleteContentsParam) -} - -func isFormatOptionsUseInt64Timestamp(r *http.Request) bool { - return parseQueryValueAsBool(r, formatOptionsUseInt64TimestampParam) -} - -// parseQueryValueAsUint64 reads an unsigned integer query parameter, reporting -// whether it was present and valid. -func parseQueryValueAsUint64(r *http.Request, key string) (uint64, bool) { - values, exists := r.URL.Query()[key] - if !exists || len(values) != 1 { - return 0, false - } - v, err := strconv.ParseUint(values[0], 10, 64) - if err != nil { - return 0, false - } - return v, true -} - -// applyNullQueryParameters inspects the raw JSON of a query-parameters array -// and clears the ParameterValue of every scalar parameter whose value was JSON -// null or absent. The bigqueryv2 structs store a parameter value as a plain -// string, so they cannot otherwise distinguish a NULL scalar from an empty string. -// -// ARRAY and STRUCT parameters must never be cleared, even when their -// parameterValue is empty (e.g. an empty []string{} omits "arrayValues" -// entirely due to JSON omitempty). The parameter type is used as the -// authoritative signal: any parameter whose type is "ARRAY" or "STRUCT" is -// left untouched. -func applyNullQueryParameters(rawParams []json.RawMessage, params []*bigqueryv2.QueryParameter) { - for i, raw := range rawParams { - if i >= len(params) || params[i] == nil { - continue - } - var p struct { - ParameterType *struct { - Type string `json:"type"` - } `json:"parameterType"` - ParameterValue *json.RawMessage `json:"parameterValue"` - } - if err := json.Unmarshal(raw, &p); err != nil { - continue - } - // ARRAY and STRUCT parameters must not be cleared regardless of whether - // their parameterValue happens to be empty. 
- if p.ParameterType != nil { - switch p.ParameterType.Type { - case "ARRAY", "STRUCT": - continue - } - } - // No parameterValue key at all → treat as null scalar. - if p.ParameterValue == nil { - params[i].ParameterValue = nil - continue - } - // Decode parameterValue into a key-presence map. - // JSON null decodes to a nil map; map lookups on nil return false safely. - var pv map[string]json.RawMessage - if err := json.Unmarshal(*p.ParameterValue, &pv); err != nil { - continue - } - // Scalar: clear if "value" is absent or is the JSON literal null. - valueRaw, hasValue := pv["value"] - if !hasValue || string(valueRaw) == "null" { - params[i].ParameterValue = nil - } - } -} - -func parseQueryValueAsBool(r *http.Request, key string) bool { - queryValues := r.URL.Query() - values, exists := queryValues[key] - if !exists { - return false - } - if len(values) != 1 { - return false - } - b, err := strconv.ParseBool(values[0]) - if err != nil { - return false - } - return b -} - -func (h *datasetsDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - if err := h.Handle(ctx, &datasetsDeleteRequest{ - server: server, - project: project, - dataset: dataset, - deleteContents: isDeleteContents(r), - }); err != nil { - // Preserve typed *ServerError (e.g. a 400 resourceInUse for a - // non-empty dataset) so the client sees the real HTTP status - // rather than the 500 retry-forever loop a blanket wrap would - // cause. 
- var serr *ServerError - if errors.As(err, &serr) { - errorResponse(ctx, w, serr) - return - } - errorResponse(ctx, w, errInternalError(err.Error())) - return - } -} - -type datasetsDeleteRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - deleteContents bool -} - -func (h *datasetsDeleteHandler) Handle(ctx context.Context, r *datasetsDeleteRequest) error { - // BigQuery rejects deleting a non-empty dataset unless - // deleteContents=true. Reject up front so the dataset is never - // removed while its tables remain (which would orphan them and - // surface as a UNIQUE constraint violation on the next CREATE - // TABLE with the same name) and so the caller sees a 4xx rather - // than the 500 the Google SDKs retry indefinitely on. - if !r.deleteContents && len(r.dataset.Tables()) > 0 { - return errResourceInUse(fmt.Sprintf( - "Dataset %s:%s is still in use", - r.project.ID, r.dataset.ID, - )) - } - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - if err := r.project.DeleteDataset(ctx, tx.Tx(), r.dataset.ID); err != nil { - return fmt.Errorf("failed to delete dataset: %w", err) - } - if r.deleteContents { - tables := r.dataset.Tables() - deletions := make([]contentdata.TableDeletion, 0, len(tables)) - for _, table := range tables { - if err := table.Delete(ctx, tx.Tx()); err != nil { - return err - } - deletions = append(deletions, contentdata.TableDeletion{ - ID: table.ID, - IsView: table.IsView(), - }) - } - if err := r.server.contentRepo.DeleteTables(ctx, tx, r.project.ID, r.dataset.ID, deletions); err != nil { - return fmt.Errorf("failed to delete tables: %w", err) - } - } - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit delete dataset: %w", err) - 
} - return nil -} - -func (h *datasetsGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - res, err := h.Handle(ctx, &datasetsGetRequest{ - server: server, - project: project, - dataset: dataset, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type datasetsGetRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset -} - -func (h *datasetsGetHandler) Handle(ctx context.Context, r *datasetsGetRequest) (*bigqueryv2.Dataset, error) { - newContent := *r.dataset.Content() - newContent.DatasetReference = &bigqueryv2.DatasetReference{ - ProjectId: r.project.ID, - DatasetId: r.dataset.ID, - } - return &newContent, nil -} - -func (h *datasetsInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - var dataset bigqueryv2.Dataset - if err := json.NewDecoder(r.Body).Decode(&dataset); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &datasetsInsertRequest{ - server: server, - project: project, - dataset: &dataset, - }) - if err != nil { - var serr *ServerError - if errors.As(err, &serr) { - errorResponse(ctx, w, serr) - return - } - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type datasetsInsertRequest struct { - server *Server - project *metadata.Project - dataset *bigqueryv2.Dataset -} - -func (h *datasetsInsertHandler) Handle(ctx context.Context, r *datasetsInsertRequest) (*bigqueryv2.DatasetListDatasets, error) { - if r.dataset.DatasetReference == nil { - return nil, fmt.Errorf("DatasetReference is nil") - } - datasetID := r.dataset.DatasetReference.DatasetId - if datasetID == "" { - return nil, 
fmt.Errorf("dataset id is empty") - } - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, datasetID) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - - if err := r.project.AddDataset( - ctx, - tx.Tx(), - metadata.NewDataset( - r.server.metaRepo, - r.project.ID, - datasetID, - r.dataset, - nil, - nil, - nil, - ), - ); err != nil { - if errors.Is(err, metadata.ErrDuplicatedDataset) { - return nil, errDuplicate(err.Error()) - } - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - return &bigqueryv2.DatasetListDatasets{ - DatasetReference: &bigqueryv2.DatasetReference{ - ProjectId: r.project.ID, - DatasetId: datasetID, - }, - Id: datasetID, - FriendlyName: r.dataset.FriendlyName, - Kind: r.dataset.Kind, - Labels: r.dataset.Labels, - Location: r.dataset.Location, - ForceSendFields: r.dataset.ForceSendFields, - NullFields: r.dataset.NullFields, - }, nil -} - -func (h *datasetsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - res, err := h.Handle(ctx, &datasetsListRequest{ - server: server, - project: project, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type datasetsListRequest struct { - server *Server - project *metadata.Project -} - -func (h *datasetsListHandler) Handle(ctx context.Context, r *datasetsListRequest) (*bigqueryv2.DatasetList, error) { - datasetsRes := []*bigqueryv2.DatasetListDatasets{} - for _, dataset := range r.project.Datasets() { - content := dataset.Content() - datasetsRes = append(datasetsRes, &bigqueryv2.DatasetListDatasets{ - DatasetReference: &bigqueryv2.DatasetReference{ - ProjectId: r.project.ID, - DatasetId: dataset.ID, - }, - 
FriendlyName: content.FriendlyName, - Id: dataset.ID, - Kind: content.Kind, - Labels: content.Labels, - Location: content.Location, - ForceSendFields: content.ForceSendFields, - NullFields: content.NullFields, - }) - } - return &bigqueryv2.DatasetList{ - Datasets: datasetsRes, - }, nil -} - -func (h *datasetsPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - var newDataset bigqueryv2.Dataset - if err := json.NewDecoder(r.Body).Decode(&newDataset); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &datasetsPatchRequest{ - server: server, - project: project, - dataset: dataset, - newDataset: &newDataset, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type datasetsPatchRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - newDataset *bigqueryv2.Dataset -} - -func (h *datasetsPatchHandler) Handle(ctx context.Context, r *datasetsPatchRequest) (*bigqueryv2.Dataset, error) { - r.dataset.UpdateContentIfExists(r.newDataset) - newContent := *r.dataset.Content() - return &newContent, nil -} - -func (h *datasetsUpdateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - var newDataset bigqueryv2.Dataset - if err := json.NewDecoder(r.Body).Decode(&newDataset); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &datasetsUpdateRequest{ - server: server, - project: project, - dataset: dataset, - newDataset: &newDataset, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type datasetsUpdateRequest struct { - 
server *Server - project *metadata.Project - dataset *metadata.Dataset - newDataset *bigqueryv2.Dataset -} - -func (h *datasetsUpdateHandler) Handle(ctx context.Context, r *datasetsUpdateRequest) (*bigqueryv2.Dataset, error) { - r.dataset.UpdateContent(r.newDataset) - newContent := *r.dataset.Content() - return &newContent, nil -} - -func (h *jobsCancelHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - job := jobFromContext(ctx) - res, err := h.Handle(ctx, &jobsCancelRequest{ - server: server, - project: project, - job: job, - }) - if err != nil { - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type jobsCancelRequest struct { - server *Server - project *metadata.Project - job *metadata.Job -} - -func (h *jobsCancelHandler) Handle(ctx context.Context, r *jobsCancelRequest) (*bigqueryv2.JobCancelResponse, error) { - if err := r.job.Cancel(ctx); err != nil { - return nil, err - } - return &bigqueryv2.JobCancelResponse{Job: r.job.Content()}, nil -} - -func (h *jobsDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - job := jobFromContext(ctx) - if err := h.Handle(ctx, &jobsDeleteRequest{ - server: server, - project: project, - job: job, - }); err != nil { - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } -} - -type jobsDeleteRequest struct { - server *Server - project *metadata.Project - job *metadata.Job -} - -func (h *jobsDeleteHandler) Handle(ctx context.Context, r *jobsDeleteRequest) error { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "") - if err != nil { - return fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return fmt.Errorf("failed to start transaction: %w", err) - } - defer 
tx.RollbackIfNotCommitted() - if err := r.project.DeleteJob(ctx, tx.Tx(), r.job.ID); err != nil { - return fmt.Errorf("failed to delete job: %w", err) - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -func (h *jobsGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - job := jobFromContext(ctx) - res, err := h.Handle(ctx, &jobsGetRequest{ - server: server, - project: project, - job: job, - }) - if err != nil { - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type jobsGetRequest struct { - server *Server - project *metadata.Project - job *metadata.Job -} - -func (h *jobsGetHandler) Handle(ctx context.Context, r *jobsGetRequest) (*bigqueryv2.Job, error) { - content := *r.job.Content() - content.Status = &bigqueryv2.JobStatus{State: "DONE"} - return &content, nil -} - -func (h *jobsGetQueryResultsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - job := jobFromContext(ctx) - maxResults, hasMaxResults := parseQueryValueAsUint64(r, "maxResults") - startIndex, _ := parseQueryValueAsUint64(r, "startIndex") - // The page token returned by this handler is simply the next row index. 
- if token, ok := parseQueryValueAsUint64(r, "pageToken"); ok { - startIndex = token - } - res, err := h.Handle(ctx, &jobsGetQueryResultsRequest{ - server: server, - project: project, - job: job, - useInt64Timestamp: isFormatOptionsUseInt64Timestamp(r), - maxResults: maxResults, - hasMaxResults: hasMaxResults, - startIndex: startIndex, - }) - if err != nil { - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type jobsGetQueryResultsRequest struct { - server *Server - project *metadata.Project - job *metadata.Job - useInt64Timestamp bool - maxResults uint64 - hasMaxResults bool - startIndex uint64 -} - -func (h *jobsGetQueryResultsHandler) Handle(ctx context.Context, r *jobsGetQueryResultsRequest) (*internaltypes.GetQueryResultsResponse, error) { - response, err := r.job.Wait(ctx) - if err != nil { - return nil, err - } - rows := internaltypes.Format(response.Schema, response.Rows, r.useInt64Timestamp) - - // Honor maxResults/startIndex paging. Clients (notably the Python one) - // poll getQueryResults with maxResults=0 purely to await completion and - // then fetch the rows through a separate, paged request; returning every - // row on the completion poll hands them rows in a format they did not ask - // for. 
- total := uint64(len(rows)) - start := r.startIndex - if start > total { - start = total - } - end := total - pageToken := "" - if r.hasMaxResults { - end = start + r.maxResults - if end > total { - end = total - } - if end < total { - pageToken = strconv.FormatUint(end, 10) - } - } - return &internaltypes.GetQueryResultsResponse{ - JobReference: &bigqueryv2.JobReference{ - ProjectId: r.project.ID, - JobId: r.job.ID, - }, - Schema: response.Schema, - TotalRows: response.TotalRows, - JobComplete: true, - PageToken: pageToken, - Rows: rows[start:end], - }, nil -} - -func (h *jobsInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - body, err := io.ReadAll(r.Body) - // A gzip body flushed but not closed delivers all its content yet ends - // with ErrUnexpectedEOF; the json.Unmarshal below is the real validator. - if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - var job bigqueryv2.Job - if err := json.Unmarshal(body, &job); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - if job.Configuration != nil && job.Configuration.Query != nil { - var rawReq struct { - Configuration struct { - Query struct { - QueryParameters []json.RawMessage `json:"queryParameters"` - } `json:"query"` - } `json:"configuration"` - } - if err := json.Unmarshal(body, &rawReq); err == nil { - applyNullQueryParameters(rawReq.Configuration.Query.QueryParameters, job.Configuration.Query.QueryParameters) - } - } - res, err := h.Handle(ctx, &jobsInsertRequest{ - server: server, - project: project, - job: &job, - }) - if err != nil { - // Preserve typed *ServerError (e.g. a 404 notFound for a missing - // destination table under CREATE_NEVER) so the client sees the - // real HTTP status rather than a blanket 400 jobInternalError. 
- var serr *ServerError - if errors.As(err, &serr) { - errorResponse(ctx, w, serr) - return - } - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type jobsInsertRequest struct { - server *Server - project *metadata.Project - job *bigqueryv2.Job -} - -func (h *jobsInsertHandler) tableDefFromQueryResponse(tableID string, response *internaltypes.QueryResponse) (*types.Table, error) { - columns := []*types.Column{} - for _, field := range response.Schema.Fields { - columns = append(columns, types.NewColumnWithSchema(field)) - } - data := types.Data{} - for _, row := range response.Rows { - rowData, err := row.Data() - if err != nil { - return nil, err - } - data = append(data, rowData) - } - return types.NewTableWithSchema( - &bigqueryv2.Table{ - TableReference: &bigqueryv2.TableReference{ - TableId: tableID, - }, - Schema: response.Schema, - }, - data, - ) -} - -func (h *jobsInsertHandler) destinationProjectAndDataset(ctx context.Context, r *jobsInsertRequest, tableRef *bigqueryv2.TableReference) (*metadata.Project, *metadata.Dataset, error) { - projectID := tableRef.ProjectId - if projectID == "" { - projectID = r.project.ID - tableRef.ProjectId = projectID - } - - project := r.project - if projectID != r.project.ID { - var err error - project, err = r.server.metaRepo.FindProject(ctx, projectID) - if err != nil { - return nil, nil, err - } - if project == nil { - return nil, nil, errNotFound(fmt.Sprintf("project %q is not found", projectID)) - } - } - - dataset := project.Dataset(tableRef.DatasetId) - if dataset == nil { - return nil, nil, errNotFound(fmt.Sprintf("dataset %q is not found in project %q", tableRef.DatasetId, projectID)) - } - return project, dataset, nil -} - -const ( - gcsEmulatorHostEnvName = "STORAGE_EMULATOR_HOST" - gcsURIPrefix = "gs://" -) - -// gcsClientOptions builds the option set for a Cloud Storage client. 
-// -// - When STORAGE_EMULATOR_HOST is set, the client targets that GCS emulator -// with authentication disabled. -// - Otherwise, if no Application Default Credentials are configured, the -// client falls back to anonymous access so that loads from public buckets -// succeed instead of failing with "could not find default credentials". -// - When GOOGLE_APPLICATION_CREDENTIALS is set, those credentials are used. -func gcsClientOptions(jsonReads bool) []option.ClientOption { - if host := os.Getenv(gcsEmulatorHostEnvName); host != "" { - opts := []option.ClientOption{ - option.WithEndpoint(fmt.Sprintf("%s/storage/v1/", host)), - option.WithoutAuthentication(), - } - if jsonReads { - opts = append(opts, storage.WithJSONReads()) - } - return opts - } - if os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") == "" { - return []option.ClientOption{option.WithoutAuthentication()} - } - return nil -} - -func (h *jobsInsertHandler) importFromGCS(ctx context.Context, r *jobsInsertRequest) (*bigqueryv2.Job, error) { - client, err := storage.NewClient(ctx, gcsClientOptions(true)...) - if err != nil { - return nil, err - } - startTime := time.Now() - // The write disposition applies to the load job as a whole, so it is - // honored only for the first object imported; every later object (e.g. - // the matches of a wildcard URI) appends to what the first one wrote. 
- importObject := func(reader *storage.Reader) error { - if err := h.importFromGCSObject(ctx, r, reader); err != nil { - return err - } - r.job.Configuration.Load.WriteDisposition = "WRITE_APPEND" - return nil - } - for _, uri := range r.job.Configuration.Load.SourceUris { - if !strings.HasPrefix(uri, gcsURIPrefix) { - return nil, fmt.Errorf("load source uri must start with gs://") - } - uri = strings.TrimPrefix(uri, gcsURIPrefix) - paths := strings.Split(uri, "/") - if len(paths) < 2 { - return nil, fmt.Errorf("unexpected gcs uri format %s", uri) - } - bucketName := paths[0] - objectPath := strings.Join(paths[1:], "/") - switch strings.Count(objectPath, "*") { - case 0: - reader, err := client.Bucket(bucketName).Object(objectPath).NewReader(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get gcs object reader for %s: %w", uri, err) - } - if err := importObject(reader); err != nil { - return nil, err - } - case 1: - splitPath := strings.Split(objectPath, "*") - prefix := splitPath[0] - suffix := splitPath[1] - query := &storage.Query{ - Prefix: prefix, - } - query.SetAttrSelection([]string{"Name"}) - it := client.Bucket(bucketName).Objects(ctx, query) - for { - attrs, err := it.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("failed to list gcs object for %s: %w", uri, err) - } - if strings.HasSuffix(attrs.Name, suffix) { - reader, err := client.Bucket(bucketName).Object(attrs.Name).NewReader(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get gcs object reader for %s: %w", uri, err) - } - if err := importObject(reader); err != nil { - return nil, err - } - } - } - default: - return nil, fmt.Errorf("the number of wildcards in gcs uri must be 0 or 1") - } - } - endTime := time.Now() - job := r.job - job.Kind = "bigquery#job" - job.Configuration.JobType = "LOAD" - job.SelfLink = fmt.Sprintf( - "http://%s/bigquery/v2/projects/%s/jobs/%s", - r.server.httpServer.Addr, - r.project.ID, - 
job.JobReference.JobId, - ) - job.Status = &bigqueryv2.JobStatus{State: "DONE"} - job.Statistics = &bigqueryv2.JobStatistics{ - CreationTime: startTime.Unix(), - StartTime: startTime.Unix(), - EndTime: endTime.Unix(), - } - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "") - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - if err := r.project.AddJob( - ctx, - tx.Tx(), - metadata.NewJob( - r.server.metaRepo, - r.project.ID, - job.JobReference.JobId, - job, - nil, - nil, - ), - ); err != nil { - return nil, fmt.Errorf("failed to add job: %w", err) - } - if !job.Configuration.DryRun { - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("failed to commit job: %w", err) - } - } - return job, nil -} - -func (h *jobsInsertHandler) importFromGCSObject(ctx context.Context, r *jobsInsertRequest, reader *storage.Reader) error { - defer func() { - _ = reader.Close() - }() - job := metadata.NewJob( - r.server.metaRepo, - r.project.ID, - r.job.JobReference.JobId, - r.job, - nil, - nil, - ) - if err := new(uploadContentHandler).Handle(ctx, &uploadContentRequest{ - server: r.server, - project: r.project, - job: job, - reader: reader, - }); err != nil { - return err - } - return nil -} - -func (h *jobsInsertHandler) exportToGCS(ctx context.Context, r *jobsInsertRequest) (*bigqueryv2.Job, error) { - client, err := storage.NewClient(ctx, gcsClientOptions(false)...) 
- if err != nil { - return nil, err - } - startTime := time.Now() - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "") - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - extract := r.job.Configuration.Extract - sourceTable := extract.SourceTable - response, err := r.server.contentRepo.Query( - ctx, - tx, - sourceTable.ProjectId, - sourceTable.DatasetId, - fmt.Sprintf("SELECT * FROM `%s`", sourceTable.TableId), - nil, - ) - if err != nil { - return nil, err - } - for _, uri := range extract.DestinationUris { - if !strings.HasPrefix(uri, gcsURIPrefix) { - return nil, fmt.Errorf("destination uri must start with gs://") - } - uri = strings.TrimPrefix(uri, gcsURIPrefix) - paths := strings.Split(uri, "/") - if len(paths) < 2 { - return nil, fmt.Errorf("unexpected gcs uri format %s", uri) - } - bucketName := paths[0] - objectPath := strings.Join(paths[1:], "/") - bucket := client.Bucket(bucketName) - _ = bucket.Create(ctx, r.project.ID, nil) // ignore "already exists" error. 
- writer := bucket.Object(objectPath).NewWriter(ctx) - if err := h.exportToGCSWithObject(ctx, response, extract, writer); err != nil { - return nil, err - } - } - endTime := time.Now() - job := r.job - job.Kind = "bigquery#job" - job.Configuration.JobType = "EXTRACT" - job.SelfLink = fmt.Sprintf( - "http://%s/bigquery/v2/projects/%s/jobs/%s", - r.server.httpServer.Addr, - r.project.ID, - job.JobReference.JobId, - ) - job.Status = &bigqueryv2.JobStatus{State: "DONE"} - job.Statistics = &bigqueryv2.JobStatistics{ - CreationTime: startTime.Unix(), - StartTime: startTime.Unix(), - EndTime: endTime.Unix(), - } - if err := r.project.AddJob( - ctx, - tx.Tx(), - metadata.NewJob( - r.server.metaRepo, - r.project.ID, - job.JobReference.JobId, - job, - nil, - nil, - ), - ); err != nil { - return nil, fmt.Errorf("failed to add job: %w", err) - } - if !job.Configuration.DryRun { - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("failed to commit job: %w", err) - } - } - return job, nil -} - -func (h *jobsInsertHandler) exportToGCSWithObject(ctx context.Context, response *internaltypes.QueryResponse, extract *bigqueryv2.JobConfigurationExtract, writer *storage.Writer) (e error) { - defer func() { - if err := writer.Close(); err != nil { - e = err - } - }() - switch extract.DestinationFormat { - case "CSV": - if len(response.Rows) == 0 { - if _, err := writer.Write(nil); err != nil { - return fmt.Errorf("failed to empty table data to gcs object: %w", err) - } - return nil - } - csvWriter := csv.NewWriter(writer) - var columns []string - for _, cell := range response.Rows[0].F { - columns = append(columns, cell.Name) - } - if extract.PrintHeader == nil { - if err := csvWriter.Write(columns); err != nil { - return fmt.Errorf("failed to encode csv columns: %w", err) - } - } - for _, row := range response.Rows { - data, err := row.Data() - if err != nil { - return fmt.Errorf("failed to get data from table row: %w", err) - } - var records []string - for _, col := range 
columns { - value := data[col] - if value == nil { - records = append(records, "") - continue - } - if v, ok := value.(string); ok { - records = append(records, v) - continue - } - jsonValue, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("failed to encode row value: %w", err) - } - records = append(records, string(jsonValue)) - } - if err := csvWriter.Write(records); err != nil { - return fmt.Errorf("failed to encode csv data: %w", err) - } - } - csvWriter.Flush() - if err := csvWriter.Error(); err != nil { - return fmt.Errorf("failed to encode csv data: %w", err) - } - case "NEWLINE_DELIMITED_JSON": - writer.ContentType = "application/json" - enc := json.NewEncoder(writer) - for _, row := range response.Rows { - data, err := row.Data() - if err != nil { - return fmt.Errorf("failed to get data from table row: %w", err) - } - if err := enc.Encode(data); err != nil { - return fmt.Errorf("failed to encode table data: %w", err) - } - } - case "PARQUET": - var opts []parquet.WriterOption - switch extract.Compression { - case "GZIP": - opts = append(opts, parquet.Compression(&parquet.Gzip)) - case "SNAPPY": - opts = append(opts, parquet.Compression(&parquet.Snappy)) - case "DEFLATE": - opts = append(opts, parquet.Compression(&parquet.Gzip)) - } - _ = opts - fallthrough - default: - return fmt.Errorf("failed to export to gcs: unsupported destination format %s", extract.DestinationFormat) - } - return nil -} - -func (h *jobsInsertHandler) Handle(ctx context.Context, r *jobsInsertRequest) (*bigqueryv2.Job, error) { - job := r.job - if job.Configuration == nil { - return nil, fmt.Errorf("unspecified job configuration") - } - if job.Configuration.Query == nil { - if job.Configuration.Load != nil && len(job.Configuration.Load.SourceUris) != 0 { - // load from google cloud storage - job, err := h.importFromGCS(ctx, r) - if err != nil { - return nil, fmt.Errorf("failed to import from gcs: %w", err) - } - return job, nil - } else if job.Configuration.Extract != 
nil && len(job.Configuration.Extract.DestinationUris) != 0 { - job, err := h.exportToGCS(ctx, r) - if err != nil { - return nil, fmt.Errorf("failed to export to gcs: %w", err) - } - return job, nil - } - return nil, fmt.Errorf("unspecified job configuration query") - } - queryProjectID, datasetID := queryProjectAndDataset(job.Configuration.Query.DefaultDataset, r.project.ID) - conn, err := r.server.connMgr.Connection(ctx, queryProjectID, datasetID) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - hasDestinationTable := job.Configuration.Query.DestinationTable != nil - startTime := time.Now() - response, jobErr := r.server.contentRepo.Query( - ctx, - tx, - queryProjectID, - datasetID, - job.Configuration.Query.Query, - job.Configuration.Query.QueryParameters, - ) - endTime := time.Now() - if job.JobReference.JobId == "" { - job.JobReference.JobId = randomID() // generate job id - } - if jobErr == nil { - if hasDestinationTable { - // insert results to destination table - tableRef := job.Configuration.Query.DestinationTable - destinationProject, destinationDataset, err := h.destinationProjectAndDataset(ctx, r, tableRef) - if err != nil { - return nil, err - } - tableDef, err := h.tableDefFromQueryResponse(tableRef.TableId, response) - if err != nil { - return nil, err - } - destinationTable := destinationDataset.Table(tableRef.TableId) - destinationTableExists := destinationTable != nil - if !destinationTableExists { - // CreateDisposition controls whether a missing destination - // table is materialized on the fly. CREATE_NEVER must - // surface the missing table as a 404 (matching real - // BigQuery and load-job behaviour); CREATE_IF_NEEDED (the - // default) and an empty value create it from the query's - // inferred schema. 
- if job.Configuration.Query.CreateDisposition == "CREATE_NEVER" { - return nil, errNotFound(fmt.Sprintf( - "Not found: Table %s:%s.%s", - tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId, - )) - } - table := tableDef.ToBigqueryV2(destinationProject.ID, tableRef.DatasetId) - _, err := createTableMetadata(ctx, tx, r.server, destinationProject, destinationDataset, table) - if err != nil { - return nil, fmt.Errorf("failed to create table: %w", err) - } - tx.SetProjectAndDataset(destinationProject.ID, tableRef.DatasetId) - serverErr := r.server.contentRepo.CreateTable(ctx, tx, tableDef.ToBigqueryV2(destinationProject.ID, tableRef.DatasetId)) - if serverErr != nil { - return nil, fmt.Errorf("failed to create table: %w", serverErr) - } - } - if err := r.server.contentRepo.AddTableData(ctx, tx, destinationProject.ID, tableRef.DatasetId, tableDef); err != nil { - return nil, fmt.Errorf("failed to add table data: %w", err) - } - } else if response != nil && response.Schema != nil && len(response.Schema.Fields) > 0 { - // A query that produces a result set (a SELECT, as opposed to a - // DDL/DML statement that has no result schema) still gets an - // anonymous results table. The job must advertise it through - // configuration.query.destinationTable: clients such as the Ruby - // one read query results through that reference. 
- destRef, err := h.addQueryResultToDynamicDestinationTable(ctx, tx, r, response) - if err != nil { - return nil, fmt.Errorf("failed to add query result to dynamic destination table: %w", err) - } - job.Configuration.Query.DestinationTable = destRef - } - } - job.Kind = "bigquery#job" - job.Configuration.JobType = "QUERY" - job.Configuration.Query.Priority = "INTERACTIVE" - job.SelfLink = fmt.Sprintf( - "http://%s/bigquery/v2/projects/%s/jobs/%s", - r.server.httpServer.Addr, - r.project.ID, - job.JobReference.JobId, - ) - status := &bigqueryv2.JobStatus{State: "DONE"} - if jobErr != nil { - internalErr := errJobInternalError(jobErr.Error()) - status.ErrorResult = internalErr.ErrorProto() - status.Errors = []*bigqueryv2.ErrorProto{internalErr.ErrorProto()} - } - job.Status = status - var totalBytes int64 - if response != nil { - totalBytes = response.TotalBytes - } - job.Statistics = &bigqueryv2.JobStatistics{ - Query: &bigqueryv2.JobStatistics2{ - CacheHit: false, - StatementType: "SELECT", - TotalBytesBilled: totalBytes, - TotalBytesProcessed: totalBytes, - }, - CreationTime: startTime.Unix(), - StartTime: startTime.Unix(), - EndTime: endTime.Unix(), - TotalBytesProcessed: totalBytes, - } - if err := r.project.AddJob( - ctx, - tx.Tx(), - metadata.NewJob( - r.server.metaRepo, - r.project.ID, - job.JobReference.JobId, - job, - response, - jobErr, - ), - ); err != nil { - return nil, fmt.Errorf("failed to add job: %w", err) - } - if !job.Configuration.DryRun { - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("failed to commit job: %w", err) - } - if response != nil && response.ChangedCatalog.Changed() { - if err := syncCatalog(ctx, r.server, response.ChangedCatalog); err != nil { - return nil, err - } - } - } - - return job, nil -} - -func queryProjectAndDataset(defaultDataset *bigqueryv2.DatasetReference, fallbackProjectID string) (string, string) { - projectID := fallbackProjectID - var datasetID string - if defaultDataset != nil { - if 
defaultDataset.ProjectId != "" { - projectID = defaultDataset.ProjectId - } - datasetID = defaultDataset.DatasetId - } - return projectID, datasetID -} - -func syncCatalog(ctx context.Context, server *Server, cat *googlesqlite.ChangedCatalog) error { - for _, table := range cat.Table.Added { - if err := addTableMetadata(ctx, server, table); err != nil { - return err - } - } - for _, table := range cat.Table.Deleted { - if err := deleteTableMetadata(ctx, server, table); err != nil { - return err - } - } - return nil -} - -func addTableMetadata(ctx context.Context, server *Server, spec *googlesqlite.TableSpec) error { - if len(spec.NamePath) != 3 { - return fmt.Errorf("unexpected table name path: %v", spec.NamePath) - } - projectID := spec.NamePath[0] - datasetID := spec.NamePath[1] - tableID := spec.NamePath[2] - project, err := server.metaRepo.FindProject(ctx, projectID) - if err != nil { - return err - } - dataset := project.Dataset(datasetID) - if dataset == nil { - return fmt.Errorf("dataset %s is not found", datasetID) - } - fields := make([]*bigqueryv2.TableFieldSchema, 0, len(spec.Columns)) - for _, column := range spec.Columns { - fields = append(fields, types.TableFieldSchemaFromColumnType(column.Name, column.Type)) - } - conn, err := server.connMgr.Connection(ctx, projectID, datasetID) - if err != nil { - return err - } - tx, err := conn.Begin(ctx) - if err != nil { - return err - } - defer tx.RollbackIfNotCommitted() - table := &bigqueryv2.Table{ - TableReference: &bigqueryv2.TableReference{ - ProjectId: projectID, - DatasetId: datasetID, - TableId: tableID, - }, - Schema: &bigqueryv2.TableSchema{Fields: fields}, - } - // A view created by a CREATE VIEW DDL statement must be recorded as a - // view so it is typed correctly and dropped with DROP VIEW. 
- if spec.IsView { - table.View = &bigqueryv2.ViewDefinition{Query: spec.Query} - } - if _, err := createTableMetadata(ctx, tx, server, project, dataset, table); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -func deleteTableMetadata(ctx context.Context, server *Server, spec *googlesqlite.TableSpec) error { - if len(spec.NamePath) != 3 { - return fmt.Errorf("unexpected table name path: %v", spec.NamePath) - } - projectID := spec.NamePath[0] - datasetID := spec.NamePath[1] - tableID := spec.NamePath[2] - project, err := server.metaRepo.FindProject(ctx, projectID) - if err != nil { - return err - } - dataset := project.Dataset(datasetID) - if dataset == nil { - return fmt.Errorf("dataset %s is not found", datasetID) - } - table := dataset.Table(tableID) - conn, err := server.connMgr.Connection(ctx, projectID, datasetID) - if err != nil { - return err - } - tx, err := conn.Begin(ctx) - if err != nil { - return err - } - defer tx.RollbackIfNotCommitted() - if err := table.Delete(ctx, tx.Tx()); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -// addQueryResultToDynamicDestinationTable materializes the result of a query -// that had no explicit destination into an anonymous results table (named -// after the job id) and returns a reference to it. 
-func (h *jobsInsertHandler) addQueryResultToDynamicDestinationTable(ctx context.Context, tx *connection.Tx, r *jobsInsertRequest, response *internaltypes.QueryResponse) (*bigqueryv2.TableReference, error) { - projectID := r.project.ID - jobID := r.job.JobReference.JobId - datasetID := jobID - tableID := jobID - - tableDef, err := h.tableDefFromQueryResponse(tableID, response) - if err != nil { - return nil, err - } - tableDef.SetupMetadata(projectID, datasetID) - table := metadata.NewTable(r.server.metaRepo, projectID, datasetID, tableID, tableDef.Metadata) - dataset := metadata.NewDataset( - r.server.metaRepo, - projectID, - datasetID, - &bigqueryv2.Dataset{ - Id: fmt.Sprintf("%s:%s", projectID, datasetID), - DatasetReference: &bigqueryv2.DatasetReference{ - ProjectId: projectID, - DatasetId: datasetID, - }, - }, - []*metadata.Table{table}, - nil, - nil, - ) - if err := r.project.AddDataset(ctx, tx.Tx(), dataset); err != nil { - return nil, err - } - if err := r.server.metaRepo.AddTable(ctx, tx.Tx(), table); err != nil { - return nil, err - } - tx.SetProjectAndDataset(projectID, datasetID) - if err := r.server.contentRepo.CreateTable(ctx, tx, tableDef.ToBigqueryV2(projectID, datasetID)); err != nil { - return nil, err - } - if err := r.server.contentRepo.AddTableData(ctx, tx, projectID, datasetID, tableDef); err != nil { - return nil, fmt.Errorf("failed to add table data: %w", err) - } - return &bigqueryv2.TableReference{ - ProjectId: projectID, - DatasetId: datasetID, - TableId: tableID, - }, nil -} - -func (h *jobsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - res, err := h.Handle(ctx, &jobsListRequest{ - server: server, - project: project, - }) - if err != nil { - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type jobsListRequest struct { - server *Server - project *metadata.Project -} - 
-func (h *jobsListHandler) Handle(ctx context.Context, r *jobsListRequest) (*bigqueryv2.JobList, error) { - jobs := []*bigqueryv2.JobListJobs{} - for _, job := range r.project.Jobs() { - content := job.Content() - jobs = append(jobs, &bigqueryv2.JobListJobs{ - Id: content.Id, - JobReference: content.JobReference, - Kind: content.Kind, - Statistics: content.Statistics, - Status: content.Status, - UserEmail: content.UserEmail, - }) - } - return &bigqueryv2.JobList{Jobs: jobs}, nil -} - -func (h *jobsQueryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - body, err := io.ReadAll(r.Body) - // A gzip body flushed but not closed delivers all its content yet ends - // with ErrUnexpectedEOF; the json.Unmarshal below is the real validator. - if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - var req bigqueryv2.QueryRequest - if err := json.Unmarshal(body, &req); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - var rawReq struct { - QueryParameters []json.RawMessage `json:"queryParameters"` - } - if err := json.Unmarshal(body, &rawReq); err == nil { - applyNullQueryParameters(rawReq.QueryParameters, req.QueryParameters) - } - useInt64Timestamp := false - if options := req.FormatOptions; options != nil { - useInt64Timestamp = options.UseInt64Timestamp - } - useInt64Timestamp = useInt64Timestamp || isFormatOptionsUseInt64Timestamp(r) - res, err := h.Handle(ctx, &jobsQueryRequest{ - server: server, - project: project, - queryRequest: &req, - useInt64Timestamp: useInt64Timestamp, - }) - if err != nil { - errorResponse(ctx, w, errJobInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type jobsQueryRequest struct { - server *Server - project *metadata.Project - queryRequest *bigqueryv2.QueryRequest - useInt64Timestamp bool -} - -func (h *jobsQueryHandler) 
Handle(ctx context.Context, r *jobsQueryRequest) (*internaltypes.QueryResponse, error) { - queryProjectID, datasetID := queryProjectAndDataset(r.queryRequest.DefaultDataset, r.project.ID) - conn, err := r.server.connMgr.Connection(ctx, queryProjectID, datasetID) - if err != nil { - return nil, err - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.RollbackIfNotCommitted() - startTime := time.Now() - response, queryErr := r.server.contentRepo.Query( - ctx, - tx, - queryProjectID, - datasetID, - r.queryRequest.Query, - r.queryRequest.QueryParameters, - ) - if queryErr != nil { - return nil, queryErr - } - endTime := time.Now() - // jobs.query allocates jobIDs server-side (real BigQuery does the - // same). queryRequest.RequestId is the *idempotency* key — same - // RequestId on retry should return the cached result — and is - // deliberately not the jobID. The Go BigQuery client in particular - // instantiates a fresh `uid.NewSpace("request", …)` per call, so its - // per-Space atomic counter always emits `-0001`; concurrent - // in-process callers therefore submit identical RequestIds and would - // collide on AddJob if we routed them through as the jobID. Idempotency - // (RequestId → cached response) is a TODO; for now every call gets a - // fresh jobID. - jobID := randomID() - - // Persist the job in the metadata store, mirroring what - // jobsInsertHandler.Handle does at line ~1758. The previous behaviour - // returned a JobReference whose ID was never recorded, so any client - // that then issued `jobs.get(jobID)` — e.g. Go's - // `RowIterator.SourceJob().Status(ctx)`, or the Java/Node clients - // that poll the job for status after a synchronous query — got a 404 - // and (for clients that treat 404 as transient) hung re-polling. The - // synthetic Job below carries the minimum fields the GET handler - // returns to the caller; DryRun queries skip both the AddJob and the - // Commit so they remain side-effect-free. 
- var totalBytes int64 - if response != nil { - totalBytes = response.TotalBytes - } - job := &bigqueryv2.Job{ - Kind: "bigquery#job", - JobReference: &bigqueryv2.JobReference{ - ProjectId: r.project.ID, - JobId: jobID, - Location: r.queryRequest.Location, - }, - Configuration: &bigqueryv2.JobConfiguration{ - JobType: "QUERY", - DryRun: r.queryRequest.DryRun, - Query: &bigqueryv2.JobConfigurationQuery{ - Query: r.queryRequest.Query, - DefaultDataset: r.queryRequest.DefaultDataset, - QueryParameters: r.queryRequest.QueryParameters, - Priority: "INTERACTIVE", - }, - }, - Status: &bigqueryv2.JobStatus{State: "DONE"}, - Statistics: &bigqueryv2.JobStatistics{ - Query: &bigqueryv2.JobStatistics2{ - CacheHit: false, - StatementType: "SELECT", - TotalBytesBilled: totalBytes, - TotalBytesProcessed: totalBytes, - }, - CreationTime: startTime.Unix(), - StartTime: startTime.Unix(), - EndTime: endTime.Unix(), - TotalBytesProcessed: totalBytes, - }, - SelfLink: fmt.Sprintf( - "http://%s/bigquery/v2/projects/%s/jobs/%s", - r.server.httpServer.Addr, - r.project.ID, - jobID, - ), - } - if !r.queryRequest.DryRun { - if err := r.project.AddJob( - ctx, - tx.Tx(), - metadata.NewJob( - r.server.metaRepo, - r.project.ID, - jobID, - job, - response, - nil, - ), - ); err != nil { - return nil, fmt.Errorf("failed to add job: %w", err) - } - if err := tx.Commit(); err != nil { - return nil, err - } - if response.ChangedCatalog.Changed() { - if err := syncCatalog(ctx, r.server, response.ChangedCatalog); err != nil { - return nil, err - } - } - } - response.Rows = internaltypes.Format(response.Schema, response.Rows, r.useInt64Timestamp) - response.JobReference = job.JobReference - return response, nil -} - -func (h *modelsDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - model := modelFromContext(ctx) - if err := h.Handle(ctx, 
&modelsDeleteRequest{ - server: server, - project: project, - dataset: dataset, - model: model, - }); err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - } -} - -type modelsDeleteRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - model *metadata.Model -} - -func (h *modelsDeleteHandler) Handle(ctx context.Context, r *modelsDeleteRequest) error { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return err - } - tx, err := conn.Begin(ctx) - if err != nil { - return fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - if err := r.dataset.DeleteModel(ctx, tx.Tx(), r.model.ID); err != nil { - return fmt.Errorf("failed to delete model: %w", err) - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -func (h *modelsGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - model := modelFromContext(ctx) - res, err := h.Handle(ctx, &modelsGetRequest{ - server: server, - project: project, - dataset: dataset, - model: model, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type modelsGetRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - model *metadata.Model -} - -func (h *modelsGetHandler) Handle(ctx context.Context, r *modelsGetRequest) (*bigqueryv2.Model, error) { - return &bigqueryv2.Model{ - BestTrialId: 0, - CreationTime: 0, - DefaultTrialId: 0, - Description: "", - EncryptionConfiguration: nil, - Etag: "", - ExpirationTime: 0, - FeatureColumns: nil, - FriendlyName: "", - HparamSearchSpaces: nil, - HparamTrials: nil, - LabelColumns: nil, - Labels: nil, - LastModifiedTime: 0, - Location: "", - ModelReference: nil, - ModelType: "", - OptimalTrialIds: 
nil, - TrainingRuns: nil, - }, nil -} - -func (h *modelsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - res, err := h.Handle(ctx, &modelsListRequest{ - server: server, - project: project, - dataset: dataset, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type modelsListRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset -} - -func (h *modelsListHandler) Handle(ctx context.Context, r *modelsListRequest) (*bigqueryv2.ListModelsResponse, error) { - models := []*bigqueryv2.Model{} - for _, m := range r.dataset.Models() { - _ = m - models = append(models, &bigqueryv2.Model{}) - } - return &bigqueryv2.ListModelsResponse{ - Models: models, - }, nil -} - -func (h *modelsPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - model := modelFromContext(ctx) - res, err := h.Handle(ctx, &modelsPatchRequest{ - server: server, - project: project, - dataset: dataset, - model: model, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type modelsPatchRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - model *metadata.Model -} - -func (h *modelsPatchHandler) Handle(ctx context.Context, r *modelsPatchRequest) (*bigqueryv2.Model, error) { - return &bigqueryv2.Model{}, nil -} - -func (h *projectsGetServiceAccountHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - res, err := h.Handle(ctx, &projectsGetServiceAccountRequest{ - server: server, - project: project, - }) - 
if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type projectsGetServiceAccountRequest struct { - server *Server - project *metadata.Project -} - -func (h *projectsGetServiceAccountHandler) Handle(ctx context.Context, r *projectsGetServiceAccountRequest) (*bigqueryv2.GetServiceAccountResponse, error) { - return &bigqueryv2.GetServiceAccountResponse{}, nil -} - -func (h *projectsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - res, err := h.Handle(ctx, &projectsListRequest{ - server: server, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type projectsListRequest struct { - server *Server -} - -func (h *projectsListHandler) Handle(ctx context.Context, r *projectsListRequest) (*bigqueryv2.ProjectList, error) { - projects, err := r.server.metaRepo.FindAllProjects(ctx) - if err != nil { - return nil, err - } - - projectList := []*bigqueryv2.ProjectListProjects{} - for i, p := range projects { - projectList = append(projectList, &bigqueryv2.ProjectListProjects{ - Id: p.ID, - NumericId: uint64(i + 1), - FriendlyName: p.ID, - }) - } - return &bigqueryv2.ProjectList{ - Projects: projectList, - }, nil -} - -func (h *routinesDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - routine := routineFromContext(ctx) - if err := h.Handle(ctx, &routinesDeleteRequest{ - server: server, - project: project, - dataset: dataset, - routine: routine, - }); err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } -} - -type routinesDeleteRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - routine *metadata.Routine -} - -func (h *routinesDeleteHandler) Handle(ctx 
context.Context, r *routinesDeleteRequest) error { - return fmt.Errorf("unsupported bigquery.routines.delete") -} - -func (h *routinesGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - routine := routineFromContext(ctx) - res, err := h.Handle(ctx, &routinesGetRequest{ - server: server, - project: project, - dataset: dataset, - routine: routine, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type routinesGetRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - routine *metadata.Routine -} - -func (h *routinesGetHandler) Handle(ctx context.Context, r *routinesGetRequest) (*bigqueryv2.Routine, error) { - return nil, fmt.Errorf("unsupported bigquery.routines.get") -} - -func (h *routinesInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - var routine bigqueryv2.Routine - if err := json.NewDecoder(r.Body).Decode(&routine); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &routinesInsertRequest{ - server: server, - project: project, - dataset: dataset, - routine: &routine, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type routinesInsertRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - routine *bigqueryv2.Routine -} - -func (h *routinesInsertHandler) Handle(ctx context.Context, r *routinesInsertRequest) (*bigqueryv2.Routine, error) { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, 
err := conn.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.RollbackIfNotCommitted() - if err := r.server.contentRepo.AddRoutineByMetaData(ctx, tx, r.routine); err != nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - return r.routine, nil -} - -func (h *routinesListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - res, err := h.Handle(ctx, &routinesListRequest{ - server: server, - project: project, - dataset: dataset, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type routinesListRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset -} - -func (h *routinesListHandler) Handle(ctx context.Context, r *routinesListRequest) (*bigqueryv2.ListRoutinesResponse, error) { - var routineList []*bigqueryv2.Routine - for _, routine := range r.dataset.Routines() { - _ = routine - routineList = append(routineList, &bigqueryv2.Routine{}) - } - return &bigqueryv2.ListRoutinesResponse{ - Routines: routineList, - }, fmt.Errorf("unsupported bigquery.routines.list") -} - -func (h *routinesUpdateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - routine := routineFromContext(ctx) - res, err := h.Handle(ctx, &routinesUpdateRequest{ - server: server, - project: project, - dataset: dataset, - routine: routine, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type routinesUpdateRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - routine *metadata.Routine -} - -func (h *routinesUpdateHandler) Handle(ctx context.Context, r 
*routinesUpdateRequest) (*bigqueryv2.Routine, error) { - return nil, fmt.Errorf("unsupported bigquery.routines.update") -} - -func (h *rowAccessPoliciesGetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - res, err := h.Handle(ctx, &rowAccessPoliciesGetIamPolicyRequest{ - server: server, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type rowAccessPoliciesGetIamPolicyRequest struct { - server *Server -} - -func (h *rowAccessPoliciesGetIamPolicyHandler) Handle(ctx context.Context, r *rowAccessPoliciesGetIamPolicyRequest) (*bigqueryv2.Policy, error) { - return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.getIamPolicy") -} - -func (h *rowAccessPoliciesListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - res, err := h.Handle(ctx, &rowAccessPoliciesListRequest{ - server: server, - project: project, - dataset: dataset, - table: table, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type rowAccessPoliciesListRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *metadata.Table -} - -func (h *rowAccessPoliciesListHandler) Handle(ctx context.Context, r *rowAccessPoliciesListRequest) (*bigqueryv2.ListRowAccessPoliciesResponse, error) { - return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.list") -} - -func (h *rowAccessPoliciesSetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - res, err := h.Handle(ctx, &rowAccessPoliciesSetIamPolicyRequest{ - server: server, - }) - if err != nil { - errorResponse(ctx, w, 
errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type rowAccessPoliciesSetIamPolicyRequest struct { - server *Server -} - -func (h *rowAccessPoliciesSetIamPolicyHandler) Handle(ctx context.Context, r *rowAccessPoliciesSetIamPolicyRequest) (*bigqueryv2.Policy, error) { - return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.setIamPolicy") -} - -func (h *rowAccessPoliciesTestIamPermissionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - res, err := h.Handle(ctx, &rowAccessPoliciesTestIamPermissionsRequest{ - server: server, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type rowAccessPoliciesTestIamPermissionsRequest struct { - server *Server -} - -func (h *rowAccessPoliciesTestIamPermissionsHandler) Handle(ctx context.Context, r *rowAccessPoliciesTestIamPermissionsRequest) (*bigqueryv2.TestIamPermissionsResponse, error) { - return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.testIamPermissions") -} - -func (h *tabledataInsertAllHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - var req bigqueryv2.TableDataInsertAllRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &tabledataInsertAllRequest{ - server: server, - project: project, - dataset: dataset, - table: table, - req: &req, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tabledataInsertAllRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *metadata.Table - req *bigqueryv2.TableDataInsertAllRequest -} - -func 
normalizeInsertValue(v interface{}, field *bigqueryv2.TableFieldSchema) (interface{}, error) { - rv := reflect.ValueOf(v) - kind := rv.Kind() - if field.Mode == "REPEATED" { - if kind != reflect.Slice && kind != reflect.Array { - return nil, fmt.Errorf("invalid value type %T for ARRAY column", v) - } - values := make([]interface{}, 0, rv.Len()) - for i := 0; i < rv.Len(); i++ { - value, err := normalizeInsertValue(rv.Index(i).Interface(), &bigqueryv2.TableFieldSchema{ - Fields: field.Fields, - }) - if err != nil { - return nil, err - } - values = append(values, value) - } - return values, nil - } - if kind == reflect.Map { - fieldMap := map[string]*bigqueryv2.TableFieldSchema{} - for _, f := range field.Fields { - fieldMap[f.Name] = f - } - columnNameToValueMap := map[string]interface{}{} - for _, key := range rv.MapKeys() { - if key.Kind() != reflect.String { - return nil, fmt.Errorf("invalid value type %s for STRUCT column", key.Kind()) - } - columnName := key.Interface().(string) - value, err := normalizeInsertValue(rv.MapIndex(key).Interface(), fieldMap[columnName]) - if err != nil { - return nil, err - } - columnNameToValueMap[columnName] = value - } - fields := make([]map[string]interface{}, 0, len(fieldMap)) - for _, f := range field.Fields { - value, exists := columnNameToValueMap[f.Name] - if !exists { - return nil, fmt.Errorf("failed to find value from %s", f.Name) - } - fields = append(fields, map[string]interface{}{f.Name: value}) - } - return fields, nil - } - return v, nil -} - -func (h *tabledataInsertAllHandler) Handle(ctx context.Context, r *tabledataInsertAllRequest) (*bigqueryv2.TableDataInsertAllResponse, error) { - content, err := r.table.Content() - if err != nil { - return nil, err - } - var insertErrors []*bigqueryv2.TableDataInsertAllResponseInsertErrors - data := types.Data{} - for i, row := range r.req.Rows { - // A row that carries fields absent from the table schema is rejected - // unless ignoreUnknownValues is set; it is reported 
per-row and not - // inserted, while the remaining rows still go in. - if !r.req.IgnoreUnknownValues { - if unknown := types.ValidateRowFields(content.Schema, row.Json); len(unknown) > 0 { - errs := make([]*bigqueryv2.ErrorProto, 0, len(unknown)) - for _, name := range unknown { - errs = append(errs, &bigqueryv2.ErrorProto{ - Reason: "invalid", - Location: name, - Message: fmt.Sprintf("no such field: %s.", name), - }) - } - insertErrors = append(insertErrors, &bigqueryv2.TableDataInsertAllResponseInsertErrors{ - Index: int64(i), - Errors: errs, - }) - continue - } - } - rowData := map[string]interface{}{} - for k, v := range row.Json { - rowData[k] = v - } - data = append(data, rowData) - } - if len(data) > 0 { - tableDef, err := types.NewTableWithSchema(content, data) - if err != nil { - return nil, err - } - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.RollbackIfNotCommitted() - if err := r.server.contentRepo.AddTableData(ctx, tx, r.project.ID, r.dataset.ID, tableDef); err != nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - } - return &bigqueryv2.TableDataInsertAllResponse{InsertErrors: insertErrors}, nil -} - -func (h *tabledataListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - res, err := h.Handle(ctx, &tabledataListRequest{ - server: server, - project: project, - dataset: dataset, - table: table, - useInt64Timestamp: isFormatOptionsUseInt64Timestamp(r), - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tabledataListRequest struct { - server *Server - project 
*metadata.Project - dataset *metadata.Dataset - table *metadata.Table - useInt64Timestamp bool -} - -func (h *tabledataListHandler) Handle(ctx context.Context, r *tabledataListRequest) (*internaltypes.TableDataList, error) { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - response, err := r.server.contentRepo.Query( - ctx, - tx, - r.project.ID, - r.dataset.ID, - fmt.Sprintf("SELECT * FROM `%s`", r.table.ID), - nil, - ) - if err != nil { - return nil, err - } - - return &internaltypes.TableDataList{ - Rows: internaltypes.Format(response.Schema, response.Rows, r.useInt64Timestamp), - TotalRows: response.TotalRows, - }, nil -} - -func (h *tablesDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - if err := h.Handle(ctx, &tablesDeleteRequest{ - server: server, - project: project, - dataset: dataset, - table: table, - }); err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - } -} - -type tablesDeleteRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *metadata.Table -} - -func (h *tablesDeleteHandler) Handle(ctx context.Context, r *tablesDeleteRequest) error { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return err - } - tx, err := conn.Begin(ctx) - if err != nil { - return err - } - defer tx.RollbackIfNotCommitted() - // delete table metadata - if err := r.table.Delete(ctx, tx.Tx()); err != nil { - return err - } - // delete table - if err := r.server.contentRepo.DeleteTables( - ctx, - tx, - r.project.ID, - r.dataset.ID, 
- []contentdata.TableDeletion{{ID: r.table.ID, IsView: r.table.IsView()}}, - ); err != nil { - return fmt.Errorf("failed to delete table %s: %w", r.table.ID, err) - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -func (h *tablesGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - res, err := h.Handle(ctx, &tablesGetRequest{ - server: server, - project: project, - dataset: dataset, - table: table, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesGetRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *metadata.Table -} - -func (h *tablesGetHandler) Handle(ctx context.Context, r *tablesGetRequest) (*bigqueryv2.Table, error) { - table, err := r.table.Content() - if err != nil { - return nil, fmt.Errorf("failed to get table content: %w", err) - } - // Populate NumRows from the backing table so clients that depend on it - // (e.g. Table.getNumRows) observe an accurate count. Views and external - // tables have no backing row store and are left untouched. 
- if table.Type == "" || table.Type == "TABLE" { - if numRows, err := h.countRows(ctx, r); err == nil { - table.NumRows = uint64(numRows) - } - } - return table, nil -} - -func (h *tablesGetHandler) countRows(ctx context.Context, r *tablesGetRequest) (int64, error) { - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return 0, err - } - tx, err := conn.Begin(ctx) - if err != nil { - return 0, err - } - defer tx.RollbackIfNotCommitted() - count, err := r.server.contentRepo.CountTableRows(ctx, tx, r.project.ID, r.dataset.ID, r.table.ID) - if err != nil { - return 0, err - } - if err := tx.Commit(); err != nil { - return 0, err - } - return count, nil -} - -func (h *tablesGetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - var req bigqueryv2.GetIamPolicyRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &tablesGetIamPolicyRequest{ - server: server, - req: &req, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesGetIamPolicyRequest struct { - server *Server - req *bigqueryv2.GetIamPolicyRequest -} - -func (h *tablesGetIamPolicyHandler) Handle(ctx context.Context, r *tablesGetIamPolicyRequest) (*bigqueryv2.Policy, error) { - return nil, fmt.Errorf("bigquery.tables.getIamPolicy") -} - -func (h *tablesInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - var table bigqueryv2.Table - if err := json.NewDecoder(r.Body).Decode(&table); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &tablesInsertRequest{ - server: server, - project: project, - dataset: dataset, - 
table: &table, - }) - if err != nil { - errorResponse(ctx, w, err) - return - } - encodeResponse(ctx, w, res) -} - -type tablesInsertRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *bigqueryv2.Table -} - -type TableType string - -const ( - DefaultTableType TableType = "TABLE" - ViewTableType TableType = "VIEW" - ExternalTableType TableType = "EXTERNAL" - MaterializedViewTableType TableType = "MATERIALIZED_VIEW" - SnapshotTableType TableType = "SNAPSHOT" -) - -func createTableMetadata(ctx context.Context, tx *connection.Tx, server *Server, project *metadata.Project, dataset *metadata.Dataset, table *bigqueryv2.Table) (*bigqueryv2.Table, *ServerError) { - now := time.Now().Unix() - table.Id = fmt.Sprintf("%s:%s.%s", project.ID, dataset.ID, table.TableReference.TableId) - table.CreationTime = now - table.LastModifiedTime = uint64(now) - table.Type = string(DefaultTableType) // TODO: need to handle other table types - if table.View != nil { - table.Type = string(ViewTableType) - } - if table.MaterializedView != nil { - table.Type = string(MaterializedViewTableType) - } - table.Kind = "bigquery#table" - table.SelfLink = fmt.Sprintf( - "http://%s/bigquery/v2/projects/%s/datasets/%s/tables/%s", - server.httpServer.Addr, - project.ID, - dataset.ID, - table.TableReference.TableId, - ) - encodedTableData, err := json.Marshal(table) - if err != nil { - return nil, errInternalError(err.Error()) - } - var tableMetadata map[string]interface{} - if err := json.Unmarshal(encodedTableData, &tableMetadata); err != nil { - return nil, errInternalError(err.Error()) - } - if err := dataset.AddTable( - ctx, - tx.Tx(), - metadata.NewTable( - server.metaRepo, - project.ID, - dataset.ID, - table.TableReference.TableId, - tableMetadata, - ), - ); err != nil { - if errors.Is(err, metadata.ErrDuplicatedTable) { - return nil, errDuplicate(err.Error()) - } - return nil, errInternalError(err.Error()) - } - return table, nil -} - -func (h 
*tablesInsertHandler) Handle(ctx context.Context, r *tablesInsertRequest) (*bigqueryv2.Table, *ServerError) { - if r.table.ExternalDataConfiguration != nil { - return h.handleExternalTable(ctx, r) - } - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, errInternalError(err.Error()) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, errInternalError(err.Error()) - } - defer tx.RollbackIfNotCommitted() - - isView := r.table.View != nil || r.table.MaterializedView != nil - if isView { - // Create the view first so its resolved column schema can be read - // back and recorded in the metadata, as real BigQuery does. - if err := r.server.contentRepo.CreateView(ctx, tx, r.table); err != nil { - return nil, errInvalid(err.Error()) - } - schema, err := r.server.contentRepo.ViewSchema( - ctx, tx, r.project.ID, r.dataset.ID, r.table.TableReference.TableId, - ) - if err != nil { - return nil, errInvalid(err.Error()) - } - r.table.Schema = schema - } - table, serverErr := createTableMetadata(ctx, tx, r.server, r.project, r.dataset, r.table) - if serverErr != nil { - return nil, serverErr - } - if !isView && r.table.Schema != nil { - if err := r.server.contentRepo.CreateTable(ctx, tx, r.table); err != nil { - return nil, errInternalError(err.Error()) - } - } - if err := tx.Commit(); err != nil { - return nil, errInternalError(fmt.Errorf("failed to commit table: %w", err).Error()) - } - return table, nil -} - -// handleExternalTable materializes an external table's source data into a -// backing table so the table is registered with the query engine and can be -// queried. Real BigQuery reads an external table live on every query; the -// emulator snapshots the source at creation time, which is enough to make -// the table queryable. 
-func (h *tablesInsertHandler) handleExternalTable(ctx context.Context, r *tablesInsertRequest) (*bigqueryv2.Table, *ServerError) { - edc := r.table.ExternalDataConfiguration - tableRef := r.table.TableReference - if tableRef == nil { - return nil, errInvalid("external table is missing tableReference") - } - if tableRef.ProjectId == "" { - tableRef.ProjectId = r.project.ID - } - if tableRef.DatasetId == "" { - tableRef.DatasetId = r.dataset.ID - } - if len(edc.SourceUris) == 0 { - return nil, errInvalid("external table is missing sourceUris") - } - - // Translate the external data configuration into a load job and route it - // through the existing load pipeline, which creates the backing table and - // loads the rows (inferring the schema when autodetect is requested). - load := &bigqueryv2.JobConfigurationLoad{ - DestinationTable: tableRef, - SourceUris: edc.SourceUris, - SourceFormat: edc.SourceFormat, - Autodetect: edc.Autodetect, - Schema: edc.Schema, - CreateDisposition: "CREATE_IF_NEEDED", - } - if load.Schema == nil { - load.Schema = r.table.Schema - } - if csv := edc.CsvOptions; csv != nil { - load.Quote = csv.Quote - load.FieldDelimiter = csv.FieldDelimiter - load.SkipLeadingRows = csv.SkipLeadingRows - load.AllowJaggedRows = csv.AllowJaggedRows - load.AllowQuotedNewlines = csv.AllowQuotedNewlines - load.Encoding = csv.Encoding - } - job := &bigqueryv2.Job{ - JobReference: &bigqueryv2.JobReference{JobId: randomID(), ProjectId: r.project.ID}, - Configuration: &bigqueryv2.JobConfiguration{Load: load}, - } - if _, err := (&jobsInsertHandler{}).importFromGCS(ctx, &jobsInsertRequest{ - server: r.server, - project: r.project, - job: job, - }); err != nil { - return nil, errInvalid(fmt.Sprintf("failed to load external table data: %s", err)) - } - - // The load created a plain table; record the external configuration on - // its metadata so it round-trips on tables.get and tables.list. 
- table := r.dataset.Table(tableRef.TableId) - if table == nil { - return nil, errInternalError("external table backing data was not created") - } - content, err := table.Content() - if err != nil { - return nil, errInternalError(err.Error()) - } - content.ExternalDataConfiguration = edc - content.Type = string(ExternalTableType) - encoded, err := json.Marshal(content) - if err != nil { - return nil, errInternalError(err.Error()) - } - var newMetadata map[string]interface{} - if err := json.Unmarshal(encoded, &newMetadata); err != nil { - return nil, errInternalError(err.Error()) - } - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, errInternalError(err.Error()) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, errInternalError(err.Error()) - } - defer tx.RollbackIfNotCommitted() - if err := table.Replace(ctx, tx.Tx(), newMetadata); err != nil { - return nil, errInternalError(err.Error()) - } - if err := tx.Commit(); err != nil { - return nil, errInternalError(err.Error()) - } - return content, nil -} - -func (h *tablesListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - res, err := h.Handle(ctx, &tablesListRequest{ - server: server, - project: project, - dataset: dataset, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesListRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset -} - -func (h *tablesListHandler) Handle(ctx context.Context, r *tablesListRequest) (*bigqueryv2.TableList, error) { - var tables []*bigqueryv2.TableListTables - for _, tableID := range r.dataset.TableIDs() { - table, err := r.dataset.Table(tableID).Content() - if err != nil { - return nil, fmt.Errorf("failed to get table metadata from %s: %w", tableID, 
err) - } - tables = append(tables, &bigqueryv2.TableListTables{ - Clustering: table.Clustering, - CreationTime: table.CreationTime, - ExpirationTime: table.ExpirationTime, - FriendlyName: table.FriendlyName, - Id: table.Id, - Kind: table.Kind, - Labels: table.Labels, - RangePartitioning: table.RangePartitioning, - TableReference: table.TableReference, - TimePartitioning: table.TimePartitioning, - Type: table.Type, - }) - } - return &bigqueryv2.TableList{ - Tables: tables, - TotalItems: int64(len(tables)), - }, nil -} - -func (h *tablesPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - var newTable bigqueryv2.Table - if err := json.NewDecoder(r.Body).Decode(&newTable); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &tablesPatchRequest{ - server: server, - project: project, - dataset: dataset, - table: table, - newTable: &newTable, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesPatchRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *metadata.Table - newTable *bigqueryv2.Table -} - -func (h *tablesPatchHandler) Handle(ctx context.Context, r *tablesPatchRequest) (*bigqueryv2.Table, error) { - encodedTableData, err := json.Marshal(r.newTable) - if err != nil { - return nil, err - } - var tableMetadata map[string]interface{} - if err := json.Unmarshal(encodedTableData, &tableMetadata); err != nil { - return nil, err - } - - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, err - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.RollbackIfNotCommitted() - if err := r.table.Patch(ctx, tx.Tx(), tableMetadata); err != 
nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - // Return the full, merged table resource (kind/type/id/creationTime - // included) rather than echoing the request body. - return r.table.Content() -} - -func (h *tablesSetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - res, err := h.Handle(ctx, &tablesSetIamPolicyRequest{ - server: server, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesSetIamPolicyRequest struct { - server *Server -} - -func (h *tablesSetIamPolicyHandler) Handle(ctx context.Context, r *tablesSetIamPolicyRequest) (*bigqueryv2.Policy, error) { - return nil, fmt.Errorf("unsupported bigquery.tables.setIamPolicy") -} - -func (h *tablesTestIamPermissionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - res, err := h.Handle(ctx, &tablesTestIamPermissionsRequest{ - server: server, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesTestIamPermissionsRequest struct { - server *Server -} - -func (h *tablesTestIamPermissionsHandler) Handle(ctx context.Context, r *tablesTestIamPermissionsRequest) (*bigqueryv2.TestIamPermissionsResponse, error) { - return nil, fmt.Errorf("unsupported bigquery.tables.testIamPermissions") -} - -func (h *tablesUpdateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - server := serverFromContext(ctx) - project := projectFromContext(ctx) - dataset := datasetFromContext(ctx) - table := tableFromContext(ctx) - var newTable bigqueryv2.Table - if err := json.NewDecoder(r.Body).Decode(&newTable); err != nil { - errorResponse(ctx, w, errInvalid(err.Error())) - return - } - res, err := h.Handle(ctx, &tablesUpdateRequest{ - server: server, - 
project: project, - dataset: dataset, - table: table, - newTable: &newTable, - }) - if err != nil { - errorResponse(ctx, w, errInternalError(err.Error())) - return - } - encodeResponse(ctx, w, res) -} - -type tablesUpdateRequest struct { - server *Server - project *metadata.Project - dataset *metadata.Dataset - table *metadata.Table - newTable *bigqueryv2.Table -} - -func (h *tablesUpdateHandler) Handle(ctx context.Context, r *tablesUpdateRequest) (*bigqueryv2.Table, error) { - encodedTableData, err := json.Marshal(r.newTable) - if err != nil { - return nil, err - } - var tableMetadata map[string]interface{} - if err := json.Unmarshal(encodedTableData, &tableMetadata); err != nil { - return nil, err - } - - conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) - if err != nil { - return nil, err - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.RollbackIfNotCommitted() - if err := r.table.Replace(ctx, tx.Tx(), tableMetadata); err != nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - return r.table.Content() -} - -type defaultHandler struct{} - -func (h *defaultHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - errorResponse(ctx, w, errInternalError(fmt.Sprintf("unexpected request path: %s", html.EscapeString(r.URL.Path)))) -} +package server + +import ( + "bytes" + "context" + _ "embed" + "encoding/csv" + "errors" + "fmt" + "html" + "io" + "mime" + "mime/multipart" + "net/http" + "os" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "cloud.google.com/go/storage" + "github.com/goccy/go-json" + "github.com/goccy/googlesqlite" + "go.uber.org/zap" + bigqueryv2 "google.golang.org/api/bigquery/v2" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + + "github.com/goccy/bigquery-emulator/internal/connection" + "github.com/goccy/bigquery-emulator/internal/contentdata" + "github.com/goccy/bigquery-emulator/internal/logger" + 
"github.com/goccy/bigquery-emulator/internal/metadata" + internaltypes "github.com/goccy/bigquery-emulator/internal/types" + "github.com/goccy/bigquery-emulator/types" + "github.com/parquet-go/parquet-go" +) + +func errorResponse(ctx context.Context, w http.ResponseWriter, e *ServerError) { + logger.Logger(ctx).WithOptions(zap.AddCallerSkip(1)).Error(string(e.Reason), zap.Error(e)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(e.Status) + w.Write(e.Response()) +} + +// uploadErrorResponse renders an error returned by the upload content handler. +// When the handler produced a typed *ServerError (e.g. a missing dataset), its +// HTTP status is preserved; any other failure is reported as a job error. +func uploadErrorResponse(ctx context.Context, w http.ResponseWriter, err error) { + var serr *ServerError + if errors.As(err, &serr) { + errorResponse(ctx, w, serr) + return + } + errorResponse(ctx, w, errJobInternalError(err.Error())) +} + +func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) { + b, err := json.Marshal(response) + if err != nil { + errorResponse(ctx, w, errInternalError(fmt.Sprintf("failed to encode json: %s", err.Error()))) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(b) +} + +const ( + discoveryAPIEndpoint = "/discovery/v1/apis/bigquery/v2/rest" + newDiscoveryAPIEndpoint = "/$discovery/rest" + uploadAPIEndpoint = "/upload/bigquery/v2/projects/{projectId}/jobs" +) + +//go:embed resources/discovery.json +var bigqueryAPIJSON []byte + +var ( + discoveryAPIOnce sync.Once + discoveryAPIResponse map[string]interface{} +) + +type discoveryHandler struct { + server *Server +} + +func newDiscoveryHandler(server *Server) *discoveryHandler { + return &discoveryHandler{server: server} +} + +func (h *discoveryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var decodeJSONErr error + discoveryAPIOnce.Do(func() { + if err := 
json.Unmarshal(bigqueryAPIJSON, &discoveryAPIResponse); err != nil { + decodeJSONErr = err + return + } + addr := h.server.httpServer.Addr + if !strings.HasPrefix(addr, "http") { + addr = "http://" + addr + } + discoveryAPIResponse["mtlsRootUrl"] = addr + discoveryAPIResponse["rootUrl"] = addr + discoveryAPIResponse["baseUrl"] = addr + }) + if decodeJSONErr != nil { + errorResponse(ctx, w, errInternalError(decodeJSONErr.Error())) + return + } + encodeResponse(ctx, w, discoveryAPIResponse) +} + +type uploadHandler struct{} + +type UploadJobConfigurationLoad struct { + AllowJaggedRows bool `json:"allowJaggedRows,omitempty"` + AllowQuotedNewlines bool `json:"allowQuotedNewlines,omitempty"` + Autodetect bool `json:"autodetect,omitempty"` + Clustering *bigqueryv2.Clustering `json:"clustering,omitempty"` + CreateDisposition string `json:"createDisposition,omitempty"` + DecimalTargetTypes []string `json:"decimalTargetTypes,omitempty"` + DestinationEncryptionConfiguration *bigqueryv2.EncryptionConfiguration `json:"destinationEncryptionConfiguration,omitempty"` + DestinationTable *bigqueryv2.TableReference `json:"destinationTable,omitempty"` + DestinationTableProperties *bigqueryv2.DestinationTableProperties `json:"destinationTableProperties,omitempty"` + Encoding string `json:"encoding,omitempty"` + FieldDelimiter string `json:"fieldDelimiter,omitempty"` + HivePartitioningOptions *bigqueryv2.HivePartitioningOptions `json:"hivePartitioningOptions,omitempty"` + IgnoreUnknownValues bool `json:"ignoreUnknownValues,omitempty"` + JsonExtension string `json:"jsonExtension,omitempty"` + MaxBadRecords int64 `json:"maxBadRecords,omitempty"` + NullMarker string `json:"nullMarker,omitempty"` + ParquetOptions *bigqueryv2.ParquetOptions `json:"parquetOptions,omitempty"` + PreserveAsciiControlCharacters bool `json:"preserveAsciiControlCharacters,omitempty"` + ProjectionFields []string `json:"projectionFields,omitempty"` + Quote *string `json:"quote,omitempty"` + RangePartitioning 
*bigqueryv2.RangePartitioning `json:"rangePartitioning,omitempty"` + Schema *bigqueryv2.TableSchema `json:"schema,omitempty"` + SchemaInline string `json:"schemaInline,omitempty"` + SchemaInlineFormat string `json:"schemaInlineFormat,omitempty"` + SchemaUpdateOptions []string `json:"schemaUpdateOptions,omitempty"` + SkipLeadingRows json.Number `json:"skipLeadingRows,omitempty"` + SourceFormat string `json:"sourceFormat,omitempty"` + SourceUris []string `json:"sourceUris,omitempty"` + TimePartitioning *bigqueryv2.TimePartitioning `json:"timePartitioning,omitempty"` + UseAvroLogicalTypes bool `json:"useAvroLogicalTypes,omitempty"` + WriteDisposition string `json:"writeDisposition,omitempty"` +} + +type UploadJobConfiguration struct { + Load *UploadJobConfigurationLoad `json:"load"` +} + +type UploadJob struct { + JobReference *bigqueryv2.JobReference `json:"jobReference"` + Configuration *UploadJobConfiguration `json:"configuration"` +} + +// normalize fills in fields that some client libraries omit from the upload +// metadata. The Node.js client in particular sends no jobReference, which +// previously caused a nil pointer dereference when the handler read the job +// id. A missing job id is generated so the upload still gets a stable handle. 
+func (j *UploadJob) normalize(projectID string) *ServerError { + if j.Configuration == nil || j.Configuration.Load == nil { + return errInvalid("upload job is missing configuration.load") + } + if j.JobReference == nil { + j.JobReference = &bigqueryv2.JobReference{} + } + if j.JobReference.JobId == "" { + j.JobReference.JobId = randomID() + } + if j.JobReference.ProjectId == "" { + j.JobReference.ProjectId = projectID + } + return nil +} + +func (j *UploadJob) ToJob() *bigqueryv2.Job { + load := j.Configuration.Load + skipLeadingRows, _ := load.SkipLeadingRows.Int64() + return &bigqueryv2.Job{ + JobReference: j.JobReference, + Configuration: &bigqueryv2.JobConfiguration{ + Load: &bigqueryv2.JobConfigurationLoad{ + AllowJaggedRows: load.AllowJaggedRows, + AllowQuotedNewlines: load.AllowQuotedNewlines, + Autodetect: load.Autodetect, + Clustering: load.Clustering, + CreateDisposition: load.CreateDisposition, + DecimalTargetTypes: load.DecimalTargetTypes, + DestinationEncryptionConfiguration: load.DestinationEncryptionConfiguration, + DestinationTable: load.DestinationTable, + DestinationTableProperties: load.DestinationTableProperties, + Encoding: load.Encoding, + FieldDelimiter: load.FieldDelimiter, + HivePartitioningOptions: load.HivePartitioningOptions, + IgnoreUnknownValues: load.IgnoreUnknownValues, + JsonExtension: load.JsonExtension, + MaxBadRecords: load.MaxBadRecords, + NullMarker: load.NullMarker, + ParquetOptions: load.ParquetOptions, + PreserveAsciiControlCharacters: load.PreserveAsciiControlCharacters, + ProjectionFields: load.ProjectionFields, + Quote: load.Quote, + RangePartitioning: load.RangePartitioning, + Schema: load.Schema, + SchemaInline: load.SchemaInline, + SchemaInlineFormat: load.SchemaInlineFormat, + SchemaUpdateOptions: load.SchemaUpdateOptions, + SkipLeadingRows: skipLeadingRows, + SourceFormat: load.SourceFormat, + SourceUris: load.SourceUris, + TimePartitioning: load.TimePartitioning, + UseAvroLogicalTypes: load.UseAvroLogicalTypes, + 
WriteDisposition: load.WriteDisposition, + }, + }, + } +} + +func (h *uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Query().Get("uploadType") { + case "multipart": + h.serveMultipart(w, r) + case "resumable": + h.serveResumable(w, r) + default: + errorResponse(r.Context(), w, errInvalid(`uploadType should be "multipart" or "resumable"`)) + } +} + +func (h *uploadHandler) serveMultipart(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + contentType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil || !strings.HasPrefix(contentType, "multipart/") { + errorResponse(ctx, w, errInvalid("expecting a multipart message")) + return + } + mul := multipart.NewReader(r.Body, params["boundary"]) + p, err := mul.NextPart() + if err != nil { + errorResponse(ctx, w, errInvalid(fmt.Sprintf("failed to load metadata: %s", err.Error()))) + return + } + var job UploadJob + if err := json.NewDecoder(p).Decode(&job); err != nil { + errorResponse(ctx, w, errInvalid(fmt.Sprintf("failed to decode job: %s", err.Error()))) + return + } + if serr := job.normalize(project.ID); serr != nil { + errorResponse(ctx, w, serr) + return + } + uploadJob, err := h.Handle(ctx, &uploadRequest{ + server: server, + project: project, + job: &job, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + + p, err = mul.NextPart() + if err != nil { + errorResponse(ctx, w, errInvalid(fmt.Sprintf("multipart request is invalid: %s", err.Error()))) + return + } + u := &uploadContentHandler{} + err = u.Handle(ctx, &uploadContentRequest{ + server: server, + project: project, + job: uploadJob, + reader: p, + }) + if err != nil { + uploadErrorResponse(ctx, w, err) + return + } + encodeResponse(ctx, w, uploadJob.Content()) +} + +func (h *uploadHandler) serveResumable(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + 
server := serverFromContext(ctx) + project := projectFromContext(ctx) + var job UploadJob + if err := json.NewDecoder(r.Body).Decode(&job); err != nil { + errorResponse(ctx, w, errInvalid(fmt.Sprintf("failed to decode job: %s", err.Error()))) + return + } + if serr := job.normalize(project.ID); serr != nil { + errorResponse(ctx, w, serr) + return + } + res, err := h.Handle(ctx, &uploadRequest{ + server: server, + project: project, + job: &job, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + addr := server.httpServer.Addr + if !strings.HasPrefix(addr, "http") { + addr = "http://" + addr + } + addr = strings.TrimRight(addr, "/") + w.Header().Add( + "Location", + fmt.Sprintf( + "%s/upload/bigquery/v2/projects/%s/jobs?uploadType=resumable&upload_id=%s", + addr, + project.ID, + job.JobReference.JobId, + ), + ) + encodeResponse(ctx, w, res.Content()) +} + +type uploadRequest struct { + server *Server + project *metadata.Project + job *UploadJob +} + +func (h *uploadHandler) Handle(ctx context.Context, r *uploadRequest) (*metadata.Job, error) { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "") + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + job := metadata.NewJob(r.server.metaRepo, r.project.ID, r.job.JobReference.JobId, r.job.ToJob(), nil, nil) + if err := r.project.AddJob(ctx, tx.Tx(), job); err != nil { + return nil, fmt.Errorf("failed to add job: %w", err) + } + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit job: %w", err) + } + return job, nil +} + +type uploadContentHandler struct{} + +func (h *uploadContentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + query := r.URL.Query() + 
uploadType := query["uploadType"] + if len(uploadType) == 0 { + errorResponse(ctx, w, errInvalid("uploadType parameter is not found")) + return + } + if uploadType[0] != "resumable" { + errorResponse(ctx, w, errInvalid(fmt.Sprintf("uploadType parameter is not resumable %s", uploadType[0]))) + return + } + uploadID := query["upload_id"] + if len(uploadID) == 0 { + errorResponse(ctx, w, errInvalid("upload_id parameter is not found")) + return + } + jobID := uploadID[0] + job := project.Job(jobID) + if job == nil { + errorResponse(ctx, w, errNotFound(fmt.Sprintf("upload job %s is not found", jobID))) + return + } + if err := h.Handle(ctx, &uploadContentRequest{ + server: server, + project: project, + job: job, + reader: r.Body, + }); err != nil { + uploadErrorResponse(ctx, w, err) + return + } + content := job.Content() + content.Status = &bigqueryv2.JobStatus{State: "DONE"} + encodeResponse(ctx, w, content) +} + +type uploadContentRequest struct { + server *Server + project *metadata.Project + job *metadata.Job + reader io.Reader +} + +func (h *uploadContentHandler) getCandidateName(col string, columnNames []string) string { + var ( + foundName string + foundCount int + ) + for _, name := range columnNames { + if strings.Contains(name, col) { + foundName = name + foundCount++ + } + } + if foundCount == 1 { + return foundName + } + return "" +} + +func (h *uploadContentHandler) existsColumnNameInCSVHeader(col string, header []string) bool { + for _, h := range header { + if col == h { + return true + } + } + return false +} + +func (h *uploadContentHandler) normalizeColumnNameForJSONData(columnMap map[string]*types.Column, data map[string]interface{}) { + for k, v := range data { + if _, exists := columnMap[k]; exists { + continue + } + lowerKey := strings.ToLower(k) + var ( + foundCount int + columnName string + ) + for colName := range columnMap { + if lowerKey == strings.ToLower(colName) { + foundCount++ + columnName = colName + } + } + if foundCount == 1 { + 
delete(data, k) + data[columnName] = v + } + } +} + +func newLoadCSVReader(reader io.Reader, fieldDelimiter string) (*csv.Reader, error) { + csvReader := csv.NewReader(reader) + csvReader.FieldsPerRecord = -1 + if fieldDelimiter == "" { + return csvReader, nil + } + delimiters := []rune(fieldDelimiter) + if len(delimiters) != 1 { + return nil, fmt.Errorf("fieldDelimiter must be a single character") + } + csvReader.Comma = delimiters[0] + return csvReader, nil +} + +func schemaColumns(fields []*bigqueryv2.TableFieldSchema) []*types.Column { + columns := make([]*types.Column, 0, len(fields)) + for _, field := range fields { + columns = append(columns, types.NewColumnWithSchema(field)) + } + return columns +} + +func columnsFromCSVHeader(header []string, columnToType map[string]types.Type) ([]*types.Column, bool) { + columns := make([]*types.Column, 0, len(header)) + for _, col := range header { + columnType, exists := columnToType[col] + if !exists { + return nil, false + } + columns = append(columns, &types.Column{ + Name: col, + Type: columnType, + }) + } + return columns, true +} + +func csvLoadColumnsAndRows(records [][]string, schemaFields []*bigqueryv2.TableFieldSchema, columnToType map[string]types.Type, skipLeadingRows int64) ([]*types.Column, [][]string, error) { + window, err := csvRowWindowFor(len(records), skipLeadingRows) + if err != nil { + return nil, nil, err + } + if !window.hasHeader { + return schemaColumns(schemaFields), nil, nil + } + + dataStart := window.dataStart + if window.headerIndex < len(records) { + if columns, ok := columnsFromCSVHeader(records[window.headerIndex], columnToType); ok { + if dataStart <= window.headerIndex { + dataStart = window.headerIndex + 1 + } + return columns, records[dataStart:], nil + } + } + return schemaColumns(schemaFields), records[dataStart:], nil +} + +func csvRowsToTableData(records [][]string, schemaFields []*bigqueryv2.TableFieldSchema, skipLeadingRows int64, allowJaggedRows bool) ([]*types.Column, 
types.Data, error) { + columnToType := map[string]types.Type{} + for _, field := range schemaFields { + columnToType[field.Name] = types.Type(field.Type) + } + columns, dataRows, err := csvLoadColumnsAndRows(records, schemaFields, columnToType, skipLeadingRows) + if err != nil { + return nil, nil, err + } + data := types.Data{} + for _, record := range dataRows { + rowData := map[string]interface{}{} + if len(record) > len(columns) || (!allowJaggedRows && len(record) != len(columns)) { + return nil, nil, fmt.Errorf("invalid column number: found broken row data: %v", record) + } + for i := 0; i < len(columns); i++ { + var colData string + if i < len(record) { + colData = record[i] + } + if colData == "" { + rowData[columns[i].Name] = nil + } else { + rowData[columns[i].Name] = colData + } + } + data = append(data, rowData) + } + return columns, data, nil +} + +func (h *uploadContentHandler) Handle(ctx context.Context, r *uploadContentRequest) error { + load := r.job.Content().Configuration.Load + tableRef := load.DestinationTable + if tableRef == nil { + return errInvalid("load job is missing configuration.load.destinationTable") + } + dataset := r.project.Dataset(tableRef.DatasetId) + if dataset == nil { + return errNotFound(fmt.Sprintf("dataset %q is not found", tableRef.DatasetId)) + } + table := dataset.Table(tableRef.TableId) + // The write disposition only matters for a table that already exists; a + // freshly created one is empty regardless. + tableExisted := table != nil + + // Read CSV content up front so an autodetect load can infer the schema + // before the destination table is created. 
+ var csvRecords [][]string + var csvColumns []*types.Column + var csvData types.Data + var csvDataReady bool + if load.SourceFormat == "CSV" { + csvReader, err := newLoadCSVReader(r.reader, load.FieldDelimiter) + if err != nil { + return err + } + records, err := csvReader.ReadAll() + if err != nil { + return fmt.Errorf("failed to read csv: %w", err) + } + csvRecords = records + if !tableExisted && load.Schema == nil && load.Autodetect { + schema, err := inferCSVSchema(csvRecords, load.SkipLeadingRows) + if err != nil { + return err + } + load.Schema = schema + } + } + if table == nil { + if load.CreateDisposition == "CREATE_NEVER" { + return fmt.Errorf("`%s` is not found", tableRef.TableId) + } + if load.SourceFormat == "CSV" && load.Schema != nil { + var err error + csvColumns, csvData, err = csvRowsToTableData(csvRecords, load.Schema.Fields, load.SkipLeadingRows, load.AllowJaggedRows) + if err != nil { + return err + } + csvDataReady = true + } + if _, err := (&tablesInsertHandler{}).Handle(ctx, &tablesInsertRequest{ + server: r.server, + project: r.project, + dataset: dataset, + table: &bigqueryv2.Table{ + Schema: load.Schema, + TableReference: tableRef, + }, + }); err != nil { + return err + } + table = dataset.Table(tableRef.TableId) + } + + tableContent, err := table.Content() + if err != nil { + return err + } + + sourceFormat := load.SourceFormat + columns := []*types.Column{} + data := types.Data{} + switch sourceFormat { + case "CSV": + if csvDataReady { + columns = csvColumns + data = csvData + break + } + var err error + columns, data, err = csvRowsToTableData(csvRecords, tableContent.Schema.Fields, load.SkipLeadingRows, load.AllowJaggedRows) + if err != nil { + return err + } + case "PARQUET": + b, err := io.ReadAll(r.reader) + if err != nil { + return err + } + reader := parquet.NewReader(bytes.NewReader(b)) + defer reader.Close() + + columns = schemaColumns(load.Schema.Fields) + + for i := 0; i < int(reader.NumRows()); i++ { + var rowData 
interface{} + err := reader.Read(&rowData) + if err != nil { + return err + } + + data = append(data, rowData.(map[string]interface{})) + } + case "NEWLINE_DELIMITED_JSON": + columns = schemaColumns(tableContent.Schema.Fields) + columnMap := map[string]*types.Column{} + for _, col := range columns { + columnMap[col.Name] = col + } + decoder := json.NewDecoder(r.reader) + decoder.UseNumber() + for decoder.More() { + d := make(map[string]interface{}) + if err := decoder.Decode(&d); err != nil { + return err + } + h.normalizeColumnNameForJSONData(columnMap, d) + data = append(data, d) + } + default: + return fmt.Errorf("not support sourceFormat: %s", sourceFormat) + } + tableDef := &types.Table{ + ID: tableRef.TableId, + Columns: columns, + Data: data, + } + conn, err := r.server.connMgr.Connection(ctx, tableRef.ProjectId, tableRef.DatasetId) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return err + } + defer tx.RollbackIfNotCommitted() + if tableExisted { + switch load.WriteDisposition { + case "WRITE_TRUNCATE": + if err := r.server.contentRepo.TruncateTable(ctx, tx, tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId); err != nil { + return err + } + case "WRITE_EMPTY": + count, err := r.server.contentRepo.CountTableRows(ctx, tx, tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId) + if err != nil { + return err + } + if count > 0 { + return fmt.Errorf("table %s already exists and contains data (WRITE_EMPTY)", tableRef.TableId) + } + } + } + if err := r.server.contentRepo.AddTableData(ctx, tx, tableRef.ProjectId, tableRef.DatasetId, tableDef); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return err + } + return nil +} + +const ( + formatOptionsUseInt64TimestampParam = "formatOptions.useInt64Timestamp" + deleteContentsParam = "deleteContents" +) + +func isDeleteContents(r *http.Request) bool { + return parseQueryValueAsBool(r, deleteContentsParam) +} 
+ +func isFormatOptionsUseInt64Timestamp(r *http.Request) bool { + return parseQueryValueAsBool(r, formatOptionsUseInt64TimestampParam) +} + +// parseQueryValueAsUint64 reads an unsigned integer query parameter, reporting +// whether it was present and valid. +func parseQueryValueAsUint64(r *http.Request, key string) (uint64, bool) { + values, exists := r.URL.Query()[key] + if !exists || len(values) != 1 { + return 0, false + } + v, err := strconv.ParseUint(values[0], 10, 64) + if err != nil { + return 0, false + } + return v, true +} + +// applyNullQueryParameters inspects the raw JSON of a query-parameters array +// and clears the ParameterValue of every scalar parameter whose value was JSON +// null or absent. The bigqueryv2 structs store a parameter value as a plain +// string, so they cannot otherwise distinguish a NULL scalar from an empty string. +// +// ARRAY and STRUCT parameters must never be cleared, even when their +// parameterValue is empty (e.g. an empty []string{} omits "arrayValues" +// entirely due to JSON omitempty). The parameter type is used as the +// authoritative signal: any parameter whose type is "ARRAY" or "STRUCT" is +// left untouched. +func applyNullQueryParameters(rawParams []json.RawMessage, params []*bigqueryv2.QueryParameter) { + for i, raw := range rawParams { + if i >= len(params) || params[i] == nil { + continue + } + var p struct { + ParameterType *struct { + Type string `json:"type"` + } `json:"parameterType"` + ParameterValue *json.RawMessage `json:"parameterValue"` + } + if err := json.Unmarshal(raw, &p); err != nil { + continue + } + // ARRAY and STRUCT parameters must not be cleared regardless of whether + // their parameterValue happens to be empty. + if p.ParameterType != nil { + switch p.ParameterType.Type { + case "ARRAY", "STRUCT": + continue + } + } + // No parameterValue key at all → treat as null scalar. 
+ if p.ParameterValue == nil { + params[i].ParameterValue = nil + continue + } + // Decode parameterValue into a key-presence map. + // JSON null decodes to a nil map; map lookups on nil return false safely. + var pv map[string]json.RawMessage + if err := json.Unmarshal(*p.ParameterValue, &pv); err != nil { + continue + } + // Scalar: clear if "value" is absent or is the JSON literal null. + valueRaw, hasValue := pv["value"] + if !hasValue || string(valueRaw) == "null" { + params[i].ParameterValue = nil + } + } +} + +func parseQueryValueAsBool(r *http.Request, key string) bool { + queryValues := r.URL.Query() + values, exists := queryValues[key] + if !exists { + return false + } + if len(values) != 1 { + return false + } + b, err := strconv.ParseBool(values[0]) + if err != nil { + return false + } + return b +} + +func (h *datasetsDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + if err := h.Handle(ctx, &datasetsDeleteRequest{ + server: server, + project: project, + dataset: dataset, + deleteContents: isDeleteContents(r), + }); err != nil { + // Preserve typed *ServerError (e.g. a 400 resourceInUse for a + // non-empty dataset) so the client sees the real HTTP status + // rather than the 500 retry-forever loop a blanket wrap would + // cause. + var serr *ServerError + if errors.As(err, &serr) { + errorResponse(ctx, w, serr) + return + } + errorResponse(ctx, w, errInternalError(err.Error())) + return + } +} + +type datasetsDeleteRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + deleteContents bool +} + +func (h *datasetsDeleteHandler) Handle(ctx context.Context, r *datasetsDeleteRequest) error { + // BigQuery rejects deleting a non-empty dataset unless + // deleteContents=true. 
Reject up front so the dataset is never + // removed while its tables remain (which would orphan them and + // surface as a UNIQUE constraint violation on the next CREATE + // TABLE with the same name) and so the caller sees a 4xx rather + // than the 500 the Google SDKs retry indefinitely on. + if !r.deleteContents && len(r.dataset.Tables()) > 0 { + return errResourceInUse(fmt.Sprintf( + "Dataset %s:%s is still in use", + r.project.ID, r.dataset.ID, + )) + } + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + if err := r.project.DeleteDataset(ctx, tx.Tx(), r.dataset.ID); err != nil { + return fmt.Errorf("failed to delete dataset: %w", err) + } + if r.deleteContents { + tables := r.dataset.Tables() + deletions := make([]contentdata.TableDeletion, 0, len(tables)) + for _, table := range tables { + if err := table.Delete(ctx, tx.Tx()); err != nil { + return err + } + deletions = append(deletions, contentdata.TableDeletion{ + ID: table.ID, + IsView: table.IsView(), + }) + } + if err := r.server.contentRepo.DeleteTables(ctx, tx, r.project.ID, r.dataset.ID, deletions); err != nil { + return fmt.Errorf("failed to delete tables: %w", err) + } + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit delete dataset: %w", err) + } + return nil +} + +func (h *datasetsGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + res, err := h.Handle(ctx, &datasetsGetRequest{ + server: server, + project: project, + dataset: dataset, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type 
datasetsGetRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset +} + +func (h *datasetsGetHandler) Handle(ctx context.Context, r *datasetsGetRequest) (*bigqueryv2.Dataset, error) { + newContent := *r.dataset.Content() + newContent.DatasetReference = &bigqueryv2.DatasetReference{ + ProjectId: r.project.ID, + DatasetId: r.dataset.ID, + } + return &newContent, nil +} + +func (h *datasetsInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + var dataset bigqueryv2.Dataset + if err := json.NewDecoder(r.Body).Decode(&dataset); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &datasetsInsertRequest{ + server: server, + project: project, + dataset: &dataset, + }) + if err != nil { + var serr *ServerError + if errors.As(err, &serr) { + errorResponse(ctx, w, serr) + return + } + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type datasetsInsertRequest struct { + server *Server + project *metadata.Project + dataset *bigqueryv2.Dataset +} + +func (h *datasetsInsertHandler) Handle(ctx context.Context, r *datasetsInsertRequest) (*bigqueryv2.DatasetListDatasets, error) { + if r.dataset.DatasetReference == nil { + return nil, fmt.Errorf("DatasetReference is nil") + } + datasetID := r.dataset.DatasetReference.DatasetId + if datasetID == "" { + return nil, fmt.Errorf("dataset id is empty") + } + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, datasetID) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + + if err := r.project.AddDataset( + ctx, + tx.Tx(), + metadata.NewDataset( + r.server.metaRepo, + r.project.ID, + datasetID, + r.dataset, 
+ nil, + nil, + nil, + ), + ); err != nil { + if errors.Is(err, metadata.ErrDuplicatedDataset) { + return nil, errDuplicate(err.Error()) + } + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + return &bigqueryv2.DatasetListDatasets{ + DatasetReference: &bigqueryv2.DatasetReference{ + ProjectId: r.project.ID, + DatasetId: datasetID, + }, + Id: datasetID, + FriendlyName: r.dataset.FriendlyName, + Kind: r.dataset.Kind, + Labels: r.dataset.Labels, + Location: r.dataset.Location, + ForceSendFields: r.dataset.ForceSendFields, + NullFields: r.dataset.NullFields, + }, nil +} + +func (h *datasetsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + res, err := h.Handle(ctx, &datasetsListRequest{ + server: server, + project: project, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type datasetsListRequest struct { + server *Server + project *metadata.Project +} + +func (h *datasetsListHandler) Handle(ctx context.Context, r *datasetsListRequest) (*bigqueryv2.DatasetList, error) { + datasetsRes := []*bigqueryv2.DatasetListDatasets{} + for _, dataset := range r.project.Datasets() { + content := dataset.Content() + datasetsRes = append(datasetsRes, &bigqueryv2.DatasetListDatasets{ + DatasetReference: &bigqueryv2.DatasetReference{ + ProjectId: r.project.ID, + DatasetId: dataset.ID, + }, + FriendlyName: content.FriendlyName, + Id: dataset.ID, + Kind: content.Kind, + Labels: content.Labels, + Location: content.Location, + ForceSendFields: content.ForceSendFields, + NullFields: content.NullFields, + }) + } + return &bigqueryv2.DatasetList{ + Datasets: datasetsRes, + }, nil +} + +func (h *datasetsPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := 
datasetFromContext(ctx) + var newDataset bigqueryv2.Dataset + if err := json.NewDecoder(r.Body).Decode(&newDataset); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &datasetsPatchRequest{ + server: server, + project: project, + dataset: dataset, + newDataset: &newDataset, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type datasetsPatchRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + newDataset *bigqueryv2.Dataset +} + +func (h *datasetsPatchHandler) Handle(ctx context.Context, r *datasetsPatchRequest) (*bigqueryv2.Dataset, error) { + r.dataset.UpdateContentIfExists(r.newDataset) + newContent := *r.dataset.Content() + return &newContent, nil +} + +func (h *datasetsUpdateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + var newDataset bigqueryv2.Dataset + if err := json.NewDecoder(r.Body).Decode(&newDataset); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &datasetsUpdateRequest{ + server: server, + project: project, + dataset: dataset, + newDataset: &newDataset, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type datasetsUpdateRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + newDataset *bigqueryv2.Dataset +} + +func (h *datasetsUpdateHandler) Handle(ctx context.Context, r *datasetsUpdateRequest) (*bigqueryv2.Dataset, error) { + r.dataset.UpdateContent(r.newDataset) + newContent := *r.dataset.Content() + return &newContent, nil +} + +func (h *jobsCancelHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project 
:= projectFromContext(ctx) + job := jobFromContext(ctx) + res, err := h.Handle(ctx, &jobsCancelRequest{ + server: server, + project: project, + job: job, + }) + if err != nil { + errorResponse(ctx, w, errJobInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type jobsCancelRequest struct { + server *Server + project *metadata.Project + job *metadata.Job +} + +func (h *jobsCancelHandler) Handle(ctx context.Context, r *jobsCancelRequest) (*bigqueryv2.JobCancelResponse, error) { + if err := r.job.Cancel(ctx); err != nil { + return nil, err + } + return &bigqueryv2.JobCancelResponse{Job: r.job.Content()}, nil +} + +func (h *jobsDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + job := jobFromContext(ctx) + if err := h.Handle(ctx, &jobsDeleteRequest{ + server: server, + project: project, + job: job, + }); err != nil { + errorResponse(ctx, w, errJobInternalError(err.Error())) + return + } +} + +type jobsDeleteRequest struct { + server *Server + project *metadata.Project + job *metadata.Job +} + +func (h *jobsDeleteHandler) Handle(ctx context.Context, r *jobsDeleteRequest) error { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "") + if err != nil { + return fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + if err := r.project.DeleteJob(ctx, tx.Tx(), r.job.ID); err != nil { + return fmt.Errorf("failed to delete job: %w", err) + } + if err := tx.Commit(); err != nil { + return err + } + return nil +} + +func (h *jobsGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + job := jobFromContext(ctx) + res, err := h.Handle(ctx, &jobsGetRequest{ + server: server, + project: 
project, + job: job, + }) + if err != nil { + errorResponse(ctx, w, errJobInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type jobsGetRequest struct { + server *Server + project *metadata.Project + job *metadata.Job +} + +func (h *jobsGetHandler) Handle(ctx context.Context, r *jobsGetRequest) (*bigqueryv2.Job, error) { + content := *r.job.Content() + content.Status = &bigqueryv2.JobStatus{State: "DONE"} + return &content, nil +} + +func (h *jobsGetQueryResultsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + job := jobFromContext(ctx) + maxResults, hasMaxResults := parseQueryValueAsUint64(r, "maxResults") + startIndex, _ := parseQueryValueAsUint64(r, "startIndex") + // The page token returned by this handler is simply the next row index. + if token, ok := parseQueryValueAsUint64(r, "pageToken"); ok { + startIndex = token + } + res, err := h.Handle(ctx, &jobsGetQueryResultsRequest{ + server: server, + project: project, + job: job, + useInt64Timestamp: isFormatOptionsUseInt64Timestamp(r), + maxResults: maxResults, + hasMaxResults: hasMaxResults, + startIndex: startIndex, + }) + if err != nil { + errorResponse(ctx, w, errJobInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type jobsGetQueryResultsRequest struct { + server *Server + project *metadata.Project + job *metadata.Job + useInt64Timestamp bool + maxResults uint64 + hasMaxResults bool + startIndex uint64 +} + +func (h *jobsGetQueryResultsHandler) Handle(ctx context.Context, r *jobsGetQueryResultsRequest) (*internaltypes.GetQueryResultsResponse, error) { + response, err := r.job.Wait(ctx) + if err != nil { + return nil, err + } + rows := internaltypes.Format(response.Schema, response.Rows, r.useInt64Timestamp) + + // Honor maxResults/startIndex paging. 
Clients (notably the Python one) + // poll getQueryResults with maxResults=0 purely to await completion and + // then fetch the rows through a separate, paged request; returning every + // row on the completion poll hands them rows in a format they did not ask + // for. + total := uint64(len(rows)) + start := r.startIndex + if start > total { + start = total + } + end := total + pageToken := "" + if r.hasMaxResults { + end = start + r.maxResults + if end > total { + end = total + } + if end < total { + pageToken = strconv.FormatUint(end, 10) + } + } + return &internaltypes.GetQueryResultsResponse{ + JobReference: &bigqueryv2.JobReference{ + ProjectId: r.project.ID, + JobId: r.job.ID, + }, + Schema: response.Schema, + TotalRows: response.TotalRows, + JobComplete: true, + PageToken: pageToken, + Rows: rows[start:end], + }, nil +} + +func (h *jobsInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + body, err := io.ReadAll(r.Body) + // A gzip body flushed but not closed delivers all its content yet ends + // with ErrUnexpectedEOF; the json.Unmarshal below is the real validator. 
+ if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + var job bigqueryv2.Job + if err := json.Unmarshal(body, &job); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + if job.Configuration != nil && job.Configuration.Query != nil { + var rawReq struct { + Configuration struct { + Query struct { + QueryParameters []json.RawMessage `json:"queryParameters"` + } `json:"query"` + } `json:"configuration"` + } + if err := json.Unmarshal(body, &rawReq); err == nil { + applyNullQueryParameters(rawReq.Configuration.Query.QueryParameters, job.Configuration.Query.QueryParameters) + } + } + res, err := h.Handle(ctx, &jobsInsertRequest{ + server: server, + project: project, + job: &job, + }) + if err != nil { + // Preserve typed *ServerError (e.g. a 404 notFound for a missing + // destination table under CREATE_NEVER) so the client sees the + // real HTTP status rather than a blanket 400 jobInternalError. + var serr *ServerError + if errors.As(err, &serr) { + errorResponse(ctx, w, serr) + return + } + errorResponse(ctx, w, errJobInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type jobsInsertRequest struct { + server *Server + project *metadata.Project + job *bigqueryv2.Job +} + +func (h *jobsInsertHandler) tableDefFromQueryResponse(tableID string, response *internaltypes.QueryResponse) (*types.Table, error) { + columns := []*types.Column{} + for _, field := range response.Schema.Fields { + columns = append(columns, types.NewColumnWithSchema(field)) + } + data := types.Data{} + for _, row := range response.Rows { + rowData, err := row.Data() + if err != nil { + return nil, err + } + data = append(data, rowData) + } + return types.NewTableWithSchema( + &bigqueryv2.Table{ + TableReference: &bigqueryv2.TableReference{ + TableId: tableID, + }, + Schema: response.Schema, + }, + data, + ) +} + +func (h *jobsInsertHandler) destinationProjectAndDataset(ctx 
context.Context, r *jobsInsertRequest, tableRef *bigqueryv2.TableReference) (*metadata.Project, *metadata.Dataset, error) { + projectID := tableRef.ProjectId + if projectID == "" { + projectID = r.project.ID + tableRef.ProjectId = projectID + } + + project := r.project + if projectID != r.project.ID { + var err error + project, err = r.server.metaRepo.FindProject(ctx, projectID) + if err != nil { + return nil, nil, err + } + if project == nil { + return nil, nil, errNotFound(fmt.Sprintf("project %q is not found", projectID)) + } + } + + dataset := project.Dataset(tableRef.DatasetId) + if dataset == nil { + return nil, nil, errNotFound(fmt.Sprintf("dataset %q is not found in project %q", tableRef.DatasetId, projectID)) + } + return project, dataset, nil +} + +const ( + gcsEmulatorHostEnvName = "STORAGE_EMULATOR_HOST" + gcsURIPrefix = "gs://" +) + +// gcsClientOptions builds the option set for a Cloud Storage client. +// +// - When STORAGE_EMULATOR_HOST is set, the client targets that GCS emulator +// with authentication disabled. +// - Otherwise, if no Application Default Credentials are configured, the +// client falls back to anonymous access so that loads from public buckets +// succeed instead of failing with "could not find default credentials". +// - When GOOGLE_APPLICATION_CREDENTIALS is set, those credentials are used. 
+func gcsClientOptions(jsonReads bool) []option.ClientOption {
+	if host := os.Getenv(gcsEmulatorHostEnvName); host != "" {
+		// Emulator mode: point the client at the emulator, unauthenticated.
+		opts := []option.ClientOption{
+			option.WithEndpoint(fmt.Sprintf("%s/storage/v1/", host)),
+			option.WithoutAuthentication(),
+		}
+		if jsonReads {
+			opts = append(opts, storage.WithJSONReads())
+		}
+		return opts
+	}
+	if os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") == "" {
+		return []option.ClientOption{option.WithoutAuthentication()}
+	}
+	return nil
+}
+
+// importFromGCS runs a load job whose sources are gs:// URIs. Each URI names
+// either a single object or, with exactly one "*" wildcard, every object whose
+// name has the prefix before and the suffix after the wildcard. Once every
+// object has been handed to importFromGCSObject the job is marked DONE and
+// recorded through project.AddJob; that record is committed unless the job is
+// a dry run.
+func (h *jobsInsertHandler) importFromGCS(ctx context.Context, r *jobsInsertRequest) (*bigqueryv2.Job, error) {
+	client, err := storage.NewClient(ctx, gcsClientOptions(true)...)
+	if err != nil {
+		return nil, err
+	}
+	// Every reader is closed by importFromGCSObject before this function
+	// returns, so the client can be released on the way out.
+	defer func() {
+		_ = client.Close()
+	}()
+	startTime := time.Now()
+	load := r.job.Configuration.Load
+	requestedDisposition := load.WriteDisposition
+	// The write disposition applies to the load job as a whole, so it is
+	// honored only for the first object imported; every later object (e.g.
+	// the matches of a wildcard URI) appends to what the first one wrote.
+	importObject := func(reader *storage.Reader) error {
+		if err := h.importFromGCSObject(ctx, r, reader); err != nil {
+			return err
+		}
+		load.WriteDisposition = "WRITE_APPEND"
+		return nil
+	}
+	for _, uri := range load.SourceUris {
+		if !strings.HasPrefix(uri, gcsURIPrefix) {
+			return nil, fmt.Errorf("load source uri must start with gs://")
+		}
+		uri = strings.TrimPrefix(uri, gcsURIPrefix)
+		paths := strings.Split(uri, "/")
+		if len(paths) < 2 {
+			return nil, fmt.Errorf("unexpected gcs uri format %s", uri)
+		}
+		bucketName := paths[0]
+		objectPath := strings.Join(paths[1:], "/")
+		switch strings.Count(objectPath, "*") {
+		case 0:
+			reader, err := client.Bucket(bucketName).Object(objectPath).NewReader(ctx)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get gcs object reader for %s: %w", uri, err)
+			}
+			if err := importObject(reader); err != nil {
+				return nil, err
+			}
+		case 1:
+			splitPath := strings.Split(objectPath, "*")
+			prefix := splitPath[0]
+			suffix := splitPath[1]
+			query := &storage.Query{
+				Prefix: prefix,
+			}
+			query.SetAttrSelection([]string{"Name"})
+			it := client.Bucket(bucketName).Objects(ctx, query)
+			for {
+				attrs, err := it.Next()
+				if err == iterator.Done {
+					break
+				}
+				if err != nil {
+					return nil, fmt.Errorf("failed to list gcs object for %s: %w", uri, err)
+				}
+				if strings.HasSuffix(attrs.Name, suffix) {
+					reader, err := client.Bucket(bucketName).Object(attrs.Name).NewReader(ctx)
+					if err != nil {
+						return nil, fmt.Errorf("failed to get gcs object reader for %s: %w", uri, err)
+					}
+					if err := importObject(reader); err != nil {
+						return nil, err
+					}
+				}
+			}
+		default:
+			return nil, fmt.Errorf("the number of wildcards in gcs uri must be 0 or 1")
+		}
+	}
+	// Forcing WRITE_APPEND is an implementation detail of the multi-object
+	// import; the job that is stored and returned must carry the
+	// disposition the caller actually requested.
+	load.WriteDisposition = requestedDisposition
+	endTime := time.Now()
+	job := r.job
+	job.Kind = "bigquery#job"
+	job.Configuration.JobType = "LOAD"
+	job.SelfLink = fmt.Sprintf(
+		"http://%s/bigquery/v2/projects/%s/jobs/%s",
+		r.server.httpServer.Addr,
+		r.project.ID,
+		job.JobReference.JobId,
+	)
+	job.Status = &bigqueryv2.JobStatus{State: "DONE"}
+	job.Statistics = &bigqueryv2.JobStatistics{
+		CreationTime: startTime.Unix(),
+		StartTime:    startTime.Unix(),
+		EndTime:      endTime.Unix(),
+	}
+	conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "")
+	if err != nil {
+		return nil, fmt.Errorf("failed to get connection: %w", err)
+	}
+	tx, err := conn.Begin(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to start transaction: %w", err)
+	}
+	defer tx.RollbackIfNotCommitted()
+	if err := r.project.AddJob(
+		ctx,
+		tx.Tx(),
+		metadata.NewJob(
+			r.server.metaRepo,
+			r.project.ID,
+			job.JobReference.JobId,
+			job,
+			nil,
+			nil,
+		),
+	); err != nil {
+		return nil, fmt.Errorf("failed to add job: %w", err)
+	}
+	if !job.Configuration.DryRun {
+		if err := tx.Commit(); err != nil {
+			return nil, fmt.Errorf("failed to commit job: %w", err)
+		}
+	}
+	return job, nil
+}
+
+// importFromGCSObject feeds one Cloud Storage object to uploadContentHandler
+// on behalf of the load job in r. The reader is always closed before return.
+func (h *jobsInsertHandler) importFromGCSObject(ctx context.Context, r *jobsInsertRequest, reader *storage.Reader) error {
+	defer func() {
+		_ = reader.Close()
+	}()
+	job := metadata.NewJob(
+		r.server.metaRepo,
+		r.project.ID,
+		r.job.JobReference.JobId,
+		r.job,
+		nil,
+		nil,
+	)
+	return new(uploadContentHandler).Handle(ctx, &uploadContentRequest{
+		server:  r.server,
+		project: r.project,
+		job:     job,
+		reader:  reader,
+	})
+}
+
+// exportToGCS runs an extract job: it selects every row of the source table
+// and writes the result to each gs:// destination URI in the configured
+// format. The job is then marked DONE and recorded through project.AddJob;
+// that record is committed unless the job is a dry run.
+func (h *jobsInsertHandler) exportToGCS(ctx context.Context, r *jobsInsertRequest) (*bigqueryv2.Job, error) {
+	client, err := storage.NewClient(ctx, gcsClientOptions(false)...)
+	if err != nil {
+		return nil, err
+	}
+	// Every writer is closed by exportToGCSWithObject before this function
+	// returns, so the client can be released on the way out.
+	defer func() {
+		_ = client.Close()
+	}()
+	startTime := time.Now()
+	conn, err := r.server.connMgr.Connection(ctx, r.project.ID, "")
+	if err != nil {
+		return nil, fmt.Errorf("failed to get connection: %w", err)
+	}
+	tx, err := conn.Begin(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to start transaction: %w", err)
+	}
+	defer tx.RollbackIfNotCommitted()
+	extract := r.job.Configuration.Extract
+	sourceTable := extract.SourceTable
+	if sourceTable == nil {
+		// The table reference comes straight from the request body; reject
+		// its absence instead of dereferencing nil below.
+		return nil, fmt.Errorf("unspecified extract source table")
+	}
+	response, err := r.server.contentRepo.Query(
+		ctx,
+		tx,
+		sourceTable.ProjectId,
+		sourceTable.DatasetId,
+		fmt.Sprintf("SELECT * FROM `%s`", sourceTable.TableId),
+		nil,
+	)
+	if err != nil {
+		return nil, err
+	}
+	for _, uri := range extract.DestinationUris {
+		if !strings.HasPrefix(uri, gcsURIPrefix) {
+			return nil, fmt.Errorf("destination uri must start with gs://")
+		}
+		uri = strings.TrimPrefix(uri, gcsURIPrefix)
+		paths := strings.Split(uri, "/")
+		if len(paths) < 2 {
+			return nil, fmt.Errorf("unexpected gcs uri format %s", uri)
+		}
+		bucketName := paths[0]
+		objectPath := strings.Join(paths[1:], "/")
+		bucket := client.Bucket(bucketName)
+		_ = bucket.Create(ctx, r.project.ID, nil) // ignore "already exists" error.
+		writer := bucket.Object(objectPath).NewWriter(ctx)
+		if err := h.exportToGCSWithObject(ctx, response, extract, writer); err != nil {
+			return nil, err
+		}
+	}
+	endTime := time.Now()
+	job := r.job
+	job.Kind = "bigquery#job"
+	job.Configuration.JobType = "EXTRACT"
+	job.SelfLink = fmt.Sprintf(
+		"http://%s/bigquery/v2/projects/%s/jobs/%s",
+		r.server.httpServer.Addr,
+		r.project.ID,
+		job.JobReference.JobId,
+	)
+	job.Status = &bigqueryv2.JobStatus{State: "DONE"}
+	job.Statistics = &bigqueryv2.JobStatistics{
+		CreationTime: startTime.Unix(),
+		StartTime:    startTime.Unix(),
+		EndTime:      endTime.Unix(),
+	}
+	if err := r.project.AddJob(
+		ctx,
+		tx.Tx(),
+		metadata.NewJob(
+			r.server.metaRepo,
+			r.project.ID,
+			job.JobReference.JobId,
+			job,
+			nil,
+			nil,
+		),
+	); err != nil {
+		return nil, fmt.Errorf("failed to add job: %w", err)
+	}
+	if !job.Configuration.DryRun {
+		if err := tx.Commit(); err != nil {
+			return nil, fmt.Errorf("failed to commit job: %w", err)
+		}
+	}
+	return job, nil
+}
+
+// exportToGCSWithObject serializes the query response into writer using the
+// extract job's destination format (CSV or NEWLINE_DELIMITED_JSON) and closes
+// the writer. Any other format, PARQUET included, is reported as unsupported.
+func (h *jobsInsertHandler) exportToGCSWithObject(ctx context.Context, response *internaltypes.QueryResponse, extract *bigqueryv2.JobConfigurationExtract, writer *storage.Writer) (e error) {
+	defer func() {
+		// Close completes the upload. Surface its failure, but never let it
+		// mask a more specific error that is already being returned.
+		if err := writer.Close(); err != nil && e == nil {
+			e = err
+		}
+	}()
+	switch extract.DestinationFormat {
+	case "CSV":
+		if len(response.Rows) == 0 {
+			if _, err := writer.Write(nil); err != nil {
+				return fmt.Errorf("failed to empty table data to gcs object: %w", err)
+			}
+			return nil
+		}
+		csvWriter := csv.NewWriter(writer)
+		columns := make([]string, 0, len(response.Rows[0].F))
+		for _, cell := range response.Rows[0].F {
+			columns = append(columns, cell.Name)
+		}
+		// printHeader is optional and defaults to true: emit the header when
+		// it is unset or explicitly true, and skip it only when it is false.
+		if extract.PrintHeader == nil || *extract.PrintHeader {
+			if err := csvWriter.Write(columns); err != nil {
+				return fmt.Errorf("failed to encode csv columns: %w", err)
+			}
+		}
+		for _, row := range response.Rows {
+			data, err := row.Data()
+			if err != nil {
+				return fmt.Errorf("failed to get data from table row: %w", err)
+			}
+			records := make([]string, 0, len(columns))
+			for _, col := range columns {
+				value := data[col]
+				if value == nil {
+					records = append(records, "")
+					continue
+				}
+				if v, ok := value.(string); ok {
+					records = append(records, v)
+					continue
+				}
+				// Non-string values are written as their JSON encoding.
+				jsonValue, err := json.Marshal(value)
+				if err != nil {
+					return fmt.Errorf("failed to encode row value: %w", err)
+				}
+				records = append(records, string(jsonValue))
+			}
+			if err := csvWriter.Write(records); err != nil {
+				return fmt.Errorf("failed to encode csv data: %w", err)
+			}
+		}
+		csvWriter.Flush()
+		if err := csvWriter.Error(); err != nil {
+			return fmt.Errorf("failed to encode csv data: %w", err)
+		}
+	case "NEWLINE_DELIMITED_JSON":
+		writer.ContentType = "application/json"
+		enc := json.NewEncoder(writer)
+		for _, row := range response.Rows {
+			data, err := row.Data()
+			if err != nil {
+				return fmt.Errorf("failed to get data from table row: %w", err)
+			}
+			if err := enc.Encode(data); err != nil {
+				return fmt.Errorf("failed to encode table data: %w", err)
+			}
+		}
+	case "PARQUET":
+		// The compression option is resolved but no parquet writer is wired
+		// up yet, so this case falls through to the unsupported-format error.
+		var opts []parquet.WriterOption
+		switch extract.Compression {
+		case "GZIP":
+			opts = append(opts, parquet.Compression(&parquet.Gzip))
+		case "SNAPPY":
+			opts = append(opts, parquet.Compression(&parquet.Snappy))
+		case "DEFLATE":
+			opts = append(opts, parquet.Compression(&parquet.Gzip))
+		}
+		_ = opts
+		fallthrough
+	default:
+		return fmt.Errorf("failed to export to gcs: unsupported destination format %s", extract.DestinationFormat)
+	}
+	return nil
+}
+
+// Handle executes a jobs.insert request. A configuration without a query is
+// dispatched to the GCS load or extract path. A query job runs the statement
+// in a transaction, writes the result either to the requested destination
+// table or to an anonymous results table, records the job (including a failed
+// query, whose error is stored in the job status) and, unless the job is a
+// dry run, commits and propagates catalog changes to the metadata store.
+func (h *jobsInsertHandler) Handle(ctx context.Context, r *jobsInsertRequest) (*bigqueryv2.Job, error) {
+	job := r.job
+	if job.Configuration == nil {
+		return nil, fmt.Errorf("unspecified job configuration")
+	}
+	if job.JobReference == nil {
+		// jobReference is decoded from the request body and may be absent;
+		// every path below dereferences it.
+		job.JobReference = &bigqueryv2.JobReference{ProjectId: r.project.ID}
+	}
+	if job.Configuration.Query == nil {
+		if job.Configuration.Load != nil && len(job.Configuration.Load.SourceUris) != 0 {
+			// load from google cloud storage
+			job, err := h.importFromGCS(ctx, r)
+			if err != nil {
+				return nil, fmt.Errorf("failed to import from gcs: %w", err)
+			}
+			return job, nil
+		}
+		if job.Configuration.Extract != nil && len(job.Configuration.Extract.DestinationUris) != 0 {
+			job, err := h.exportToGCS(ctx, r)
+			if err != nil {
+				return nil, fmt.Errorf("failed to export to gcs: %w", err)
+			}
+			return job, nil
+		}
+		return nil, fmt.Errorf("unspecified job configuration query")
+	}
+	queryProjectID, datasetID := queryProjectAndDataset(job.Configuration.Query.DefaultDataset, r.project.ID)
+	conn, err := r.server.connMgr.Connection(ctx, queryProjectID, datasetID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get connection: %w", err)
+	}
+	tx, err := conn.Begin(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to start transaction: %w", err)
+	}
+	defer tx.RollbackIfNotCommitted()
+	hasDestinationTable := job.Configuration.Query.DestinationTable != nil
+	startTime := time.Now()
+	response, jobErr := r.server.contentRepo.Query(
+		ctx,
+		tx,
+		queryProjectID,
+		datasetID,
+		job.Configuration.Query.Query,
+		job.Configuration.Query.QueryParameters,
+	)
+	endTime := time.Now()
+	if job.JobReference.JobId == "" {
+		job.JobReference.JobId = randomID() // generate job id
+	}
+	if jobErr == nil {
+		if hasDestinationTable {
+			// insert results to destination table
+			tableRef := job.Configuration.Query.DestinationTable
+			destinationProject, destinationDataset, err := h.destinationProjectAndDataset(ctx, r, tableRef)
+			if err != nil {
+				return nil, err
+			}
+			tableDef, err := h.tableDefFromQueryResponse(tableRef.TableId, response)
+			if err != nil {
+				return nil, err
+			}
+			destinationTable := destinationDataset.Table(tableRef.TableId)
+			destinationTableExists := destinationTable != nil
+			if !destinationTableExists {
+				// CreateDisposition controls whether a missing destination
+				// table is materialized on the fly. CREATE_NEVER must
+				// surface the missing table as a 404 (matching real
+				// BigQuery and load-job behaviour); CREATE_IF_NEEDED (the
+				// default) and an empty value create it from the query's
+				// inferred schema.
+				if job.Configuration.Query.CreateDisposition == "CREATE_NEVER" {
+					return nil, errNotFound(fmt.Sprintf(
+						"Not found: Table %s:%s.%s",
+						tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId,
+					))
+				}
+				table := tableDef.ToBigqueryV2(destinationProject.ID, tableRef.DatasetId)
+				_, err := createTableMetadata(ctx, tx, r.server, destinationProject, destinationDataset, table)
+				if err != nil {
+					return nil, fmt.Errorf("failed to create table: %w", err)
+				}
+				tx.SetProjectAndDataset(destinationProject.ID, tableRef.DatasetId)
+				serverErr := r.server.contentRepo.CreateTable(ctx, tx, tableDef.ToBigqueryV2(destinationProject.ID, tableRef.DatasetId))
+				if serverErr != nil {
+					return nil, fmt.Errorf("failed to create table: %w", serverErr)
+				}
+			}
+			if err := r.server.contentRepo.AddTableData(ctx, tx, destinationProject.ID, tableRef.DatasetId, tableDef); err != nil {
+				return nil, fmt.Errorf("failed to add table data: %w", err)
+			}
+		} else if response != nil && response.Schema != nil && len(response.Schema.Fields) > 0 {
+			// A query that produces a result set (a SELECT, as opposed to a
+			// DDL/DML statement that has no result schema) still gets an
+			// anonymous results table. The job must advertise it through
+			// configuration.query.destinationTable: clients such as the Ruby
+			// one read query results through that reference.
+			destRef, err := h.addQueryResultToDynamicDestinationTable(ctx, tx, r, response)
+			if err != nil {
+				return nil, fmt.Errorf("failed to add query result to dynamic destination table: %w", err)
+			}
+			job.Configuration.Query.DestinationTable = destRef
+		}
+	}
+	job.Kind = "bigquery#job"
+	job.Configuration.JobType = "QUERY"
+	job.Configuration.Query.Priority = "INTERACTIVE"
+	job.SelfLink = fmt.Sprintf(
+		"http://%s/bigquery/v2/projects/%s/jobs/%s",
+		r.server.httpServer.Addr,
+		r.project.ID,
+		job.JobReference.JobId,
+	)
+	// A failed query still yields a DONE job; the failure travels in the
+	// job status rather than as an error from this handler.
+	status := &bigqueryv2.JobStatus{State: "DONE"}
+	if jobErr != nil {
+		internalErr := errJobInternalError(jobErr.Error())
+		status.ErrorResult = internalErr.ErrorProto()
+		status.Errors = []*bigqueryv2.ErrorProto{internalErr.ErrorProto()}
+	}
+	job.Status = status
+	var totalBytes int64
+	if response != nil {
+		totalBytes = response.TotalBytes
+	}
+	job.Statistics = &bigqueryv2.JobStatistics{
+		Query: &bigqueryv2.JobStatistics2{
+			CacheHit:            false,
+			StatementType:       "SELECT",
+			TotalBytesBilled:    totalBytes,
+			TotalBytesProcessed: totalBytes,
+		},
+		CreationTime:        startTime.Unix(),
+		StartTime:           startTime.Unix(),
+		EndTime:             endTime.Unix(),
+		TotalBytesProcessed: totalBytes,
+	}
+	if err := r.project.AddJob(
+		ctx,
+		tx.Tx(),
+		metadata.NewJob(
+			r.server.metaRepo,
+			r.project.ID,
+			job.JobReference.JobId,
+			job,
+			response,
+			jobErr,
+		),
+	); err != nil {
+		return nil, fmt.Errorf("failed to add job: %w", err)
+	}
+	if !job.Configuration.DryRun {
+		if err := tx.Commit(); err != nil {
+			return nil, fmt.Errorf("failed to commit job: %w", err)
+		}
+		if response != nil && response.ChangedCatalog.Changed() {
+			if err := syncCatalog(ctx, r.server, response.ChangedCatalog); err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	return job, nil
+}
+
+// queryProjectAndDataset resolves the project and dataset a query runs
+// against: the default dataset's values when given, with fallbackProjectID
+// used when the default dataset is absent or names no project.
+func queryProjectAndDataset(defaultDataset *bigqueryv2.DatasetReference, fallbackProjectID string) (string, string) {
+	if defaultDataset == nil {
+		return fallbackProjectID, ""
+	}
+	projectID := fallbackProjectID
+	if defaultDataset.ProjectId != "" {
+		projectID = defaultDataset.ProjectId
+	}
+	return projectID, defaultDataset.DatasetId
+}
+
+// syncCatalog mirrors the tables a query added to or deleted from the catalog
+// into the metadata store, additions first.
+func syncCatalog(ctx context.Context, server *Server, cat *googlesqlite.ChangedCatalog) error {
+	for _, table := range cat.Table.Added {
+		if err := addTableMetadata(ctx, server, table); err != nil {
+			return err
+		}
+	}
+	for _, table := range cat.Table.Deleted {
+		if err := deleteTableMetadata(ctx, server, table); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// addTableMetadata records metadata for a table (or view) that a statement
+// created. spec.NamePath must be exactly project, dataset, table.
+func addTableMetadata(ctx context.Context, server *Server, spec *googlesqlite.TableSpec) error {
+	if len(spec.NamePath) != 3 {
+		return fmt.Errorf("unexpected table name path: %v", spec.NamePath)
+	}
+	projectID := spec.NamePath[0]
+	datasetID := spec.NamePath[1]
+	tableID := spec.NamePath[2]
+	project, err := server.metaRepo.FindProject(ctx, projectID)
+	if err != nil {
+		return err
+	}
+	if project == nil {
+		// FindProject can report an unknown project as (nil, nil).
+		return fmt.Errorf("project %s is not found", projectID)
+	}
+	dataset := project.Dataset(datasetID)
+	if dataset == nil {
+		return fmt.Errorf("dataset %s is not found", datasetID)
+	}
+	fields := make([]*bigqueryv2.TableFieldSchema, 0, len(spec.Columns))
+	for _, column := range spec.Columns {
+		fields = append(fields, types.TableFieldSchemaFromColumnType(column.Name, column.Type))
+	}
+	conn, err := server.connMgr.Connection(ctx, projectID, datasetID)
+	if err != nil {
+		return err
+	}
+	tx, err := conn.Begin(ctx)
+	if err != nil {
+		return err
+	}
+	defer tx.RollbackIfNotCommitted()
+	table := &bigqueryv2.Table{
+		TableReference: &bigqueryv2.TableReference{
+			ProjectId: projectID,
+			DatasetId: datasetID,
+			TableId:   tableID,
+		},
+		Schema: &bigqueryv2.TableSchema{Fields: fields},
+	}
+	// A view created by a CREATE VIEW DDL statement must be recorded as a
+	// view so it is typed correctly and dropped with DROP VIEW.
+	if spec.IsView {
+		table.View = &bigqueryv2.ViewDefinition{Query: spec.Query}
+	}
+	if _, err := createTableMetadata(ctx, tx, server, project, dataset, table); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// deleteTableMetadata removes the metadata of a table that a statement
+// dropped. spec.NamePath must be exactly project, dataset, table. A table the
+// metadata store does not know about is treated as already deleted.
+func deleteTableMetadata(ctx context.Context, server *Server, spec *googlesqlite.TableSpec) error {
+	if len(spec.NamePath) != 3 {
+		return fmt.Errorf("unexpected table name path: %v", spec.NamePath)
+	}
+	projectID := spec.NamePath[0]
+	datasetID := spec.NamePath[1]
+	tableID := spec.NamePath[2]
+	project, err := server.metaRepo.FindProject(ctx, projectID)
+	if err != nil {
+		return err
+	}
+	if project == nil {
+		// FindProject can report an unknown project as (nil, nil).
+		return fmt.Errorf("project %s is not found", projectID)
+	}
+	dataset := project.Dataset(datasetID)
+	if dataset == nil {
+		return fmt.Errorf("dataset %s is not found", datasetID)
+	}
+	table := dataset.Table(tableID)
+	if table == nil {
+		// Nothing is recorded for this table, so there is nothing to remove.
+		return nil
+	}
+	conn, err := server.connMgr.Connection(ctx, projectID, datasetID)
+	if err != nil {
+		return err
+	}
+	tx, err := conn.Begin(ctx)
+	if err != nil {
+		return err
+	}
+	defer tx.RollbackIfNotCommitted()
+	if err := table.Delete(ctx, tx.Tx()); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// addQueryResultToDynamicDestinationTable materializes the result of a query
+// that had no explicit destination into an anonymous results table (named
+// after the job id) and returns a reference to it.
+func (h *jobsInsertHandler) addQueryResultToDynamicDestinationTable(ctx context.Context, tx *connection.Tx, r *jobsInsertRequest, response *internaltypes.QueryResponse) (*bigqueryv2.TableReference, error) {
+	projectID := r.project.ID
+	// The job id doubles as the id of both the anonymous dataset and the
+	// single table it holds.
+	anonymousID := r.job.JobReference.JobId
+	destination := &bigqueryv2.TableReference{
+		ProjectId: projectID,
+		DatasetId: anonymousID,
+		TableId:   anonymousID,
+	}
+
+	definition, err := h.tableDefFromQueryResponse(destination.TableId, response)
+	if err != nil {
+		return nil, err
+	}
+	definition.SetupMetadata(projectID, destination.DatasetId)
+
+	// Metadata side: register the dataset, then the table inside it.
+	resultTable := metadata.NewTable(r.server.metaRepo, projectID, destination.DatasetId, destination.TableId, definition.Metadata)
+	datasetContent := &bigqueryv2.Dataset{
+		Id: fmt.Sprintf("%s:%s", projectID, destination.DatasetId),
+		DatasetReference: &bigqueryv2.DatasetReference{
+			ProjectId: projectID,
+			DatasetId: destination.DatasetId,
+		},
+	}
+	resultDataset := metadata.NewDataset(
+		r.server.metaRepo,
+		projectID,
+		destination.DatasetId,
+		datasetContent,
+		[]*metadata.Table{resultTable},
+		nil,
+		nil,
+	)
+	if err := r.project.AddDataset(ctx, tx.Tx(), resultDataset); err != nil {
+		return nil, err
+	}
+	if err := r.server.metaRepo.AddTable(ctx, tx.Tx(), resultTable); err != nil {
+		return nil, err
+	}
+
+	// Content side: point the transaction at the new dataset, create the
+	// table there and fill it with the query rows.
+	tx.SetProjectAndDataset(projectID, destination.DatasetId)
+	content := r.server.contentRepo
+	if err := content.CreateTable(ctx, tx, definition.ToBigqueryV2(projectID, destination.DatasetId)); err != nil {
+		return nil, err
+	}
+	if err := content.AddTableData(ctx, tx, projectID, destination.DatasetId, definition); err != nil {
+		return nil, fmt.Errorf("failed to add table data: %w", err)
+	}
+	return destination, nil
+}
+
+// ServeHTTP answers jobs.list: it takes the server and project from the
+// request context, delegates to Handle and encodes the result, reporting any
+// failure as a job internal error.
+func (h *jobsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
+	request := &jobsListRequest{
+		server:  serverFromContext(ctx),
+		project: projectFromContext(ctx),
+	}
+	jobList, err := h.Handle(ctx, request)
+	if err != nil {
+		errorResponse(ctx, w, errJobInternalError(err.Error()))
+		return
+	}
+	encodeResponse(ctx, w, jobList)
+}
+
+// jobsListRequest carries the inputs of a jobs.list call.
+type jobsListRequest struct {
+	// server is the emulator server taken from the request context.
+	server *Server
+	// project is the project whose jobs are listed.
+	project *metadata.Project
+}
+
+func (h *jobsListHandler) Handle(ctx context.Context, r *jobsListRequest) (*bigqueryv2.JobList, error) { + jobs := []*bigqueryv2.JobListJobs{} + for _, job := range r.project.Jobs() { + content := job.Content() + jobs = append(jobs, &bigqueryv2.JobListJobs{ + Id: content.Id, + JobReference: content.JobReference, + Kind: content.Kind, + Statistics: content.Statistics, + Status: content.Status, + UserEmail: content.UserEmail, + }) + } + return &bigqueryv2.JobList{Jobs: jobs}, nil +} + +func (h *jobsQueryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + body, err := io.ReadAll(r.Body) + // A gzip body flushed but not closed delivers all its content yet ends + // with ErrUnexpectedEOF; the json.Unmarshal below is the real validator. + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + var req bigqueryv2.QueryRequest + if err := json.Unmarshal(body, &req); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + var rawReq struct { + QueryParameters []json.RawMessage `json:"queryParameters"` + } + if err := json.Unmarshal(body, &rawReq); err == nil { + applyNullQueryParameters(rawReq.QueryParameters, req.QueryParameters) + } + useInt64Timestamp := false + if options := req.FormatOptions; options != nil { + useInt64Timestamp = options.UseInt64Timestamp + } + useInt64Timestamp = useInt64Timestamp || isFormatOptionsUseInt64Timestamp(r) + res, err := h.Handle(ctx, &jobsQueryRequest{ + server: server, + project: project, + queryRequest: &req, + useInt64Timestamp: useInt64Timestamp, + }) + if err != nil { + errorResponse(ctx, w, errJobInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type jobsQueryRequest struct { + server *Server + project *metadata.Project + queryRequest *bigqueryv2.QueryRequest + useInt64Timestamp bool +} + +func (h *jobsQueryHandler) 
Handle(ctx context.Context, r *jobsQueryRequest) (*internaltypes.QueryResponse, error) { + queryProjectID, datasetID := queryProjectAndDataset(r.queryRequest.DefaultDataset, r.project.ID) + conn, err := r.server.connMgr.Connection(ctx, queryProjectID, datasetID) + if err != nil { + return nil, err + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.RollbackIfNotCommitted() + startTime := time.Now() + response, queryErr := r.server.contentRepo.Query( + ctx, + tx, + queryProjectID, + datasetID, + r.queryRequest.Query, + r.queryRequest.QueryParameters, + ) + if queryErr != nil { + return nil, queryErr + } + endTime := time.Now() + // jobs.query allocates jobIDs server-side (real BigQuery does the + // same). queryRequest.RequestId is the *idempotency* key — same + // RequestId on retry should return the cached result — and is + // deliberately not the jobID. The Go BigQuery client in particular + // instantiates a fresh `uid.NewSpace("request", …)` per call, so its + // per-Space atomic counter always emits `-0001`; concurrent + // in-process callers therefore submit identical RequestIds and would + // collide on AddJob if we routed them through as the jobID. Idempotency + // (RequestId → cached response) is a TODO; for now every call gets a + // fresh jobID. + jobID := randomID() + + // Persist the job in the metadata store, mirroring what + // jobsInsertHandler.Handle does at line ~1758. The previous behaviour + // returned a JobReference whose ID was never recorded, so any client + // that then issued `jobs.get(jobID)` — e.g. Go's + // `RowIterator.SourceJob().Status(ctx)`, or the Java/Node clients + // that poll the job for status after a synchronous query — got a 404 + // and (for clients that treat 404 as transient) hung re-polling. The + // synthetic Job below carries the minimum fields the GET handler + // returns to the caller; DryRun queries skip both the AddJob and the + // Commit so they remain side-effect-free. 
+ var totalBytes int64 + if response != nil { + totalBytes = response.TotalBytes + } + job := &bigqueryv2.Job{ + Kind: "bigquery#job", + JobReference: &bigqueryv2.JobReference{ + ProjectId: r.project.ID, + JobId: jobID, + Location: r.queryRequest.Location, + }, + Configuration: &bigqueryv2.JobConfiguration{ + JobType: "QUERY", + DryRun: r.queryRequest.DryRun, + Query: &bigqueryv2.JobConfigurationQuery{ + Query: r.queryRequest.Query, + DefaultDataset: r.queryRequest.DefaultDataset, + QueryParameters: r.queryRequest.QueryParameters, + Priority: "INTERACTIVE", + }, + }, + Status: &bigqueryv2.JobStatus{State: "DONE"}, + Statistics: &bigqueryv2.JobStatistics{ + Query: &bigqueryv2.JobStatistics2{ + CacheHit: false, + StatementType: "SELECT", + TotalBytesBilled: totalBytes, + TotalBytesProcessed: totalBytes, + }, + CreationTime: startTime.Unix(), + StartTime: startTime.Unix(), + EndTime: endTime.Unix(), + TotalBytesProcessed: totalBytes, + }, + SelfLink: fmt.Sprintf( + "http://%s/bigquery/v2/projects/%s/jobs/%s", + r.server.httpServer.Addr, + r.project.ID, + jobID, + ), + } + if !r.queryRequest.DryRun { + if err := r.project.AddJob( + ctx, + tx.Tx(), + metadata.NewJob( + r.server.metaRepo, + r.project.ID, + jobID, + job, + response, + nil, + ), + ); err != nil { + return nil, fmt.Errorf("failed to add job: %w", err) + } + if err := tx.Commit(); err != nil { + return nil, err + } + if response.ChangedCatalog.Changed() { + if err := syncCatalog(ctx, r.server, response.ChangedCatalog); err != nil { + return nil, err + } + } + } + response.Rows = internaltypes.Format(response.Schema, response.Rows, r.useInt64Timestamp) + response.JobReference = job.JobReference + return response, nil +} + +func (h *modelsDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + model := modelFromContext(ctx) + if err := h.Handle(ctx, 
&modelsDeleteRequest{ + server: server, + project: project, + dataset: dataset, + model: model, + }); err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + } +} + +type modelsDeleteRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + model *metadata.Model +} + +func (h *modelsDeleteHandler) Handle(ctx context.Context, r *modelsDeleteRequest) error { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return err + } + tx, err := conn.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + if err := r.dataset.DeleteModel(ctx, tx.Tx(), r.model.ID); err != nil { + return fmt.Errorf("failed to delete model: %w", err) + } + if err := tx.Commit(); err != nil { + return err + } + return nil +} + +func (h *modelsGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + model := modelFromContext(ctx) + res, err := h.Handle(ctx, &modelsGetRequest{ + server: server, + project: project, + dataset: dataset, + model: model, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type modelsGetRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + model *metadata.Model +} + +func (h *modelsGetHandler) Handle(ctx context.Context, r *modelsGetRequest) (*bigqueryv2.Model, error) { + return &bigqueryv2.Model{ + BestTrialId: 0, + CreationTime: 0, + DefaultTrialId: 0, + Description: "", + EncryptionConfiguration: nil, + Etag: "", + ExpirationTime: 0, + FeatureColumns: nil, + FriendlyName: "", + HparamSearchSpaces: nil, + HparamTrials: nil, + LabelColumns: nil, + Labels: nil, + LastModifiedTime: 0, + Location: "", + ModelReference: nil, + ModelType: "", + OptimalTrialIds: 
nil, + TrainingRuns: nil, + }, nil +} + +func (h *modelsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + res, err := h.Handle(ctx, &modelsListRequest{ + server: server, + project: project, + dataset: dataset, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type modelsListRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset +} + +func (h *modelsListHandler) Handle(ctx context.Context, r *modelsListRequest) (*bigqueryv2.ListModelsResponse, error) { + models := []*bigqueryv2.Model{} + for _, m := range r.dataset.Models() { + _ = m + models = append(models, &bigqueryv2.Model{}) + } + return &bigqueryv2.ListModelsResponse{ + Models: models, + }, nil +} + +func (h *modelsPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + model := modelFromContext(ctx) + res, err := h.Handle(ctx, &modelsPatchRequest{ + server: server, + project: project, + dataset: dataset, + model: model, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type modelsPatchRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + model *metadata.Model +} + +func (h *modelsPatchHandler) Handle(ctx context.Context, r *modelsPatchRequest) (*bigqueryv2.Model, error) { + return &bigqueryv2.Model{}, nil +} + +func (h *projectsGetServiceAccountHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + res, err := h.Handle(ctx, &projectsGetServiceAccountRequest{ + server: server, + project: project, + }) + 
if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type projectsGetServiceAccountRequest struct { + server *Server + project *metadata.Project +} + +func (h *projectsGetServiceAccountHandler) Handle(ctx context.Context, r *projectsGetServiceAccountRequest) (*bigqueryv2.GetServiceAccountResponse, error) { + return &bigqueryv2.GetServiceAccountResponse{}, nil +} + +func (h *projectsListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + res, err := h.Handle(ctx, &projectsListRequest{ + server: server, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type projectsListRequest struct { + server *Server +} + +func (h *projectsListHandler) Handle(ctx context.Context, r *projectsListRequest) (*bigqueryv2.ProjectList, error) { + projects, err := r.server.metaRepo.FindAllProjects(ctx) + if err != nil { + return nil, err + } + + projectList := []*bigqueryv2.ProjectListProjects{} + for i, p := range projects { + projectList = append(projectList, &bigqueryv2.ProjectListProjects{ + Id: p.ID, + NumericId: uint64(i + 1), + FriendlyName: p.ID, + }) + } + return &bigqueryv2.ProjectList{ + Projects: projectList, + }, nil +} + +func (h *routinesDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + routine := routineFromContext(ctx) + if err := h.Handle(ctx, &routinesDeleteRequest{ + server: server, + project: project, + dataset: dataset, + routine: routine, + }); err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } +} + +type routinesDeleteRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + routine *metadata.Routine +} + +func (h *routinesDeleteHandler) Handle(ctx 
context.Context, r *routinesDeleteRequest) error { + return fmt.Errorf("unsupported bigquery.routines.delete") +} + +func (h *routinesGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + routine := routineFromContext(ctx) + res, err := h.Handle(ctx, &routinesGetRequest{ + server: server, + project: project, + dataset: dataset, + routine: routine, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type routinesGetRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + routine *metadata.Routine +} + +func (h *routinesGetHandler) Handle(ctx context.Context, r *routinesGetRequest) (*bigqueryv2.Routine, error) { + return nil, fmt.Errorf("unsupported bigquery.routines.get") +} + +func (h *routinesInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + var routine bigqueryv2.Routine + if err := json.NewDecoder(r.Body).Decode(&routine); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &routinesInsertRequest{ + server: server, + project: project, + dataset: dataset, + routine: &routine, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type routinesInsertRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + routine *bigqueryv2.Routine +} + +func (h *routinesInsertHandler) Handle(ctx context.Context, r *routinesInsertRequest) (*bigqueryv2.Routine, error) { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + tx, 
err := conn.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.RollbackIfNotCommitted() + if err := r.server.contentRepo.AddRoutineByMetaData(ctx, tx, r.routine); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + return r.routine, nil +} + +func (h *routinesListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + res, err := h.Handle(ctx, &routinesListRequest{ + server: server, + project: project, + dataset: dataset, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type routinesListRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset +} + +func (h *routinesListHandler) Handle(ctx context.Context, r *routinesListRequest) (*bigqueryv2.ListRoutinesResponse, error) { + var routineList []*bigqueryv2.Routine + for _, routine := range r.dataset.Routines() { + _ = routine + routineList = append(routineList, &bigqueryv2.Routine{}) + } + return &bigqueryv2.ListRoutinesResponse{ + Routines: routineList, + }, fmt.Errorf("unsupported bigquery.routines.list") +} + +func (h *routinesUpdateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + routine := routineFromContext(ctx) + res, err := h.Handle(ctx, &routinesUpdateRequest{ + server: server, + project: project, + dataset: dataset, + routine: routine, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type routinesUpdateRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + routine *metadata.Routine +} + +func (h *routinesUpdateHandler) Handle(ctx context.Context, r 
*routinesUpdateRequest) (*bigqueryv2.Routine, error) { + return nil, fmt.Errorf("unsupported bigquery.routines.update") +} + +func (h *rowAccessPoliciesGetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + res, err := h.Handle(ctx, &rowAccessPoliciesGetIamPolicyRequest{ + server: server, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type rowAccessPoliciesGetIamPolicyRequest struct { + server *Server +} + +func (h *rowAccessPoliciesGetIamPolicyHandler) Handle(ctx context.Context, r *rowAccessPoliciesGetIamPolicyRequest) (*bigqueryv2.Policy, error) { + return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.getIamPolicy") +} + +func (h *rowAccessPoliciesListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + res, err := h.Handle(ctx, &rowAccessPoliciesListRequest{ + server: server, + project: project, + dataset: dataset, + table: table, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type rowAccessPoliciesListRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *metadata.Table +} + +func (h *rowAccessPoliciesListHandler) Handle(ctx context.Context, r *rowAccessPoliciesListRequest) (*bigqueryv2.ListRowAccessPoliciesResponse, error) { + return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.list") +} + +func (h *rowAccessPoliciesSetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + res, err := h.Handle(ctx, &rowAccessPoliciesSetIamPolicyRequest{ + server: server, + }) + if err != nil { + errorResponse(ctx, w, 
errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type rowAccessPoliciesSetIamPolicyRequest struct { + server *Server +} + +func (h *rowAccessPoliciesSetIamPolicyHandler) Handle(ctx context.Context, r *rowAccessPoliciesSetIamPolicyRequest) (*bigqueryv2.Policy, error) { + return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.setIamPolicy") +} + +func (h *rowAccessPoliciesTestIamPermissionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + res, err := h.Handle(ctx, &rowAccessPoliciesTestIamPermissionsRequest{ + server: server, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type rowAccessPoliciesTestIamPermissionsRequest struct { + server *Server +} + +func (h *rowAccessPoliciesTestIamPermissionsHandler) Handle(ctx context.Context, r *rowAccessPoliciesTestIamPermissionsRequest) (*bigqueryv2.TestIamPermissionsResponse, error) { + return nil, fmt.Errorf("unsupported bigquery.rowAccessPolicies.testIamPermissions") +} + +func (h *tabledataInsertAllHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + var req bigqueryv2.TableDataInsertAllRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &tabledataInsertAllRequest{ + server: server, + project: project, + dataset: dataset, + table: table, + req: &req, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tabledataInsertAllRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *metadata.Table + req *bigqueryv2.TableDataInsertAllRequest +} + +func 
normalizeInsertValue(v interface{}, field *bigqueryv2.TableFieldSchema) (interface{}, error) { + rv := reflect.ValueOf(v) + kind := rv.Kind() + if field.Mode == "REPEATED" { + if kind != reflect.Slice && kind != reflect.Array { + return nil, fmt.Errorf("invalid value type %T for ARRAY column", v) + } + values := make([]interface{}, 0, rv.Len()) + for i := 0; i < rv.Len(); i++ { + value, err := normalizeInsertValue(rv.Index(i).Interface(), &bigqueryv2.TableFieldSchema{ + Fields: field.Fields, + }) + if err != nil { + return nil, err + } + values = append(values, value) + } + return values, nil + } + if kind == reflect.Map { + fieldMap := map[string]*bigqueryv2.TableFieldSchema{} + for _, f := range field.Fields { + fieldMap[f.Name] = f + } + columnNameToValueMap := map[string]interface{}{} + for _, key := range rv.MapKeys() { + if key.Kind() != reflect.String { + return nil, fmt.Errorf("invalid value type %s for STRUCT column", key.Kind()) + } + columnName := key.Interface().(string) + value, err := normalizeInsertValue(rv.MapIndex(key).Interface(), fieldMap[columnName]) + if err != nil { + return nil, err + } + columnNameToValueMap[columnName] = value + } + fields := make([]map[string]interface{}, 0, len(fieldMap)) + for _, f := range field.Fields { + value, exists := columnNameToValueMap[f.Name] + if !exists { + return nil, fmt.Errorf("failed to find value from %s", f.Name) + } + fields = append(fields, map[string]interface{}{f.Name: value}) + } + return fields, nil + } + return v, nil +} + +func (h *tabledataInsertAllHandler) Handle(ctx context.Context, r *tabledataInsertAllRequest) (*bigqueryv2.TableDataInsertAllResponse, error) { + content, err := r.table.Content() + if err != nil { + return nil, err + } + var insertErrors []*bigqueryv2.TableDataInsertAllResponseInsertErrors + data := types.Data{} + for i, row := range r.req.Rows { + // A row that carries fields absent from the table schema is rejected + // unless ignoreUnknownValues is set; it is reported 
per-row and not + // inserted, while the remaining rows still go in. + if !r.req.IgnoreUnknownValues { + if unknown := types.ValidateRowFields(content.Schema, row.Json); len(unknown) > 0 { + errs := make([]*bigqueryv2.ErrorProto, 0, len(unknown)) + for _, name := range unknown { + errs = append(errs, &bigqueryv2.ErrorProto{ + Reason: "invalid", + Location: name, + Message: fmt.Sprintf("no such field: %s.", name), + }) + } + insertErrors = append(insertErrors, &bigqueryv2.TableDataInsertAllResponseInsertErrors{ + Index: int64(i), + Errors: errs, + }) + continue + } + } + rowData := map[string]interface{}{} + for k, v := range row.Json { + rowData[k] = v + } + data = append(data, rowData) + } + if len(data) > 0 { + tableDef, err := types.NewTableWithSchema(content, data) + if err != nil { + return nil, err + } + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.RollbackIfNotCommitted() + if err := r.server.contentRepo.AddTableData(ctx, tx, r.project.ID, r.dataset.ID, tableDef); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + } + return &bigqueryv2.TableDataInsertAllResponse{InsertErrors: insertErrors}, nil +} + +func (h *tabledataListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + res, err := h.Handle(ctx, &tabledataListRequest{ + server: server, + project: project, + dataset: dataset, + table: table, + useInt64Timestamp: isFormatOptionsUseInt64Timestamp(r), + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tabledataListRequest struct { + server *Server + project 
*metadata.Project + dataset *metadata.Dataset + table *metadata.Table + useInt64Timestamp bool +} + +func (h *tabledataListHandler) Handle(ctx context.Context, r *tabledataListRequest) (*internaltypes.TableDataList, error) { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, fmt.Errorf("failed to get connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + response, err := r.server.contentRepo.Query( + ctx, + tx, + r.project.ID, + r.dataset.ID, + fmt.Sprintf("SELECT * FROM `%s`", r.table.ID), + nil, + ) + if err != nil { + return nil, err + } + + return &internaltypes.TableDataList{ + Rows: internaltypes.Format(response.Schema, response.Rows, r.useInt64Timestamp), + TotalRows: response.TotalRows, + }, nil +} + +func (h *tablesDeleteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + if err := h.Handle(ctx, &tablesDeleteRequest{ + server: server, + project: project, + dataset: dataset, + table: table, + }); err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + } +} + +type tablesDeleteRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *metadata.Table +} + +func (h *tablesDeleteHandler) Handle(ctx context.Context, r *tablesDeleteRequest) error { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return err + } + tx, err := conn.Begin(ctx) + if err != nil { + return err + } + defer tx.RollbackIfNotCommitted() + // delete table metadata + if err := r.table.Delete(ctx, tx.Tx()); err != nil { + return err + } + // delete table + if err := r.server.contentRepo.DeleteTables( + ctx, + tx, + r.project.ID, + r.dataset.ID, 
+ []contentdata.TableDeletion{{ID: r.table.ID, IsView: r.table.IsView()}}, + ); err != nil { + return fmt.Errorf("failed to delete table %s: %w", r.table.ID, err) + } + if err := tx.Commit(); err != nil { + return err + } + return nil +} + +func (h *tablesGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + res, err := h.Handle(ctx, &tablesGetRequest{ + server: server, + project: project, + dataset: dataset, + table: table, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesGetRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *metadata.Table +} + +func (h *tablesGetHandler) Handle(ctx context.Context, r *tablesGetRequest) (*bigqueryv2.Table, error) { + table, err := r.table.Content() + if err != nil { + return nil, fmt.Errorf("failed to get table content: %w", err) + } + // Populate NumRows from the backing table so clients that depend on it + // (e.g. Table.getNumRows) observe an accurate count. Views and external + // tables have no backing row store and are left untouched. 
+ if table.Type == "" || table.Type == "TABLE" { + if numRows, err := h.countRows(ctx, r); err == nil { + table.NumRows = uint64(numRows) + } + } + return table, nil +} + +func (h *tablesGetHandler) countRows(ctx context.Context, r *tablesGetRequest) (int64, error) { + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return 0, err + } + tx, err := conn.Begin(ctx) + if err != nil { + return 0, err + } + defer tx.RollbackIfNotCommitted() + count, err := r.server.contentRepo.CountTableRows(ctx, tx, r.project.ID, r.dataset.ID, r.table.ID) + if err != nil { + return 0, err + } + if err := tx.Commit(); err != nil { + return 0, err + } + return count, nil +} + +func (h *tablesGetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + var req bigqueryv2.GetIamPolicyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &tablesGetIamPolicyRequest{ + server: server, + req: &req, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesGetIamPolicyRequest struct { + server *Server + req *bigqueryv2.GetIamPolicyRequest +} + +func (h *tablesGetIamPolicyHandler) Handle(ctx context.Context, r *tablesGetIamPolicyRequest) (*bigqueryv2.Policy, error) { + return nil, fmt.Errorf("bigquery.tables.getIamPolicy") +} + +func (h *tablesInsertHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + var table bigqueryv2.Table + if err := json.NewDecoder(r.Body).Decode(&table); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &tablesInsertRequest{ + server: server, + project: project, + dataset: dataset, + 
table: &table, + }) + if err != nil { + errorResponse(ctx, w, err) + return + } + encodeResponse(ctx, w, res) +} + +type tablesInsertRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *bigqueryv2.Table +} + +type TableType string + +const ( + DefaultTableType TableType = "TABLE" + ViewTableType TableType = "VIEW" + ExternalTableType TableType = "EXTERNAL" + MaterializedViewTableType TableType = "MATERIALIZED_VIEW" + SnapshotTableType TableType = "SNAPSHOT" +) + +func createTableMetadata(ctx context.Context, tx *connection.Tx, server *Server, project *metadata.Project, dataset *metadata.Dataset, table *bigqueryv2.Table) (*bigqueryv2.Table, *ServerError) { + now := time.Now().Unix() + if table.TableReference == nil { + table.TableReference = &bigqueryv2.TableReference{} + } + if table.TableReference.ProjectId == "" { + table.TableReference.ProjectId = project.ID + } + if table.TableReference.DatasetId == "" { + table.TableReference.DatasetId = dataset.ID + } + table.Id = fmt.Sprintf("%s:%s.%s", project.ID, dataset.ID, table.TableReference.TableId) + table.CreationTime = now + table.LastModifiedTime = uint64(now) + table.Type = string(DefaultTableType) // TODO: need to handle other table types + if table.View != nil { + table.Type = string(ViewTableType) + } + if table.MaterializedView != nil { + table.Type = string(MaterializedViewTableType) + } + table.Kind = "bigquery#table" + table.SelfLink = fmt.Sprintf( + "http://%s/bigquery/v2/projects/%s/datasets/%s/tables/%s", + server.httpServer.Addr, + project.ID, + dataset.ID, + table.TableReference.TableId, + ) + encodedTableData, err := json.Marshal(table) + if err != nil { + return nil, errInternalError(err.Error()) + } + var tableMetadata map[string]interface{} + if err := json.Unmarshal(encodedTableData, &tableMetadata); err != nil { + return nil, errInternalError(err.Error()) + } + if err := dataset.AddTable( + ctx, + tx.Tx(), + metadata.NewTable( + server.metaRepo, + 
project.ID, + dataset.ID, + table.TableReference.TableId, + tableMetadata, + ), + ); err != nil { + if errors.Is(err, metadata.ErrDuplicatedTable) { + return nil, errDuplicate(err.Error()) + } + return nil, errInternalError(err.Error()) + } + return table, nil +} + +func (h *tablesInsertHandler) Handle(ctx context.Context, r *tablesInsertRequest) (*bigqueryv2.Table, *ServerError) { + if r.table.ExternalDataConfiguration != nil { + return h.handleExternalTable(ctx, r) + } + if r.table.TableReference == nil { + r.table.TableReference = &bigqueryv2.TableReference{} + } + if r.table.TableReference.ProjectId == "" { + r.table.TableReference.ProjectId = r.project.ID + } + if r.table.TableReference.DatasetId == "" { + r.table.TableReference.DatasetId = r.dataset.ID + } + if r.table.TableReference.TableId == "" { + r.table.TableReference.TableId = terminalTableID(r.table.Id) + } + + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, errInternalError(err.Error()) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, errInternalError(err.Error()) + } + defer tx.RollbackIfNotCommitted() + + isView := r.table.View != nil || r.table.MaterializedView != nil + if isView { + // Create the view first so its resolved column schema can be read + // back and recorded in the metadata, as real BigQuery does. 
// terminalTableID returns the last component of a table identifier, i.e. the
// text after the final ':' or '.' separator. "myproj:sales.users" and
// "sales.users" both yield "users"; an identifier without separators is
// returned unchanged, and an empty identifier yields "".
func terminalTableID(id string) string {
	// Cutting after the last ':' and then after the last '.' of the remainder
	// is the same as cutting after whichever separator occurs last.
	cut := strings.LastIndexAny(id, ":.")
	if cut < 0 {
		return id
	}
	return id[cut+1:]
}
+func (h *tablesInsertHandler) handleExternalTable(ctx context.Context, r *tablesInsertRequest) (*bigqueryv2.Table, *ServerError) { + edc := r.table.ExternalDataConfiguration + tableRef := r.table.TableReference + if tableRef == nil { + return nil, errInvalid("external table is missing tableReference") + } + if tableRef.ProjectId == "" { + tableRef.ProjectId = r.project.ID + } + if tableRef.DatasetId == "" { + tableRef.DatasetId = r.dataset.ID + } + if len(edc.SourceUris) == 0 { + return nil, errInvalid("external table is missing sourceUris") + } + + // Translate the external data configuration into a load job and route it + // through the existing load pipeline, which creates the backing table and + // loads the rows (inferring the schema when autodetect is requested). + load := &bigqueryv2.JobConfigurationLoad{ + DestinationTable: tableRef, + SourceUris: edc.SourceUris, + SourceFormat: edc.SourceFormat, + Autodetect: edc.Autodetect, + Schema: edc.Schema, + CreateDisposition: "CREATE_IF_NEEDED", + } + if load.Schema == nil { + load.Schema = r.table.Schema + } + if csv := edc.CsvOptions; csv != nil { + load.Quote = csv.Quote + load.FieldDelimiter = csv.FieldDelimiter + load.SkipLeadingRows = csv.SkipLeadingRows + load.AllowJaggedRows = csv.AllowJaggedRows + load.AllowQuotedNewlines = csv.AllowQuotedNewlines + load.Encoding = csv.Encoding + } + job := &bigqueryv2.Job{ + JobReference: &bigqueryv2.JobReference{JobId: randomID(), ProjectId: r.project.ID}, + Configuration: &bigqueryv2.JobConfiguration{Load: load}, + } + if _, err := (&jobsInsertHandler{}).importFromGCS(ctx, &jobsInsertRequest{ + server: r.server, + project: r.project, + job: job, + }); err != nil { + return nil, errInvalid(fmt.Sprintf("failed to load external table data: %s", err)) + } + + // The load created a plain table; record the external configuration on + // its metadata so it round-trips on tables.get and tables.list. 
+ table := r.dataset.Table(tableRef.TableId) + if table == nil { + return nil, errInternalError("external table backing data was not created") + } + content, err := table.Content() + if err != nil { + return nil, errInternalError(err.Error()) + } + content.ExternalDataConfiguration = edc + content.Type = string(ExternalTableType) + encoded, err := json.Marshal(content) + if err != nil { + return nil, errInternalError(err.Error()) + } + var newMetadata map[string]interface{} + if err := json.Unmarshal(encoded, &newMetadata); err != nil { + return nil, errInternalError(err.Error()) + } + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, errInternalError(err.Error()) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, errInternalError(err.Error()) + } + defer tx.RollbackIfNotCommitted() + if err := table.Replace(ctx, tx.Tx(), newMetadata); err != nil { + return nil, errInternalError(err.Error()) + } + if err := tx.Commit(); err != nil { + return nil, errInternalError(err.Error()) + } + return content, nil +} + +func (h *tablesListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + res, err := h.Handle(ctx, &tablesListRequest{ + server: server, + project: project, + dataset: dataset, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesListRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset +} + +func (h *tablesListHandler) Handle(ctx context.Context, r *tablesListRequest) (*bigqueryv2.TableList, error) { + var tables []*bigqueryv2.TableListTables + for _, tableID := range r.dataset.TableIDs() { + table, err := r.dataset.Table(tableID).Content() + if err != nil { + return nil, fmt.Errorf("failed to get table metadata from %s: %w", tableID, 
err) + } + tables = append(tables, &bigqueryv2.TableListTables{ + Clustering: table.Clustering, + CreationTime: table.CreationTime, + ExpirationTime: table.ExpirationTime, + FriendlyName: table.FriendlyName, + Id: table.Id, + Kind: table.Kind, + Labels: table.Labels, + RangePartitioning: table.RangePartitioning, + TableReference: table.TableReference, + TimePartitioning: table.TimePartitioning, + Type: table.Type, + }) + } + return &bigqueryv2.TableList{ + Tables: tables, + TotalItems: int64(len(tables)), + }, nil +} + +func (h *tablesPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + var newTable bigqueryv2.Table + if err := json.NewDecoder(r.Body).Decode(&newTable); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &tablesPatchRequest{ + server: server, + project: project, + dataset: dataset, + table: table, + newTable: &newTable, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesPatchRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *metadata.Table + newTable *bigqueryv2.Table +} + +func (h *tablesPatchHandler) Handle(ctx context.Context, r *tablesPatchRequest) (*bigqueryv2.Table, error) { + encodedTableData, err := json.Marshal(r.newTable) + if err != nil { + return nil, err + } + var tableMetadata map[string]interface{} + if err := json.Unmarshal(encodedTableData, &tableMetadata); err != nil { + return nil, err + } + + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, err + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.RollbackIfNotCommitted() + if err := r.table.Patch(ctx, tx.Tx(), tableMetadata); err != 
nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + // Return the full, merged table resource (kind/type/id/creationTime + // included) rather than echoing the request body. + return r.table.Content() +} + +func (h *tablesSetIamPolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + res, err := h.Handle(ctx, &tablesSetIamPolicyRequest{ + server: server, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesSetIamPolicyRequest struct { + server *Server +} + +func (h *tablesSetIamPolicyHandler) Handle(ctx context.Context, r *tablesSetIamPolicyRequest) (*bigqueryv2.Policy, error) { + return nil, fmt.Errorf("unsupported bigquery.tables.setIamPolicy") +} + +func (h *tablesTestIamPermissionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + res, err := h.Handle(ctx, &tablesTestIamPermissionsRequest{ + server: server, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesTestIamPermissionsRequest struct { + server *Server +} + +func (h *tablesTestIamPermissionsHandler) Handle(ctx context.Context, r *tablesTestIamPermissionsRequest) (*bigqueryv2.TestIamPermissionsResponse, error) { + return nil, fmt.Errorf("unsupported bigquery.tables.testIamPermissions") +} + +func (h *tablesUpdateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + server := serverFromContext(ctx) + project := projectFromContext(ctx) + dataset := datasetFromContext(ctx) + table := tableFromContext(ctx) + var newTable bigqueryv2.Table + if err := json.NewDecoder(r.Body).Decode(&newTable); err != nil { + errorResponse(ctx, w, errInvalid(err.Error())) + return + } + res, err := h.Handle(ctx, &tablesUpdateRequest{ + server: server, + 
project: project, + dataset: dataset, + table: table, + newTable: &newTable, + }) + if err != nil { + errorResponse(ctx, w, errInternalError(err.Error())) + return + } + encodeResponse(ctx, w, res) +} + +type tablesUpdateRequest struct { + server *Server + project *metadata.Project + dataset *metadata.Dataset + table *metadata.Table + newTable *bigqueryv2.Table +} + +func (h *tablesUpdateHandler) Handle(ctx context.Context, r *tablesUpdateRequest) (*bigqueryv2.Table, error) { + encodedTableData, err := json.Marshal(r.newTable) + if err != nil { + return nil, err + } + var tableMetadata map[string]interface{} + if err := json.Unmarshal(encodedTableData, &tableMetadata); err != nil { + return nil, err + } + + conn, err := r.server.connMgr.Connection(ctx, r.project.ID, r.dataset.ID) + if err != nil { + return nil, err + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.RollbackIfNotCommitted() + if err := r.table.Replace(ctx, tx.Tx(), tableMetadata); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + return r.table.Content() +} + +type defaultHandler struct{} + +func (h *defaultHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + errorResponse(ctx, w, errInternalError(fmt.Sprintf("unexpected request path: %s", html.EscapeString(r.URL.Path)))) +} diff --git a/server/handler_tableid_test.go b/server/handler_tableid_test.go new file mode 100644 index 000000000..84d71bd21 --- /dev/null +++ b/server/handler_tableid_test.go @@ -0,0 +1,25 @@ +package server + +import "testing" + +func TestTerminalTableID(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "empty", in: "", want: ""}, + {name: "table only", in: "users", want: "users"}, + {name: "dataset table", in: "sales.users", want: "users"}, + {name: "project dataset table", in: "myproj:sales.users", want: "users"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + got := terminalTableID(tt.in) + if got != tt.want { + t.Fatalf("terminalTableID(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} diff --git a/server/server_test.go b/server/server_test.go index 92e8c5615..c246e001a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4797,3 +4797,128 @@ func TestExportDataStatement(t *testing.T) { } }) } + +// TestRepeatedRecordFieldOrder verifies that REPEATED RECORD (array-of-struct) +// columns preserve field ordering when data is inserted via the streaming-insert +// API and then read back via a SQL query. This is a regression guard for a bug +// where field values were assigned to the wrong sub-field non-deterministically. +func TestRepeatedRecordFieldOrder(t *testing.T) { + const ( + projectID = "test" + datasetID = "dataset1" + tableID = "repeated_rec" + ) + ctx := context.Background() + bqServer, err := server.New(server.TempStorage) + if err != nil { + t.Fatal(err) + } + if err := bqServer.Load(server.StructSource( + types.NewProject(projectID, + types.NewDataset(datasetID, + types.NewTable(tableID, + []*types.Column{ + types.NewColumn("id", types.INT64), + types.NewColumn( + "tags", + types.STRUCT, + types.ColumnFields( + types.NewColumn("label", types.STRING), + types.NewColumn("score", types.INT64), + ), + types.ColumnMode(types.RepeatedMode), + ), + }, + nil, + ), + ), + ), + )); err != nil { + t.Fatal(err) + } + testServer := bqServer.TestServer() + defer func() { + testServer.Close() + bqServer.Close() + }() + + client, err := bigquery.NewClient(ctx, projectID, + option.WithEndpoint(testServer.URL), + option.WithoutAuthentication(), + ) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + type tagRow struct { + Label string + Score int64 + } + type row struct { + ID int64 + Tags []tagRow + } + inserter := client.Dataset(datasetID).Table(tableID).Inserter() + saver := &bigquery.StructSaver{ + Schema: bigquery.Schema{ + {Name: "id", Type: bigquery.IntegerFieldType}, + 
{Name: "tags", Type: bigquery.RecordFieldType, Repeated: true, Schema: bigquery.Schema{ + {Name: "label", Type: bigquery.StringFieldType}, + {Name: "score", Type: bigquery.IntegerFieldType}, + }}, + }, + InsertID: "", + Struct: &row{ + ID: 1, + Tags: []tagRow{ + {Label: "alpha", Score: 10}, + {Label: "beta", Score: 20}, + }, + }, + } + if err := inserter.Put(ctx, saver); err != nil { + t.Fatalf("insert: %v", err) + } + + it, err := client.Query( + "SELECT id, tags FROM `" + projectID + "." + datasetID + "." + tableID + "` WHERE id = 1", + ).Read(ctx) + if err != nil { + t.Fatalf("query: %v", err) + } + var row0 []bigquery.Value + if err := it.Next(&row0); err != nil { + t.Fatalf("next: %v", err) + } + if len(row0) != 2 { + t.Fatalf("expected 2 columns, got %d", len(row0)) + } + tags, ok := row0[1].([]bigquery.Value) + if !ok { + t.Fatalf("tags field type %T, want []bigquery.Value", row0[1]) + } + if len(tags) != 2 { + t.Fatalf("expected 2 tags, got %d", len(tags)) + } + for i, want := range []struct { + label string + score int64 + }{{"alpha", 10}, {"beta", 20}} { + fields, ok := tags[i].([]bigquery.Value) + if !ok { + t.Fatalf("tags[%d] type %T, want []bigquery.Value", i, tags[i]) + } + if len(fields) != 2 { + t.Fatalf("tags[%d] length %d, want 2", i, len(fields)) + } + label, _ := fields[0].(string) + score, _ := fields[1].(int64) + if label != want.label { + t.Errorf("tags[%d].label = %q, want %q", i, label, want.label) + } + if score != want.score { + t.Errorf("tags[%d].score = %d, want %d", i, score, want.score) + } + } +} diff --git a/server/storage_handler.go b/server/storage_handler.go index 42a44e2e0..e893a9ad3 100644 --- a/server/storage_handler.go +++ b/server/storage_handler.go @@ -1,1019 +1,1052 @@ -package server - -import ( - "bytes" - "context" - "fmt" - "io" - "strings" - "sync" - "time" - - storagepb "cloud.google.com/go/bigquery/storage/apiv1/storagepb" - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/array" - 
"github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/goccy/go-json" - goavro "github.com/linkedin/goavro/v2" - bigqueryv2 "google.golang.org/api/bigquery/v2" - "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/dynamicpb" - "google.golang.org/protobuf/types/known/timestamppb" - "google.golang.org/protobuf/types/known/wrapperspb" - - "github.com/goccy/bigquery-emulator/internal/connection" - "github.com/goccy/bigquery-emulator/internal/logger" - internaltypes "github.com/goccy/bigquery-emulator/internal/types" - "github.com/goccy/bigquery-emulator/types" -) - -var ( - _ storagepb.BigQueryReadServer = &storageReadServer{} - _ storagepb.BigQueryWriteServer = &storageWriteServer{} -) - -type storageReadServer struct { - server *Server - streamMap map[string]*readStreamStatus - mu sync.RWMutex -} - -type readStreamStatus struct { - projectID string - datasetID string - tableID string - outputColumns []string - condition string - dataFormat storagepb.DataFormat - avroSchema *types.AVROSchema - arrowSchema *arrow.Schema - schemaText string -} - -type AVROSchema struct { - ReadSessionSchema *storagepb.ReadSession_AvroSchema - Schema *types.AVROSchema - Text string -} - -type ARROWSchema struct { - ReadSessionSchema *storagepb.ReadSession_ArrowSchema - Schema *arrow.Schema - Text string -} - -func (s *storageReadServer) CreateReadSession(ctx context.Context, req *storagepb.CreateReadSessionRequest) (*storagepb.ReadSession, error) { - sessionID := randomID() - sessionName := fmt.Sprintf("%s/locations/%s/sessions/%s", req.Parent, "location", sessionID) - projectID, datasetID, tableID, err := getIDsFromPath(req.ReadSession.Table) - if err != nil { 
- return nil, err - } - tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) - if err != nil { - return nil, fmt.Errorf("failed to get table metadata: %w", err) - } - // MaxStreamCount is documented as a request for an upper bound, with 0 - // asking the server to choose a sensible default. Real BigQuery always - // returns at least one stream when data is available; the emulator only - // supports a single stream, so treat any non-1 request (including the - // unset 0) as 1 rather than producing a session with zero streams. - if req.MaxStreamCount > 1 { - return nil, fmt.Errorf("currently supported only one stream") - } - streamID := randomID() - streamName := fmt.Sprintf("%s/streams/%s", sessionName, streamID) - streams := []*storagepb.ReadStream{{Name: streamName}} - readSession := &storagepb.ReadSession{ - Name: sessionName, - ExpireTime: timestamppb.New(time.Now().Add(1 * time.Hour)), - Streams: streams, - EstimatedTotalBytesScanned: 0, - DataFormat: req.ReadSession.DataFormat, - Table: req.ReadSession.Table, - ReadOptions: req.ReadSession.ReadOptions, - TableModifiers: req.ReadSession.TableModifiers, - TraceId: req.ReadSession.TraceId, - } - // ReadOptions is optional: a client may create a read session - // without selected fields or a row restriction to read the whole - // table. Guard the dereference so an absent ReadOptions reads - // every column with no filter instead of panicking. 
- var ( - outputColumns []string - condition string - ) - if readOptions := req.ReadSession.ReadOptions; readOptions != nil { - outputColumns = readOptions.SelectedFields - condition = readOptions.RowRestriction - } - outputColumnMap := map[string]struct{}{} - for _, outputColumn := range outputColumns { - outputColumnMap[outputColumn] = struct{}{} - } - status := &readStreamStatus{ - projectID: projectID, - datasetID: datasetID, - tableID: tableID, - outputColumns: outputColumns, - condition: condition, - dataFormat: readSession.DataFormat, - } - switch readSession.DataFormat { - case storagepb.DataFormat_AVRO: - schema, err := s.getAVROSchema(tableMetadata, outputColumnMap) - if err != nil { - return nil, err - } - readSession.Schema = schema.ReadSessionSchema - status.avroSchema = schema.Schema - status.schemaText = schema.Text - case storagepb.DataFormat_ARROW: - schema, err := s.getARROWSchema(tableMetadata, outputColumnMap) - if err != nil { - return nil, err - } - readSession.Schema = schema.ReadSessionSchema - status.arrowSchema = schema.Schema - status.schemaText = schema.Text - default: - return nil, fmt.Errorf("unexpected data format %s", readSession.DataFormat) - } - s.mu.Lock() - s.streamMap[streamName] = status - s.mu.Unlock() - return readSession, nil -} - -func (s *storageReadServer) ReadRows(req *storagepb.ReadRowsRequest, stream storagepb.BigQueryRead_ReadRowsServer) error { - s.mu.RLock() - status := s.streamMap[req.ReadStream] - s.mu.RUnlock() - - if status == nil { - return fmt.Errorf("failed to find stream status from %s", req.ReadStream) - } - ctx := context.Background() - ctx = logger.WithLogger(ctx, s.server.logger) - - response, err := s.query(ctx, status) - if err != nil { - return err - } - switch status.dataFormat { - case storagepb.DataFormat_AVRO: - if err := s.sendAVRORows(status, response, stream); err != nil { - return err - } - case storagepb.DataFormat_ARROW: - if err := s.sendARROWRows(status, response, stream); err != nil { - 
return err - } - } - return nil -} - -func (s *storageReadServer) SplitReadStream(ctx context.Context, req *storagepb.SplitReadStreamRequest) (*storagepb.SplitReadStreamResponse, error) { - return nil, fmt.Errorf("unimplemented split read stream") -} - -func (s *storageReadServer) buildQuery(status *readStreamStatus) string { - var columns string - if len(status.outputColumns) != 0 { - outputColumns := make([]string, len(status.outputColumns)) - for idx, outputColumn := range status.outputColumns { - outputColumns[idx] = fmt.Sprintf("`%s`", outputColumn) - } - columns = strings.Join(outputColumns, ",") - } else { - columns = "*" - } - var condition string - if status.condition != "" { - condition = fmt.Sprintf("WHERE %s", status.condition) - } - return fmt.Sprintf("SELECT %s FROM `%s` %s", columns, status.tableID, condition) -} - -func (s *storageReadServer) query(ctx context.Context, status *readStreamStatus) (*internaltypes.QueryResponse, error) { - conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) - if err != nil { - return nil, fmt.Errorf("failed to get connection: %w", err) - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.RollbackIfNotCommitted() - - query := s.buildQuery(status) - return s.server.contentRepo.Query( - ctx, - tx, - status.projectID, - status.datasetID, - query, - nil, - ) -} - -func (s *storageReadServer) getAVROSchema(tableMetadata *bigqueryv2.Table, outputColumnMap map[string]struct{}) (*AVROSchema, error) { - avroSchema := types.TableToAVRO(tableMetadata) - if len(outputColumnMap) != 0 { - filteredFields := make([]*types.AVROFieldSchema, 0, len(avroSchema.Fields)) - for _, field := range avroSchema.Fields { - if _, exists := outputColumnMap[field.Name]; exists { - filteredFields = append(filteredFields, field) - } - } - avroSchema.Fields = filteredFields - } - schema, err := json.Marshal(avroSchema) - if err != nil { - return 
nil, err - } - schemaText := string(schema) - return &AVROSchema{ - ReadSessionSchema: &storagepb.ReadSession_AvroSchema{ - AvroSchema: &storagepb.AvroSchema{Schema: schemaText}, - }, - Schema: avroSchema, - Text: schemaText, - }, nil -} - -func (s *storageReadServer) sendAVRORows(status *readStreamStatus, response *internaltypes.QueryResponse, stream storagepb.BigQueryRead_ReadRowsServer) error { - codec, err := goavro.NewCodec(status.schemaText) - if err != nil { - return fmt.Errorf("failed to create avro codec from schema %s: %w", status.schemaText, err) - } - var buf []byte - for _, row := range response.Rows { - value, err := row.AVROValue(status.avroSchema.Namespace, status.avroSchema.Fields) - if err != nil { - return fmt.Errorf("failed to convert response fields to avro value: %w", err) - } - b, err := codec.BinaryFromNative(buf, value) - if err != nil { - return fmt.Errorf("failed to encode binary from go value: %w", err) - } - buf = b - } - rows := &storagepb.ReadRowsResponse_AvroRows{ - AvroRows: &storagepb.AvroRows{ - SerializedBinaryRows: buf, - RowCount: int64(response.TotalRows), - }, - } - if err := stream.Send(&storagepb.ReadRowsResponse{ - Rows: rows, - RowCount: int64(response.TotalRows), - Schema: &storagepb.ReadRowsResponse_AvroSchema{ - AvroSchema: &storagepb.AvroSchema{ - Schema: status.schemaText, - }, - }, - }); err != nil { - return fmt.Errorf("failed to send read rows response for avro format: %w", err) - } - return nil -} - -func (s *storageReadServer) getARROWSchema(tableMetadata *bigqueryv2.Table, outputColumnMap map[string]struct{}) (*ARROWSchema, error) { - arrowSchema, err := types.TableToARROW(tableMetadata) - if err != nil { - return nil, err - } - if len(outputColumnMap) != 0 { - filteredFields := make([]arrow.Field, 0, len(arrowSchema.Fields())) - for _, field := range arrowSchema.Fields() { - if _, exists := outputColumnMap[field.Name]; exists { - filteredFields = append(filteredFields, field) - } - } - arrowSchema = 
arrow.NewSchema(filteredFields, nil) - } - schemaText := arrowSchema.String() - schema, err := s.getSerializedARROWSchema(arrowSchema) - if err != nil { - return nil, err - } - return &ARROWSchema{ - ReadSessionSchema: &storagepb.ReadSession_ArrowSchema{ - ArrowSchema: &storagepb.ArrowSchema{ - SerializedSchema: schema, - }, - }, - Schema: arrowSchema, - Text: schemaText, - }, nil -} - -func (s *storageReadServer) getSerializedARROWSchema(schema *arrow.Schema) ([]byte, error) { - mem := memory.NewGoAllocator() - buf := new(bytes.Buffer) - writer := ipc.NewWriter(buf, ipc.WithAllocator(mem), ipc.WithSchema(schema)) - builder := array.NewRecordBuilder(mem, schema) - defer builder.Release() - record := builder.NewRecord() - if err := writer.Write(record); err != nil { - return nil, err - } - record.Release() - if err := writer.Close(); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -func (s *storageReadServer) sendARROWRows(status *readStreamStatus, response *internaltypes.QueryResponse, stream storagepb.BigQueryRead_ReadRowsServer) error { - schema, err := s.getSerializedARROWSchema(status.arrowSchema) - if err != nil { - return err - } - mem := memory.NewGoAllocator() - builder := array.NewRecordBuilder(mem, status.arrowSchema) - defer builder.Release() - for _, row := range response.Rows { - if err := row.AppendValueToARROWBuilder(builder); err != nil { - return err - } - } - record := builder.NewRecord() - buf := new(bytes.Buffer) - writer := ipc.NewWriter(buf, ipc.WithAllocator(mem), ipc.WithSchema(status.arrowSchema)) - if err := writer.Write(record); err != nil { - return err - } - record.Release() - if err := writer.Close(); err != nil { - return err - } - rows := &storagepb.ReadRowsResponse_ArrowRecordBatch{ - ArrowRecordBatch: &storagepb.ArrowRecordBatch{ - SerializedRecordBatch: buf.Bytes(), - RowCount: int64(response.TotalRows), - }, - } - if err := stream.Send(&storagepb.ReadRowsResponse{ - Rows: rows, - RowCount: 
int64(response.TotalRows), - Schema: &storagepb.ReadRowsResponse_ArrowSchema{ - ArrowSchema: &storagepb.ArrowSchema{ - SerializedSchema: schema, - }, - }, - }); err != nil { - return fmt.Errorf("failed to send read rows response for arrow format: %w", err) - } - return nil -} - -type storageWriteServer struct { - server *Server - streamMap map[string]*writeStreamStatus - mu sync.RWMutex -} - -type writeStreamStatus struct { - mu sync.Mutex - streamType storagepb.WriteStream_Type - stream *storagepb.WriteStream - projectID string - datasetID string - tableID string - tableMetadata *bigqueryv2.Table - rows types.Data - finalized bool -} - -func newWriteStreamStatus(streamName string, streamType storagepb.WriteStream_Type, projectID, datasetID, tableID string, tableMetadata *bigqueryv2.Table) *writeStreamStatus { - createTime := timestamppb.New(time.Now()) - var commitTime *timestamppb.Timestamp - if streamType == storagepb.WriteStream_COMMITTED { - commitTime = createTime - } - schema := types.TableToProto(tableMetadata) - stream := &storagepb.WriteStream{ - Name: streamName, - Type: streamType, - CreateTime: createTime, - CommitTime: commitTime, - TableSchema: schema, - WriteMode: storagepb.WriteStream_INSERT, - } - return &writeStreamStatus{ - streamType: streamType, - stream: stream, - projectID: projectID, - datasetID: datasetID, - tableID: tableID, - tableMetadata: tableMetadata, - } -} - -func (s *writeStreamStatus) ensureAppendable() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.finalized { - return fmt.Errorf("stream is already finalized") - } - return nil -} - -func (s *writeStreamStatus) appendBufferedRows(data types.Data) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.finalized { - return fmt.Errorf("stream is already finalized") - } - s.rows = append(s.rows, data...) 
- return nil -} - -func (s *storageWriteServer) CreateWriteStream(ctx context.Context, req *storagepb.CreateWriteStreamRequest) (*storagepb.WriteStream, error) { - projectID, datasetID, tableID, err := getIDsFromPath(req.Parent) - if err != nil { - return nil, err - } - tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) - if err != nil { - return nil, fmt.Errorf("failed to get table metadata: %w", err) - } - streamID := randomID() - streamName := fmt.Sprintf("%s/streams/%s", req.Parent, streamID) - streamType := req.GetWriteStream().GetType() - streamStatus := newWriteStreamStatus(streamName, streamType, projectID, datasetID, tableID, tableMetadata) - s.mu.Lock() - s.streamMap[streamName] = streamStatus - s.mu.Unlock() - return streamStatus.stream, nil -} - -func (s *storageWriteServer) AppendRows(stream storagepb.BigQueryWrite_AppendRowsServer) error { - req, err := stream.Recv() - if err == io.EOF { - return nil - } - if err != nil { - return err - } - msgDesc, err := s.getMessageDescriptor(req) - if err != nil { - return err - } - status, streamName, err := s.appendRows(req, msgDesc, stream, nil, "") - if err != nil { - return fmt.Errorf("failed to append rows: %w", err) - } - for { - req, err := stream.Recv() - if err == io.EOF { - break - } - if err != nil { - return err - } - status, streamName, err = s.appendRows(req, msgDesc, stream, status, streamName) - if err != nil { - return fmt.Errorf("failed to append rows: %w", err) - } - } - return nil -} - -func (s *storageWriteServer) getMessageDescriptor(req *storagepb.AppendRowsRequest) (protoreflect.MessageDescriptor, error) { - descProto := req.GetProtoRows().GetWriterSchema().GetProtoDescriptor() - fdProto := &descriptorpb.FileDescriptorProto{ - Name: proto.String("proto"), - MessageType: []*descriptorpb.DescriptorProto{ - descProto, - }, - } - fd, err := protodesc.NewFile(fdProto, nil) - if err != nil { - return nil, fmt.Errorf("failed to create file descriptor: %w", err) - 
} - return fd.Messages().ByName(protoreflect.Name(descProto.GetName())), nil -} - -func (s *storageWriteServer) appendRows(req *storagepb.AppendRowsRequest, msgDesc protoreflect.MessageDescriptor, stream storagepb.BigQueryWrite_AppendRowsServer, fallbackStatus *writeStreamStatus, fallbackStreamName string) (*writeStreamStatus, string, error) { - streamName := req.GetWriteStream() - var status *writeStreamStatus - if streamName == "" { - status = fallbackStatus - streamName = fallbackStreamName - } else { - var err error - status, streamName, err = s.getOrCreateWriteStreamStatus(stream.Context(), streamName) - if err != nil { - return nil, "", err - } - } - if status == nil { - return nil, "", fmt.Errorf("write stream is not specified") - } - offset := int64(0) - if req.GetOffset() != nil { - offset = req.GetOffset().Value - } - rows := req.GetProtoRows().GetRows().GetSerializedRows() - data, err := s.decodeData(msgDesc, rows) - if err != nil { - s.sendErrorMessage(stream, streamName, err) - return nil, "", err - } - if status.streamType == storagepb.WriteStream_COMMITTED { - if err := status.ensureAppendable(); err != nil { - return nil, "", err - } - ctx := context.Background() - ctx = logger.WithLogger(ctx, s.server.logger) - - conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) - if err != nil { - s.sendErrorMessage(stream, streamName, err) - return nil, "", err - } - tx, err := conn.Begin(ctx) - if err != nil { - s.sendErrorMessage(stream, streamName, err) - return nil, "", err - } - defer tx.RollbackIfNotCommitted() - if err := s.insertTableData(ctx, tx, status, data); err != nil { - s.sendErrorMessage(stream, streamName, err) - return nil, "", err - } - if err := tx.Commit(); err != nil { - s.sendErrorMessage(stream, streamName, err) - return nil, "", err - } - } else { - if err := status.appendBufferedRows(data); err != nil { - return nil, "", err - } - } - if err := s.sendResult(stream, streamName, offset+int64(len(rows))); err 
!= nil { - return nil, "", err - } - return status, streamName, nil - -} - -func (s *storageWriteServer) lookupWriteStreamStatus(streamName string) (*writeStreamStatus, string, bool) { - canonicalName := canonicalWriteStreamName(streamName) - s.mu.RLock() - status, exists := s.streamMap[canonicalName] - s.mu.RUnlock() - return status, canonicalName, exists -} - -func (s *storageWriteServer) getOrCreateWriteStreamStatus(ctx context.Context, streamName string) (*writeStreamStatus, string, error) { - status, canonicalName, exists := s.lookupWriteStreamStatus(streamName) - if exists { - return status, canonicalName, nil - } - if !isDefaultWriteStreamName(streamName) { - return nil, "", fmt.Errorf("failed to get stream from %s", streamName) - } - status, canonicalName, err := s.createDefaultStreamStatus(ctx, streamName) - if err != nil { - return nil, "", fmt.Errorf("failed to get stream from %s", streamName) - } - return status, canonicalName, nil -} - -func (s *storageWriteServer) sendResult(stream storagepb.BigQueryWrite_AppendRowsServer, streamName string, offset int64) error { - return stream.Send(&storagepb.AppendRowsResponse{ - WriteStream: streamName, - Response: &storagepb.AppendRowsResponse_AppendResult_{ - AppendResult: &storagepb.AppendRowsResponse_AppendResult{ - Offset: wrapperspb.Int64(offset), - }, - }, - }) -} - -func (s *storageWriteServer) sendErrorMessage(stream storagepb.BigQueryWrite_AppendRowsServer, streamName string, err error) error { - return stream.Send(&storagepb.AppendRowsResponse{ - WriteStream: streamName, - Response: &storagepb.AppendRowsResponse_Error{ - Error: &status.Status{ - Code: int32(codes.Internal), - Message: err.Error(), - }, - }, - }) -} - -func (s *storageWriteServer) decodeData(msgDesc protoreflect.MessageDescriptor, rows [][]byte) (types.Data, error) { - data := types.Data{} - for _, row := range rows { - msg := dynamicpb.NewMessage(msgDesc) - rowData, err := s.decodeRowData(row, msg) - if err != nil { - return nil, err - 
} - data = append(data, rowData) - } - return data, nil -} - -func (s *storageWriteServer) decodeRowData(data []byte, msg *dynamicpb.Message) (map[string]interface{}, error) { - if err := proto.Unmarshal(data, msg); err != nil { - return nil, fmt.Errorf("failed to decode message: %w", err) - } - ret := map[string]interface{}{} - var decodeErr error - msg.Range(func(f protoreflect.FieldDescriptor, val protoreflect.Value) bool { - v, err := s.decodeProtoReflectValue(f, val) - if err != nil { - decodeErr = err - return false - } - ret[f.TextName()] = v - return true - }) - return ret, decodeErr -} - -func (s *storageWriteServer) decodeProtoReflectValue(f protoreflect.FieldDescriptor, v protoreflect.Value) (interface{}, error) { - if f.IsList() { - list := v.List() - ret := make([]interface{}, 0, list.Len()) - if !list.IsValid() { - return ret, nil - } - for i := 0; i < list.Len(); i++ { - vv := list.Get(i) - elem, err := s.decodeProtoReflectValueFromKind(f.Kind(), vv) - if err != nil { - return nil, err - } - ret = append(ret, elem) - } - return ret, nil - } - return s.decodeProtoReflectValueFromKind(f.Kind(), v) -} - -func (s *storageWriteServer) decodeProtoReflectValueFromKind(kind protoreflect.Kind, v protoreflect.Value) (interface{}, error) { - if !v.IsValid() { - return nil, nil - } - switch kind { - case protoreflect.BoolKind: - return v.Bool(), nil - case protoreflect.EnumKind: - return v.Enum(), nil - case protoreflect.Int32Kind: - return v.Int(), nil - case protoreflect.Sint32Kind: - return v.Int(), nil - case protoreflect.Uint32Kind: - return v.Uint(), nil - case protoreflect.Int64Kind: - return v.Int(), nil - case protoreflect.Sint64Kind: - return v.Int(), nil - case protoreflect.Uint64Kind: - return v.Uint(), nil - case protoreflect.Sfixed32Kind: - return v.Int(), nil - case protoreflect.Fixed32Kind: - return v.Int(), nil - case protoreflect.FloatKind: - return v.Float(), nil - case protoreflect.Sfixed64Kind: - return v.Int(), nil - case 
protoreflect.Fixed64Kind: - return v.Float(), nil - case protoreflect.DoubleKind: - return v.Float(), nil - case protoreflect.StringKind: - return v.String(), nil - case protoreflect.BytesKind: - return v.Bytes(), nil - case protoreflect.MessageKind: - msg := v.Message() - // A google.protobuf scalar wrapper (wrappers.proto) carries a - // single `value` field; BigQuery maps a wrapper-typed proto - // field to its underlying scalar column, so unwrap it rather - // than rendering a one-field STRUCT. - if scalar, ok, err := s.unwrapWrapperValue(msg); ok || err != nil { - return scalar, err - } - structV := map[string]interface{}{} - var decodeErr error - msg.Range(func(f protoreflect.FieldDescriptor, val protoreflect.Value) bool { - v, err := s.decodeProtoReflectValue(f, val) - if err != nil { - decodeErr = err - return false - } - structV[f.TextName()] = v - return true - }) - return structV, decodeErr - case protoreflect.GroupKind: - return nil, fmt.Errorf("unsupported group kind for storage api") - } - return nil, fmt.Errorf("specified unknown kind") -} - -// wrapperMessageNames is the set of google.protobuf scalar wrapper messages -// (wrappers.proto). Each carries a single scalar `value` field. -var wrapperMessageNames = map[protoreflect.FullName]struct{}{ - "google.protobuf.DoubleValue": {}, - "google.protobuf.FloatValue": {}, - "google.protobuf.Int64Value": {}, - "google.protobuf.UInt64Value": {}, - "google.protobuf.Int32Value": {}, - "google.protobuf.UInt32Value": {}, - "google.protobuf.BoolValue": {}, - "google.protobuf.StringValue": {}, - "google.protobuf.BytesValue": {}, -} - -// unwrapWrapperValue reports whether msg is a google.protobuf scalar wrapper -// and, if so, returns its underlying `value`. A wrapper-typed proto field -// maps to its scalar BigQuery column, so the Storage Write API decoder must -// hand back the bare scalar instead of a single-field STRUCT, which the -// schema-driven normalizer cannot reconcile with a scalar column. 
-func (s *storageWriteServer) unwrapWrapperValue(msg protoreflect.Message) (interface{}, bool, error) { - if _, ok := wrapperMessageNames[msg.Descriptor().FullName()]; !ok { - return nil, false, nil - } - valueField := msg.Descriptor().Fields().ByName("value") - if valueField == nil { - return nil, false, nil - } - v, err := s.decodeProtoReflectValueFromKind(valueField.Kind(), msg.Get(valueField)) - return v, true, err -} - -func (s *storageWriteServer) insertTableData(ctx context.Context, tx *connection.Tx, status *writeStreamStatus, data types.Data) error { - tableDef, err := types.NewTableWithSchema(status.tableMetadata, data) - if err != nil { - return err - } - if err := s.server.contentRepo.AddTableData( - ctx, - tx, - status.projectID, - status.datasetID, - tableDef, - ); err != nil { - return fmt.Errorf("failed to add table data: %w", err) - } - return nil -} - -func (s *storageWriteServer) GetWriteStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) { - status, _, err := s.getOrCreateWriteStreamStatus(ctx, req.Name) - if err != nil { - return nil, fmt.Errorf("failed to find stream from %s", req.Name) - } - return status.stream, nil -} - -func (s *storageWriteServer) FinalizeWriteStream(ctx context.Context, req *storagepb.FinalizeWriteStreamRequest) (*storagepb.FinalizeWriteStreamResponse, error) { - status, _, exists := s.lookupWriteStreamStatus(req.GetName()) - if !exists { - return nil, fmt.Errorf("failed to get stream from %s", req.GetName()) - } - status.mu.Lock() - status.finalized = true - rowCount := int64(len(status.rows)) - status.mu.Unlock() - return &storagepb.FinalizeWriteStreamResponse{ - RowCount: rowCount, - }, nil -} - -func (s *storageWriteServer) BatchCommitWriteStreams(ctx context.Context, req *storagepb.BatchCommitWriteStreamsRequest) (*storagepb.BatchCommitWriteStreamsResponse, error) { - var streamErrors []*storagepb.StorageError - for _, streamName := range req.GetWriteStreams() { - 
status, _, exists := s.lookupWriteStreamStatus(streamName) - if !exists { - streamErrors = append(streamErrors, &storagepb.StorageError{ - Code: storagepb.StorageError_STREAM_NOT_FOUND, - Entity: streamName, - ErrorMessage: fmt.Sprintf("failed to find stream from %s", streamName), - }) - continue - } - status.mu.Lock() - rows := append(types.Data(nil), status.rows...) - status.mu.Unlock() - conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) - if err != nil { - streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) - continue - } - tx, err := conn.Begin(ctx) - if err != nil { - streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) - continue - } - defer tx.RollbackIfNotCommitted() - if err := s.insertTableData(ctx, tx, status, rows); err != nil { - streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) - continue - } - if err := tx.Commit(); err != nil { - streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) - } - } - return &storagepb.BatchCommitWriteStreamsResponse{ - CommitTime: timestamppb.New(time.Now()), - StreamErrors: streamErrors, - }, nil -} - -func (s *storageWriteServer) createUnspecifiedStorageError(streamName string, err error) *storagepb.StorageError { - return &storagepb.StorageError{ - Code: storagepb.StorageError_STORAGE_ERROR_CODE_UNSPECIFIED, - Entity: streamName, - ErrorMessage: err.Error(), - } -} - -func (s *storageWriteServer) FlushRows(ctx context.Context, req *storagepb.FlushRowsRequest) (*storagepb.FlushRowsResponse, error) { - streamName := req.GetWriteStream() - status, _, exists := s.lookupWriteStreamStatus(streamName) - if !exists { - return nil, fmt.Errorf("failed to find stream from %s", streamName) - } - if req.GetOffset() == nil { - return nil, fmt.Errorf("offset is required") - } - offset := req.GetOffset().Value - status.mu.Lock() - if offset < 0 || offset >= 
int64(len(status.rows)) { - status.mu.Unlock() - return nil, fmt.Errorf("offset %d is out of range", offset) - } - rows := append(types.Data(nil), status.rows[:offset+1]...) - status.mu.Unlock() - conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) - if err != nil { - return nil, err - } - tx, err := conn.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.RollbackIfNotCommitted() - if err := s.insertTableData(ctx, tx, status, rows); err != nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - return &storagepb.FlushRowsResponse{ - Offset: offset, - }, nil -} - -// According to google documentation every table has a special stream named -// `_default` to which data can be written. This stream doesn't need to be -// created using CreateWriteStream. Client libraries use both -// projects//datasets//tables//_default and -// projects//datasets//tables//streams/_default, -// so accept both forms. -func (s *storageWriteServer) createDefaultStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) { - status, _, err := s.createDefaultStreamStatus(ctx, req.Name) - if err != nil { - return nil, err - } - return status.stream, nil -} - -func (s *storageWriteServer) createDefaultStreamStatus(ctx context.Context, streamName string) (*writeStreamStatus, string, error) { - tablePath, err := defaultWriteStreamTablePath(streamName) - if err != nil { - return nil, "", err - } - canonicalName := canonicalDefaultWriteStreamName(tablePath) - projectID, datasetID, tableID, err := getIDsFromPath(tablePath) - if err != nil { - return nil, "", err - } - tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) - if err != nil { - return nil, "", err - } - streamStatus := newWriteStreamStatus(canonicalName, storagepb.WriteStream_COMMITTED, projectID, datasetID, tableID, tableMetadata) - s.mu.Lock() - defer s.mu.Unlock() - if status, exists := 
s.streamMap[canonicalName]; exists { - return status, canonicalName, nil - } - s.streamMap[canonicalName] = streamStatus - return streamStatus, canonicalName, nil -} - -func isDefaultWriteStreamName(name string) bool { - _, err := defaultWriteStreamTablePath(name) - return err == nil -} - -func canonicalWriteStreamName(name string) string { - tablePath, err := defaultWriteStreamTablePath(name) - if err != nil { - return name - } - return canonicalDefaultWriteStreamName(tablePath) -} - -func canonicalDefaultWriteStreamName(tablePath string) string { - return tablePath + "/_default" -} - -func defaultWriteStreamTablePath(name string) (string, error) { - switch { - case strings.HasSuffix(name, "/streams/_default"): - return strings.TrimSuffix(name, "/streams/_default"), nil - case strings.HasSuffix(name, "/_default"): - return strings.TrimSuffix(name, "/_default"), nil - default: - return "", fmt.Errorf("unexpected default stream name: %s", name) - } -} - -func getIDsFromPath(path string) (string, string, string, error) { - paths := strings.Split(path, "/") - if len(paths)%2 != 0 { - return "", "", "", fmt.Errorf("unexpected table path: %s", path) - } - var ( - projectID string - datasetID string - tableID string - ) - for i := 0; i < len(paths); i += 2 { - switch paths[i] { - case "projects": - projectID = paths[i+1] - case "datasets": - datasetID = paths[i+1] - case "tables": - tableID = paths[i+1] - } - } - if projectID == "" { - return "", "", "", fmt.Errorf("unspecified project id") - } - if datasetID == "" { - return "", "", "", fmt.Errorf("unspecified dataset id") - } - if tableID == "" { - return "", "", "", fmt.Errorf("unspecified table id") - } - return projectID, datasetID, tableID, nil -} - -func getTableMetadata(ctx context.Context, server *Server, projectID, datasetID, tableID string) (*bigqueryv2.Table, error) { - project, err := server.metaRepo.FindProject(ctx, projectID) - if err != nil { - return nil, err - } - if project == nil { - return nil, 
fmt.Errorf("project %s is not found", projectID) - } - dataset := project.Dataset(datasetID) - if dataset == nil { - return nil, fmt.Errorf("dataset %s is not found in project %s", datasetID, projectID) - } - table := dataset.Table(tableID) - if table == nil { - return nil, fmt.Errorf("table %s is not found in dataset %s", tableID, datasetID) - } - return table.Content() -} - -func registerStorageServer(grpcServer *grpc.Server, srv *Server) { - storagepb.RegisterBigQueryReadServer( - grpcServer, - &storageReadServer{ - server: srv, - streamMap: map[string]*readStreamStatus{}, - }, - ) - storagepb.RegisterBigQueryWriteServer( - grpcServer, - &storageWriteServer{ - server: srv, - streamMap: map[string]*writeStreamStatus{}, - }, - ) -} +package server + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "io" + "strings" + "sync" + "time" + + storagepb "cloud.google.com/go/bigquery/storage/apiv1/storagepb" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/goccy/go-json" + goavro "github.com/linkedin/goavro/v2" + bigqueryv2 "google.golang.org/api/bigquery/v2" + "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" + "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/goccy/bigquery-emulator/internal/connection" + "github.com/goccy/bigquery-emulator/internal/logger" + internaltypes "github.com/goccy/bigquery-emulator/internal/types" + "github.com/goccy/bigquery-emulator/types" +) + +var ( + _ storagepb.BigQueryReadServer = &storageReadServer{} + _ 
storagepb.BigQueryWriteServer = &storageWriteServer{} +) + +type storageReadServer struct { + server *Server + streamMap map[string]*readStreamStatus + mu sync.RWMutex +} + +type readStreamStatus struct { + projectID string + datasetID string + tableID string + outputColumns []string + condition string + dataFormat storagepb.DataFormat + avroSchema *types.AVROSchema + arrowSchema *arrow.Schema + schemaText string +} + +type AVROSchema struct { + ReadSessionSchema *storagepb.ReadSession_AvroSchema + Schema *types.AVROSchema + Text string +} + +type ARROWSchema struct { + ReadSessionSchema *storagepb.ReadSession_ArrowSchema + Schema *arrow.Schema + Text string +} + +func (s *storageReadServer) CreateReadSession(ctx context.Context, req *storagepb.CreateReadSessionRequest) (*storagepb.ReadSession, error) { + sessionID := randomID() + sessionName := fmt.Sprintf("%s/locations/%s/sessions/%s", req.Parent, "location", sessionID) + projectID, datasetID, tableID, err := getIDsFromPath(req.ReadSession.Table) + if err != nil { + return nil, err + } + tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) + if err != nil { + return nil, fmt.Errorf("failed to get table metadata: %w", err) + } + // MaxStreamCount is documented as a request for an upper bound, with 0 + // asking the server to choose a sensible default. Real BigQuery always + // returns at least one stream when data is available; the emulator only + // supports a single stream, so treat any non-1 request (including the + // unset 0) as 1 rather than producing a session with zero streams. 
+ if req.MaxStreamCount > 1 { + return nil, fmt.Errorf("currently supported only one stream") + } + streamID := randomID() + streamName := fmt.Sprintf("%s/streams/%s", sessionName, streamID) + streams := []*storagepb.ReadStream{{Name: streamName}} + readSession := &storagepb.ReadSession{ + Name: sessionName, + ExpireTime: timestamppb.New(time.Now().Add(1 * time.Hour)), + Streams: streams, + EstimatedTotalBytesScanned: 0, + DataFormat: req.ReadSession.DataFormat, + Table: req.ReadSession.Table, + ReadOptions: req.ReadSession.ReadOptions, + TableModifiers: req.ReadSession.TableModifiers, + TraceId: req.ReadSession.TraceId, + } + // ReadOptions is optional: a client may create a read session + // without selected fields or a row restriction to read the whole + // table. Guard the dereference so an absent ReadOptions reads + // every column with no filter instead of panicking. + var ( + outputColumns []string + condition string + ) + if readOptions := req.ReadSession.ReadOptions; readOptions != nil { + outputColumns = readOptions.SelectedFields + condition = readOptions.RowRestriction + } + outputColumnMap := map[string]struct{}{} + for _, outputColumn := range outputColumns { + outputColumnMap[outputColumn] = struct{}{} + } + status := &readStreamStatus{ + projectID: projectID, + datasetID: datasetID, + tableID: tableID, + outputColumns: outputColumns, + condition: condition, + dataFormat: readSession.DataFormat, + } + switch readSession.DataFormat { + case storagepb.DataFormat_AVRO: + schema, err := s.getAVROSchema(tableMetadata, outputColumnMap) + if err != nil { + return nil, err + } + readSession.Schema = schema.ReadSessionSchema + status.avroSchema = schema.Schema + status.schemaText = schema.Text + case storagepb.DataFormat_ARROW: + schema, err := s.getARROWSchema(tableMetadata, outputColumnMap) + if err != nil { + return nil, err + } + readSession.Schema = schema.ReadSessionSchema + status.arrowSchema = schema.Schema + status.schemaText = schema.Text + default: 
+ return nil, fmt.Errorf("unexpected data format %s", readSession.DataFormat) + } + s.mu.Lock() + s.streamMap[streamName] = status + s.mu.Unlock() + return readSession, nil +} + +func (s *storageReadServer) ReadRows(req *storagepb.ReadRowsRequest, stream storagepb.BigQueryRead_ReadRowsServer) error { + s.mu.RLock() + status := s.streamMap[req.ReadStream] + s.mu.RUnlock() + + if status == nil { + return fmt.Errorf("failed to find stream status from %s", req.ReadStream) + } + ctx := context.Background() + ctx = logger.WithLogger(ctx, s.server.logger) + + response, err := s.query(ctx, status) + if err != nil { + return err + } + switch status.dataFormat { + case storagepb.DataFormat_AVRO: + if err := s.sendAVRORows(status, response, stream); err != nil { + return err + } + case storagepb.DataFormat_ARROW: + if err := s.sendARROWRows(status, response, stream); err != nil { + return err + } + } + return nil +} + +func (s *storageReadServer) SplitReadStream(ctx context.Context, req *storagepb.SplitReadStreamRequest) (*storagepb.SplitReadStreamResponse, error) { + return nil, fmt.Errorf("unimplemented split read stream") +} + +func (s *storageReadServer) buildQuery(status *readStreamStatus) string { + var columns string + if len(status.outputColumns) != 0 { + outputColumns := make([]string, len(status.outputColumns)) + for idx, outputColumn := range status.outputColumns { + outputColumns[idx] = fmt.Sprintf("`%s`", outputColumn) + } + columns = strings.Join(outputColumns, ",") + } else { + columns = "*" + } + var condition string + if status.condition != "" { + condition = fmt.Sprintf("WHERE %s", status.condition) + } + return fmt.Sprintf("SELECT %s FROM `%s` %s", columns, status.tableID, condition) +} + +func (s *storageReadServer) query(ctx context.Context, status *readStreamStatus) (*internaltypes.QueryResponse, error) { + conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) + if err != nil { + return nil, fmt.Errorf("failed to get 
connection: %w", err) + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.RollbackIfNotCommitted() + + query := s.buildQuery(status) + return s.server.contentRepo.Query( + ctx, + tx, + status.projectID, + status.datasetID, + query, + nil, + ) +} + +func (s *storageReadServer) getAVROSchema(tableMetadata *bigqueryv2.Table, outputColumnMap map[string]struct{}) (*AVROSchema, error) { + avroSchema := types.TableToAVRO(tableMetadata) + if len(outputColumnMap) != 0 { + filteredFields := make([]*types.AVROFieldSchema, 0, len(avroSchema.Fields)) + for _, field := range avroSchema.Fields { + if _, exists := outputColumnMap[field.Name]; exists { + filteredFields = append(filteredFields, field) + } + } + avroSchema.Fields = filteredFields + } + schema, err := json.Marshal(avroSchema) + if err != nil { + return nil, err + } + schemaText := string(schema) + return &AVROSchema{ + ReadSessionSchema: &storagepb.ReadSession_AvroSchema{ + AvroSchema: &storagepb.AvroSchema{Schema: schemaText}, + }, + Schema: avroSchema, + Text: schemaText, + }, nil +} + +func (s *storageReadServer) sendAVRORows(status *readStreamStatus, response *internaltypes.QueryResponse, stream storagepb.BigQueryRead_ReadRowsServer) error { + codec, err := goavro.NewCodec(status.schemaText) + if err != nil { + return fmt.Errorf("failed to create avro codec from schema %s: %w", status.schemaText, err) + } + var buf []byte + for _, row := range response.Rows { + value, err := row.AVROValue(status.avroSchema.Namespace, status.avroSchema.Fields) + if err != nil { + return fmt.Errorf("failed to convert response fields to avro value: %w", err) + } + b, err := codec.BinaryFromNative(buf, value) + if err != nil { + return fmt.Errorf("failed to encode binary from go value: %w", err) + } + buf = b + } + rows := &storagepb.ReadRowsResponse_AvroRows{ + AvroRows: &storagepb.AvroRows{ + SerializedBinaryRows: buf, + RowCount: 
int64(response.TotalRows), + }, + } + if err := stream.Send(&storagepb.ReadRowsResponse{ + Rows: rows, + RowCount: int64(response.TotalRows), + Schema: &storagepb.ReadRowsResponse_AvroSchema{ + AvroSchema: &storagepb.AvroSchema{ + Schema: status.schemaText, + }, + }, + }); err != nil { + return fmt.Errorf("failed to send read rows response for avro format: %w", err) + } + return nil +} + +func (s *storageReadServer) getARROWSchema(tableMetadata *bigqueryv2.Table, outputColumnMap map[string]struct{}) (*ARROWSchema, error) { + arrowSchema, err := types.TableToARROW(tableMetadata) + if err != nil { + return nil, err + } + if len(outputColumnMap) != 0 { + filteredFields := make([]arrow.Field, 0, len(arrowSchema.Fields())) + for _, field := range arrowSchema.Fields() { + if _, exists := outputColumnMap[field.Name]; exists { + filteredFields = append(filteredFields, field) + } + } + arrowSchema = arrow.NewSchema(filteredFields, nil) + } + schemaText := arrowSchema.String() + schema, err := s.getSerializedARROWSchema(arrowSchema) + if err != nil { + return nil, err + } + return &ARROWSchema{ + ReadSessionSchema: &storagepb.ReadSession_ArrowSchema{ + ArrowSchema: &storagepb.ArrowSchema{ + SerializedSchema: schema, + }, + }, + Schema: arrowSchema, + Text: schemaText, + }, nil +} + +// splitIPCStream splits an Arrow IPC stream that contains exactly one schema +// message followed by one record-batch message (no EOS marker) into the two +// bare IPC-encapsulated messages required by the BigQuery Storage Read API. +// +// Each IPC message on the wire is framed as: +// +// [continuation: 4 bytes][metaLen: 4 bytes LE][metadata: metaLen bytes][padding to 8-byte][body] +// +// The schema message body is always empty, so its total byte length is +// 4+4+metaLen rounded up to the next 8-byte boundary. Everything that follows +// is the record-batch message. 
+func splitIPCStream(data []byte) (schemaMsg, recordBatchMsg []byte, err error) { + if len(data) < 8 { + return nil, nil, fmt.Errorf("arrow IPC data too short (%d bytes)", len(data)) + } + metaLen := int(binary.LittleEndian.Uint32(data[4:8])) + schemaEnd := 8 + metaLen + if pad := schemaEnd % 8; pad != 0 { + schemaEnd += 8 - pad + } + if schemaEnd > len(data) { + return nil, nil, fmt.Errorf("schema IPC message overruns buffer (need %d, have %d)", schemaEnd, len(data)) + } + return data[:schemaEnd], data[schemaEnd:], nil +} + +func (s *storageReadServer) getSerializedARROWSchema(schema *arrow.Schema) ([]byte, error) { + mem := memory.NewGoAllocator() + var buf bytes.Buffer + writer := ipc.NewWriter(&buf, ipc.WithAllocator(mem), ipc.WithSchema(schema)) + builder := array.NewRecordBuilder(mem, schema) + defer builder.Release() + record := builder.NewRecord() + if err := writer.Write(record); err != nil { + return nil, err + } + record.Release() + // Do not call writer.Close(): we do not want the EOS marker. + // buf now holds [schema_message][empty_record_batch_message]. 
+ schemaBytes, _, err := splitIPCStream(buf.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to extract schema IPC message: %w", err) + } + return schemaBytes, nil +} + +func (s *storageReadServer) sendARROWRows(status *readStreamStatus, response *internaltypes.QueryResponse, stream storagepb.BigQueryRead_ReadRowsServer) error { + schema, err := s.getSerializedARROWSchema(status.arrowSchema) + if err != nil { + return err + } + mem := memory.NewGoAllocator() + builder := array.NewRecordBuilder(mem, status.arrowSchema) + defer builder.Release() + for _, row := range response.Rows { + if err := row.AppendValueToARROWBuilder(builder); err != nil { + return err + } + } + record := builder.NewRecord() + var buf bytes.Buffer + writer := ipc.NewWriter(&buf, ipc.WithAllocator(mem), ipc.WithSchema(status.arrowSchema)) + if err := writer.Write(record); err != nil { + return err + } + record.Release() + // Do not call writer.Close(): we do not want the EOS marker. + // buf now holds [schema_message][record_batch_message]; extract only the record batch. 
+ _, recordBatchBytes, err := splitIPCStream(buf.Bytes()) + if err != nil { + return fmt.Errorf("failed to extract record batch IPC message: %w", err) + } + rows := &storagepb.ReadRowsResponse_ArrowRecordBatch{ + ArrowRecordBatch: &storagepb.ArrowRecordBatch{ + SerializedRecordBatch: recordBatchBytes, + RowCount: int64(response.TotalRows), + }, + } + if err := stream.Send(&storagepb.ReadRowsResponse{ + Rows: rows, + RowCount: int64(response.TotalRows), + Schema: &storagepb.ReadRowsResponse_ArrowSchema{ + ArrowSchema: &storagepb.ArrowSchema{ + SerializedSchema: schema, + }, + }, + }); err != nil { + return fmt.Errorf("failed to send read rows response for arrow format: %w", err) + } + return nil +} + +type storageWriteServer struct { + server *Server + streamMap map[string]*writeStreamStatus + mu sync.RWMutex +} + +type writeStreamStatus struct { + mu sync.Mutex + streamType storagepb.WriteStream_Type + stream *storagepb.WriteStream + projectID string + datasetID string + tableID string + tableMetadata *bigqueryv2.Table + rows types.Data + finalized bool +} + +func newWriteStreamStatus(streamName string, streamType storagepb.WriteStream_Type, projectID, datasetID, tableID string, tableMetadata *bigqueryv2.Table) *writeStreamStatus { + createTime := timestamppb.New(time.Now()) + var commitTime *timestamppb.Timestamp + if streamType == storagepb.WriteStream_COMMITTED { + commitTime = createTime + } + schema := types.TableToProto(tableMetadata) + stream := &storagepb.WriteStream{ + Name: streamName, + Type: streamType, + CreateTime: createTime, + CommitTime: commitTime, + TableSchema: schema, + WriteMode: storagepb.WriteStream_INSERT, + } + return &writeStreamStatus{ + streamType: streamType, + stream: stream, + projectID: projectID, + datasetID: datasetID, + tableID: tableID, + tableMetadata: tableMetadata, + } +} + +func (s *writeStreamStatus) ensureAppendable() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.finalized { + return fmt.Errorf("stream is already 
finalized") + } + return nil +} + +func (s *writeStreamStatus) appendBufferedRows(data types.Data) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.finalized { + return fmt.Errorf("stream is already finalized") + } + s.rows = append(s.rows, data...) + return nil +} + +func (s *storageWriteServer) CreateWriteStream(ctx context.Context, req *storagepb.CreateWriteStreamRequest) (*storagepb.WriteStream, error) { + projectID, datasetID, tableID, err := getIDsFromPath(req.Parent) + if err != nil { + return nil, err + } + tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) + if err != nil { + return nil, fmt.Errorf("failed to get table metadata: %w", err) + } + streamID := randomID() + streamName := fmt.Sprintf("%s/streams/%s", req.Parent, streamID) + streamType := req.GetWriteStream().GetType() + streamStatus := newWriteStreamStatus(streamName, streamType, projectID, datasetID, tableID, tableMetadata) + s.mu.Lock() + s.streamMap[streamName] = streamStatus + s.mu.Unlock() + return streamStatus.stream, nil +} + +func (s *storageWriteServer) AppendRows(stream storagepb.BigQueryWrite_AppendRowsServer) error { + req, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + msgDesc, err := s.getMessageDescriptor(req) + if err != nil { + return err + } + status, streamName, err := s.appendRows(req, msgDesc, stream, nil, "") + if err != nil { + return fmt.Errorf("failed to append rows: %w", err) + } + for { + req, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return err + } + status, streamName, err = s.appendRows(req, msgDesc, stream, status, streamName) + if err != nil { + return fmt.Errorf("failed to append rows: %w", err) + } + } + return nil +} + +func (s *storageWriteServer) getMessageDescriptor(req *storagepb.AppendRowsRequest) (protoreflect.MessageDescriptor, error) { + descProto := req.GetProtoRows().GetWriterSchema().GetProtoDescriptor() + fdProto := 
&descriptorpb.FileDescriptorProto{ + Name: proto.String("proto"), + MessageType: []*descriptorpb.DescriptorProto{ + descProto, + }, + } + fd, err := protodesc.NewFile(fdProto, nil) + if err != nil { + return nil, fmt.Errorf("failed to create file descriptor: %w", err) + } + return fd.Messages().ByName(protoreflect.Name(descProto.GetName())), nil +} + +func (s *storageWriteServer) appendRows(req *storagepb.AppendRowsRequest, msgDesc protoreflect.MessageDescriptor, stream storagepb.BigQueryWrite_AppendRowsServer, fallbackStatus *writeStreamStatus, fallbackStreamName string) (*writeStreamStatus, string, error) { + streamName := req.GetWriteStream() + var status *writeStreamStatus + if streamName == "" { + status = fallbackStatus + streamName = fallbackStreamName + } else { + var err error + status, streamName, err = s.getOrCreateWriteStreamStatus(stream.Context(), streamName) + if err != nil { + return nil, "", err + } + } + if status == nil { + return nil, "", fmt.Errorf("write stream is not specified") + } + offset := int64(0) + if req.GetOffset() != nil { + offset = req.GetOffset().Value + } + rows := req.GetProtoRows().GetRows().GetSerializedRows() + data, err := s.decodeData(msgDesc, rows) + if err != nil { + s.sendErrorMessage(stream, streamName, err) + return nil, "", err + } + if status.streamType == storagepb.WriteStream_COMMITTED { + if err := status.ensureAppendable(); err != nil { + return nil, "", err + } + ctx := context.Background() + ctx = logger.WithLogger(ctx, s.server.logger) + + conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) + if err != nil { + s.sendErrorMessage(stream, streamName, err) + return nil, "", err + } + tx, err := conn.Begin(ctx) + if err != nil { + s.sendErrorMessage(stream, streamName, err) + return nil, "", err + } + defer tx.RollbackIfNotCommitted() + if err := s.insertTableData(ctx, tx, status, data); err != nil { + s.sendErrorMessage(stream, streamName, err) + return nil, "", err + } + if err := 
tx.Commit(); err != nil { + s.sendErrorMessage(stream, streamName, err) + return nil, "", err + } + } else { + if err := status.appendBufferedRows(data); err != nil { + return nil, "", err + } + } + if err := s.sendResult(stream, streamName, offset+int64(len(rows))); err != nil { + return nil, "", err + } + return status, streamName, nil + +} + +func (s *storageWriteServer) lookupWriteStreamStatus(streamName string) (*writeStreamStatus, string, bool) { + canonicalName := canonicalWriteStreamName(streamName) + s.mu.RLock() + status, exists := s.streamMap[canonicalName] + s.mu.RUnlock() + return status, canonicalName, exists +} + +func (s *storageWriteServer) getOrCreateWriteStreamStatus(ctx context.Context, streamName string) (*writeStreamStatus, string, error) { + status, canonicalName, exists := s.lookupWriteStreamStatus(streamName) + if exists { + return status, canonicalName, nil + } + if !isDefaultWriteStreamName(streamName) { + return nil, "", fmt.Errorf("failed to get stream from %s", streamName) + } + status, canonicalName, err := s.createDefaultStreamStatus(ctx, streamName) + if err != nil { + return nil, "", fmt.Errorf("failed to get stream from %s", streamName) + } + return status, canonicalName, nil +} + +func (s *storageWriteServer) sendResult(stream storagepb.BigQueryWrite_AppendRowsServer, streamName string, offset int64) error { + return stream.Send(&storagepb.AppendRowsResponse{ + WriteStream: streamName, + Response: &storagepb.AppendRowsResponse_AppendResult_{ + AppendResult: &storagepb.AppendRowsResponse_AppendResult{ + Offset: wrapperspb.Int64(offset), + }, + }, + }) +} + +func (s *storageWriteServer) sendErrorMessage(stream storagepb.BigQueryWrite_AppendRowsServer, streamName string, err error) error { + return stream.Send(&storagepb.AppendRowsResponse{ + WriteStream: streamName, + Response: &storagepb.AppendRowsResponse_Error{ + Error: &status.Status{ + Code: int32(codes.Internal), + Message: err.Error(), + }, + }, + }) +} + +func (s 
*storageWriteServer) decodeData(msgDesc protoreflect.MessageDescriptor, rows [][]byte) (types.Data, error) { + data := types.Data{} + for _, row := range rows { + msg := dynamicpb.NewMessage(msgDesc) + rowData, err := s.decodeRowData(row, msg) + if err != nil { + return nil, err + } + data = append(data, rowData) + } + return data, nil +} + +func (s *storageWriteServer) decodeRowData(data []byte, msg *dynamicpb.Message) (map[string]interface{}, error) { + if err := proto.Unmarshal(data, msg); err != nil { + return nil, fmt.Errorf("failed to decode message: %w", err) + } + ret := map[string]interface{}{} + var decodeErr error + msg.Range(func(f protoreflect.FieldDescriptor, val protoreflect.Value) bool { + v, err := s.decodeProtoReflectValue(f, val) + if err != nil { + decodeErr = err + return false + } + ret[f.TextName()] = v + return true + }) + return ret, decodeErr +} + +func (s *storageWriteServer) decodeProtoReflectValue(f protoreflect.FieldDescriptor, v protoreflect.Value) (interface{}, error) { + if f.IsList() { + list := v.List() + ret := make([]interface{}, 0, list.Len()) + if !list.IsValid() { + return ret, nil + } + for i := 0; i < list.Len(); i++ { + vv := list.Get(i) + elem, err := s.decodeProtoReflectValueFromKind(f.Kind(), vv) + if err != nil { + return nil, err + } + ret = append(ret, elem) + } + return ret, nil + } + return s.decodeProtoReflectValueFromKind(f.Kind(), v) +} + +func (s *storageWriteServer) decodeProtoReflectValueFromKind(kind protoreflect.Kind, v protoreflect.Value) (interface{}, error) { + if !v.IsValid() { + return nil, nil + } + switch kind { + case protoreflect.BoolKind: + return v.Bool(), nil + case protoreflect.EnumKind: + return v.Enum(), nil + case protoreflect.Int32Kind: + return v.Int(), nil + case protoreflect.Sint32Kind: + return v.Int(), nil + case protoreflect.Uint32Kind: + return v.Uint(), nil + case protoreflect.Int64Kind: + return v.Int(), nil + case protoreflect.Sint64Kind: + return v.Int(), nil + case 
protoreflect.Uint64Kind: + return v.Uint(), nil + case protoreflect.Sfixed32Kind: + return v.Int(), nil + case protoreflect.Fixed32Kind: + return v.Int(), nil + case protoreflect.FloatKind: + return v.Float(), nil + case protoreflect.Sfixed64Kind: + return v.Int(), nil + case protoreflect.Fixed64Kind: + return v.Float(), nil + case protoreflect.DoubleKind: + return v.Float(), nil + case protoreflect.StringKind: + return v.String(), nil + case protoreflect.BytesKind: + return v.Bytes(), nil + case protoreflect.MessageKind: + msg := v.Message() + // A google.protobuf scalar wrapper (wrappers.proto) carries a + // single `value` field; BigQuery maps a wrapper-typed proto + // field to its underlying scalar column, so unwrap it rather + // than rendering a one-field STRUCT. + if scalar, ok, err := s.unwrapWrapperValue(msg); ok || err != nil { + return scalar, err + } + structV := map[string]interface{}{} + var decodeErr error + msg.Range(func(f protoreflect.FieldDescriptor, val protoreflect.Value) bool { + v, err := s.decodeProtoReflectValue(f, val) + if err != nil { + decodeErr = err + return false + } + structV[f.TextName()] = v + return true + }) + return structV, decodeErr + case protoreflect.GroupKind: + return nil, fmt.Errorf("unsupported group kind for storage api") + } + return nil, fmt.Errorf("specified unknown kind") +} + +// wrapperMessageNames is the set of google.protobuf scalar wrapper messages +// (wrappers.proto). Each carries a single scalar `value` field. 
+var wrapperMessageNames = map[protoreflect.FullName]struct{}{ + "google.protobuf.DoubleValue": {}, + "google.protobuf.FloatValue": {}, + "google.protobuf.Int64Value": {}, + "google.protobuf.UInt64Value": {}, + "google.protobuf.Int32Value": {}, + "google.protobuf.UInt32Value": {}, + "google.protobuf.BoolValue": {}, + "google.protobuf.StringValue": {}, + "google.protobuf.BytesValue": {}, +} + +// unwrapWrapperValue reports whether msg is a google.protobuf scalar wrapper +// and, if so, returns its underlying `value`. A wrapper-typed proto field +// maps to its scalar BigQuery column, so the Storage Write API decoder must +// hand back the bare scalar instead of a single-field STRUCT, which the +// schema-driven normalizer cannot reconcile with a scalar column. +func (s *storageWriteServer) unwrapWrapperValue(msg protoreflect.Message) (interface{}, bool, error) { + if _, ok := wrapperMessageNames[msg.Descriptor().FullName()]; !ok { + return nil, false, nil + } + valueField := msg.Descriptor().Fields().ByName("value") + if valueField == nil { + return nil, false, nil + } + v, err := s.decodeProtoReflectValueFromKind(valueField.Kind(), msg.Get(valueField)) + return v, true, err +} + +func (s *storageWriteServer) insertTableData(ctx context.Context, tx *connection.Tx, status *writeStreamStatus, data types.Data) error { + tableDef, err := types.NewTableWithSchema(status.tableMetadata, data) + if err != nil { + return err + } + if err := s.server.contentRepo.AddTableData( + ctx, + tx, + status.projectID, + status.datasetID, + tableDef, + ); err != nil { + return fmt.Errorf("failed to add table data: %w", err) + } + return nil +} + +func (s *storageWriteServer) GetWriteStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) { + status, _, err := s.getOrCreateWriteStreamStatus(ctx, req.Name) + if err != nil { + return nil, fmt.Errorf("failed to find stream from %s", req.Name) + } + return status.stream, nil +} + +func (s 
*storageWriteServer) FinalizeWriteStream(ctx context.Context, req *storagepb.FinalizeWriteStreamRequest) (*storagepb.FinalizeWriteStreamResponse, error) { + status, _, exists := s.lookupWriteStreamStatus(req.GetName()) + if !exists { + return nil, fmt.Errorf("failed to get stream from %s", req.GetName()) + } + status.mu.Lock() + status.finalized = true + rowCount := int64(len(status.rows)) + status.mu.Unlock() + return &storagepb.FinalizeWriteStreamResponse{ + RowCount: rowCount, + }, nil +} + +func (s *storageWriteServer) BatchCommitWriteStreams(ctx context.Context, req *storagepb.BatchCommitWriteStreamsRequest) (*storagepb.BatchCommitWriteStreamsResponse, error) { + var streamErrors []*storagepb.StorageError + for _, streamName := range req.GetWriteStreams() { + status, _, exists := s.lookupWriteStreamStatus(streamName) + if !exists { + streamErrors = append(streamErrors, &storagepb.StorageError{ + Code: storagepb.StorageError_STREAM_NOT_FOUND, + Entity: streamName, + ErrorMessage: fmt.Sprintf("failed to find stream from %s", streamName), + }) + continue + } + status.mu.Lock() + rows := append(types.Data(nil), status.rows...) 
+ status.mu.Unlock() + conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) + if err != nil { + streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) + continue + } + tx, err := conn.Begin(ctx) + if err != nil { + streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) + continue + } + defer tx.RollbackIfNotCommitted() + if err := s.insertTableData(ctx, tx, status, rows); err != nil { + streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) + continue + } + if err := tx.Commit(); err != nil { + streamErrors = append(streamErrors, s.createUnspecifiedStorageError(streamName, err)) + } + } + return &storagepb.BatchCommitWriteStreamsResponse{ + CommitTime: timestamppb.New(time.Now()), + StreamErrors: streamErrors, + }, nil +} + +func (s *storageWriteServer) createUnspecifiedStorageError(streamName string, err error) *storagepb.StorageError { + return &storagepb.StorageError{ + Code: storagepb.StorageError_STORAGE_ERROR_CODE_UNSPECIFIED, + Entity: streamName, + ErrorMessage: err.Error(), + } +} + +func (s *storageWriteServer) FlushRows(ctx context.Context, req *storagepb.FlushRowsRequest) (*storagepb.FlushRowsResponse, error) { + streamName := req.GetWriteStream() + status, _, exists := s.lookupWriteStreamStatus(streamName) + if !exists { + return nil, fmt.Errorf("failed to find stream from %s", streamName) + } + if req.GetOffset() == nil { + return nil, fmt.Errorf("offset is required") + } + offset := req.GetOffset().Value + status.mu.Lock() + if offset < 0 || offset >= int64(len(status.rows)) { + status.mu.Unlock() + return nil, fmt.Errorf("offset %d is out of range", offset) + } + rows := append(types.Data(nil), status.rows[:offset+1]...) 
+ status.mu.Unlock() + conn, err := s.server.connMgr.Connection(ctx, status.projectID, status.datasetID) + if err != nil { + return nil, err + } + tx, err := conn.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.RollbackIfNotCommitted() + if err := s.insertTableData(ctx, tx, status, rows); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + return &storagepb.FlushRowsResponse{ + Offset: offset, + }, nil +} + +// According to google documentation every table has a special stream named +// `_default` to which data can be written. This stream doesn't need to be +// created using CreateWriteStream. Client libraries use both +// projects//datasets//tables//_default and +// projects//datasets//tables//streams/_default, +// so accept both forms. +func (s *storageWriteServer) createDefaultStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) { + status, _, err := s.createDefaultStreamStatus(ctx, req.Name) + if err != nil { + return nil, err + } + return status.stream, nil +} + +func (s *storageWriteServer) createDefaultStreamStatus(ctx context.Context, streamName string) (*writeStreamStatus, string, error) { + tablePath, err := defaultWriteStreamTablePath(streamName) + if err != nil { + return nil, "", err + } + canonicalName := canonicalDefaultWriteStreamName(tablePath) + projectID, datasetID, tableID, err := getIDsFromPath(tablePath) + if err != nil { + return nil, "", err + } + tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) + if err != nil { + return nil, "", err + } + streamStatus := newWriteStreamStatus(canonicalName, storagepb.WriteStream_COMMITTED, projectID, datasetID, tableID, tableMetadata) + s.mu.Lock() + defer s.mu.Unlock() + if status, exists := s.streamMap[canonicalName]; exists { + return status, canonicalName, nil + } + s.streamMap[canonicalName] = streamStatus + return streamStatus, canonicalName, nil +} + +func 
isDefaultWriteStreamName(name string) bool { + _, err := defaultWriteStreamTablePath(name) + return err == nil +} + +func canonicalWriteStreamName(name string) string { + tablePath, err := defaultWriteStreamTablePath(name) + if err != nil { + return name + } + return canonicalDefaultWriteStreamName(tablePath) +} + +func canonicalDefaultWriteStreamName(tablePath string) string { + return tablePath + "/_default" +} + +func defaultWriteStreamTablePath(name string) (string, error) { + switch { + case strings.HasSuffix(name, "/streams/_default"): + return strings.TrimSuffix(name, "/streams/_default"), nil + case strings.HasSuffix(name, "/_default"): + return strings.TrimSuffix(name, "/_default"), nil + default: + return "", fmt.Errorf("unexpected default stream name: %s", name) + } +} + +func getIDsFromPath(path string) (string, string, string, error) { + paths := strings.Split(path, "/") + if len(paths)%2 != 0 { + return "", "", "", fmt.Errorf("unexpected table path: %s", path) + } + var ( + projectID string + datasetID string + tableID string + ) + for i := 0; i < len(paths); i += 2 { + switch paths[i] { + case "projects": + projectID = paths[i+1] + case "datasets": + datasetID = paths[i+1] + case "tables": + tableID = paths[i+1] + } + } + if projectID == "" { + return "", "", "", fmt.Errorf("unspecified project id") + } + if datasetID == "" { + return "", "", "", fmt.Errorf("unspecified dataset id") + } + if tableID == "" { + return "", "", "", fmt.Errorf("unspecified table id") + } + return projectID, datasetID, tableID, nil +} + +func getTableMetadata(ctx context.Context, server *Server, projectID, datasetID, tableID string) (*bigqueryv2.Table, error) { + project, err := server.metaRepo.FindProject(ctx, projectID) + if err != nil { + return nil, err + } + if project == nil { + return nil, fmt.Errorf("project %s is not found", projectID) + } + dataset := project.Dataset(datasetID) + if dataset == nil { + return nil, fmt.Errorf("dataset %s is not found in project 
%s", datasetID, projectID) + } + table := dataset.Table(tableID) + if table == nil { + return nil, fmt.Errorf("table %s is not found in dataset %s", tableID, datasetID) + } + return table.Content() +} + +func registerStorageServer(grpcServer *grpc.Server, srv *Server) { + storagepb.RegisterBigQueryReadServer( + grpcServer, + &storageReadServer{ + server: srv, + streamMap: map[string]*readStreamStatus{}, + }, + ) + storagepb.RegisterBigQueryWriteServer( + grpcServer, + &storageWriteServer{ + server: srv, + streamMap: map[string]*writeStreamStatus{}, + }, + ) +} diff --git a/server/storage_test.go b/server/storage_test.go index d1cfb7a73..3c91ba620 100644 --- a/server/storage_test.go +++ b/server/storage_test.go @@ -1,896 +1,893 @@ -// FYI: https://cloud.google.com/bigquery/docs/reference/storage/libraries?hl=ja#client-libraries-usage-go -package server_test - -import ( - "bytes" - "context" - "fmt" - "io" - "math/rand" - "path/filepath" - "sync" - "testing" - "time" - - "cloud.google.com/go/bigquery" - bqStorage "cloud.google.com/go/bigquery/storage/apiv1" - storagepb "cloud.google.com/go/bigquery/storage/apiv1/storagepb" - "cloud.google.com/go/bigquery/storage/managedwriter" - "cloud.google.com/go/bigquery/storage/managedwriter/adapt" - "github.com/GoogleCloudPlatform/golang-samples/bigquery/snippets/managedwriter/exampleproto" - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/goccy/bigquery-emulator/server" - "github.com/goccy/go-json" - gax "github.com/googleapis/gax-go/v2" - goavro "github.com/linkedin/goavro/v2" - "google.golang.org/api/iterator" - "google.golang.org/api/option" - "google.golang.org/genproto/googleapis/rpc/errdetails" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - - "github.com/goccy/bigquery-emulator/types" -) - -var ( - rpcOpts = gax.WithGRPCOptions( - 
grpc.MaxCallRecvMsgSize(1024 * 1024 * 129), - ) - outputColumns = []string{"id", "name", "structarr", "birthday", "skillNum", "created_at"} -) - -func TestStorageReadAVRO(t *testing.T) { - const ( - project = "test" - dataset = "dataset1" - table = "table_a" - ) - ctx := context.Background() - bqServer, err := server.New(server.TempStorage) - if err != nil { - t.Fatal(err) - } - if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { - t.Fatal(err) - } - testServer := bqServer.TestServer() - defer func() { - testServer.Close() - bqServer.Close() - }() - opts, err := testServer.GRPCClientOptions(ctx) - if err != nil { - t.Fatal(err) - } - bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) - if err != nil { - t.Fatal(err) - } - defer bqReadClient.Close() - - readTable := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", project, dataset, table) - - tableReadOptions := &storagepb.ReadSession_TableReadOptions{ - SelectedFields: outputColumns, - RowRestriction: `id = 1`, - } - - createReadSessionRequest := &storagepb.CreateReadSessionRequest{ - Parent: fmt.Sprintf("projects/%s", project), - ReadSession: &storagepb.ReadSession{ - Table: readTable, - DataFormat: storagepb.DataFormat_AVRO, - ReadOptions: tableReadOptions, - }, - MaxStreamCount: 1, - } - - // Create the session from the request. - session, err := bqReadClient.CreateReadSession(ctx, createReadSessionRequest, rpcOpts) - if err != nil { - t.Fatalf("CreateReadSession: %v", err) - } - if len(session.GetStreams()) == 0 { - t.Fatalf("no streams in session. if this was a small query result, consider writing to output to a named table.") - } - - // We'll use only a single stream for reading data from the table. Because - // of dynamic sharding, this will yield all the rows in the table. However, - // if you wanted to fan out multiple readers you could do so by having a - // increasing the MaxStreamCount. 
- readStream := session.GetStreams()[0].Name - - // streamCtx lets either goroutine unblock the other on failure: the - // decoder cancels it when it stops, which aborts the reader's stream. - streamCtx, cancel := context.WithCancel(ctx) - defer cancel() - - ch := make(chan *storagepb.ReadRowsResponse) - - // Use a waitgroup to coordinate the reading and decoding goroutines. - var wg sync.WaitGroup - - // Start the reading in one goroutine. - wg.Add(1) - go func() { - defer wg.Done() - // Always close ch so the decoder goroutine cannot block forever - // waiting for rows that will never arrive. - defer close(ch) - if err := processStream(t, streamCtx, bqReadClient, readStream, ch); err != nil { - t.Errorf("processStream failure: %v", err) - } - }() - - // Start Avro processing and decoding in another goroutine. - wg.Add(1) - go func() { - defer wg.Done() - // Cancel the read stream when decoding stops so the reader - // goroutine cannot block forever sending on ch. - defer cancel() - if err := processAvro(t, streamCtx, session.GetAvroSchema().GetSchema(), ch); err != nil { - t.Errorf("error processing %s: %v", storagepb.DataFormat_AVRO, err) - } - }() - - // Wait until both the reading and decoding goroutines complete. - wg.Wait() -} - -func TestStorageReadARROW(t *testing.T) { - const ( - project = "test" - dataset = "dataset1" - table = "table_a" - ) - ctx := context.Background() - bqServer, err := server.New(server.TempStorage) - if err != nil { - t.Fatal(err) - } - if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { - t.Fatal(err) - } - testServer := bqServer.TestServer() - defer func() { - testServer.Close() - bqServer.Close() - }() - opts, err := testServer.GRPCClientOptions(ctx) - if err != nil { - t.Fatal(err) - } - bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) 
- if err != nil { - t.Fatal(err) - } - defer bqReadClient.Close() - - readTable := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", project, dataset, table) - - tableReadOptions := &storagepb.ReadSession_TableReadOptions{ - SelectedFields: outputColumns, - RowRestriction: `id = 1`, - } - - createReadSessionRequest := &storagepb.CreateReadSessionRequest{ - Parent: fmt.Sprintf("projects/%s", project), - ReadSession: &storagepb.ReadSession{ - Table: readTable, - DataFormat: storagepb.DataFormat_ARROW, - ReadOptions: tableReadOptions, - }, - MaxStreamCount: 1, - } - - // Create the session from the request. - session, err := bqReadClient.CreateReadSession(ctx, createReadSessionRequest, rpcOpts) - if err != nil { - t.Fatalf("CreateReadSession: %v", err) - } - if len(session.GetStreams()) == 0 { - t.Fatalf("no streams in session. if this was a small query result, consider writing to output to a named table.") - } - - // We'll use only a single stream for reading data from the table. Because - // of dynamic sharding, this will yield all the rows in the table. However, - // if you wanted to fan out multiple readers you could do so by having a - // increasing the MaxStreamCount. - readStream := session.GetStreams()[0].Name - - // streamCtx lets either goroutine unblock the other on failure: the - // decoder cancels it when it stops, which aborts the reader's stream. - streamCtx, cancel := context.WithCancel(ctx) - defer cancel() - - ch := make(chan *storagepb.ReadRowsResponse) - - // Use a waitgroup to coordinate the reading and decoding goroutines. - var wg sync.WaitGroup - - // Start the reading in one goroutine. - wg.Add(1) - go func() { - defer wg.Done() - // Always close ch so the decoder goroutine cannot block forever - // waiting for rows that will never arrive. 
- defer close(ch) - if err := processStream(t, streamCtx, bqReadClient, readStream, ch); err != nil { - t.Errorf("processStream failure: %v", err) - } - }() - - // Start Arrow processing and decoding in another goroutine. - wg.Add(1) - go func() { - defer wg.Done() - // Cancel the read stream when decoding stops so the reader - // goroutine cannot block forever sending on ch. - defer cancel() - if err := processArrow(t, streamCtx, session.GetArrowSchema().GetSerializedSchema(), ch); err != nil { - t.Errorf("error processing %s: %v", storagepb.DataFormat_ARROW, err) - } - }() - - // Wait until both the reading and decoding goroutines complete. - wg.Wait() -} - -func processStream(t *testing.T, ctx context.Context, client *bqStorage.BigQueryReadClient, st string, ch chan<- *storagepb.ReadRowsResponse) error { - var offset int64 - - // Streams may be long-running. Rather than using a global retry for the - // stream, implement a retry that resets once progress is made. - retryLimit := 3 - retries := 0 - for { - // Send the initiating request to start streaming row blocks. - rowStream, err := client.ReadRows(ctx, &storagepb.ReadRowsRequest{ - ReadStream: st, - Offset: offset, - }, rpcOpts) - if err != nil { - return fmt.Errorf("couldn't invoke ReadRows: %v", err) - } - - // Process the streamed responses. - for { - r, err := rowStream.Recv() - if err == io.EOF { - return nil - } - if err != nil { - // If there is an error, check whether it is a retryable - // error with a retry delay and sleep instead of increasing - // retries count. 
- var retryDelayDuration time.Duration - if errorStatus, ok := status.FromError(err); ok && errorStatus.Code() == codes.ResourceExhausted { - for _, detail := range errorStatus.Details() { - retryInfo, ok := detail.(*errdetails.RetryInfo) - if !ok { - continue - } - retryDelay := retryInfo.GetRetryDelay() - retryDelayDuration = time.Duration(retryDelay.Seconds)*time.Second + time.Duration(retryDelay.Nanos)*time.Nanosecond - break - } - } - if retryDelayDuration != 0 { - t.Logf("processStream failed with a retryable error, retrying in %v", retryDelayDuration) - time.Sleep(retryDelayDuration) - } else { - retries++ - if retries >= retryLimit { - return fmt.Errorf("processStream retries exhausted: %v", err) - } - } - // break the inner loop, and try to recover by starting a new streaming - // ReadRows call at the last known good offset. - break - } else { - // Reset retries after a successful response. - retries = 0 - } - - rc := r.GetRowCount() - if rc > 0 { - // Bookmark our progress in case of retries and send the rowblock on the channel. - offset = offset + rc - // We're making progress, reset retries. - retries = 0 - select { - case ch <- r: - case <-ctx.Done(): - return ctx.Err() - } - } - } - } -} - -// processAvro receives row blocks from a channel, and uses the provided Avro -// schema to decode the blocks into individual row messages for printing. Will -// continue to run until the channel is closed or the provided context is -// cancelled. -func processAvro(t *testing.T, ctx context.Context, schema string, ch <-chan *storagepb.ReadRowsResponse) error { - // Establish a decoder that can process blocks of messages using the - // reference schema. All blocks share the same schema, so the decoder - // can be long-lived. - codec, err := goavro.NewCodec(schema) - if err != nil { - return fmt.Errorf("couldn't create codec: %v", err) - } - - for { - select { - case <-ctx.Done(): - // Context was cancelled. Stop. 
- return ctx.Err() - case rows, ok := <-ch: - if !ok { - // Channel closed, no further avro messages. Stop. - return nil - } - undecoded := rows.GetAvroRows().GetSerializedBinaryRows() - for len(undecoded) > 0 { - datum, remainingBytes, err := codec.NativeFromBinary(undecoded) - - if err != nil { - if err == io.EOF { - break - } - return fmt.Errorf("decoding error with %d bytes remaining: %v", len(undecoded), err) - } - validateDatum(t, datum) - undecoded = remainingBytes - } - } - } -} - -func processArrow(t *testing.T, ctx context.Context, schema []byte, ch <-chan *storagepb.ReadRowsResponse) error { - mem := memory.NewGoAllocator() - buf := bytes.NewBuffer(schema) - r, err := ipc.NewReader(buf, ipc.WithAllocator(mem)) - if err != nil { - return err - } - aschema := r.Schema() - for { - select { - case <-ctx.Done(): - // Context was cancelled. Stop. - return ctx.Err() - case rows, ok := <-ch: - if !ok { - // Channel closed, no further arrow messages. Stop. - return nil - } - undecoded := rows.GetArrowRecordBatch().GetSerializedRecordBatch() - if len(undecoded) > 0 { - buf = bytes.NewBuffer(undecoded) - r, err = ipc.NewReader(buf, ipc.WithAllocator(mem), ipc.WithSchema(aschema)) - if err != nil { - return err - } - for r.Next() { - rec := r.Record() - validateArrowRecord(t, rec) - } - } - } - } -} - -func validateArrowRecord(t *testing.T, record arrow.Record) { - out, err := record.MarshalJSON() - if err != nil { - t.Fatal(err) - } - list := []map[string]interface{}{} - if err := json.Unmarshal(out, &list); err != nil { - t.Fatal(err) - } - if len(list) == 0 { - t.Fatal("failed to get arrow record") - } - first := list[0] - if len(first) != len(outputColumns) { - t.Fatalf("failed to get arrow record %+v", first) - } -} - -func validateDatum(t *testing.T, d interface{}) { - m, ok := d.(map[string]interface{}) - if !ok { - t.Logf("failed type assertion: %v", d) - } - if len(m) != len(outputColumns) { - t.Fatalf("failed to receive table data. 
expected columns %v but got %v", outputColumns, m) - } -} - -func TestStorageWrite(t *testing.T) { - for _, test := range []struct { - name string - streamType storagepb.WriteStream_Type - isDefaultStream bool - expectedRowsAfterFirstWrite int - expectedRowsAfterSecondWrite int - expectedRowsAfterThirdWrite int - expectedRowsAfterExplicitCommit int - }{ - { - name: "pending", - streamType: storagepb.WriteStream_PENDING, - expectedRowsAfterFirstWrite: 0, - expectedRowsAfterSecondWrite: 0, - expectedRowsAfterThirdWrite: 0, - expectedRowsAfterExplicitCommit: 6, - }, - { - name: "committed", - streamType: storagepb.WriteStream_COMMITTED, - expectedRowsAfterFirstWrite: 1, - expectedRowsAfterSecondWrite: 4, - expectedRowsAfterThirdWrite: 6, - expectedRowsAfterExplicitCommit: 6, - }, - { - name: "default", - streamType: storagepb.WriteStream_COMMITTED, - isDefaultStream: true, - expectedRowsAfterFirstWrite: 1, - expectedRowsAfterSecondWrite: 4, - expectedRowsAfterThirdWrite: 6, - expectedRowsAfterExplicitCommit: 6, - }, - } { - const ( - projectID = "test" - datasetID = "test" - tableID = "sample" - ) - - ctx := context.Background() - bqServer, err := server.New(server.TempStorage) - if err != nil { - t.Fatal(err) - } - if err := bqServer.Load( - server.StructSource( - types.NewProject( - projectID, - types.NewDataset( - datasetID, - types.NewTable( - tableID, - []*types.Column{ - types.NewColumn("bool_col", types.BOOL), - types.NewColumn("bytes_col", types.BYTES), - types.NewColumn("float64_col", types.FLOAT64), - types.NewColumn("int64_col", types.INT64), - types.NewColumn("string_col", types.STRING), - types.NewColumn("date_col", types.DATE), - types.NewColumn("datetime_col", types.DATETIME), - types.NewColumn("geography_col", types.GEOGRAPHY), - types.NewColumn("numeric_col", types.NUMERIC), - types.NewColumn("bignumeric_col", types.BIGNUMERIC), - types.NewColumn("time_col", types.TIME), - types.NewColumn("timestamp_col", types.TIMESTAMP), - 
types.NewColumn("int64_list", types.INT64, types.ColumnMode(types.RepeatedMode)), - types.NewColumn( - "struct_col", - types.STRUCT, - types.ColumnFields( - types.NewColumn("sub_int_col", types.INT64), - ), - ), - types.NewColumn( - "struct_list", - types.STRUCT, - types.ColumnFields( - types.NewColumn("sub_int_col", types.INT64), - ), - types.ColumnMode(types.RepeatedMode), - ), - }, - nil, - ), - ), - ), - ), - ); err != nil { - t.Fatal(err) - } - testServer := bqServer.TestServer() - defer func() { - testServer.Close() - bqServer.Close() - }() - opts, err := testServer.GRPCClientOptions(ctx) - if err != nil { - t.Fatal(err) - } - - client, err := managedwriter.NewClient(ctx, projectID, opts...) - if err != nil { - t.Fatal(err) - } - defer client.Close() - t.Run(test.name, func(t *testing.T) { - var writeStreamName string - fullTableName := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", projectID, datasetID, tableID) - if !test.isDefaultStream { - writeStream, err := client.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{ - Parent: fullTableName, - WriteStream: &storagepb.WriteStream{ - Type: test.streamType, - }, - }) - if err != nil { - t.Fatalf("CreateWriteStream: %v", err) - } - writeStreamName = writeStream.GetName() - } - m := &exampleproto.SampleData{} - descriptorProto, err := adapt.NormalizeDescriptor(m.ProtoReflect().Descriptor()) - if err != nil { - t.Fatalf("NormalizeDescriptor: %v", err) - } - var writerOptions []managedwriter.WriterOption - if test.isDefaultStream { - writerOptions = append(writerOptions, managedwriter.WithType(managedwriter.DefaultStream)) - writerOptions = append(writerOptions, managedwriter.WithDestinationTable(fullTableName)) - } else { - writerOptions = append(writerOptions, managedwriter.WithStreamName(writeStreamName)) - } - writerOptions = append(writerOptions, managedwriter.WithSchemaDescriptor(descriptorProto)) - managedStream, err := client.NewManagedStream( - ctx, - writerOptions..., - ) - if err != nil { - 
t.Fatalf("NewManagedStream: %v", err) - } - - bqClient, err := bigquery.NewClient( - ctx, - projectID, - option.WithEndpoint(testServer.URL), - option.WithoutAuthentication(), - ) - if err != nil { - t.Fatal(err) - } - defer bqClient.Close() - - rows, err := generateExampleMessages(1) - if err != nil { - t.Fatalf("generateExampleMessages: %v", err) - } - - var ( - curOffset int64 - results []*managedwriter.AppendResult - ) - result, err := managedStream.AppendRows(ctx, rows, managedwriter.WithOffset(0)) - if err != nil { - t.Fatalf("AppendRows first call error: %v", err) - } - // managedwriter's AppendRows returns a future; the gRPC - // request itself is sent (and acknowledged) asynchronously. - // Wait for the ACK before issuing the Read below — otherwise - // the Read can race the server's COMMITTED-stream insert and - // observe an empty (or partial) table. The wazero-backed - // SQL engine was slow enough to hide the race; a native-Go - // engine (e.g. wasm2go) surfaces it as a flake on the first - // / second AppendRows iteration. 
- if _, err := result.GetResult(ctx); err != nil { - t.Fatalf("first AppendRows result error: %v", err) - } - - iter := bqClient.Dataset(datasetID).Table(tableID).Read(ctx) - resultRowCount := countRows(t, iter) - if resultRowCount != test.expectedRowsAfterFirstWrite { - t.Fatalf("expected the number of rows after first AppendRows %d but got %d", test.expectedRowsAfterFirstWrite, resultRowCount) - } - - results = append(results, result) - curOffset = curOffset + 1 - rows, err = generateExampleMessages(3) - if err != nil { - t.Fatalf("generateExampleMessages: %v", err) - } - result, err = managedStream.AppendRows(ctx, rows, managedwriter.WithOffset(curOffset)) - if err != nil { - t.Fatalf("AppendRows second call error: %v", err) - } - if _, err := result.GetResult(ctx); err != nil { - t.Fatalf("second AppendRows result error: %v", err) - } - - iter = bqClient.Dataset(datasetID).Table(tableID).Read(ctx) - resultRowCount = countRows(t, iter) - if resultRowCount != test.expectedRowsAfterSecondWrite { - t.Fatalf("expected the number of rows after second AppendRows %d but got %d", test.expectedRowsAfterSecondWrite, resultRowCount) - } - - results = append(results, result) - curOffset = curOffset + 3 - rows, err = generateExampleMessages(2) - if err != nil { - t.Fatalf("generateExampleMessages: %v", err) - } - result, err = managedStream.AppendRows(ctx, rows, managedwriter.WithOffset(curOffset)) - if err != nil { - t.Fatalf("AppendRows third call error: %v", err) - } - results = append(results, result) - - for k, v := range results { - recvOffset, err := v.GetResult(ctx) - if err != nil { - t.Fatalf("append %d returned error: %v", k, err) - } - t.Logf("Successfully appended data at offset %d", recvOffset) - } - - iter = bqClient.Dataset(datasetID).Table(tableID).Read(ctx) - resultRowCount = countRows(t, iter) - if resultRowCount != test.expectedRowsAfterThirdWrite { - t.Fatalf("expected the number of rows after third AppendRows %d but got %d", 
test.expectedRowsAfterThirdWrite, resultRowCount) - } - - rowCount, err := managedStream.Finalize(ctx) - if err != nil { - t.Fatalf("error during Finalize: %v", err) - } - - t.Logf("Stream %s finalized with %d rows", managedStream.StreamName(), rowCount) - - req := &storagepb.BatchCommitWriteStreamsRequest{ - Parent: managedwriter.TableParentFromStreamName(managedStream.StreamName()), - WriteStreams: []string{managedStream.StreamName()}, - } - - resp, err := client.BatchCommitWriteStreams(ctx, req) - if err != nil { - t.Fatalf("client.BatchCommit: %v", err) - } - if len(resp.GetStreamErrors()) > 0 { - t.Fatalf("stream errors present: %v", resp.GetStreamErrors()) - } - - iter = bqClient.Dataset(datasetID).Table(tableID).Read(ctx) - resultRowCount = countRows(t, iter) - if resultRowCount != test.expectedRowsAfterExplicitCommit { - t.Fatalf("expected the number of rows after Finalize %d but got %d", test.expectedRowsAfterExplicitCommit, resultRowCount) - } - - t.Logf("Table data committed at %s", resp.GetCommitTime().AsTime().Format(time.RFC3339Nano)) - }) - } -} - -func countRows(t *testing.T, iter *bigquery.RowIterator) int { - var resultRowCount int - for { - v := map[string]bigquery.Value{} - if err := iter.Next(&v); err != nil { - if err == iterator.Done { - break - } - t.Fatal(err) - } - resultRowCount++ - } - return resultRowCount -} - -func generateExampleMessages(numMessages int) ([][]byte, error) { - msgs := make([][]byte, numMessages) - for i := 0; i < numMessages; i++ { - - random := rand.New(rand.NewSource(time.Now().UnixNano())) - - // Our example data embeds an array of structs, so we'll construct that first. 
- sList := make([]*exampleproto.SampleStruct, 5) - for i := 0; i < int(random.Int63n(5)+1); i++ { - sList[i] = &exampleproto.SampleStruct{ - SubIntCol: proto.Int64(random.Int63()), - } - } - - m := &exampleproto.SampleData{ - BoolCol: proto.Bool(true), - BytesCol: []byte("some bytes"), - Float64Col: proto.Float64(3.14), - Int64Col: proto.Int64(123), - StringCol: proto.String("example string value"), - - // These types require special encoding/formatting to transmit. - - // DATE values are number of days since the Unix epoch. - - DateCol: proto.Int32(int32(time.Now().UnixNano() / 86400000000000)), - - // DATETIME uses the literal format. - DatetimeCol: proto.String("2022-01-01 12:13:14.000000"), - - // GEOGRAPHY uses Well-Known-Text (WKT) format. - GeographyCol: proto.String("POINT(-122.350220 47.649154)"), - - // NUMERIC and BIGNUMERIC can be passed as string, or more efficiently - // using a packed byte representation. - NumericCol: proto.String("99999999999999999999999999999.999999999"), - BignumericCol: proto.String("578960446186580977117854925043439539266.34992332820282019728792003956564819967"), - - // TIME also uses literal format. - TimeCol: proto.String("12:13:14.000000"), - - // TIMESTAMP uses microseconds since Unix epoch. - TimestampCol: proto.Int64(time.Now().UnixNano() / 1000), - - // Int64List is an array of INT64 types. - Int64List: []int64{2, 4, 6, 8}, - - // This is a required field, and thus must be present. - RowNum: proto.Int64(23), - - // StructCol is a single nested message. - StructCol: &exampleproto.SampleStruct{ - SubIntCol: proto.Int64(random.Int63()), - }, - - // StructList is a repeated array of a nested message. 
- StructList: sList, - } - b, err := proto.Marshal(m) - if err != nil { - return nil, fmt.Errorf("error generating message %d: %w", i, err) - } - msgs[i] = b - } - return msgs, nil -} - -// TestIssue382CreateReadSessionWithoutReadOptions is a regression test for -// https://github.com/goccy/bigquery-emulator/issues/382: CreateReadSession -// panicked with a nil pointer dereference when the request's ReadSession -// carried no ReadOptions (a valid request shape — read the whole table with -// no column projection or row restriction). -func TestIssue382CreateReadSessionWithoutReadOptions(t *testing.T) { - const ( - project = "test" - dataset = "dataset1" - table = "table_a" - ) - ctx := context.Background() - bqServer, err := server.New(server.TempStorage) - if err != nil { - t.Fatal(err) - } - if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { - t.Fatal(err) - } - testServer := bqServer.TestServer() - defer func() { - testServer.Close() - bqServer.Close() - }() - opts, err := testServer.GRPCClientOptions(ctx) - if err != nil { - t.Fatal(err) - } - bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) - if err != nil { - t.Fatal(err) - } - defer bqReadClient.Close() - - // ReadSession deliberately carries no ReadOptions. 
- req := &storagepb.CreateReadSessionRequest{ - Parent: fmt.Sprintf("projects/%s", project), - ReadSession: &storagepb.ReadSession{ - Table: fmt.Sprintf("projects/%s/datasets/%s/tables/%s", - project, dataset, table), - DataFormat: storagepb.DataFormat_AVRO, - }, - MaxStreamCount: 1, - } - session, err := bqReadClient.CreateReadSession(ctx, req, rpcOpts) - if err != nil { - t.Fatalf("CreateReadSession without ReadOptions: %v", err) - } - if len(session.GetStreams()) == 0 { - t.Fatal("expected at least one stream in the session") - } - if session.GetAvroSchema().GetSchema() == "" { - t.Fatal("expected an AVRO schema on the session") - } - - // The session must still be readable end to end: with no selected - // fields the whole row is streamed back. - streamCtx, cancel := context.WithCancel(ctx) - defer cancel() - ch := make(chan *storagepb.ReadRowsResponse) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - defer close(ch) - if err := processStream(t, streamCtx, bqReadClient, session.GetStreams()[0].Name, ch); err != nil { - t.Errorf("processStream failure: %v", err) - } - }() - go func() { - defer wg.Done() - defer cancel() - if err := processAvro(t, streamCtx, session.GetAvroSchema().GetSchema(), ch); err != nil { - t.Errorf("error processing AVRO: %v", err) - } - }() - wg.Wait() -} - -// TestCreateReadSessionDefaultsToOneStream is a regression test for -// https://github.com/goccy/bigquery-emulator/issues/409: a -// CreateReadSession request with MaxStreamCount left at the zero default -// produced a session carrying zero streams, so the client could never read -// any rows. Real BigQuery picks a sensible default; the emulator caps at one -// stream, so an unset MaxStreamCount must yield one stream too. 
-func TestCreateReadSessionDefaultsToOneStream(t *testing.T) { - const ( - project = "test" - dataset = "dataset1" - table = "table_a" - ) - ctx := context.Background() - bqServer, err := server.New(server.TempStorage) - if err != nil { - t.Fatal(err) - } - if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { - t.Fatal(err) - } - testServer := bqServer.TestServer() - defer func() { - testServer.Close() - bqServer.Close() - }() - opts, err := testServer.GRPCClientOptions(ctx) - if err != nil { - t.Fatal(err) - } - bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) - if err != nil { - t.Fatal(err) - } - defer bqReadClient.Close() - - // MaxStreamCount intentionally omitted (zero value). - req := &storagepb.CreateReadSessionRequest{ - Parent: fmt.Sprintf("projects/%s", project), - ReadSession: &storagepb.ReadSession{ - Table: fmt.Sprintf("projects/%s/datasets/%s/tables/%s", - project, dataset, table), - DataFormat: storagepb.DataFormat_AVRO, - ReadOptions: &storagepb.ReadSession_TableReadOptions{ - SelectedFields: outputColumns, - }, - }, - } - session, err := bqReadClient.CreateReadSession(ctx, req, rpcOpts) - if err != nil { - t.Fatalf("CreateReadSession: %v", err) - } - if got := len(session.GetStreams()); got != 1 { - t.Fatalf("session has %d streams; want 1 (unset MaxStreamCount must default to one)", got) - } -} +// FYI: https://cloud.google.com/bigquery/docs/reference/storage/libraries?hl=ja#client-libraries-usage-go +package server_test + +import ( + "bytes" + "context" + "fmt" + "io" + "math/rand" + "path/filepath" + "sync" + "testing" + "time" + + "cloud.google.com/go/bigquery" + bqStorage "cloud.google.com/go/bigquery/storage/apiv1" + storagepb "cloud.google.com/go/bigquery/storage/apiv1/storagepb" + "cloud.google.com/go/bigquery/storage/managedwriter" + "cloud.google.com/go/bigquery/storage/managedwriter/adapt" + 
"github.com/GoogleCloudPlatform/golang-samples/bigquery/snippets/managedwriter/exampleproto" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/goccy/bigquery-emulator/server" + "github.com/goccy/go-json" + gax "github.com/googleapis/gax-go/v2" + goavro "github.com/linkedin/goavro/v2" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + "github.com/goccy/bigquery-emulator/types" +) + +var ( + rpcOpts = gax.WithGRPCOptions( + grpc.MaxCallRecvMsgSize(1024 * 1024 * 129), + ) + outputColumns = []string{"id", "name", "structarr", "birthday", "skillNum", "created_at"} +) + +func TestStorageReadAVRO(t *testing.T) { + const ( + project = "test" + dataset = "dataset1" + table = "table_a" + ) + ctx := context.Background() + bqServer, err := server.New(server.TempStorage) + if err != nil { + t.Fatal(err) + } + if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { + t.Fatal(err) + } + testServer := bqServer.TestServer() + defer func() { + testServer.Close() + bqServer.Close() + }() + opts, err := testServer.GRPCClientOptions(ctx) + if err != nil { + t.Fatal(err) + } + bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) 
+ if err != nil { + t.Fatal(err) + } + defer bqReadClient.Close() + + readTable := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", project, dataset, table) + + tableReadOptions := &storagepb.ReadSession_TableReadOptions{ + SelectedFields: outputColumns, + RowRestriction: `id = 1`, + } + + createReadSessionRequest := &storagepb.CreateReadSessionRequest{ + Parent: fmt.Sprintf("projects/%s", project), + ReadSession: &storagepb.ReadSession{ + Table: readTable, + DataFormat: storagepb.DataFormat_AVRO, + ReadOptions: tableReadOptions, + }, + MaxStreamCount: 1, + } + + // Create the session from the request. + session, err := bqReadClient.CreateReadSession(ctx, createReadSessionRequest, rpcOpts) + if err != nil { + t.Fatalf("CreateReadSession: %v", err) + } + if len(session.GetStreams()) == 0 { + t.Fatalf("no streams in session. if this was a small query result, consider writing to output to a named table.") + } + + // We'll use only a single stream for reading data from the table. Because + // of dynamic sharding, this will yield all the rows in the table. However, + // if you wanted to fan out multiple readers you could do so by having a + // increasing the MaxStreamCount. + readStream := session.GetStreams()[0].Name + + // streamCtx lets either goroutine unblock the other on failure: the + // decoder cancels it when it stops, which aborts the reader's stream. + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + ch := make(chan *storagepb.ReadRowsResponse) + + // Use a waitgroup to coordinate the reading and decoding goroutines. + var wg sync.WaitGroup + + // Start the reading in one goroutine. + wg.Add(1) + go func() { + defer wg.Done() + // Always close ch so the decoder goroutine cannot block forever + // waiting for rows that will never arrive. 
+ defer close(ch) + if err := processStream(t, streamCtx, bqReadClient, readStream, ch); err != nil { + t.Errorf("processStream failure: %v", err) + } + }() + + // Start Avro processing and decoding in another goroutine. + wg.Add(1) + go func() { + defer wg.Done() + // Cancel the read stream when decoding stops so the reader + // goroutine cannot block forever sending on ch. + defer cancel() + if err := processAvro(t, streamCtx, session.GetAvroSchema().GetSchema(), ch); err != nil { + t.Errorf("error processing %s: %v", storagepb.DataFormat_AVRO, err) + } + }() + + // Wait until both the reading and decoding goroutines complete. + wg.Wait() +} + +func TestStorageReadARROW(t *testing.T) { + const ( + project = "test" + dataset = "dataset1" + table = "table_a" + ) + ctx := context.Background() + bqServer, err := server.New(server.TempStorage) + if err != nil { + t.Fatal(err) + } + if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { + t.Fatal(err) + } + testServer := bqServer.TestServer() + defer func() { + testServer.Close() + bqServer.Close() + }() + opts, err := testServer.GRPCClientOptions(ctx) + if err != nil { + t.Fatal(err) + } + bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) + if err != nil { + t.Fatal(err) + } + defer bqReadClient.Close() + + readTable := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", project, dataset, table) + + tableReadOptions := &storagepb.ReadSession_TableReadOptions{ + SelectedFields: outputColumns, + RowRestriction: `id = 1`, + } + + createReadSessionRequest := &storagepb.CreateReadSessionRequest{ + Parent: fmt.Sprintf("projects/%s", project), + ReadSession: &storagepb.ReadSession{ + Table: readTable, + DataFormat: storagepb.DataFormat_ARROW, + ReadOptions: tableReadOptions, + }, + MaxStreamCount: 1, + } + + // Create the session from the request. 
+ session, err := bqReadClient.CreateReadSession(ctx, createReadSessionRequest, rpcOpts) + if err != nil { + t.Fatalf("CreateReadSession: %v", err) + } + if len(session.GetStreams()) == 0 { + t.Fatalf("no streams in session. if this was a small query result, consider writing to output to a named table.") + } + + // We'll use only a single stream for reading data from the table. Because + // of dynamic sharding, this will yield all the rows in the table. However, + // if you wanted to fan out multiple readers you could do so by having a + // increasing the MaxStreamCount. + readStream := session.GetStreams()[0].Name + + // streamCtx lets either goroutine unblock the other on failure: the + // decoder cancels it when it stops, which aborts the reader's stream. + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + ch := make(chan *storagepb.ReadRowsResponse) + + // Use a waitgroup to coordinate the reading and decoding goroutines. + var wg sync.WaitGroup + + // Start the reading in one goroutine. + wg.Add(1) + go func() { + defer wg.Done() + // Always close ch so the decoder goroutine cannot block forever + // waiting for rows that will never arrive. + defer close(ch) + if err := processStream(t, streamCtx, bqReadClient, readStream, ch); err != nil { + t.Errorf("processStream failure: %v", err) + } + }() + + // Start Arrow processing and decoding in another goroutine. + wg.Add(1) + go func() { + defer wg.Done() + // Cancel the read stream when decoding stops so the reader + // goroutine cannot block forever sending on ch. + defer cancel() + if err := processArrow(t, streamCtx, session.GetArrowSchema().GetSerializedSchema(), ch); err != nil { + t.Errorf("error processing %s: %v", storagepb.DataFormat_ARROW, err) + } + }() + + // Wait until both the reading and decoding goroutines complete. 
+ wg.Wait() +} + +func processStream(t *testing.T, ctx context.Context, client *bqStorage.BigQueryReadClient, st string, ch chan<- *storagepb.ReadRowsResponse) error { + var offset int64 + + // Streams may be long-running. Rather than using a global retry for the + // stream, implement a retry that resets once progress is made. + retryLimit := 3 + retries := 0 + for { + // Send the initiating request to start streaming row blocks. + rowStream, err := client.ReadRows(ctx, &storagepb.ReadRowsRequest{ + ReadStream: st, + Offset: offset, + }, rpcOpts) + if err != nil { + return fmt.Errorf("couldn't invoke ReadRows: %v", err) + } + + // Process the streamed responses. + for { + r, err := rowStream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + // If there is an error, check whether it is a retryable + // error with a retry delay and sleep instead of increasing + // retries count. + var retryDelayDuration time.Duration + if errorStatus, ok := status.FromError(err); ok && errorStatus.Code() == codes.ResourceExhausted { + for _, detail := range errorStatus.Details() { + retryInfo, ok := detail.(*errdetails.RetryInfo) + if !ok { + continue + } + retryDelay := retryInfo.GetRetryDelay() + retryDelayDuration = time.Duration(retryDelay.Seconds)*time.Second + time.Duration(retryDelay.Nanos)*time.Nanosecond + break + } + } + if retryDelayDuration != 0 { + t.Logf("processStream failed with a retryable error, retrying in %v", retryDelayDuration) + time.Sleep(retryDelayDuration) + } else { + retries++ + if retries >= retryLimit { + return fmt.Errorf("processStream retries exhausted: %v", err) + } + } + // break the inner loop, and try to recover by starting a new streaming + // ReadRows call at the last known good offset. + break + } else { + // Reset retries after a successful response. + retries = 0 + } + + rc := r.GetRowCount() + if rc > 0 { + // Bookmark our progress in case of retries and send the rowblock on the channel. 
+ offset = offset + rc + // We're making progress, reset retries. + retries = 0 + select { + case ch <- r: + case <-ctx.Done(): + return ctx.Err() + } + } + } + } +} + +// processAvro receives row blocks from a channel, and uses the provided Avro +// schema to decode the blocks into individual row messages for printing. Will +// continue to run until the channel is closed or the provided context is +// cancelled. +func processAvro(t *testing.T, ctx context.Context, schema string, ch <-chan *storagepb.ReadRowsResponse) error { + // Establish a decoder that can process blocks of messages using the + // reference schema. All blocks share the same schema, so the decoder + // can be long-lived. + codec, err := goavro.NewCodec(schema) + if err != nil { + return fmt.Errorf("couldn't create codec: %v", err) + } + + for { + select { + case <-ctx.Done(): + // Context was cancelled. Stop. + return ctx.Err() + case rows, ok := <-ch: + if !ok { + // Channel closed, no further avro messages. Stop. + return nil + } + undecoded := rows.GetAvroRows().GetSerializedBinaryRows() + for len(undecoded) > 0 { + datum, remainingBytes, err := codec.NativeFromBinary(undecoded) + + if err != nil { + if err == io.EOF { + break + } + return fmt.Errorf("decoding error with %d bytes remaining: %v", len(undecoded), err) + } + validateDatum(t, datum) + undecoded = remainingBytes + } + } + } +} + +func processArrow(t *testing.T, ctx context.Context, schema []byte, ch <-chan *storagepb.ReadRowsResponse) error { + mem := memory.NewGoAllocator() + for { + select { + case <-ctx.Done(): + // Context was cancelled. Stop. + return ctx.Err() + case rows, ok := <-ch: + if !ok { + // Channel closed, no further arrow messages. Stop. + return nil + } + undecoded := rows.GetArrowRecordBatch().GetSerializedRecordBatch() + if len(undecoded) > 0 { + // Reconstruct a valid IPC stream: bare schema message + bare record batch message. 
+ // The BigQuery Storage API sends them as separate bare messages; ipc.NewReader + // requires a stream that starts with a schema message. + stream := append(schema, undecoded...) + r, err := ipc.NewReader(bytes.NewBuffer(stream), ipc.WithAllocator(mem)) + if err != nil { + return err + } + for r.Next() { + rec := r.Record() + validateArrowRecord(t, rec) + } + } + } + } +} + +func validateArrowRecord(t *testing.T, record arrow.Record) { + out, err := record.MarshalJSON() + if err != nil { + t.Fatal(err) + } + list := []map[string]interface{}{} + if err := json.Unmarshal(out, &list); err != nil { + t.Fatal(err) + } + if len(list) == 0 { + t.Fatal("failed to get arrow record") + } + first := list[0] + if len(first) != len(outputColumns) { + t.Fatalf("failed to get arrow record %+v", first) + } +} + +func validateDatum(t *testing.T, d interface{}) { + m, ok := d.(map[string]interface{}) + if !ok { + t.Logf("failed type assertion: %v", d) + } + if len(m) != len(outputColumns) { + t.Fatalf("failed to receive table data. 
expected columns %v but got %v", outputColumns, m) + } +} + +func TestStorageWrite(t *testing.T) { + for _, test := range []struct { + name string + streamType storagepb.WriteStream_Type + isDefaultStream bool + expectedRowsAfterFirstWrite int + expectedRowsAfterSecondWrite int + expectedRowsAfterThirdWrite int + expectedRowsAfterExplicitCommit int + }{ + { + name: "pending", + streamType: storagepb.WriteStream_PENDING, + expectedRowsAfterFirstWrite: 0, + expectedRowsAfterSecondWrite: 0, + expectedRowsAfterThirdWrite: 0, + expectedRowsAfterExplicitCommit: 6, + }, + { + name: "committed", + streamType: storagepb.WriteStream_COMMITTED, + expectedRowsAfterFirstWrite: 1, + expectedRowsAfterSecondWrite: 4, + expectedRowsAfterThirdWrite: 6, + expectedRowsAfterExplicitCommit: 6, + }, + { + name: "default", + streamType: storagepb.WriteStream_COMMITTED, + isDefaultStream: true, + expectedRowsAfterFirstWrite: 1, + expectedRowsAfterSecondWrite: 4, + expectedRowsAfterThirdWrite: 6, + expectedRowsAfterExplicitCommit: 6, + }, + } { + const ( + projectID = "test" + datasetID = "test" + tableID = "sample" + ) + + ctx := context.Background() + bqServer, err := server.New(server.TempStorage) + if err != nil { + t.Fatal(err) + } + if err := bqServer.Load( + server.StructSource( + types.NewProject( + projectID, + types.NewDataset( + datasetID, + types.NewTable( + tableID, + []*types.Column{ + types.NewColumn("bool_col", types.BOOL), + types.NewColumn("bytes_col", types.BYTES), + types.NewColumn("float64_col", types.FLOAT64), + types.NewColumn("int64_col", types.INT64), + types.NewColumn("string_col", types.STRING), + types.NewColumn("date_col", types.DATE), + types.NewColumn("datetime_col", types.DATETIME), + types.NewColumn("geography_col", types.GEOGRAPHY), + types.NewColumn("numeric_col", types.NUMERIC), + types.NewColumn("bignumeric_col", types.BIGNUMERIC), + types.NewColumn("time_col", types.TIME), + types.NewColumn("timestamp_col", types.TIMESTAMP), + 
types.NewColumn("int64_list", types.INT64, types.ColumnMode(types.RepeatedMode)), + types.NewColumn( + "struct_col", + types.STRUCT, + types.ColumnFields( + types.NewColumn("sub_int_col", types.INT64), + ), + ), + types.NewColumn( + "struct_list", + types.STRUCT, + types.ColumnFields( + types.NewColumn("sub_int_col", types.INT64), + ), + types.ColumnMode(types.RepeatedMode), + ), + }, + nil, + ), + ), + ), + ), + ); err != nil { + t.Fatal(err) + } + testServer := bqServer.TestServer() + defer func() { + testServer.Close() + bqServer.Close() + }() + opts, err := testServer.GRPCClientOptions(ctx) + if err != nil { + t.Fatal(err) + } + + client, err := managedwriter.NewClient(ctx, projectID, opts...) + if err != nil { + t.Fatal(err) + } + defer client.Close() + t.Run(test.name, func(t *testing.T) { + var writeStreamName string + fullTableName := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", projectID, datasetID, tableID) + if !test.isDefaultStream { + writeStream, err := client.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{ + Parent: fullTableName, + WriteStream: &storagepb.WriteStream{ + Type: test.streamType, + }, + }) + if err != nil { + t.Fatalf("CreateWriteStream: %v", err) + } + writeStreamName = writeStream.GetName() + } + m := &exampleproto.SampleData{} + descriptorProto, err := adapt.NormalizeDescriptor(m.ProtoReflect().Descriptor()) + if err != nil { + t.Fatalf("NormalizeDescriptor: %v", err) + } + var writerOptions []managedwriter.WriterOption + if test.isDefaultStream { + writerOptions = append(writerOptions, managedwriter.WithType(managedwriter.DefaultStream)) + writerOptions = append(writerOptions, managedwriter.WithDestinationTable(fullTableName)) + } else { + writerOptions = append(writerOptions, managedwriter.WithStreamName(writeStreamName)) + } + writerOptions = append(writerOptions, managedwriter.WithSchemaDescriptor(descriptorProto)) + managedStream, err := client.NewManagedStream( + ctx, + writerOptions..., + ) + if err != nil { + 
t.Fatalf("NewManagedStream: %v", err) + } + + bqClient, err := bigquery.NewClient( + ctx, + projectID, + option.WithEndpoint(testServer.URL), + option.WithoutAuthentication(), + ) + if err != nil { + t.Fatal(err) + } + defer bqClient.Close() + + rows, err := generateExampleMessages(1) + if err != nil { + t.Fatalf("generateExampleMessages: %v", err) + } + + var ( + curOffset int64 + results []*managedwriter.AppendResult + ) + result, err := managedStream.AppendRows(ctx, rows, managedwriter.WithOffset(0)) + if err != nil { + t.Fatalf("AppendRows first call error: %v", err) + } + // managedwriter's AppendRows returns a future; the gRPC + // request itself is sent (and acknowledged) asynchronously. + // Wait for the ACK before issuing the Read below — otherwise + // the Read can race the server's COMMITTED-stream insert and + // observe an empty (or partial) table. The wazero-backed + // SQL engine was slow enough to hide the race; a native-Go + // engine (e.g. wasm2go) surfaces it as a flake on the first + // / second AppendRows iteration. 
+ if _, err := result.GetResult(ctx); err != nil { + t.Fatalf("first AppendRows result error: %v", err) + } + + iter := bqClient.Dataset(datasetID).Table(tableID).Read(ctx) + resultRowCount := countRows(t, iter) + if resultRowCount != test.expectedRowsAfterFirstWrite { + t.Fatalf("expected the number of rows after first AppendRows %d but got %d", test.expectedRowsAfterFirstWrite, resultRowCount) + } + + results = append(results, result) + curOffset = curOffset + 1 + rows, err = generateExampleMessages(3) + if err != nil { + t.Fatalf("generateExampleMessages: %v", err) + } + result, err = managedStream.AppendRows(ctx, rows, managedwriter.WithOffset(curOffset)) + if err != nil { + t.Fatalf("AppendRows second call error: %v", err) + } + if _, err := result.GetResult(ctx); err != nil { + t.Fatalf("second AppendRows result error: %v", err) + } + + iter = bqClient.Dataset(datasetID).Table(tableID).Read(ctx) + resultRowCount = countRows(t, iter) + if resultRowCount != test.expectedRowsAfterSecondWrite { + t.Fatalf("expected the number of rows after second AppendRows %d but got %d", test.expectedRowsAfterSecondWrite, resultRowCount) + } + + results = append(results, result) + curOffset = curOffset + 3 + rows, err = generateExampleMessages(2) + if err != nil { + t.Fatalf("generateExampleMessages: %v", err) + } + result, err = managedStream.AppendRows(ctx, rows, managedwriter.WithOffset(curOffset)) + if err != nil { + t.Fatalf("AppendRows third call error: %v", err) + } + results = append(results, result) + + for k, v := range results { + recvOffset, err := v.GetResult(ctx) + if err != nil { + t.Fatalf("append %d returned error: %v", k, err) + } + t.Logf("Successfully appended data at offset %d", recvOffset) + } + + iter = bqClient.Dataset(datasetID).Table(tableID).Read(ctx) + resultRowCount = countRows(t, iter) + if resultRowCount != test.expectedRowsAfterThirdWrite { + t.Fatalf("expected the number of rows after third AppendRows %d but got %d", 
test.expectedRowsAfterThirdWrite, resultRowCount) + } + + rowCount, err := managedStream.Finalize(ctx) + if err != nil { + t.Fatalf("error during Finalize: %v", err) + } + + t.Logf("Stream %s finalized with %d rows", managedStream.StreamName(), rowCount) + + req := &storagepb.BatchCommitWriteStreamsRequest{ + Parent: managedwriter.TableParentFromStreamName(managedStream.StreamName()), + WriteStreams: []string{managedStream.StreamName()}, + } + + resp, err := client.BatchCommitWriteStreams(ctx, req) + if err != nil { + t.Fatalf("client.BatchCommit: %v", err) + } + if len(resp.GetStreamErrors()) > 0 { + t.Fatalf("stream errors present: %v", resp.GetStreamErrors()) + } + + iter = bqClient.Dataset(datasetID).Table(tableID).Read(ctx) + resultRowCount = countRows(t, iter) + if resultRowCount != test.expectedRowsAfterExplicitCommit { + t.Fatalf("expected the number of rows after Finalize %d but got %d", test.expectedRowsAfterExplicitCommit, resultRowCount) + } + + t.Logf("Table data committed at %s", resp.GetCommitTime().AsTime().Format(time.RFC3339Nano)) + }) + } +} + +func countRows(t *testing.T, iter *bigquery.RowIterator) int { + var resultRowCount int + for { + v := map[string]bigquery.Value{} + if err := iter.Next(&v); err != nil { + if err == iterator.Done { + break + } + t.Fatal(err) + } + resultRowCount++ + } + return resultRowCount +} + +func generateExampleMessages(numMessages int) ([][]byte, error) { + msgs := make([][]byte, numMessages) + for i := 0; i < numMessages; i++ { + + random := rand.New(rand.NewSource(time.Now().UnixNano())) + + // Our example data embeds an array of structs, so we'll construct that first. 
+ sList := make([]*exampleproto.SampleStruct, 5) + for i := 0; i < int(random.Int63n(5)+1); i++ { + sList[i] = &exampleproto.SampleStruct{ + SubIntCol: proto.Int64(random.Int63()), + } + } + + m := &exampleproto.SampleData{ + BoolCol: proto.Bool(true), + BytesCol: []byte("some bytes"), + Float64Col: proto.Float64(3.14), + Int64Col: proto.Int64(123), + StringCol: proto.String("example string value"), + + // These types require special encoding/formatting to transmit. + + // DATE values are number of days since the Unix epoch. + + DateCol: proto.Int32(int32(time.Now().UnixNano() / 86400000000000)), + + // DATETIME uses the literal format. + DatetimeCol: proto.String("2022-01-01 12:13:14.000000"), + + // GEOGRAPHY uses Well-Known-Text (WKT) format. + GeographyCol: proto.String("POINT(-122.350220 47.649154)"), + + // NUMERIC and BIGNUMERIC can be passed as string, or more efficiently + // using a packed byte representation. + NumericCol: proto.String("99999999999999999999999999999.999999999"), + BignumericCol: proto.String("578960446186580977117854925043439539266.34992332820282019728792003956564819967"), + + // TIME also uses literal format. + TimeCol: proto.String("12:13:14.000000"), + + // TIMESTAMP uses microseconds since Unix epoch. + TimestampCol: proto.Int64(time.Now().UnixNano() / 1000), + + // Int64List is an array of INT64 types. + Int64List: []int64{2, 4, 6, 8}, + + // This is a required field, and thus must be present. + RowNum: proto.Int64(23), + + // StructCol is a single nested message. + StructCol: &exampleproto.SampleStruct{ + SubIntCol: proto.Int64(random.Int63()), + }, + + // StructList is a repeated array of a nested message. 
+ StructList: sList, + } + b, err := proto.Marshal(m) + if err != nil { + return nil, fmt.Errorf("error generating message %d: %w", i, err) + } + msgs[i] = b + } + return msgs, nil +} + +// TestIssue382CreateReadSessionWithoutReadOptions is a regression test for +// https://github.com/goccy/bigquery-emulator/issues/382: CreateReadSession +// panicked with a nil pointer dereference when the request's ReadSession +// carried no ReadOptions (a valid request shape — read the whole table with +// no column projection or row restriction). +func TestIssue382CreateReadSessionWithoutReadOptions(t *testing.T) { + const ( + project = "test" + dataset = "dataset1" + table = "table_a" + ) + ctx := context.Background() + bqServer, err := server.New(server.TempStorage) + if err != nil { + t.Fatal(err) + } + if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { + t.Fatal(err) + } + testServer := bqServer.TestServer() + defer func() { + testServer.Close() + bqServer.Close() + }() + opts, err := testServer.GRPCClientOptions(ctx) + if err != nil { + t.Fatal(err) + } + bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) + if err != nil { + t.Fatal(err) + } + defer bqReadClient.Close() + + // ReadSession deliberately carries no ReadOptions. 
+ req := &storagepb.CreateReadSessionRequest{ + Parent: fmt.Sprintf("projects/%s", project), + ReadSession: &storagepb.ReadSession{ + Table: fmt.Sprintf("projects/%s/datasets/%s/tables/%s", + project, dataset, table), + DataFormat: storagepb.DataFormat_AVRO, + }, + MaxStreamCount: 1, + } + session, err := bqReadClient.CreateReadSession(ctx, req, rpcOpts) + if err != nil { + t.Fatalf("CreateReadSession without ReadOptions: %v", err) + } + if len(session.GetStreams()) == 0 { + t.Fatal("expected at least one stream in the session") + } + if session.GetAvroSchema().GetSchema() == "" { + t.Fatal("expected an AVRO schema on the session") + } + + // The session must still be readable end to end: with no selected + // fields the whole row is streamed back. + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan *storagepb.ReadRowsResponse) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + defer close(ch) + if err := processStream(t, streamCtx, bqReadClient, session.GetStreams()[0].Name, ch); err != nil { + t.Errorf("processStream failure: %v", err) + } + }() + go func() { + defer wg.Done() + defer cancel() + if err := processAvro(t, streamCtx, session.GetAvroSchema().GetSchema(), ch); err != nil { + t.Errorf("error processing AVRO: %v", err) + } + }() + wg.Wait() +} + +// TestCreateReadSessionDefaultsToOneStream is a regression test for +// https://github.com/goccy/bigquery-emulator/issues/409: a +// CreateReadSession request with MaxStreamCount left at the zero default +// produced a session carrying zero streams, so the client could never read +// any rows. Real BigQuery picks a sensible default; the emulator caps at one +// stream, so an unset MaxStreamCount must yield one stream too. 
+func TestCreateReadSessionDefaultsToOneStream(t *testing.T) { + const ( + project = "test" + dataset = "dataset1" + table = "table_a" + ) + ctx := context.Background() + bqServer, err := server.New(server.TempStorage) + if err != nil { + t.Fatal(err) + } + if err := bqServer.Load(server.YAMLSource(filepath.Join("testdata", "data.yaml"))); err != nil { + t.Fatal(err) + } + testServer := bqServer.TestServer() + defer func() { + testServer.Close() + bqServer.Close() + }() + opts, err := testServer.GRPCClientOptions(ctx) + if err != nil { + t.Fatal(err) + } + bqReadClient, err := bqStorage.NewBigQueryReadClient(ctx, opts...) + if err != nil { + t.Fatal(err) + } + defer bqReadClient.Close() + + // MaxStreamCount intentionally omitted (zero value). + req := &storagepb.CreateReadSessionRequest{ + Parent: fmt.Sprintf("projects/%s", project), + ReadSession: &storagepb.ReadSession{ + Table: fmt.Sprintf("projects/%s/datasets/%s/tables/%s", + project, dataset, table), + DataFormat: storagepb.DataFormat_AVRO, + ReadOptions: &storagepb.ReadSession_TableReadOptions{ + SelectedFields: outputColumns, + }, + }, + } + session, err := bqReadClient.CreateReadSession(ctx, req, rpcOpts) + if err != nil { + t.Fatalf("CreateReadSession: %v", err) + } + if got := len(session.GetStreams()); got != 1 { + t.Fatalf("session has %d streams; want 1 (unset MaxStreamCount must default to one)", got) + } +}