diff --git a/cmd/substreams/sink_clickhouse.go b/cmd/substreams/sink_clickhouse.go new file mode 100644 index 000000000..cebc38dc9 --- /dev/null +++ b/cmd/substreams/sink_clickhouse.go @@ -0,0 +1,194 @@ +package main + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/jhump/protoreflect/desc" + "github.com/spf13/cobra" + "github.com/streamingfast/cli/sflags" + "github.com/streamingfast/derr" + sinksqlbytes "github.com/streamingfast/substreams/sink/sql/bytes" + dbchangesdb "github.com/streamingfast/substreams/sink/sql/db_changes/db" + dbchangessinker "github.com/streamingfast/substreams/sink/sql/db_changes/sinker" + dbproto "github.com/streamingfast/substreams/sink/sql/db_proto" + dbprotoproto "github.com/streamingfast/substreams/sink/sql/db_proto/proto" + "github.com/streamingfast/substreams/sink" + "google.golang.org/protobuf/types/descriptorpb" +) + +func init() { + sink.AddFlagsToSet(sinkClickhouseCmd.Flags(), + sink.FlagExcludeDefault(sink.FlagUndoBufferSize)) + + sinkClickhouseCmd.Flags().String("on-module-hash-mismatch", "error", "What to do when the module hash in the manifest does not match the one in the database, can be 'error', 'warn' or 'ignore'") + sinkClickhouseCmd.Flags().String("cursors-table", "cursors", "Name of the table to use for storing cursors") + sinkClickhouseCmd.Flags().String("history-table", "substreams_history", "Name of the table to use for storing block history") + sinkClickhouseCmd.Flags().String("bytes-encoding", "raw", "Encoding for protobuf bytes fields: raw, hex, 0xhex, base64, base58") + sinkClickhouseCmd.Flags().Int("batch-block-flush-interval", 1000, "When in catch up mode, flush every N blocks") + sinkClickhouseCmd.Flags().Int("batch-row-flush-interval", 100000, "When in catch up mode, flush every N rows") + sinkClickhouseCmd.Flags().Int("live-block-flush-interval", 1, "When processing in live mode, flush every N blocks") + sinkClickhouseCmd.Flags().Int("flush-retry-count", 3, "Number of retry attempts for flush operations") + sinkClickhouseCmd.Flags().Duration("flush-retry-delay", 1*time.Second, "Base delay for retry backoff on flush failures") + sinkClickhouseCmd.Flags().Bool("no-constraints", false, "Do not add constraints to the database (proto-based mode only)") + sinkClickhouseCmd.Flags().Int("block-batch-size", 25, "Number of blocks to process at a time (proto-based mode only)") + sinkClickhouseCmd.Flags().String("clickhouse-cluster", "", "If non-empty, a 'ON CLUSTER ' clause will be applied when setting up tables in ClickHouse") + sinkClickhouseCmd.Flags().String("clickhouse-sink-info-folder", "", "Folder where to store the ClickHouse sink info (proto-based mode only)") + sinkClickhouseCmd.Flags().String("clickhouse-cursor-file-path", "cursor.txt", "File path where to store the ClickHouse cursor (proto-based mode only)") + sinkClickhouseCmd.Flags().Int("clickhouse-query-retry-count", 3, "Number of retries for ClickHouse queries when an error occurs") + sinkClickhouseCmd.Flags().Duration("clickhouse-query-retry-sleep", time.Second, "Sleep duration between ClickHouse query retries") + + SinkCmd.AddCommand(sinkClickhouseCmd) +} + +var sinkClickhouseCmd = &cobra.Command{ + Use: "clickhouse [ []]", + Short: "Run a ClickHouse sink for Substreams", + RunE: sinkClickhouseE, + Args: cobra.RangeArgs(1, 3), +} + +func sinkClickhouseE(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cmd.SilenceUsage = true + + dsnString := args[0] + var manifestPath, outputModule string + if len(args) > 1 { + manifestPath = args[1] + } + if len(args) > 2 { + outputModule = args[2] + } + + sink.LoadSubstreamsAuthEnvFile(manifestPath) + + sinkerConfig, err := sink.ConfigFromViper(cmd, sink.IgnoreOutputModuleType, manifestPath, outputModule, "sink_clickhouse", zlog, tracer) + if err != nil { + return err + } + + outputType := strings.TrimPrefix(sinkerConfig.OutputModule.Output.Type, "proto:") + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-derr.SetupSignalHandler(0) + cancel() + }() + + if strings.Contains(outputType, "DatabaseChanges") { + return sinkClickhouseDatabaseChanges(ctx, cmd, dsnString, sinkerConfig) + } + return sinkClickhouseProto(ctx, cmd, dsnString, sinkerConfig, outputType) +} + +func sinkClickhouseDatabaseChanges(ctx context.Context, cmd *cobra.Command, dsnString string, sinkerConfig *sink.SinkerConfig) error { + dbchangessinker.RegisterMetrics() + + baseSink, err := sink.NewFromConfig(sinkerConfig) + if err != nil { + return fmt.Errorf("creating base sinker: %w", err) + } + + sinkerFactory := dbchangessinker.SinkerFactory(baseSink, dbchangessinker.SinkerFactoryOptions{ + CursorTableName: sflags.MustGetString(cmd, "cursors-table"), + HistoryTableName: sflags.MustGetString(cmd, "history-table"), + ClickhouseCluster: sflags.MustGetString(cmd, "clickhouse-cluster"), + BatchBlockFlushInterval: sflags.MustGetInt(cmd, "batch-block-flush-interval"), + BatchRowFlushInterval: sflags.MustGetInt(cmd, "batch-row-flush-interval"), + LiveBlockFlushInterval: sflags.MustGetInt(cmd, "live-block-flush-interval"), + OnModuleHashMismatch: sflags.MustGetString(cmd, "on-module-hash-mismatch"), + HandleReorgs: false, + FlushRetryCount: sflags.MustGetInt(cmd, "flush-retry-count"), + FlushRetryDelay: sflags.MustGetDuration(cmd, "flush-retry-delay"), + }) + + sqlSinker, err := sinkerFactory(ctx, dsnString, zlog, tracer) + if err != nil { + return fmt.Errorf("unable to setup sql sinker: %w", err) + } + + sqlSinker.Run(ctx) + return sqlSinker.Err() +} + +func sinkClickhouseProto(ctx context.Context, cmd *cobra.Command, dsnString string, sinkerConfig *sink.SinkerConfig, outputType string) error { + dsn, err := dbchangesdb.ParseDSN(dsnString) + if err != nil { + return fmt.Errorf("parsing dsn: %w", err) + } + + spkg := sinkerConfig.Pkg + protoFiles := make(map[string]*descriptorpb.FileDescriptorProto, len(spkg.ProtoFiles)) + for _, file := range spkg.ProtoFiles { + protoFiles[file.GetName()] = file + } + + deps, err := dbprotoproto.ResolveDependencies(protoFiles) + if err != nil { + return fmt.Errorf("resolving dependencies: %w", err) + } + + fileDescriptor, err := dbprotoproto.FileDescriptorForOutputType(spkg, nil, deps, outputType) + if err != nil { + return fmt.Errorf("finding file descriptor for output type %q: %w", outputType, err) + } + + var rootMessageDescriptor *desc.MessageDescriptor + for _, md := range fileDescriptor.GetMessageTypes() { + if md.GetFullyQualifiedName() == outputType { + rootMessageDescriptor = md + break + } + } + if rootMessageDescriptor == nil { + return fmt.Errorf("message descriptor not found for output type %q, ensure your substreams bundles its protobuf definitions", outputType) + } + + useConstraints := !sflags.MustGetBool(cmd, "no-constraints") + useProtoOption := false + for _, dep := range fileDescriptor.GetDependencies() { + if dep.GetName() == "sf/substreams/sink/sql/schema/v1/schema.proto" { + useProtoOption = true + } + } + if !useProtoOption { + useConstraints = false + } + + encodingStr := sflags.MustGetString(cmd, "bytes-encoding") + encoding, err := sinksqlbytes.ParseEncoding(encodingStr) + if err != nil { + return fmt.Errorf("invalid bytes encoding %q: %w", encodingStr, err) + } + + baseSink, err := sink.NewFromConfig(sinkerConfig) + if err != nil { + return fmt.Errorf("creating base sinker: %w", err) + } + + outputModuleName := sinkerConfig.OutputModule.Name + factory := dbproto.SinkerFactory(baseSink, outputModuleName, rootMessageDescriptor.UnwrapMessage(), dbproto.SinkerFactoryOptions{ + UseProtoOption: useProtoOption, + UseConstraints: useConstraints, + UseTransactions: true, + BlockBatchSize: sflags.MustGetInt(cmd, "block-batch-size"), + Parallel: false, + Encoding: encoding, + Clickhouse: dbproto.SinkerFactoryClickhouse{ + SinkInfoFolder: sflags.MustGetString(cmd, "clickhouse-sink-info-folder"), + CursorFilePath: sflags.MustGetString(cmd, "clickhouse-cursor-file-path"), + QueryRetryCount: sflags.MustGetInt(cmd, "clickhouse-query-retry-count"), + QueryRetrySleep: sflags.MustGetDuration(cmd, "clickhouse-query-retry-sleep"), + }, + }) + + dbProtoSinker, err := factory(ctx, dsnString, dsn.Schema(), zlog, tracer) + if err != nil { + return fmt.Errorf("creating sinker: %w", err) + } + + return dbProtoSinker.Run(ctx) +} diff --git a/cmd/substreams/sink_postgres.go b/cmd/substreams/sink_postgres.go new file mode 100644 index 000000000..55d59332a --- /dev/null +++ b/cmd/substreams/sink_postgres.go @@ -0,0 +1,182 @@ +package main + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/jhump/protoreflect/desc" + "github.com/spf13/cobra" + "github.com/streamingfast/cli/sflags" + "github.com/streamingfast/derr" + sinksqlbytes "github.com/streamingfast/substreams/sink/sql/bytes" + dbchangesdb "github.com/streamingfast/substreams/sink/sql/db_changes/db" + dbchangessinker "github.com/streamingfast/substreams/sink/sql/db_changes/sinker" + dbproto "github.com/streamingfast/substreams/sink/sql/db_proto" + dbprotoproto "github.com/streamingfast/substreams/sink/sql/db_proto/proto" + "github.com/streamingfast/substreams/sink" + "google.golang.org/protobuf/types/descriptorpb" +) + +func init() { + sink.AddFlagsToSet(sinkPostgresCmd.Flags(), + sink.FlagExcludeDefault(sink.FlagUndoBufferSize)) + + sinkPostgresCmd.Flags().String("on-module-hash-mismatch", "error", "What to do when the module hash in the manifest does not match the one in the database, can be 'error', 'warn' or 'ignore'") + sinkPostgresCmd.Flags().String("cursors-table", "cursors", "Name of the table to use for storing cursors") + sinkPostgresCmd.Flags().String("history-table", "substreams_history", "Name of the table to use for storing block history, used to handle reorgs") + sinkPostgresCmd.Flags().String("bytes-encoding", "raw", "Encoding for protobuf bytes fields: raw, hex, 0xhex, base64, base58") + sinkPostgresCmd.Flags().Int("batch-block-flush-interval", 1000, "When in catch up mode, flush every N blocks") + sinkPostgresCmd.Flags().Int("batch-row-flush-interval", 100000, "When in catch up mode, flush every N rows") + sinkPostgresCmd.Flags().Int("live-block-flush-interval", 1, "When processing in live mode, flush every N blocks") + sinkPostgresCmd.Flags().Int("flush-retry-count", 3, "Number of retry attempts for flush operations") + sinkPostgresCmd.Flags().Duration("flush-retry-delay", 1*time.Second, "Base delay for retry backoff on flush failures") + sinkPostgresCmd.Flags().Bool("no-constraints", false, "Do not add constraints to the database (proto-based mode only)") + sinkPostgresCmd.Flags().Int("block-batch-size", 25, "Number of blocks to process at a time (proto-based mode only)") + + SinkCmd.AddCommand(sinkPostgresCmd) +} + +var sinkPostgresCmd = &cobra.Command{ + Use: "postgres [ []]", + Short: "Run a PostgreSQL sink for Substreams", + RunE: sinkPostgresE, + Args: cobra.RangeArgs(1, 3), +} + +func sinkPostgresE(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cmd.SilenceUsage = true + + dsnString := args[0] + var manifestPath, outputModule string + if len(args) > 1 { + manifestPath = args[1] + } + if len(args) > 2 { + outputModule = args[2] + } + + sink.LoadSubstreamsAuthEnvFile(manifestPath) + + sinkerConfig, err := sink.ConfigFromViper(cmd, sink.IgnoreOutputModuleType, manifestPath, outputModule, "sink_postgres", zlog, tracer) + if err != nil { + return err + } + + outputType := strings.TrimPrefix(sinkerConfig.OutputModule.Output.Type, "proto:") + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-derr.SetupSignalHandler(0) + cancel() + }() + + if strings.Contains(outputType, "DatabaseChanges") { + return sinkPostgresDatabaseChanges(ctx, cmd, dsnString, sinkerConfig) + } + return sinkPostgresProto(ctx, cmd, dsnString, sinkerConfig, outputType) +} + +func sinkPostgresDatabaseChanges(ctx context.Context, cmd *cobra.Command, dsnString string, sinkerConfig *sink.SinkerConfig) error { + dbchangessinker.RegisterMetrics() + + baseSink, err := sink.NewFromConfig(sinkerConfig) + if err != nil { + return fmt.Errorf("creating base sinker: %w", err) + } + + sinkerFactory := dbchangessinker.SinkerFactory(baseSink, dbchangessinker.SinkerFactoryOptions{ + CursorTableName: sflags.MustGetString(cmd, "cursors-table"), + HistoryTableName: sflags.MustGetString(cmd, "history-table"), + BatchBlockFlushInterval: sflags.MustGetInt(cmd, "batch-block-flush-interval"), + BatchRowFlushInterval: sflags.MustGetInt(cmd, "batch-row-flush-interval"), + LiveBlockFlushInterval: sflags.MustGetInt(cmd, "live-block-flush-interval"), + OnModuleHashMismatch: sflags.MustGetString(cmd, "on-module-hash-mismatch"), + HandleReorgs: true, + FlushRetryCount: sflags.MustGetInt(cmd, "flush-retry-count"), + FlushRetryDelay: sflags.MustGetDuration(cmd, "flush-retry-delay"), + }) + + sqlSinker, err := sinkerFactory(ctx, dsnString, zlog, tracer) + if err != nil { + return fmt.Errorf("unable to setup sql sinker: %w", err) + } + + sqlSinker.Run(ctx) + return sqlSinker.Err() +} + +func sinkPostgresProto(ctx context.Context, cmd *cobra.Command, dsnString string, sinkerConfig *sink.SinkerConfig, outputType string) error { + dsn, err := dbchangesdb.ParseDSN(dsnString) + if err != nil { + return fmt.Errorf("parsing dsn: %w", err) + } + + spkg := sinkerConfig.Pkg + protoFiles := make(map[string]*descriptorpb.FileDescriptorProto, len(spkg.ProtoFiles)) + for _, file := range spkg.ProtoFiles { + protoFiles[file.GetName()] = file + } + + deps, err := dbprotoproto.ResolveDependencies(protoFiles) + if err != nil { + return fmt.Errorf("resolving dependencies: %w", err) + } + + fileDescriptor, err := dbprotoproto.FileDescriptorForOutputType(spkg, nil, deps, outputType) + if err != nil { + return fmt.Errorf("finding file descriptor for output type %q: %w", outputType, err) + } + + var rootMessageDescriptor *desc.MessageDescriptor + for _, md := range fileDescriptor.GetMessageTypes() { + if md.GetFullyQualifiedName() == outputType { + rootMessageDescriptor = md + break + } + } + if rootMessageDescriptor == nil { + return fmt.Errorf("message descriptor not found for output type %q, ensure your substreams bundles its protobuf definitions", outputType) + } + + useConstraints := !sflags.MustGetBool(cmd, "no-constraints") + useProtoOption := false + for _, dep := range fileDescriptor.GetDependencies() { + if dep.GetName() == "sf/substreams/sink/sql/schema/v1/schema.proto" { + useProtoOption = true + } + } + if !useProtoOption { + useConstraints = false + } + + encodingStr := sflags.MustGetString(cmd, "bytes-encoding") + encoding, err := sinksqlbytes.ParseEncoding(encodingStr) + if err != nil { + return fmt.Errorf("invalid bytes encoding %q: %w", encodingStr, err) + } + + baseSink, err := sink.NewFromConfig(sinkerConfig) + if err != nil { + return fmt.Errorf("creating base sinker: %w", err) + } + + outputModuleName := sinkerConfig.OutputModule.Name + factory := dbproto.SinkerFactory(baseSink, outputModuleName, rootMessageDescriptor.UnwrapMessage(), dbproto.SinkerFactoryOptions{ + UseProtoOption: useProtoOption, + UseConstraints: useConstraints, + UseTransactions: true, + BlockBatchSize: sflags.MustGetInt(cmd, "block-batch-size"), + Parallel: false, + Encoding: encoding, + }) + + dbProtoSinker, err := factory(ctx, dsnString, dsn.Schema(), zlog, tracer) + if err != nil { + return fmt.Errorf("creating sinker: %w", err) + } + + return dbProtoSinker.Run(ctx) +} diff --git a/go.mod b/go.mod index 974bc20bf..fcb8f7892 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,15 @@ go 1.25.0 toolchain go1.25.4 require ( + github.com/AfterShip/clickhouse-sql-parser v0.4.9 + github.com/ClickHouse/ch-go v0.68.0 + github.com/ClickHouse/clickhouse-go/v2 v2.40.3 + github.com/drone/envsubst v1.0.3 github.com/golang/protobuf v1.5.4 + github.com/jackc/pgx/v4 v4.18.1 github.com/jhump/protoreflect v1.14.0 + github.com/jmoiron/sqlx v1.4.0 + github.com/lib/pq v1.10.9 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 github.com/streamingfast/bstream v0.0.2-0.20260402095814-607e840ece3d @@ -20,6 +27,7 @@ require ( github.com/streamingfast/logging v0.0.0-20260108192805-38f96de0a641 github.com/streamingfast/pbgo v0.0.6-0.20250120164644-a58d8066ab4b github.com/stretchr/testify v1.11.1 + github.com/wk8/go-ordered-map/v2 v2.1.7 github.com/yourbasic/graph v0.0.0-20210606180040-8ecfec1c2869 go.uber.org/zap v1.27.1 google.golang.org/protobuf v1.36.11 diff --git a/pb/sf/substreams/sink/database/v1/database.pb.go b/pb/sf/substreams/sink/database/v1/database.pb.go new file mode 100644 index 000000000..92eaa3ef5 --- /dev/null +++ b/pb/sf/substreams/sink/database/v1/database.pb.go @@ -0,0 +1,81 @@ +// Code generated stub for compilation purposes only. +package pbdatabase + +import "google.golang.org/protobuf/reflect/protoreflect" + +type TableChange_Operation int32 + +const ( + TableChange_OPERATION_UNSET TableChange_Operation = 0 + TableChange_OPERATION_CREATE TableChange_Operation = 1 + TableChange_OPERATION_UPDATE TableChange_Operation = 2 + TableChange_OPERATION_DELETE TableChange_Operation = 3 + TableChange_OPERATION_UPSERT TableChange_Operation = 4 +) + +type Field_UpdateOp int32 + +const ( + Field_UPDATE_OP_UNSET Field_UpdateOp = 0 + Field_UPDATE_OP_ADD Field_UpdateOp = 1 + Field_UPDATE_OP_MAX Field_UpdateOp = 2 + Field_UPDATE_OP_MIN Field_UpdateOp = 3 + Field_UPDATE_OP_SET_IF_NULL Field_UpdateOp = 4 +) + +type DatabaseChanges struct { + TableChanges []*TableChange `protobuf:"bytes,1,rep,name=table_changes,json=tableChanges,proto3" json:"table_changes,omitempty"` +} + +func (x *DatabaseChanges) Reset() {} +func (x *DatabaseChanges) String() string { return "" } +func (x *DatabaseChanges) ProtoMessage() {} +func (x *DatabaseChanges) ProtoReflect() protoreflect.Message { return nil } + +type isTableChange_PrimaryKey interface { + isTableChange_PrimaryKey() +} + +type TableChange_Pk struct { + Pk string +} + +func (*TableChange_Pk) isTableChange_PrimaryKey() {} + +type TableChange_CompositePk struct { + CompositePk *CompositePrimaryKey +} + +func (*TableChange_CompositePk) isTableChange_PrimaryKey() {} + +type TableChange struct { + Table string `protobuf:"bytes,1,opt,name=table,proto3" json:"table,omitempty"` + Operation TableChange_Operation `protobuf:"varint,2,opt,name=operation,proto3,enum=sf.substreams.sink.database.v1.TableChange_Operation" json:"operation,omitempty"` + Fields []*Field `protobuf:"bytes,4,rep,name=fields,proto3" json:"fields,omitempty"` + PrimaryKey isTableChange_PrimaryKey +} + +func (x *TableChange) Reset() {} +func (x *TableChange) String() string { return x.Table } +func (x *TableChange) ProtoMessage() {} +func (x *TableChange) ProtoReflect() protoreflect.Message { return nil } + +type CompositePrimaryKey struct { + Keys map[string]string `protobuf:"bytes,1,rep,name=keys,proto3" json:"keys,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (x *CompositePrimaryKey) Reset() {} +func (x *CompositePrimaryKey) String() string { return "" } +func (x *CompositePrimaryKey) ProtoMessage() {} +func (x *CompositePrimaryKey) ProtoReflect() protoreflect.Message { return nil } + +type Field struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Value string `protobuf:"bytes,3,opt,name=new_value,json=newValue,proto3" json:"new_value,omitempty"` + UpdateOp Field_UpdateOp `protobuf:"varint,5,opt,name=update_op,json=updateOp,proto3,enum=sf.substreams.sink.database.v1.Field_UpdateOp" json:"update_op,omitempty"` +} + +func (x *Field) Reset() {} +func (x *Field) String() string { return x.Name } +func (x *Field) ProtoMessage() {} +func (x *Field) ProtoReflect() protoreflect.Message { return nil } diff --git a/pb/sf/substreams/sink/sql/schema/v1/schema.pb.go b/pb/sf/substreams/sink/sql/schema/v1/schema.pb.go new file mode 100644 index 000000000..223b77780 --- /dev/null +++ b/pb/sf/substreams/sink/sql/schema/v1/schema.pb.go @@ -0,0 +1,174 @@ +// Code generated stub for compilation purposes only. +package pbschema + +import ( + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoimpl" +) + +var E_Table *protoimpl.ExtensionInfo +var E_Field *protoimpl.ExtensionInfo + +type Function int32 + +const ( + Function_unset Function = 0 + Function_toMonth Function = 1 + Function_toDate Function = 2 + Function_toStartOfMonth Function = 3 + Function_toYear Function = 4 + Function_toYYYYDD Function = 5 + Function_toYYYYMM Function = 6 +) + +func (f Function) String() string { + switch f { + case Function_toMonth: + return "toMonth" + case Function_toDate: + return "toDate" + case Function_toStartOfMonth: + return "toStartOfMonth" + case Function_toYear: + return "toYear" + case Function_toYYYYDD: + return "toYYYYDD" + case Function_toYYYYMM: + return "toYYYYMM" + default: + return "unknown" + } +} + +type IndexType string + +func (t IndexType) String() string { return string(t) } + +type Table struct { + Name string + ChildOf *string + ClickhouseTableOptions *ClickhouseTableOptions +} + +func (t *Table) Reset() {} +func (t *Table) String() string { return t.Name } +func (t *Table) ProtoMessage() {} +func (t *Table) ProtoReflect() protoreflect.Message { return nil } + +type Column struct { + Name *string + ForeignKey *string + PrimaryKey bool + Unique bool + ConvertTo *StringConvertion + Inline bool +} + +func (c *Column) Reset() {} +func (c *Column) String() string { return "" } +func (c *Column) ProtoMessage() {} +func (c *Column) ProtoReflect() protoreflect.Message { return nil } + +type StringConvertion struct { + Convertion isStringConvertion_Convertion +} + +func (s *StringConvertion) Reset() {} +func (s *StringConvertion) String() string { return "" } +func (s *StringConvertion) ProtoMessage() {} +func (s *StringConvertion) ProtoReflect() protoreflect.Message { return nil } + +type isStringConvertion_Convertion interface { + isStringConvertion_Convertion() +} + +type StringConvertion_Int128 struct{} + +func (*StringConvertion_Int128) isStringConvertion_Convertion() {} + +type StringConvertion_Uint128 struct{} + +func (*StringConvertion_Uint128) isStringConvertion_Convertion() {} + +type StringConvertion_Int256 struct{} + +func (*StringConvertion_Int256) isStringConvertion_Convertion() {} + +type StringConvertion_Uint256 struct{} + +func (*StringConvertion_Uint256) isStringConvertion_Convertion() {} + +type StringConvertion_Decimal128 struct { + Decimal128 *Decimal128Precision +} + +func (*StringConvertion_Decimal128) isStringConvertion_Convertion() {} + +type StringConvertion_Decimal256 struct { + Decimal256 *Decimal256Precision +} + +func (*StringConvertion_Decimal256) isStringConvertion_Convertion() {} + +type Decimal128Precision struct { + Scale int32 +} + +func (d *Decimal128Precision) Reset() {} +func (d *Decimal128Precision) String() string { return "" } +func (d *Decimal128Precision) ProtoMessage() {} +func (d *Decimal128Precision) ProtoReflect() protoreflect.Message { return nil } + +type Decimal256Precision struct { + Scale int32 +} + +func (d *Decimal256Precision) Reset() {} +func (d *Decimal256Precision) String() string { return "" } +func (d *Decimal256Precision) ProtoMessage() {} +func (d *Decimal256Precision) ProtoReflect() protoreflect.Message { return nil } + +type ClickhouseTableOptions struct { + OrderByFields []*ClickhouseOrderByField + PartitionFields []*ClickhousePartitionField + IndexFields []*ClickhouseIndexField +} + +func (c *ClickhouseTableOptions) Reset() {} +func (c *ClickhouseTableOptions) String() string { return "" } +func (c *ClickhouseTableOptions) ProtoMessage() {} +func (c *ClickhouseTableOptions) ProtoReflect() protoreflect.Message { return nil } + +type ClickhouseOrderByField struct { + Name string + Descending bool + Function Function +} + +func (c *ClickhouseOrderByField) Reset() {} +func (c *ClickhouseOrderByField) String() string { return c.Name } +func (c *ClickhouseOrderByField) ProtoMessage() {} +func (c *ClickhouseOrderByField) ProtoReflect() protoreflect.Message { return nil } + +type ClickhousePartitionField struct { + Name string + Function Function +} + +func (c *ClickhousePartitionField) Reset() {} +func (c *ClickhousePartitionField) String() string { return c.Name } +func (c *ClickhousePartitionField) ProtoMessage() {} +func (c *ClickhousePartitionField) ProtoReflect() protoreflect.Message { return nil } + +type ClickhouseIndexField struct { + Name string + FieldName string + Function Function + Type IndexType + Granularity uint64 +} + +func (c *ClickhouseIndexField) Reset() {} +func (c *ClickhouseIndexField) String() string { return c.Name } +func (c *ClickhouseIndexField) ProtoMessage() {} +func (c *ClickhouseIndexField) ProtoReflect() protoreflect.Message { return nil } diff --git a/pb/sf/substreams/sink/sql/services/v1/services.pb.go b/pb/sf/substreams/sink/sql/services/v1/services.pb.go new file mode 100644 index 000000000..d600907d5 --- /dev/null +++ b/pb/sf/substreams/sink/sql/services/v1/services.pb.go @@ -0,0 +1,13 @@ +// Code generated stub for compilation purposes only. +package pbsql + +import "google.golang.org/protobuf/reflect/protoreflect" + +type Service struct { + Schema string `protobuf:"bytes,1,opt,name=schema,proto3" json:"schema,omitempty"` +} + +func (x *Service) Reset() {} +func (x *Service) String() string { return x.Schema } +func (x *Service) ProtoMessage() {} +func (x *Service) ProtoReflect() protoreflect.Message { return nil } diff --git a/sink/sql/bytes/encoding.go b/sink/sql/bytes/encoding.go new file mode 100644 index 000000000..ce285c919 --- /dev/null +++ b/sink/sql/bytes/encoding.go @@ -0,0 +1,125 @@ +package bytes + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + + "github.com/mr-tron/base58/base58" +) + +// Encoding represents the different encoding types for protobuf bytes fields +type Encoding int + +const ( + // EncodingRaw keeps bytes as raw binary data (default) + EncodingRaw Encoding = iota + // EncodingHex encodes bytes as hexadecimal string + EncodingHex + // EncodingHexWith0x encodes bytes as hexadecimal string with 0x prefix + EncodingHexWith0x + // EncodingBase64 encodes bytes as base64 string + EncodingBase64 + // EncodingBase58 encodes bytes as base58 string + EncodingBase58 +) + +// String returns the string representation of the encoding +func (e Encoding) String() string { + switch e { + case EncodingRaw: + return "raw" + case EncodingHex: + return "hex" + case EncodingHexWith0x: + return "0xhex" + case EncodingBase64: + return "base64" + case EncodingBase58: + return "base58" + default: + return "unknown" + } +} + +// ParseEncoding parses a string into an Encoding type +func ParseEncoding(s string) (Encoding, error) { + switch strings.ToLower(s) { + case "raw": + return EncodingRaw, nil + case "hex": + return EncodingHex, nil + case "0xhex": + return EncodingHexWith0x, nil + case "base64": + return EncodingBase64, nil + case "base58": + return EncodingBase58, nil + default: + return EncodingRaw, fmt.Errorf("invalid encoding: %s", s) + } +} + +// IsStringType returns true if the encoding converts bytes to string database type +func (e Encoding) IsStringType() bool { + return e != EncodingRaw +} + +// EncodeBytes encodes the given bytes using the specified encoding +func (e Encoding) EncodeBytes(data []byte) (interface{}, error) { + switch e { + case EncodingRaw: + return data, nil + case EncodingHex: + return hex.EncodeToString(data), nil + case EncodingHexWith0x: + return "0x" + hex.EncodeToString(data), nil + case EncodingBase64: + return base64.StdEncoding.EncodeToString(data), nil + case EncodingBase58: + return base58.Encode(data), nil + default: + return nil, fmt.Errorf("unsupported encoding: %s", e) + } +} + +// DecodeBytes decodes the given string back to bytes using the specified encoding +func (e Encoding) DecodeBytes(encoded interface{}) ([]byte, error) { + switch e { + case EncodingRaw: + if data, ok := encoded.([]byte); ok { + return data, nil + } + return nil, fmt.Errorf("expected []byte for raw encoding, got %T", encoded) + case EncodingHex: + if str, ok := encoded.(string); ok { + return hex.DecodeString(str) + } + return nil, fmt.Errorf("expected string for hex encoding, got %T", encoded) + case EncodingHexWith0x: + if str, ok := encoded.(string); ok { + if strings.HasPrefix(str, "0x") || strings.HasPrefix(str, "0X") { + return hex.DecodeString(str[2:]) + } + return hex.DecodeString(str) + } + return nil, fmt.Errorf("expected string for 0xhex encoding, got %T", encoded) + case EncodingBase64: + if str, ok := encoded.(string); ok { + return base64.StdEncoding.DecodeString(str) + } + return nil, fmt.Errorf("expected string for base64 encoding, got %T", encoded) + case EncodingBase58: + if str, ok := encoded.(string); ok { + decoded, err := base58.Decode(str) + if err != nil { + return nil, fmt.Errorf("base58 decode: %w", err) + } + return decoded, nil + } + return nil, fmt.Errorf("expected string for base58 encoding, got %T", encoded) + default: + return nil, fmt.Errorf("unsupported encoding: %s", e) + } +} diff --git a/sink/sql/db_changes/bundler/bundler.go b/sink/sql/db_changes/bundler/bundler.go new file mode 100644 index 000000000..4e6583980 --- /dev/null +++ b/sink/sql/db_changes/bundler/bundler.go @@ -0,0 +1,224 @@ +package bundler + +import ( + "context" + "errors" + "fmt" + "path" + "time" + + "github.com/streamingfast/bstream" + "github.com/streamingfast/dhammer" + "github.com/streamingfast/dstore" + "github.com/streamingfast/shutter" + "github.com/streamingfast/substreams/sink/sql/db_changes/bundler/writer" + "go.uber.org/zap" +) + +type Bundler struct { + *shutter.Shutter + + blockCount uint64 + stats *boundaryStats + boundaryWriter writer.Writer + outputStore dstore.Store + Header []byte + HeaderWritten bool + + activeBoundary *bstream.Range + stopBlock uint64 + uploadQueue *dhammer.Nailer + zlogger *zap.Logger +} + +var ErrStopBlockReached = errors.New("stop block reached") + +func New( + size uint64, + stopBlock uint64, + boundaryWriter writer.Writer, + outputStore dstore.Store, + zlogger *zap.Logger, + header []byte, +) (*Bundler, error) { + + b := &Bundler{ + Shutter: shutter.New(), + boundaryWriter: boundaryWriter, + outputStore: outputStore, + blockCount: size, + stopBlock: stopBlock, + stats: newStats(), + zlogger: zlogger, + Header: header, + HeaderWritten: false, + } + + b.uploadQueue = dhammer.NewNailer(5, b.uploadBoundary, dhammer.NailerLogger(zlogger)) + + return b, nil +} + +func (b *Bundler) name() string { + return path.Base(b.outputStore.BaseURL().Path) +} + +func (b *Bundler) Launch(ctx context.Context) { + b.OnTerminating(func(err error) { + b.zlogger.Info("shutting down bundler", zap.String("store", b.name()), zap.Error(err)) + b.Close() + }) + b.uploadQueue.Start(ctx) + + go func() { + for v := range b.uploadQueue.Out { + bf := v.(*boundaryFile) + b.zlogger.Debug("uploaded file", zap.String("filename", bf.name)) + } + if b.uploadQueue.Err() != nil { + b.Shutdown(fmt.Errorf("upload queue failed: %w", b.uploadQueue.Err())) + } + }() + + b.uploadQueue.OnTerminating(func(_ error) { + b.Shutdown(fmt.Errorf("upload queue failed: %w", b.uploadQueue.Err())) + }) +} + +func (b *Bundler) Close() { + b.zlogger.Debug("closing upload queue") + b.uploadQueue.Close() + b.zlogger.Debug("waiting till queue is drained") + b.uploadQueue.WaitUntilEmpty(context.Background()) + b.zlogger.Debug("boundary upload completed") +} + +func (b *Bundler) Roll(ctx context.Context, blockNum uint64) (rolled bool, err error) { + if b.activeBoundary.Contains(blockNum) { + return false, nil + } + + boundaries := boundariesToSkip(b.activeBoundary, blockNum, b.blockCount) + + b.zlogger.Info("block_num is not in active boundary", + zap.Stringer("active_boundary", b.activeBoundary), + zap.Int("boundaries_to_skip", len(boundaries)), + zap.Uint64("block_num", blockNum), + ) + + if err := b.stop(ctx); err != nil { + return false, fmt.Errorf("stop active boundary: %w", err) + } + + for _, boundary := range boundaries { + if err := b.Start(boundary.StartBlock()); err != nil { + return false, fmt.Errorf("start skipping boundary: %w", err) + } + if err := b.stop(ctx); err != nil { + return false, fmt.Errorf("stop skipping boundary: %w", err) + } + } + + if blockNum >= b.stopBlock { + return false, ErrStopBlockReached + } + + if err := b.Start(blockNum); err != nil { + return false, fmt.Errorf("start active boundary: %w", err) + } + + return true, nil +} + +func (b *Bundler) TrackBlockProcessDuration(elapsed time.Duration) { + b.stats.addProcessingDataDur(elapsed) +} + +func (b *Bundler) Writer() writer.Writer { + return b.boundaryWriter +} + +func (b *Bundler) Start(blockNum uint64) error { + boundaryRange := b.newBoundary(blockNum) + b.activeBoundary = boundaryRange + + b.zlogger.Debug("starting new file boundary", zap.Stringer("boundary", boundaryRange)) + if err := b.boundaryWriter.StartBoundary(boundaryRange); err != nil { + return fmt.Errorf("start file: %w", err) + } + + b.stats.startBoundary(boundaryRange) + b.zlogger.Debug("boundary started", zap.Stringer("boundary", boundaryRange)) + return nil +} + +func (b *Bundler) stop(ctx context.Context) error { + b.zlogger.Debug("stopping file boundary") + + file, err := b.boundaryWriter.CloseBoundary(ctx) + if err != nil { + return fmt.Errorf("closing file: %w", err) + } + + if b.boundaryWriter.IsWritten() { + b.zlogger.Debug("queuing boundary upload", zap.Stringer("boundary", b.activeBoundary)) + + b.uploadQueue.In <- &boundaryFile{ + name: b.activeBoundary.String(), + file: file, + } + } else { + b.zlogger.Debug("boundary not written, skipping upload of files", zap.Stringer("boundary", b.activeBoundary)) + } + + b.HeaderWritten = false + b.activeBoundary = nil + b.stats.endBoundary() + + b.zlogger.Info("bundler stats", b.stats.Log()...) + return nil +} + +func (b *Bundler) newBoundary(containingBlockNum uint64) *bstream.Range { + startBlock := containingBlockNum - (containingBlockNum % b.blockCount) + endBlock := startBlock + b.blockCount + if b.stopBlock < endBlock { + endBlock = b.stopBlock + } + return bstream.NewRangeExcludingEnd(startBlock, endBlock) +} + +func boundariesToSkip(lastBoundary *bstream.Range, blockNum uint64, size uint64) (out []*bstream.Range) { + iter := *lastBoundary.EndBlock() + endBlock := computeEndBlock(iter, size) + for blockNum >= endBlock { + out = append(out, bstream.NewRangeExcludingEnd(iter, endBlock)) + iter = endBlock + endBlock = computeEndBlock(iter, size) + } + return out +} + +func computeEndBlock(startBlockNum, size uint64) uint64 { + return (startBlockNum + size) - (startBlockNum+size)%size +} + +type boundaryFile struct { + name string + file writer.Uploadeable +} + +func (b *Bundler) uploadBoundary(ctx context.Context, v interface{}) (interface{}, error) { + bf := v.(*boundaryFile) + + outputPath, err := bf.file.Upload(ctx, b.outputStore) + if err != nil { + return nil, fmt.Errorf("unable to upload: %w", err) + } + b.zlogger.Debug("boundary file uploaded", + zap.String("boundary", bf.name), + zap.String("output_path", outputPath), + ) + + return bf, nil +} diff --git a/sink/sql/db_changes/bundler/encoder.go b/sink/sql/db_changes/bundler/encoder.go new file mode 100644 index 000000000..be422cda6 --- /dev/null +++ b/sink/sql/db_changes/bundler/encoder.go @@ -0,0 +1,49 @@ +package bundler + +import ( + "bytes" + "encoding/csv" + "encoding/json" + "fmt" + "sort" + + "github.com/golang/protobuf/proto" +) + +type Encoder func(proto.Message) ([]byte, error) + +func JSONLEncode(message proto.Message) ([]byte, error) { + buf := []byte{} + data, err := json.Marshal(message) + if err != nil { + return nil, fmt.Errorf("json marshal: %w", err) + } + buf = append(buf, data...) + buf = append(buf, byte('\n')) + return buf, nil +} + +func CSVEncode(message map[string]string) ([]byte, error) { + keys := make([]string, 0, len(message)) + for k := range message { + keys = append(keys, k) + } + sort.Strings(keys) + + row := make([]string, 0, len(keys)) + for _, key := range keys { + row = append(row, message[key]) + } + + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + if err := writer.Write(row); err != nil { + return nil, err + } + writer.Flush() + if err := writer.Error(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/sink/sql/db_changes/bundler/stats.go b/sink/sql/db_changes/bundler/stats.go new file mode 100644 index 000000000..34bf5b593 --- /dev/null +++ b/sink/sql/db_changes/bundler/stats.go @@ -0,0 +1,74 @@ +package bundler + +import ( + "time" + + "github.com/streamingfast/bstream" + "github.com/streamingfast/dmetrics" + "github.com/streamingfast/logging/zapx" + "go.uber.org/zap" +) + +type boundaryStats struct { + creationStart time.Time + + boundaryProcessTime time.Duration + procesingDataTime time.Duration + uploadedDuration time.Duration + + totalBoundaryCount uint64 + boundary *bstream.Range + + avgUploadDuration *dmetrics.AvgDurationCounter + avgBoundaryProcessDuration *dmetrics.AvgDurationCounter + avgDataProcessDuration *dmetrics.AvgDurationCounter +} + +func newStats() *boundaryStats { + return &boundaryStats{ + avgUploadDuration: dmetrics.NewAvgDurationCounter(30*time.Second, time.Second, "upload dur"), + avgBoundaryProcessDuration: dmetrics.NewAvgDurationCounter(30*time.Second, time.Second, "boundary process dur"), + avgDataProcessDuration: dmetrics.NewAvgDurationCounter(30*time.Second, time.Second, "data process dur"), + } +} + +func (s *boundaryStats) startBoundary(b *bstream.Range) { + s.creationStart = time.Now() + s.boundary = b + s.totalBoundaryCount++ + s.boundaryProcessTime = 0 + s.procesingDataTime = 0 + s.uploadedDuration = 0 +} + +func (s *boundaryStats) addUploadedDuration(dur time.Duration) { + s.avgUploadDuration.AddDuration(dur) + s.uploadedDuration = dur +} + +func (s *boundaryStats) endBoundary() { + dur := time.Since(s.creationStart) + s.avgBoundaryProcessDuration.AddDuration(dur) + s.boundaryProcessTime = dur + s.avgDataProcessDuration.AddDuration(s.procesingDataTime) +} + +func (s *boundaryStats) addProcessingDataDur(dur time.Duration) { + s.procesingDataTime += dur +} + +func (s *boundaryStats) Log() []zap.Field { + return []zap.Field{ + zap.Uint64("file_count", s.totalBoundaryCount), + zap.Stringer("boundary", s.boundary), + zapx.HumanDuration("boundary_process_duration", s.boundaryProcessTime), + zapx.HumanDuration("upload_duration", s.uploadedDuration), + zapx.HumanDuration("data_process_duration", s.procesingDataTime), + zapx.HumanDuration("avg_upload_duration", s.avgUploadDuration.Average()), + zapx.HumanDuration("total_upload_duration", s.avgUploadDuration.Total()), + zapx.HumanDuration("avg_boundary_process_duration", s.avgBoundaryProcessDuration.Average()), + zapx.HumanDuration("total_boundary_process_duration", s.avgBoundaryProcessDuration.Total()), + zapx.HumanDuration("avg_data_process_duration", s.avgDataProcessDuration.Average()), + zapx.HumanDuration("total_data_process_duration", s.avgDataProcessDuration.Total()), + } +} diff --git a/sink/sql/db_changes/bundler/writer/buffered.go b/sink/sql/db_changes/bundler/writer/buffered.go new file mode 100644 index 000000000..6285f5e4c --- /dev/null +++ b/sink/sql/db_changes/bundler/writer/buffered.go @@ -0,0 +1,248 @@ +package writer + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/streamingfast/bstream" + "go.uber.org/zap" +) + +var _ Writer = (*BufferedIO)(nil) + +type BufferedIO struct { + baseWriter + + written bool + + bufferMazSize uint64 + workingDir string + activeFile *bufferedActiveFile +} + +func NewBufferedIO( + bufferMaxSize uint64, + workingDir string, + fileType FileType, + zlogger *zap.Logger, +) *BufferedIO { + if bufferMaxSize == 0 { + bufferMaxSize = DefaultBufSize + } + + return &BufferedIO{ + bufferMazSize: bufferMaxSize, + baseWriter: newBaseWriter(fileType, zlogger), + workingDir: workingDir, + written: false, + } +} + +func (s *BufferedIO) workingFilename(blockRange *bstream.Range) string { + return fmt.Sprintf("%010d-%010d.tmp.%s", blockRange.StartBlock(), (*blockRange.EndBlock()), s.fileType) +} + +func (s *BufferedIO) StartBoundary(blockRange *bstream.Range) error { + if s.activeFile != nil { + return fmt.Errorf("unable to start a file while one (backed by %q) is already open", s.activeFile.Path()) + } + + lazyFile := LazyOpen(filepath.Join(s.workingDir, s.workingFilename(blockRange))) + + a := &bufferedActiveFile{ + lazyFile: lazyFile, + writer: NewIntelligentWriterSize(lazyFile, int(s.bufferMazSize)), + blockRange: blockRange, + outputFilename: s.filename(blockRange), + } + + s.written = false + s.activeFile = a + return nil +} + +func (s *BufferedIO) CloseBoundary(ctx context.Context) (Uploadeable, error) { + defer func() { + s.activeFile = nil + }() + + if s.activeFile == nil { + return nil, fmt.Errorf("no active file") + } + + if s.activeFile.writer.AllDataFitInMemory() { + s.zlogger.Debug("all data from range is in memory, no need to flush") + return &dataFile{ + reader: bytes.NewReader(s.activeFile.writer.MemoryData()), + outputFilename: s.activeFile.outputFilename, + }, nil + } + + s.zlogger.Debug("flushing buffered writter") + if err := s.activeFile.writer.Flush(); err != nil { + return nil, fmt.Errorf("flushing buffered active writer: %w", err) + } + + if err := s.activeFile.lazyFile.Close(); err != nil { + return nil, fmt.Errorf("closing file: %w", err) + } + + workingPath := s.activeFile.Path() + return &localFile{ + localFilePath: workingPath, + outputFilename: s.activeFile.outputFilename, + }, nil +} + +func (s *BufferedIO) Write(data []byte) (n int, err error) { + if s.activeFile == nil { + return 0, fmt.Errorf("failed to write to active file") + } + + if !s.written { + s.written = true + } + + return s.activeFile.writer.Write(data) +} + +func (s *BufferedIO) IsWritten() bool { + return s.written +} + +var _ io.WriteCloser = (*LazyFile)(nil) + +// LazyFile only creates and writes to file if `Write` is called at least one. +// +// **Important** Not safe for concurrent access, you need to gate yourself if +// you need that. +type LazyFile struct { + *os.File + + path string +} + +func LazyOpen(path string) *LazyFile { + return &LazyFile{ + File: nil, + path: path, + } +} + +func (f *LazyFile) Path() string { + return f.path +} + +func (f *LazyFile) Write(p []byte) (n int, err error) { + if f.File == nil { + if err := os.MkdirAll(filepath.Dir(f.path), os.ModePerm); err != nil { + return 0, fmt.Errorf("mkdir dirs: %w", err) + } + + file, err := os.Create(f.path) + if err != nil { + return 0, fmt.Errorf("open file: %w", err) + } + + f.File = file + } + + return f.File.Write(p) +} + +func (f *LazyFile) Close() error { + if f.File != nil { + return f.File.Close() + } + + return nil +} + +type memoryBufferedWriter struct { + io.Writer + + MemoryBuffer []byte + NextWritesToMemory bool + WrittenToWrapped bool +} + +func newMemoryBufferedWriter(w io.Writer) *memoryBufferedWriter { + return &memoryBufferedWriter{Writer: w} +} + +func (f *memoryBufferedWriter) Write(p []byte) (n int, err error) { + if f.NextWritesToMemory { + if f.MemoryBuffer == nil { + f.MemoryBuffer = p + return len(p), nil + } + + f.MemoryBuffer = append(f.MemoryBuffer, p...) + return len(p), nil + } + + f.WrittenToWrapped = true + return f.Writer.Write(p) +} + +func (f *memoryBufferedWriter) Close() error { + if v, ok := f.Writer.(io.Closer); ok { + return v.Close() + } + + return nil +} + +func NewIntelligentWriterSize(w io.Writer, size int) *IntelligentWriter { + underlyingWritter := newMemoryBufferedWriter(w) + + return &IntelligentWriter{Writer: bufio.NewWriterSize(underlyingWritter, size), underlyingWritter: underlyingWritter} +} + +func NewIntelligentWriter(w io.Writer) *IntelligentWriter { + return NewIntelligentWriterSize(w, DefaultBufSize) +} + +func (w *IntelligentWriter) AllDataFitInMemory() bool { + return !w.underlyingWritter.WrittenToWrapped +} + +func (w *IntelligentWriter) MemoryData() []byte { + if !w.AllDataFitInMemory() { + panic(fmt.Errorf("it's invalid to call MemoryData without checking if all data is held in memory, check AllDataFitInMemory prior calling this method")) + } + + w.underlyingWritter.NextWritesToMemory = true + + if err := w.Writer.Flush(); err != nil { + panic(fmt.Errorf("this should have been infallible because we write directly received 'b.buf[0:n]', there is a flaw in our logic: %w", err)) + } + + return w.underlyingWritter.MemoryBuffer +} + +type IntelligentWriter struct { + *bufio.Writer + + underlyingWritter *memoryBufferedWriter +} + +const ( + DefaultBufSize = 16 * 1024 * 1024 // 16 MiB +) + +type bufferedActiveFile struct { + lazyFile *LazyFile + writer *IntelligentWriter + blockRange *bstream.Range + outputFilename string +} + +func (f *bufferedActiveFile) Path() string { + return f.lazyFile.path +} diff --git a/sink/sql/db_changes/bundler/writer/common.go b/sink/sql/db_changes/bundler/writer/common.go new file mode 100644 index 000000000..9f5900ee3 --- /dev/null +++ b/sink/sql/db_changes/bundler/writer/common.go @@ -0,0 +1,36 @@ +package writer + +import ( + "fmt" + + "github.com/streamingfast/bstream" + "go.uber.org/zap" +) + +type FileType string + +const ( + FileTypeJSONL FileType = "jsonl" + FileTypeCSV FileType = "csv" +) + +type baseWriter struct { + fileType FileType + zlogger *zap.Logger +} + +func newBaseWriter(fileType FileType, zlogger *zap.Logger) baseWriter { + return baseWriter{ + fileType: fileType, + zlogger: zlogger, + } + +} + +func (b baseWriter) filename(blockRange *bstream.Range) string { + return fmt.Sprintf("%010d-%010d", blockRange.StartBlock(), *blockRange.EndBlock()-1) +} + +func (b baseWriter) Type() FileType { + return b.fileType +} diff --git a/sink/sql/db_changes/bundler/writer/interface.go b/sink/sql/db_changes/bundler/writer/interface.go new file mode 100644 index 000000000..6c74a9c81 --- /dev/null +++ b/sink/sql/db_changes/bundler/writer/interface.go @@ -0,0 +1,23 @@ +package writer + +import ( + "context" + "io" + + "github.com/streamingfast/dstore" + + "github.com/streamingfast/bstream" +) + +type Writer interface { + io.Writer + + IsWritten() bool + StartBoundary(*bstream.Range) error + CloseBoundary(ctx context.Context) (Uploadeable, error) + Type() FileType +} + +type Uploadeable interface { + Upload(ctx context.Context, store dstore.Store) (string, error) +} diff --git a/sink/sql/db_changes/bundler/writer/types.go b/sink/sql/db_changes/bundler/writer/types.go new file mode 100644 index 000000000..5dd58959b --- /dev/null +++ b/sink/sql/db_changes/bundler/writer/types.go @@ -0,0 +1,33 @@ +package writer + +import ( + "context" + "fmt" + "io" + + "github.com/streamingfast/dstore" +) + +type dataFile struct { + reader io.Reader + outputFilename string +} + +func (d *dataFile) Upload(ctx context.Context, store dstore.Store) (string, error) { + if err := store.WriteObject(ctx, d.outputFilename, d.reader); err != nil { + return "", fmt.Errorf("write object: %w", err) + } + return store.ObjectPath(d.outputFilename), nil +} + +type localFile struct { + localFilePath string + outputFilename string +} + +func (l *localFile) Upload(ctx context.Context, store dstore.Store) (string, error) { + if err := store.PushLocalFile(ctx, l.localFilePath, l.outputFilename); err != nil { + return "", fmt.Errorf("pushing object: %w", err) + } + return store.ObjectPath(l.outputFilename), nil +} diff --git a/sink/sql/db_changes/db/cursor.go b/sink/sql/db_changes/db/cursor.go new file mode 100644 index 000000000..6d3212c2a --- /dev/null +++ b/sink/sql/db_changes/db/cursor.go @@ -0,0 +1,177 @@ +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/lithammer/dedent" + sink "github.com/streamingfast/substreams/sink" + "go.uber.org/zap" +) + +var ErrCursorNotFound = errors.New("cursor not found") + +type cursorRow struct { + ID string + Cursor string + BlockNum uint64 + BlockID string +} + +// GetAllCursors returns an unordered map given for each module's hash recorded +// the active cursor for it. +func (l *Loader) GetAllCursors(ctx context.Context) (out map[string]*sink.Cursor, err error) { + query := l.dialect.GetAllCursorsQuery(l.cursorTable.identifier) + rows, err := l.DB.QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("query all cursors: %w", err) + } + + out = make(map[string]*sink.Cursor) + for rows.Next() { + c := &cursorRow{} + if err := rows.Scan(&c.ID, &c.Cursor, &c.BlockNum, &c.BlockID); err != nil { + return nil, fmt.Errorf("getting all cursors: %w", err) + } + + out[c.ID], err = sink.NewCursor(c.Cursor) + if err != nil { + return nil, fmt.Errorf("database corrupted: stored cursor %q is not a valid cursor", c.Cursor) + } + } + + return out, nil +} + +func (l *Loader) GetCursor(ctx context.Context, outputModuleHash string) (cursor *sink.Cursor, mismatchDetected bool, err error) { + cursors, err := l.GetAllCursors(ctx) + if err != nil { + return nil, false, fmt.Errorf("get cursor: %w", err) + } + + if len(cursors) == 0 { + return sink.NewBlankCursor(), false, ErrCursorNotFound + } + + activeCursor, found := cursors[outputModuleHash] + if found { + return activeCursor, false, err + } + + // It's not found at this point, look for one with highest block, we will report + // (maybe) a warning if the module hash is different, which is the case here. + actualOutputModuleHash, activeCursor := cursorAtHighestBlock(cursors) + + switch l.moduleMismatchMode { + case OnModuleHashMismatchIgnore: + return activeCursor, true, err + + case OnModuleHashMismatchWarn: + l.logger.Warn( + fmt.Sprintf("cursor module hash mismatch, continuing using cursor at highest block %s, this warning can be made silent by using '--on-module-hash-mismatch=ignore'", activeCursor.Block()), + zap.String("expected_module_hash", outputModuleHash), + zap.String("actual_module_hash", actualOutputModuleHash), + ) + + return activeCursor, true, err + + case OnModuleHashMismatchError: + return nil, true, fmt.Errorf("cursor module hash mismatch, refusing to continue because flag '--on-module-hash-mismatch=error' (defaults) is set, you can change to 'warn' or 'ignore': your module's hash is %q but cursor with highest block (%d) module hash is actually %q in the database", + outputModuleHash, + activeCursor.Block().Num(), + actualOutputModuleHash, + ) + + default: + panic(fmt.Errorf("unknown module mismatch mode %q", l.moduleMismatchMode)) + } +} + +func cursorAtHighestBlock(in map[string]*sink.Cursor) (hash string, highest *sink.Cursor) { + for moduleHash, cursor := range in { + if highest == nil || cursor.Block().Num() > highest.Block().Num() { + highest = cursor + hash = moduleHash + } + } + + return +} + +func (l *Loader) InsertCursor(ctx context.Context, moduleHash string, c *sink.Cursor) error { + query := fmt.Sprintf("INSERT INTO %s (id, cursor, block_num, block_id) values ('%s', '%s', %d, '%s')", + l.cursorTable.identifier, + moduleHash, + c, + c.Block().Num(), + c.Block().ID(), + ) + if _, err := l.DB.ExecContext(ctx, query); err != nil { + return fmt.Errorf("insert cursor: %w", err) + } + + return nil +} + +// UpdateCursor updates the active cursor. If no cursor is active and no update occurred, returns +// ErrCursorNotFound. If the update was not successful on the database, returns an error. +// You can use tx=nil to run the query outside of a transaction. +func (l *Loader) UpdateCursor(ctx context.Context, tx Tx, moduleHash string, c *sink.Cursor) error { + l.logger.Debug("updating cursor", zap.String("module_hash", moduleHash), zap.Stringer("cursor", c)) + _, err := l.runModifiyQuery(ctx, tx, "update", l.dialect.GetUpdateCursorQuery( + l.cursorTable.identifier, moduleHash, c, c.Block().Num(), c.Block().ID(), + )) + return err +} + +// DeleteCursor deletes the active cursor for the given 'moduleHash'. +func (l *Loader) DeleteCursor(ctx context.Context, moduleHash string) error { + _, err := l.runModifiyQuery(ctx, nil, "delete", fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", l.cursorTable.identifier, moduleHash)) + return err +} + +// DeleteAllCursors deletes all cursors. +func (l *Loader) DeleteAllCursors(ctx context.Context) (deletedCount int64, err error) { + deletedCount, err = l.runModifiyQuery(ctx, nil, "delete", fmt.Sprintf("DELETE FROM %s", l.cursorTable.identifier)) + if err != nil && errors.Is(err, ErrCursorNotFound) { + return 0, nil + } + + return deletedCount, nil +} + +type sqlExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +// runModifiyQuery runs the logic to execute a query that is supposed to modify the database in some form affecting +// at least 1 row. +func (l *Loader) runModifiyQuery(ctx context.Context, tx Tx, action string, query string) (rowsAffected int64, err error) { + var executor sqlExecutor = l.DB + if tx != nil { + executor = tx + } + + result, err := executor.ExecContext(ctx, query) + if err != nil { + return 0, fmt.Errorf("%s cursor: %w", action, err) + } + + rowsAffected, err = result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("rows affected: %w", err) + } + + if l.dialect.DriverSupportRowsAffected() && rowsAffected <= 0 { + return 0, ErrCursorNotFound + } + + return rowsAffected, nil +} + +func query(in string, args ...any) string { + return fmt.Sprintf(strings.TrimSpace(dedent.Dedent(in)), args...) +} diff --git a/sink/sql/db_changes/db/db.go b/sink/sql/db_changes/db/db.go new file mode 100644 index 000000000..dea982436 --- /dev/null +++ b/sink/sql/db_changes/db/db.go @@ -0,0 +1,395 @@ +package db + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + + "github.com/streamingfast/logging" + orderedmap "github.com/wk8/go-ordered-map/v2" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Make the typing a bit easier +type OrderedMap[K comparable, V any] struct { + *orderedmap.OrderedMap[K, V] +} + +func NewOrderedMap[K comparable, V any]() *OrderedMap[K, V] { + return &OrderedMap[K, V]{OrderedMap: orderedmap.New[K, V]()} +} + +type SystemTableError struct { + error +} + +type Loader struct { + *sql.DB + + entries *OrderedMap[string, *OrderedMap[string, *Operation]] + entriesCount uint64 + tables map[string]*TableInfo + cursorTable *TableInfo + + handleReorgs bool + batchBlockFlushInterval int + batchRowFlushInterval int + liveBlockFlushInterval int + moduleMismatchMode OnModuleHashMismatch + + dialect Dialect + + logger *zap.Logger + tracer logging.Tracer + + testTx *TestTx // used for testing: if non-nil, 'loader.BeginTx()' will return this object instead of a real *sql.Tx + dsn *DSN + + batchOrdinal uint64 // Counter for ordinals within the current batch, resets on flush +} + +func NewLoader( + dsn *DSN, + cursorTableName string, historyTableName string, clickhouseCluster string, + batchBlockFlushInterval int, + batchRowFlushInterval int, + liveBlockFlushInterval int, + OnModuleHashMismatch string, + handleReorgs *bool, + logger *zap.Logger, + tracer logging.Tracer, +) (*Loader, error) { + + // Validate ClickHouse is not using HTTP protocol ports + if dsn.Driver() == "clickhouse" { + if dsn.Port == 8123 || dsn.Port == 8443 { + return nil, fmt.Errorf("ClickHouse HTTP protocol (port %d) is not supported. Please use the native TCP protocol on port 9000 or 9440", dsn.Port) + } + } + + sqlDB, err := sql.Open(dsn.Driver(), dsn.ConnString()) + if err != nil { + return nil, fmt.Errorf("open db connection: %w", err) + } + + dialect, err := newDialect(sqlDB.Driver(), dsn, cursorTableName, historyTableName, clickhouseCluster) + if err != nil { + return nil, fmt.Errorf("get dialect: %w", err) + } + + moduleMismatchMode, err := ParseOnModuleHashMismatch(OnModuleHashMismatch) + if err != nil { + return nil, fmt.Errorf("parse on module hash mismatch: %w", err) + } + + l := &Loader{ + DB: sqlDB, + dsn: dsn, + entries: NewOrderedMap[string, *OrderedMap[string, *Operation]](), + tables: map[string]*TableInfo{}, + batchBlockFlushInterval: batchBlockFlushInterval, + batchRowFlushInterval: batchRowFlushInterval, + liveBlockFlushInterval: liveBlockFlushInterval, + moduleMismatchMode: moduleMismatchMode, + dialect: dialect, + logger: logger, + tracer: tracer, + } + + if handleReorgs == nil { + // automatic detection + l.handleReorgs = !l.dialect.OnlyInserts() + } else { + l.handleReorgs = *handleReorgs + } + + if l.handleReorgs && l.dialect.OnlyInserts() { + return nil, fmt.Errorf("driver %s does not support reorg handling. You must use set a non-zero undo-buffer-size", sqlDB.Driver()) + } + + logger.Info("created new DB loader", + zap.Int("batch_block_flush_interval", batchBlockFlushInterval), + zap.Int("batch_row_flush_interval", batchRowFlushInterval), + zap.Int("live_block_flush_interval", liveBlockFlushInterval), + zap.Stringer("on_module_hash_mismatch", moduleMismatchMode), + zap.Bool("handle_reorgs", l.handleReorgs), + zap.String("dialect", fmt.Sprintf("%T", l.dialect)), + ) + + return l, nil +} + +func newDialect(driver driver.Driver, dsn *DSN, cursorTableName string, historyTableName string, clickHouseClusterName string) (Dialect, error) { + driverType := fmt.Sprintf("%T", driver) + switch driverType { + case "*pq.Driver": + return NewPostgresDialect(dsn.Schema(), cursorTableName, historyTableName), nil + case "*clickhouse.stdDriver": + return NewClickhouseDialect(dsn.Schema(), cursorTableName, clickHouseClusterName), nil + default: + return nil, fmt.Errorf("unsupported driver: %s", driverType) + } +} + +type Tx interface { + Rollback() error + Commit() error + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func (l *Loader) Begin() (Tx, error) { + return l.BeginTx(context.Background(), nil) +} + +func (l *Loader) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + if l.testTx != nil { + return l.testTx, nil + } + return l.DB.BeginTx(ctx, opts) +} + +func (l *Loader) BatchBlockFlushInterval() int { + return l.batchBlockFlushInterval +} + +func (l *Loader) LiveBlockFlushInterval() int { + return l.liveBlockFlushInterval +} + +func (l *Loader) FlushNeeded() bool { + totalRows := 0 + for pair := l.entries.Oldest(); pair != nil; pair = pair.Next() { + totalRows += pair.Value.Len() + } + return totalRows > l.batchRowFlushInterval +} + +// getTablesFromSchema returns table information similar to schema.Tables() +// but only inspects tables in the specified schema to avoid issues with database extensions +func (l *Loader) getTablesFromSchema(schemaName string) (map[[2]string][]*sql.ColumnType, error) { + tables, err := l.dialect.GetTablesInSchema(l.DB, schemaName) + if err != nil { + return nil, fmt.Errorf("getting tables from schema: %w", err) + } + + result := make(map[[2]string][]*sql.ColumnType) + + for _, table := range tables { + schemaName, tableName := table[0], table[1] + + columns, err := l.dialect.GetTableColumns(l.DB, schemaName, tableName) + if err != nil { + l.logger.Warn("failed to get columns for table, skipping", + zap.String("schema", schemaName), + zap.String("table", tableName), + zap.Error(err), + ) + continue + } + + key := [2]string{schemaName, tableName} + result[key] = columns + } + + return result, nil +} + +func (l *Loader) LoadTables(schemaName string, cursorTableName string, historyTableName string) error { + schemaTables, err := l.getTablesFromSchema(schemaName) + if err != nil { + return fmt.Errorf("retrieving table and schemaName: %w", err) + } + + seenCursorTable := false + seenHistoryTable := false + for schemaTableName, columns := range schemaTables { + tableName := schemaTableName[1] + l.logger.Debug("processing schemaName's table", + zap.String("schema_name", schemaName), + zap.String("table_name", tableName), + ) + + if schemaTableName[0] != schemaName { + continue + } + + if tableName == cursorTableName { + if err := l.validateCursorTables(columns, schemaName, cursorTableName); err != nil { + return fmt.Errorf("invalid cursors table: %w", err) + } + + seenCursorTable = true + } + if tableName == historyTableName { + seenHistoryTable = true + } + + columnByName := make(map[string]*ColumnInfo, len(columns)) + for _, f := range columns { + columnByName[f.Name()] = &ColumnInfo{ + name: f.Name(), + escapedName: EscapeIdentifier(f.Name()), + databaseTypeName: f.DatabaseTypeName(), + scanType: f.ScanType(), + } + } + + key, err := l.dialect.GetPrimaryKey(l.DB, schemaName, tableName) + if err != nil { + return fmt.Errorf("get primary key: %w", err) + } + + l.tables[tableName], err = NewTableInfo(schemaName, tableName, key, columnByName) + if err != nil { + return fmt.Errorf("invalid table: %w", err) + } + } + + if !seenCursorTable { + return &SystemTableError{fmt.Errorf(`%s.%s table is not found`, EscapeIdentifier(schemaName), cursorTableName)} + } + if l.handleReorgs && !seenHistoryTable { + return &SystemTableError{fmt.Errorf("%s.%s table is not found and reorgs handling is enabled", EscapeIdentifier(schemaName), historyTableName)} + } + + l.cursorTable = l.tables[cursorTableName] + + return nil +} + +func (l *Loader) validateCursorTables(columns []*sql.ColumnType, schemaName string, cursorTableName string) (err error) { + if len(columns) != 4 { + return &SystemTableError{fmt.Errorf("table requires 4 columns ('id', 'cursor', 'block_num', 'block_id')")} + } + columnsCheck := map[string]string{ + "block_num": "int64", + "block_id": "string", + "cursor": "string", + "id": "string", + } + for _, f := range columns { + columnName := f.Name() + if _, found := columnsCheck[columnName]; !found { + return &SystemTableError{fmt.Errorf("unexpected column %q in cursors table", columnName)} + } + expectedType := columnsCheck[columnName] + actualType := f.ScanType().Kind().String() + if expectedType != actualType { + return &SystemTableError{fmt.Errorf("column %q has invalid type, expected %q has %q", columnName, expectedType, actualType)} + } + delete(columnsCheck, columnName) + } + if len(columnsCheck) != 0 { + for k := range columnsCheck { + return &SystemTableError{fmt.Errorf("missing column %q from cursors", k)} + } + } + key, err := l.dialect.GetPrimaryKey(l.DB, schemaName, cursorTableName) + if err != nil { + return &SystemTableError{fmt.Errorf("failed getting primary key: %w", err)} + } + if len(key) == 0 { + return &SystemTableError{fmt.Errorf("primary key not found: %w", err)} + } + if key[0] != "id" { + return &SystemTableError{fmt.Errorf("column 'id' should be primary key not %q", key[0])} + } + return nil +} + +func (l *Loader) GetColumnsForTable(name string) []string { + columns := make([]string, 0, len(l.tables[name].columnsByName)) + for column := range l.tables[name].columnsByName { + if len(column) > 0 { + columns = append(columns, column) + } + } + return columns +} + +func (l *Loader) GetAvailableTablesInSchema() []string { + tables := make([]string, len(l.tables)) + i := 0 + for table := range l.tables { + tables[i] = table + i++ + } + return tables +} + +func (l *Loader) HasTable(tableName string) bool { + if _, found := l.tables[tableName]; found { + return true + } + return false +} + +func (l *Loader) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + encoder.AddUint64("entries_count", l.entriesCount) + return nil +} + +// Setup creates the schemaName, cursors and history table where the is a byte array +// taken from somewhere. +func (l *Loader) Setup(ctx context.Context, schemaName string, userSql string, withPostgraphile bool) error { + if userSql != "" { + if err := l.dialect.ExecuteSetupScript(ctx, l, userSql); err != nil { + return fmt.Errorf("exec userSql: %w", err) + } + } + + if err := l.setupCursorTable(ctx, schemaName, withPostgraphile); err != nil { + return fmt.Errorf("setup cursor table: %w", err) + } + + if err := l.setupHistoryTable(ctx, schemaName, withPostgraphile); err != nil { + return fmt.Errorf("setup history table: %w", err) + } + + return nil +} + +func (l *Loader) setupCursorTable(ctx context.Context, schemaName string, withPostgraphile bool) error { + query := l.dialect.GetCreateCursorQuery(schemaName, withPostgraphile) + _, err := l.ExecContext(ctx, query) + return err +} + +func (l *Loader) setupHistoryTable(ctx context.Context, schemaName string, withPostgraphile bool) error { + if l.dialect.OnlyInserts() { + return nil + } + query := l.dialect.GetCreateHistoryQuery(schemaName, withPostgraphile) + _, err := l.ExecContext(ctx, query) + return err +} + +// GetIdentifier returns / suitable for user presentation +func (l *Loader) GetIdentifier() string { + return fmt.Sprintf("%s/%s", l.dsn.schema, l.dsn.schema) +} + +// GetDSN returns the DSN for the loader +func (l *Loader) GetDSN() *DSN { + return l.dsn +} + +// NextBatchOrdinal returns the next ordinal for the current batch and increments the counter +func (l *Loader) NextBatchOrdinal() uint64 { + ordinal := l.batchOrdinal + l.batchOrdinal++ + return ordinal +} + +type obfuscatedString string + +func (s obfuscatedString) String() string { + if len(s) == 0 { + return "" + } + + return "********" +} diff --git a/sink/sql/db_changes/db/dialect.go b/sink/sql/db_changes/db/dialect.go new file mode 100644 index 000000000..8ea60cb84 --- /dev/null +++ b/sink/sql/db_changes/db/dialect.go @@ -0,0 +1,36 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + + sink "github.com/streamingfast/substreams/sink" +) + +type UnknownDriverError struct { + Driver string +} + +// Error returns a formatted string description. +func (e UnknownDriverError) Error() string { + return fmt.Sprintf("unknown database driver: %s", e.Driver) +} + +type Dialect interface { + GetCreateCursorQuery(schema string, withPostgraphile bool) string + GetCreateHistoryQuery(schema string, withPostgraphile bool) string + ExecuteSetupScript(ctx context.Context, l *Loader, schemaSql string) error + DriverSupportRowsAffected() bool + GetUpdateCursorQuery(table, moduleHash string, cursor *sink.Cursor, block_num uint64, block_id string) string + GetAllCursorsQuery(table string) string + ParseDatetimeNormalization(value string) string + Flush(tx Tx, ctx context.Context, l *Loader, outputModuleHash string, lastFinalBlock uint64) (int, error) + Revert(tx Tx, ctx context.Context, l *Loader, lastValidFinalBlock uint64) error + OnlyInserts() bool + AllowPkDuplicates() bool + CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, database string, readOnly bool) error + GetTableColumns(db *sql.DB, schemaName, tableName string) ([]*sql.ColumnType, error) + GetPrimaryKey(db *sql.DB, schemaName, tableName string) ([]string, error) + GetTablesInSchema(db *sql.DB, schemaName string) ([][2]string, error) +} diff --git a/sink/sql/db_changes/db/dialect_clickhouse.go b/sink/sql/db_changes/db/dialect_clickhouse.go new file mode 100644 index 000000000..3894af6ee --- /dev/null +++ b/sink/sql/db_changes/db/dialect_clickhouse.go @@ -0,0 +1,533 @@ +package db + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math/big" + "reflect" + "sort" + "strconv" + "strings" + "time" + + clickhouse "github.com/AfterShip/clickhouse-sql-parser/parser" + _ "github.com/ClickHouse/clickhouse-go/v2" + "github.com/streamingfast/cli" + sink "github.com/streamingfast/substreams/sink" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +type ClickhouseDialect struct { + cursorTableName string + cluster string + schemaName string +} + +func NewClickhouseDialect(schemaName string, cursorTableName string, cluster string) *ClickhouseDialect { + return &ClickhouseDialect{ + cursorTableName: cursorTableName, + cluster: cluster, + schemaName: schemaName, + } +} + +// Clickhouse should be used to insert a lot of data in batches. The current official clickhouse +// driver doesn't support Transactions for multiple tables. The only way to add in batches is +// creating a transaction for a table, adding all rows and commiting it. +func (d ClickhouseDialect) Flush(tx Tx, ctx context.Context, l *Loader, outputModuleHash string, lastFinalBlock uint64) (int, error) { + var entryCount int + for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { + tableName := entriesPair.Key + entries := entriesPair.Value + sqlTx, err := l.DB.BeginTx(ctx, nil) + if err != nil { + return entryCount, fmt.Errorf("failed to begin db transaction") + } + + if l.tracer.Enabled() { + l.logger.Debug("flushing table entries", zap.String("table_name", tableName), zap.Int("entry_count", entries.Len())) + } + info := l.tables[tableName] + columns := make([]string, 0, len(info.columnsByName)) + for column := range info.columnsByName { + columns = append(columns, column) + } + sort.Strings(columns) + insertQuery := fmt.Sprintf( + "INSERT INTO %s.%s (%s)", + EscapeIdentifier(d.schemaName), + EscapeIdentifier(tableName), + strings.Join(columns, ",")) + batch, err := sqlTx.Prepare(insertQuery) + if err != nil { + return entryCount, fmt.Errorf("failed to prepare insert into %q: %w", tableName, err) + } + for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { + entry := entryPair.Value + + if l.tracer.Enabled() { + l.logger.Debug("adding query from operation to transaction", zap.Stringer("op", entry), zap.String("query", insertQuery)) + } + + values, err := convertOpToClickhouseValues(entry) + if err != nil { + return entryCount, fmt.Errorf("failed to get values: %w", err) + } + + if _, err := batch.ExecContext(ctx, values...); err != nil { + return entryCount, fmt.Errorf("executing for entry %q: %w", values, err) + } + } + + if err := sqlTx.Commit(); err != nil { + return entryCount, fmt.Errorf("failed to commit db transaction: %w", err) + } + entryCount += entries.Len() + } + + return entryCount, nil +} + +func (d ClickhouseDialect) Revert(tx Tx, ctx context.Context, l *Loader, lastValidFinalBlock uint64) error { + return fmt.Errorf("clickhouse driver does not support reorg management.") +} + +func (d ClickhouseDialect) GetCreateCursorQuery(schema string, withPostgraphile bool) string { + _ = withPostgraphile // TODO: see if this can work + + clusterClause := "" + engine := "ReplacingMergeTree()" + if d.cluster != "" { + clusterClause = fmt.Sprintf("ON CLUSTER %s", EscapeIdentifier(d.cluster)) + engine = "ReplicatedReplacingMergeTree()" + } + + return fmt.Sprintf(cli.Dedent(` + CREATE TABLE IF NOT EXISTS %s.%s %s + ( + id String, + cursor String, + block_num Int64, + block_id String + ) Engine = %s ORDER BY id; + `), EscapeIdentifier(schema), EscapeIdentifier(d.cursorTableName), clusterClause, engine) +} + +func (d ClickhouseDialect) GetCreateHistoryQuery(schema string, withPostgraphile bool) string { + panic("clickhouse does not support reorg management") +} + +func (d ClickhouseDialect) ExecuteSetupScript(ctx context.Context, l *Loader, schemaSql string) error { + if d.schemaName != "default" { + useDbQuery := fmt.Sprintf("USE %s", EscapeIdentifier(d.schemaName)) + if _, err := l.ExecContext(ctx, useDbQuery); err != nil { + l.logger.Error("failed to switch to database", zap.String("database", d.schemaName), zap.Error(err)) + return fmt.Errorf("use database %s: %w", d.schemaName, err) + } + } + + if d.cluster != "" { + stmts, err := clickhouse.NewParser(schemaSql).ParseStmts() + if err != nil { + return fmt.Errorf("parsing schemaName: %w", err) + } + + for _, stmt := range stmts { + if createDatabase, ok := stmt.(*clickhouse.CreateDatabase); ok { + l.logger.Debug("appending 'ON CLUSTER' clause to 'CREATE DATABASE'", zap.String("cluster", d.cluster), zap.Stringer("database", createDatabase.Name)) + createDatabase.OnCluster = &clickhouse.ClusterClause{Expr: &clickhouse.StringLiteral{Literal: d.cluster}} + } + if createTable, ok := stmt.(*clickhouse.CreateTable); ok { + l.logger.Debug("appending 'ON CLUSTER' clause to 'CREATE TABLE'", zap.String("cluster", d.cluster), zap.String("table", createTable.Name.String())) + createTable.OnCluster = &clickhouse.ClusterClause{Expr: &clickhouse.StringLiteral{Literal: d.cluster}} + + if !strings.HasPrefix(createTable.Engine.Name, "Replicated") && + strings.HasSuffix(createTable.Engine.Name, "MergeTree") { + newEngine := "Replicated" + createTable.Engine.Name + l.logger.Debug("replacing table engine with replicated one", zap.String("table", createTable.Name.String()), zap.String("engine", createTable.Engine.Name), zap.String("new_engine", newEngine)) + createTable.Engine.Name = newEngine + } + } + if createMaterializedView, ok := stmt.(*clickhouse.CreateMaterializedView); ok { + l.logger.Debug("appending 'ON CLUSTER' clause to 'CREATE MATERIALIZED VIEW'", zap.String("cluster", d.cluster), zap.Stringer("materialized_view", createMaterializedView.Name)) + createMaterializedView.OnCluster = &clickhouse.ClusterClause{Expr: &clickhouse.StringLiteral{Literal: d.cluster}} + + if createMaterializedView.Engine != nil && !strings.HasPrefix(createMaterializedView.Engine.Name, "Replicated") && + strings.HasSuffix(createMaterializedView.Engine.Name, "MergeTree") { + newEngine := "Replicated" + createMaterializedView.Engine.Name + l.logger.Debug("replacing table engine with replicated one", zap.Stringer("materialized_view", createMaterializedView.Name), zap.String("engine", createMaterializedView.Engine.Name), zap.String("new_engine", newEngine)) + createMaterializedView.Engine.Name = newEngine + } + } + if createView, ok := stmt.(*clickhouse.CreateView); ok { + l.logger.Debug("appending 'ON CLUSTER' clause to 'CREATE VIEW'", zap.String("cluster", d.cluster), zap.Stringer("view", createView.Name)) + createView.OnCluster = &clickhouse.ClusterClause{Expr: &clickhouse.StringLiteral{Literal: d.cluster}} + } + if createFunction, ok := stmt.(*clickhouse.CreateFunction); ok { + l.logger.Debug("appending 'ON CLUSTER' clause to 'CREATE FUNCTION'", zap.String("cluster", d.cluster), zap.Stringer("function", createFunction.FunctionName)) + createFunction.OnCluster = &clickhouse.ClusterClause{Expr: &clickhouse.StringLiteral{Literal: d.cluster}} + } + + if _, err := l.ExecContext(ctx, stmt.String()); err != nil { + l.logger.Error("failed to execute schema statement", zap.String("statement", stmt.String()), zap.Error(err)) + return fmt.Errorf("exec clickhouse cluster statements: %w", err) + } + } + } else { + // Splitting statements by ';' is not perfect but should be enough for now, + // it will fail for example if user enter a string that contains a ;! + for query := range strings.SplitSeq(schemaSql, ";") { + if len(strings.TrimSpace(query)) == 0 { + continue + } + if _, err := l.ExecContext(ctx, query); err != nil { + return fmt.Errorf("exec clickhouse statements: %w", err) + } + } + } + + return nil +} + +func (d ClickhouseDialect) GetUpdateCursorQuery(table, moduleHash string, cursor *sink.Cursor, block_num uint64, block_id string) string { + return query(` + INSERT INTO %s (id, cursor, block_num, block_id) values ('%s', '%s', %d, '%s') + `, table, moduleHash, cursor, block_num, block_id) +} + +func (d ClickhouseDialect) GetAllCursorsQuery(table string) string { + return fmt.Sprintf("SELECT id, cursor, block_num, block_id FROM %s FINAL", table) +} + +func (d ClickhouseDialect) ParseDatetimeNormalization(value string) string { + return fmt.Sprintf("parseDateTimeBestEffort(%s)", escapeStringValue(value)) +} + +func (d ClickhouseDialect) DriverSupportRowsAffected() bool { + return false +} + +func (d ClickhouseDialect) OnlyInserts() bool { + return true +} + +func (d ClickhouseDialect) AllowPkDuplicates() bool { + return true +} + +func (d ClickhouseDialect) CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, _database string, readOnly bool) error { + user, pass := EscapeIdentifier(username), escapeStringValue(password) + + onClusterClause := "" + if d.cluster != "" { + onClusterClause = fmt.Sprintf("ON CLUSTER %s", EscapeIdentifier(d.cluster)) + } + + createUserQ := fmt.Sprintf("CREATE USER IF NOT EXISTS %s %s IDENTIFIED WITH plaintext_password BY %s;", user, onClusterClause, pass) + _, err := tx.ExecContext(ctx, createUserQ) + if err != nil { + return fmt.Errorf("executing create user query %q: %w", createUserQ, err) + } + + var grantQ string + if readOnly { + grantQ = fmt.Sprintf(` + GRANT %s SELECT ON *.* TO %s; + `, onClusterClause, user) + } else { + grantQ = fmt.Sprintf(` + GRANT %s ALL ON *.* TO %s; + `, onClusterClause, user) + } + + _, err = tx.ExecContext(ctx, grantQ) + if err != nil { + return fmt.Errorf("executing grant query %q: %w", grantQ, err) + } + + return nil +} + +func convertOpToClickhouseValues(o *Operation) ([]any, error) { + columns := make([]string, len(o.data)) + i := 0 + for column := range o.data { + columns[i] = column + i++ + } + sort.Strings(columns) + values := make([]any, len(o.data)) + for i, v := range columns { + if col, exists := o.table.columnsByName[v]; exists { + fieldData := o.data[v] + convertedType, err := convertToType(fieldData.Value, col.scanType) + if err != nil { + return nil, fmt.Errorf("converting value %q to type %q in column %q: %w", fieldData.Value, col.scanType, v, err) + } + values[i] = convertedType + } else { + return nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", v, o.table.identifier, strings.Join(maps.Keys(o.table.columnsByName), ", ")) + } + } + return values, nil +} + +func convertToType(value string, valueType reflect.Type) (any, error) { + switch valueType.Kind() { + case reflect.String: + return value, nil + case reflect.Slice: + if valueType.Elem().Kind() == reflect.Struct || valueType.Elem().Kind() == reflect.Ptr { + return nil, fmt.Errorf("%q is not supported as Clickhouse Array type", valueType.Elem().Name()) + } + + res := reflect.New(reflect.SliceOf(valueType.Elem())) + if err := json.Unmarshal([]byte(value), res.Interface()); err != nil { + return "", fmt.Errorf("could not JSON unmarshal slice value %q: %w", value, err) + } + + return res.Elem().Interface(), nil + case reflect.Bool: + return strconv.ParseBool(value) + case reflect.Int: + v, err := strconv.ParseInt(value, 10, 0) + return int(v), err + case reflect.Int8: + v, err := strconv.ParseInt(value, 10, 8) + return int8(v), err + case reflect.Int16: + v, err := strconv.ParseInt(value, 10, 16) + return int16(v), err + case reflect.Int32: + v, err := strconv.ParseInt(value, 10, 32) + return int32(v), err + case reflect.Int64: + return strconv.ParseInt(value, 10, 64) + case reflect.Uint: + v, err := strconv.ParseUint(value, 10, 0) + return uint(v), err + case reflect.Uint8: + v, err := strconv.ParseUint(value, 10, 8) + return uint8(v), err + case reflect.Uint16: + v, err := strconv.ParseUint(value, 10, 16) + return uint16(v), err + case reflect.Uint32: + v, err := strconv.ParseUint(value, 10, 32) + return uint32(v), err + case reflect.Uint64: + return strconv.ParseUint(value, 10, 0) + case reflect.Float32, reflect.Float64: + return strconv.ParseFloat(value, 10) + case reflect.Struct: + if valueType == reflectTypeTime { + if integerRegex.MatchString(value) { + i, err := strconv.Atoi(value) + if err != nil { + return "", fmt.Errorf("could not convert %s to int: %w", value, err) + } + + return int64(i), nil + } + + var v time.Time + var err error + if strings.Contains(value, "T") && strings.HasSuffix(value, "Z") { + v, err = time.Parse("2006-01-02T15:04:05Z", value) + } else if dateRegex.MatchString(value) { + // This is a Clickhouse Date field. The Clickhouse Go client doesn't convert unix timestamp into Date, + // so we just validate the format here and return a string. + _, err = time.Parse("2006-01-02", value) + if err != nil { + return "", fmt.Errorf("could not convert %s to date: %w", value, err) + } + return value, nil + } else { + v, err = time.Parse("2006-01-02 15:04:05", value) + } + if err != nil { + return "", fmt.Errorf("could not convert %s to time: %w", value, err) + } + return v.Unix(), nil + } + return "", fmt.Errorf("unsupported struct type %s", valueType) + + case reflect.Ptr: + if valueType.String() == "*big.Int" { + newInt := new(big.Int) + newInt.SetString(value, 10) + return newInt, nil + } + + elemType := valueType.Elem() + val, err := convertToType(value, elemType) + if err != nil { + return nil, fmt.Errorf("invalid pointer type: %w", err) + } + + // We cannot just return &val here as this will return an *interface{} that the Clickhouse Go client won't be + // able to convert on inserting. Instead, we create a new variable using the type that valueType has been + // pointing to, assign the converted value from convertToType to that and then return a pointer to the new variable. + result := reflect.New(elemType).Elem() + result.Set(reflect.ValueOf(val)) + return result.Addr().Interface(), nil + + default: + return value, nil + } +} + +func (d ClickhouseDialect) GetTableColumns(db *sql.DB, schemaName, tableName string) ([]*sql.ColumnType, error) { + // For TCP, use DESCRIBE TABLE to filter out AggregateFunction columns + describeQuery := fmt.Sprintf("DESCRIBE TABLE %s.%s", + EscapeIdentifier(schemaName), + EscapeIdentifier(tableName)) + + describeRows, err := db.Query(describeQuery) + if err != nil { + return nil, fmt.Errorf("describing table structure: %w", err) + } + defer describeRows.Close() + + var nonAggregateColumns []string + + // Get the column types to know how many columns DESCRIBE returns + describeColumnTypes, err := describeRows.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("getting describe column types: %w", err) + } + + // Parse DESCRIBE results to filter out AggregateFunction columns + for describeRows.Next() { + // Create slice to hold all column values dynamically + values := make([]interface{}, len(describeColumnTypes)) + valuePtrs := make([]interface{}, len(describeColumnTypes)) + for i := range values { + valuePtrs[i] = &values[i] + } + + err := describeRows.Scan(valuePtrs...) + if err != nil { + return nil, fmt.Errorf("scanning describe results: %w", err) + } + + // First column is always the column name, second is the data type, + // third is the default_type (MATERIALIZED, ALIAS, DEFAULT, or empty) + name := fmt.Sprintf("%v", values[0]) + dataType := fmt.Sprintf("%v", values[1]) + defaultType := "" + if len(values) > 2 && values[2] != nil { + defaultType = fmt.Sprintf("%v", values[2]) + } + + // Skip AggregateFunction columns and MATERIALIZED columns + // MATERIALIZED columns are auto-computed and cannot be inserted into + if !strings.Contains(dataType, "AggregateFunction") && defaultType != "MATERIALIZED" { + nonAggregateColumns = append(nonAggregateColumns, EscapeIdentifier(name)) + } + } + + if err := describeRows.Err(); err != nil { + return nil, fmt.Errorf("iterating describe results: %w", err) + } + + if len(nonAggregateColumns) == 0 { + return nil, fmt.Errorf("no non-aggregate columns found in table %s.%s", schemaName, tableName) + } + + // TCP protocol works well with WHERE 1=0 + columnList := strings.Join(nonAggregateColumns, ", ") + selectQuery := fmt.Sprintf("SELECT %s FROM %s.%s WHERE 1=0", + columnList, + EscapeIdentifier(schemaName), + EscapeIdentifier(tableName)) + + rows, err := db.Query(selectQuery) + if err != nil { + return nil, fmt.Errorf("querying filtered table structure: %w", err) + } + defer rows.Close() + + return rows.ColumnTypes() +} + +func (d ClickhouseDialect) GetTablesInSchema(db *sql.DB, schemaName string) ([][2]string, error) { + // Use system.tables to query for tables in the schema + // Filter out MaterializedView as they are not regular tables and should not receive direct inserts + q := fmt.Sprintf(` + SELECT database AS table_schema, name AS table_name + FROM system.tables + WHERE database = '%s' + AND NOT is_temporary + AND engine NOT LIKE '%%View' + AND engine NOT LIKE 'System%%' + AND has_own_data != 0 + ORDER BY database, name + `, schemaName) + + rows, err := db.Query(q) + if err != nil { + return nil, fmt.Errorf("querying tables from system.tables: %w", err) + } + defer rows.Close() + + var result [][2]string + for rows.Next() { + var schema, table string + if err := rows.Scan(&schema, &table); err != nil { + return nil, fmt.Errorf("scanning table row: %w", err) + } + result = append(result, [2]string{schema, table}) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating table rows: %w", err) + } + + return result, nil +} + +const clickhousePrimaryKeyQuery = ` + SELECT name + FROM system.columns + WHERE database = %s + AND table = %s + AND is_in_primary_key + ORDER BY position DESC` + +func (d ClickhouseDialect) GetPrimaryKey(db *sql.DB, schemaName, tableName string) ([]string, error) { + var q string + var args []interface{} + + if schemaName == "" { + q = fmt.Sprintf(clickhousePrimaryKeyQuery, "currentDatabase()", "?") + args = []interface{}{tableName} + } else { + q = fmt.Sprintf(clickhousePrimaryKeyQuery, "?", "?") + args = []interface{}{schemaName, tableName} + } + + rows, err := db.Query(q, args...) + if err != nil { + return nil, fmt.Errorf("querying primary key: %w", err) + } + defer rows.Close() + + var columns []string + for rows.Next() { + var column string + if err := rows.Scan(&column); err != nil { + return nil, fmt.Errorf("scanning primary key column: %w", err) + } + columns = append(columns, column) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating primary key rows: %w", err) + } + + return columns, nil +} diff --git a/sink/sql/db_changes/db/dialect_postgres.go b/sink/sql/db_changes/db/dialect_postgres.go new file mode 100644 index 000000000..f2cde6d0b --- /dev/null +++ b/sink/sql/db_changes/db/dialect_postgres.go @@ -0,0 +1,757 @@ +package db + +import ( + "cmp" + "context" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "slices" + "sort" + "strconv" + "strings" + "time" + + "github.com/streamingfast/cli" + sink "github.com/streamingfast/substreams/sink" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +type PostgresDialect struct { + cursorTableName string + historyTableName string + schemaName string +} + +func NewPostgresDialect(schemaName string, cursorTableName string, historyTableName string) *PostgresDialect { + return &PostgresDialect{ + cursorTableName: cursorTableName, + historyTableName: historyTableName, + schemaName: schemaName, + } +} + +func (d PostgresDialect) Revert(tx Tx, ctx context.Context, l *Loader, lastValidFinalBlock uint64) error { + q := fmt.Sprintf(`SELECT op,table_name,pk,prev_value,block_num FROM %s WHERE "block_num" > %d ORDER BY "block_num" DESC`, + d.historyTable(d.schemaName), + lastValidFinalBlock, + ) + + rows, err := tx.QueryContext(ctx, q) + if err != nil { + return err + } + + var reversions []func() error + l.logger.Info("reverting forked block block(s)", zap.Uint64("last_valid_final_block", lastValidFinalBlock)) + if rows != nil { // rows will be nil with no error only in testing scenarios + defer rows.Close() + for rows.Next() { + var op string + var table_name string + var pk string + var prev_value_nullable sql.NullString + var block_num uint64 + if err := rows.Scan(&op, &table_name, &pk, &prev_value_nullable, &block_num); err != nil { + return fmt.Errorf("scanning row: %w", err) + } + l.logger.Debug("reverting", zap.String("operation", op), zap.String("table_name", table_name), zap.String("pk", pk), zap.Uint64("block_num", block_num)) + prev_value := prev_value_nullable.String + + // we can't call revertOp inside this loop, because it calls tx.ExecContext, + // which can't run while this query is "active" or it will silently discard the remaining rows! + reversions = append(reversions, func() error { + if err := d.revertOp(tx, ctx, op, table_name, pk, prev_value, block_num); err != nil { + return fmt.Errorf("revertOp: %w", err) + } + return nil + }) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterating on rows from query %q: %w", q, err) + } + for _, reversion := range reversions { + if err := reversion(); err != nil { + return fmt.Errorf("execution revert operation: %w", err) + } + } + } + pruneHistory := fmt.Sprintf(`DELETE FROM %s WHERE "block_num" > %d;`, + d.historyTable(d.schemaName), + lastValidFinalBlock, + ) + + _, err = tx.ExecContext(ctx, pruneHistory) + if err != nil { + return fmt.Errorf("executing pruneHistory: %w", err) + } + return nil +} + +func (d PostgresDialect) Flush(tx Tx, ctx context.Context, l *Loader, outputModuleHash string, lastFinalBlock uint64) (int, error) { + var totalRows int + for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { + entries := entriesPair.Value + totalRows += entries.Len() + + if l.tracer.Enabled() { + l.logger.Debug("flushing table rows", zap.String("table_name", entriesPair.Key), zap.Int("row_count", entries.Len())) + } + } + + allOperations := make([]*Operation, 0, totalRows) + for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { + entries := entriesPair.Value + for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { + allOperations = append(allOperations, entryPair.Value) + } + } + + slices.SortFunc(allOperations, func(a, b *Operation) int { + return cmp.Compare(a.ordinal, b.ordinal) + }) + + var rowCount int + for _, entry := range allOperations { + normalQuery, undoQuery, err := d.prepareStatement(d.schemaName, entry) + if err != nil { + return 0, fmt.Errorf("failed to prepare statement: %w", err) + } + + // Execute undo query first (if present) to save state before modifying + if undoQuery != "" { + if l.tracer.Enabled() { + l.logger.Debug("adding undo query from operation to transaction", zap.Stringer("op", entry), zap.String("query", undoQuery), zap.Uint64("ordinal", entry.ordinal)) + } + + undoStart := time.Now() + if _, err := tx.ExecContext(ctx, undoQuery); err != nil { + return 0, fmt.Errorf("executing undo query %q: %w", undoQuery, err) + } + undoDuration := time.Since(undoStart) + QueryExecutionDuration.AddInt64(undoDuration.Nanoseconds(), "undo") + } + + // Execute normal query + if l.tracer.Enabled() { + l.logger.Debug("adding normal query from operation to transaction", zap.Stringer("op", entry), zap.String("query", normalQuery), zap.Uint64("ordinal", entry.ordinal)) + } + + normalStart := time.Now() + if _, err := tx.ExecContext(ctx, normalQuery); err != nil { + return 0, fmt.Errorf("executing normal query %q: %w", normalQuery, err) + } + normalDuration := time.Since(normalStart) + QueryExecutionDuration.AddInt64(normalDuration.Nanoseconds(), "normal") + + rowCount++ + } + + pruneStart := time.Now() + if err := d.pruneReversibleSegment(tx, ctx, d.schemaName, lastFinalBlock); err != nil { + return 0, err + } + pruneDuration := time.Since(pruneStart) + PruneReversibleSegmentDuration.AddInt64(pruneDuration.Nanoseconds()) + + return rowCount, nil +} + +func (d PostgresDialect) revertOp(tx Tx, ctx context.Context, op, escaped_table_name, pk, prev_value string, block_num uint64) error { + pkmap := make(map[string]string) + if err := json.Unmarshal([]byte(pk), &pkmap); err != nil { + return fmt.Errorf("revertOp: unmarshalling %q: %w", pk, err) + } + switch op { + case "I": + q := fmt.Sprintf(`DELETE FROM %s WHERE %s;`, + escaped_table_name, + getPrimaryKeyWhereClause(pkmap, ""), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("executing revert query %q: %w", q, err) + } + case "D": + q := fmt.Sprintf(`INSERT INTO %s SELECT * FROM json_populate_record(null::%s,%s);`, + escaped_table_name, + escaped_table_name, + escapeStringValue(prev_value), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("executing revert query %q: %w", q, err) + } + + case "U": + columns, err := sqlColumnNamesFromJSON(prev_value) + if err != nil { + return err + } + + q := fmt.Sprintf(`UPDATE %s SET(%s)=((SELECT %s FROM json_populate_record(null::%s,%s))) WHERE %s;`, + escaped_table_name, + columns, + columns, + escaped_table_name, + escapeStringValue(prev_value), + getPrimaryKeyWhereClause(pkmap, ""), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("executing revert query %q: %w", q, err) + } + default: + panic("invalid op in revert command") + } + return nil +} + +func sqlColumnNamesFromJSON(in string) (string, error) { + valueMap := make(map[string]interface{}) + if err := json.Unmarshal([]byte(in), &valueMap); err != nil { + return "", fmt.Errorf("unmarshalling %q into valueMap: %w", in, err) + } + escapedNames := make([]string, len(valueMap)) + i := 0 + for k := range valueMap { + escapedNames[i] = EscapeIdentifier(k) + i++ + } + sort.Strings(escapedNames) + + return strings.Join(escapedNames, ","), nil +} + +func (d PostgresDialect) pruneReversibleSegment(tx Tx, ctx context.Context, schema string, highestFinalBlock uint64) error { + q := fmt.Sprintf(`DELETE FROM %s WHERE block_num <= %d;`, d.historyTable(schema), highestFinalBlock) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("executing prune query %q: %w", q, err) + } + return nil +} + +func (d PostgresDialect) GetCreateCursorQuery(schema string, withPostgraphile bool) string { + out := fmt.Sprintf(cli.Dedent(` + create table if not exists %s.%s + ( + id text not null constraint %s primary key, + cursor text, + block_num bigint, + block_id text + ); + `), EscapeIdentifier(schema), EscapeIdentifier(d.cursorTableName), EscapeIdentifier(d.cursorTableName+"_pk")) + if withPostgraphile { + out += fmt.Sprintf("COMMENT ON TABLE %s.%s IS E'@omit';", + EscapeIdentifier(schema), EscapeIdentifier(d.cursorTableName)) + } + return out +} + +func (d PostgresDialect) GetCreateHistoryQuery(schema string, withPostgraphile bool) string { + out := fmt.Sprintf(cli.Dedent(` + create table if not exists %s + ( + id SERIAL PRIMARY KEY, + op char, + table_name text, + pk text, + prev_value text, + block_num bigint + ); + `), + d.historyTable(schema), + ) + if withPostgraphile { + out += fmt.Sprintf("COMMENT ON TABLE %s.%s IS E'@omit';", + EscapeIdentifier(schema), EscapeIdentifier(d.historyTableName)) + } + return out +} + +func (d PostgresDialect) ExecuteSetupScript(ctx context.Context, l *Loader, schemaSql string) error { + // Prepend search_path directive to ensure user SQL runs in the correct schema context + fullSql := fmt.Sprintf(`SET search_path TO %s;`+"\n\n%s", EscapeIdentifier(d.schemaName), schemaSql) + + if _, err := l.ExecContext(ctx, fullSql); err != nil { + return fmt.Errorf("exec postgres statements: %w", err) + } + return nil +} + +func (d PostgresDialect) GetUpdateCursorQuery(table, moduleHash string, cursor *sink.Cursor, block_num uint64, block_id string) string { + return query(` + UPDATE %s set cursor = '%s', block_num = %d, block_id = '%s' WHERE id = '%s'; + `, table, cursor, block_num, block_id, moduleHash) +} + +func (d PostgresDialect) GetAllCursorsQuery(table string) string { + return fmt.Sprintf("SELECT id, cursor, block_num, block_id FROM %s", table) +} + +func (d PostgresDialect) ParseDatetimeNormalization(value string) string { + return escapeStringValue(value) +} + +func (d PostgresDialect) DriverSupportRowsAffected() bool { + return true +} + +func (d PostgresDialect) OnlyInserts() bool { + return false +} + +func (d PostgresDialect) AllowPkDuplicates() bool { + return false +} + +func (d PostgresDialect) CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, database string, readOnly bool) error { + user, pass, db := EscapeIdentifier(username), password, EscapeIdentifier(database) + var q string + if readOnly { + q = fmt.Sprintf(` + CREATE ROLE %s LOGIN PASSWORD '%s'; + GRANT CONNECT ON DATABASE %s TO %s; + GRANT USAGE ON SCHEMA public TO %s; + ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO %s; + GRANT SELECT ON ALL TABLES IN SCHEMA public TO %s; + `, user, pass, db, user, user, user, user) + } else { + q = fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'; GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", user, pass, db, user) + } + + _, err := tx.ExecContext(ctx, q) + if err != nil { + return fmt.Errorf("executing create user query %q: %w", q, err) + } + + return nil +} + +func (d PostgresDialect) historyTable(schema string) string { + return fmt.Sprintf("%s.%s", EscapeIdentifier(schema), EscapeIdentifier(d.historyTableName)) +} + +func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[string]string, blockNum uint64) string { + return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,block_num) values (%s,%s,%s,%d);`, + d.historyTable(schema), + escapeStringValue("I"), + escapeStringValue(table), + escapeStringValue(primaryKeyToJSON(primaryKey)), + blockNum, + ) +} + +func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { + schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) + + return fmt.Sprintf(` + WITH t as (select %s) + INSERT INTO %s (op,table_name,pk,prev_value,block_num) + SELECT CASE WHEN %s THEN 'I' ELSE 'U' END AS op, %s, %s, row_to_json(%s),%d from t left join %s.%s on %s;`, + + getPrimaryKeyFakeEmptyValues(primaryKey), + d.historyTable(schema), + + getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, escapedTableName), + + escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, + EscapeIdentifier(schema), escapedTableName, + getPrimaryKeyWhereClause(primaryKey, escapedTableName), + ) +} + +func (d PostgresDialect) saveUpdate(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { + return d.saveRow("U", schema, escapedTableName, primaryKey, blockNum) +} + +func (d PostgresDialect) saveDelete(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { + return d.saveRow("D", schema, escapedTableName, primaryKey, blockNum) +} + +func (d PostgresDialect) saveRow(op, schema, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { + schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) + return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,prev_value,block_num) SELECT %s,%s,%s,row_to_json(%s),%d FROM %s.%s WHERE %s;`, + d.historyTable(schema), + escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, + EscapeIdentifier(schema), escapedTableName, + getPrimaryKeyWhereClause(primaryKey, ""), + ) +} + +// getResultCast returns the appropriate cast suffix for the result of arithmetic operations +// based on the column's scan type. +func getResultCast(scanType reflect.Type) string { + if scanType == nil { + return "" + } + switch scanType.Kind() { + case reflect.String: + return "::text" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return "" + default: + return "" + } +} + +func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (normalQuery string, undoQuery string, err error) { + var columns, values []string + var updateOps []UpdateOp + var scanTypes []reflect.Type + if o.opType == OperationTypeInsert || o.opType == OperationTypeUpsert || o.opType == OperationTypeUpdate { + columns, values, updateOps, scanTypes, err = d.prepareColValues(o.table, o.data) + if err != nil { + return "", "", fmt.Errorf("preparing column & values: %w", err) + } + } + + if o.opType == OperationTypeUpsert || o.opType == OperationTypeUpdate || o.opType == OperationTypeDelete { + // A table without a primary key set yield a `primaryKey` map with a single entry where the key is an empty string + if _, found := o.primaryKey[""]; found { + return "", "", fmt.Errorf("trying to perform %s operation but table %q don't have a primary key set, this is not accepted", o.opType, o.table.name) + } + } + + switch o.opType { + case OperationTypeInsert: + insertQuery := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", + o.table.identifier, + strings.Join(columns, ","), + strings.Join(values, ","), + ) + + if o.reversibleBlockNum != nil { + return insertQuery, d.saveInsert(schema, o.table.identifier, o.primaryKey, *o.reversibleBlockNum), nil + } + return insertQuery, "", nil + + case OperationTypeUpsert: + // Build per-field update expressions based on UpdateOp + updates := make([]string, len(columns)) + for i := range columns { + col := columns[i] + resultCast := getResultCast(scanTypes[i]) + switch updateOps[i] { + case UpdateOpSet: + updates[i] = fmt.Sprintf("%s=EXCLUDED.%s", col, col) + case UpdateOpAdd: + updates[i] = fmt.Sprintf("%s=(COALESCE(%s.%s::numeric, 0) + EXCLUDED.%s::numeric)%s", col, o.table.nameEscaped, col, col, resultCast) + case UpdateOpMax: + updates[i] = fmt.Sprintf("%s=GREATEST(COALESCE(%s.%s::numeric, 0), EXCLUDED.%s::numeric)%s", col, o.table.nameEscaped, col, col, resultCast) + case UpdateOpMin: + updates[i] = fmt.Sprintf("%s=LEAST(COALESCE(%s.%s::numeric, 0), EXCLUDED.%s::numeric)%s", col, o.table.nameEscaped, col, col, resultCast) + case UpdateOpSetIfNull: + updates[i] = fmt.Sprintf("%s=COALESCE(%s.%s, EXCLUDED.%s)", col, o.table.nameEscaped, col, col) + default: + updates[i] = fmt.Sprintf("%s=EXCLUDED.%s", col, col) + } + } + + // Escape primary key column names to preserve case sensitivity (e.g., camelCase) + escapedPKColumns := make([]string, 0, len(o.primaryKey)) + for pkColumn := range o.primaryKey { + escapedPKColumns = append(escapedPKColumns, EscapeIdentifier(pkColumn)) + } + sort.Strings(escapedPKColumns) + + insertQuery := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s;", + o.table.identifier, + strings.Join(columns, ","), + strings.Join(values, ","), + strings.Join(escapedPKColumns, ","), + strings.Join(updates, ", "), + ) + + if o.reversibleBlockNum != nil { + return insertQuery, d.saveUpsert(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum), nil + } + return insertQuery, "", nil + + case OperationTypeUpdate: + // Build per-field update expressions based on UpdateOp + updates := make([]string, len(columns)) + for i := range columns { + col := columns[i] + val := values[i] + resultCast := getResultCast(scanTypes[i]) + switch updateOps[i] { + case UpdateOpSet: + updates[i] = fmt.Sprintf("%s=%s", col, val) + case UpdateOpAdd: + updates[i] = fmt.Sprintf("%s=(COALESCE(%s::numeric, 0) + %s::numeric)%s", col, col, val, resultCast) + case UpdateOpMax: + updates[i] = fmt.Sprintf("%s=GREATEST(COALESCE(%s::numeric, 0), %s::numeric)%s", col, col, val, resultCast) + case UpdateOpMin: + updates[i] = fmt.Sprintf("%s=LEAST(COALESCE(%s::numeric, 0), %s::numeric)%s", col, col, val, resultCast) + case UpdateOpSetIfNull: + updates[i] = fmt.Sprintf("%s=COALESCE(%s, %s)", col, col, val) + default: + updates[i] = fmt.Sprintf("%s=%s", col, val) + } + } + + primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey, "") + + updateQuery := fmt.Sprintf("UPDATE %s SET %s WHERE %s", + o.table.identifier, + strings.Join(updates, ", "), + primaryKeySelector, + ) + + if o.reversibleBlockNum != nil { + return updateQuery, d.saveUpdate(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum), nil + } + return updateQuery, "", nil + + case OperationTypeDelete: + primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey, "") + deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE %s", + o.table.identifier, + primaryKeyWhereClause, + ) + if o.reversibleBlockNum != nil { + return deleteQuery, d.saveDelete(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum), nil + } + return deleteQuery, "", nil + + default: + panic(fmt.Errorf("unknown operation type %q", o.opType)) + } +} + +func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[string]FieldData) (columns []string, values []string, updateOps []UpdateOp, scanTypes []reflect.Type, err error) { + if len(colValues) == 0 { + return + } + + columns = make([]string, len(colValues)) + values = make([]string, len(colValues)) + updateOps = make([]UpdateOp, len(colValues)) + scanTypes = make([]reflect.Type, len(colValues)) + + i := 0 + for colName := range colValues { + columns[i] = colName + i++ + } + sort.Strings(columns) // sorted for determinism in tests + + for i, columnName := range columns { + fieldData := colValues[columnName] + columnInfo, found := table.columnsByName[columnName] + if !found { + return nil, nil, nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) + } + + normalizedValue, err := d.normalizeValueType(fieldData.Value, columnInfo.scanType) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, fieldData.Value, err) + } + + values[i] = normalizedValue + columns[i] = columnInfo.escapedName // escape the column name + updateOps[i] = fieldData.UpdateOp + scanTypes[i] = columnInfo.scanType + } + return +} + +func getPrimaryKeyFakeEmptyValues(primaryKey map[string]string) string { + if len(primaryKey) == 1 { + for key := range primaryKey { + return "'' " + EscapeIdentifier(key) + } + } + + reg := make([]string, 0, len(primaryKey)) + for key := range primaryKey { + reg = append(reg, "'' "+EscapeIdentifier(key)) + } + sort.Strings(reg) + + return strings.Join(reg, ",") +} + +func getPrimaryKeyFakeEmptyValuesAssertion(primaryKey map[string]string, escapedTableName string) string { + if len(primaryKey) == 1 { + for key := range primaryKey { + return escapedTableName + "." + EscapeIdentifier(key) + " IS NULL" + } + } + + reg := make([]string, 0, len(primaryKey)) + for key := range primaryKey { + reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" IS NULL") + } + sort.Strings(reg) + + return strings.Join(reg, " AND ") +} + +func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName string) string { + // Avoid any allocation if there is a single primary key + if len(primaryKey) == 1 { + for key, value := range primaryKey { + if escapedTableName == "" { + return EscapeIdentifier(key) + " = " + escapeStringValue(value) + } + + return escapedTableName + "." + EscapeIdentifier(key) + " = " + escapeStringValue(value) + } + } + + reg := make([]string, 0, len(primaryKey)) + for key, value := range primaryKey { + + if escapedTableName == "" { + reg = append(reg, EscapeIdentifier(key)+" = "+escapeStringValue(value)) + } else { + reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+escapeStringValue(value)) + } + } + sort.Strings(reg) + + return strings.Join(reg[:], " AND ") +} + +// normalizeValueType formats a value based on its type +func (d *PostgresDialect) normalizeValueType(value string, valueType reflect.Type) (string, error) { + switch valueType.Kind() { + case reflect.String: + // replace unicode null character with empty string + value = strings.ReplaceAll(value, "\u0000", "") + return escapeStringValue(value), nil + + // BYTES in Postgres must be escaped, we receive a Vec from substreams + case reflect.Slice: + return escapeStringValue(value), nil + + case reflect.Bool: + return fmt.Sprintf("'%s'", value), nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return value, nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return value, nil + + case reflect.Float32, reflect.Float64: + return value, nil + + case reflect.Struct: + if valueType == reflectTypeTime { + if integerRegex.MatchString(value) { + i, err := strconv.Atoi(value) + if err != nil { + return "", fmt.Errorf("could not convert %s to int: %w", value, err) + } + + return escapeStringValue(time.Unix(int64(i), 0).Format(time.RFC3339)), nil + } + + // It's a plain string, parse by dialect it and pass it to the database + return d.ParseDatetimeNormalization(value), nil + } + + return "", fmt.Errorf("unsupported struct type %s", valueType) + default: + // It's a column's type the schemaName parsing don't know how to represents as + // a Go type. In that case, we pass it unmodified to the database engine. It + // will be the responsibility of the one sending the data to correctly represent + // it in the way accepted by the database. + return value, nil + } +} + +func (d PostgresDialect) GetTableColumns(db *sql.DB, schemaName, tableName string) ([]*sql.ColumnType, error) { + q := fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", + EscapeIdentifier(schemaName), + EscapeIdentifier(tableName)) + + rows, err := db.Query(q) + if err != nil { + return nil, fmt.Errorf("querying table structure: %w", err) + } + defer rows.Close() + + return rows.ColumnTypes() +} + +func (d PostgresDialect) GetTablesInSchema(db *sql.DB, schemaName string) ([][2]string, error) { + q := ` + SELECT table_schema, table_name + FROM information_schema.tables + WHERE table_type = 'BASE TABLE' + AND table_schema = $1 + ORDER BY table_schema, table_name + ` + + rows, err := db.Query(q, schemaName) + if err != nil { + return nil, fmt.Errorf("querying tables: %w", err) + } + defer rows.Close() + + var result [][2]string + for rows.Next() { + var schema, table string + if err := rows.Scan(&schema, &table); err != nil { + return nil, fmt.Errorf("scanning table row: %w", err) + } + result = append(result, [2]string{schema, table}) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating table rows: %w", err) + } + + return result, nil +} + +const postgresPrimaryKeyQuery = ` + SELECT kcu.column_name + FROM information_schema.table_constraints tco + JOIN information_schema.key_column_usage kcu + ON kcu.constraint_name = tco.constraint_name + AND kcu.constraint_schema = tco.constraint_schema + AND kcu.table_name = tco.table_name + WHERE tco.constraint_type = 'PRIMARY KEY' + AND kcu.table_schema = %s + AND kcu.table_name = %s + ORDER BY kcu.ordinal_position` + +func (d PostgresDialect) GetPrimaryKey(db *sql.DB, schemaName, tableName string) ([]string, error) { + var q string + var args []any + + if schemaName == "" { + q = fmt.Sprintf(postgresPrimaryKeyQuery, "current_schema()", "$1") + args = []any{tableName} + } else { + q = fmt.Sprintf(postgresPrimaryKeyQuery, "$1", "$2") + args = []any{schemaName, tableName} + } + + rows, err := db.Query(q, args...) + if err != nil { + return nil, fmt.Errorf("querying primary key: %w", err) + } + defer rows.Close() + + var columns []string + for rows.Next() { + var column string + if err := rows.Scan(&column); err != nil { + return nil, fmt.Errorf("scanning primary key column: %w", err) + } + columns = append(columns, column) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating primary key rows: %w", err) + } + + return columns, nil +} diff --git a/sink/sql/db_changes/db/driver.go b/sink/sql/db_changes/db/driver.go new file mode 100644 index 000000000..9ba2cc2dc --- /dev/null +++ b/sink/sql/db_changes/db/driver.go @@ -0,0 +1,6 @@ +package db + +import ( + // Register the PostgreSQL driver for use with database/sql. + _ "github.com/lib/pq" +) diff --git a/sink/sql/db_changes/db/dsn.go b/sink/sql/db_changes/db/dsn.go new file mode 100644 index 000000000..081b472bb --- /dev/null +++ b/sink/sql/db_changes/db/dsn.go @@ -0,0 +1,201 @@ +package db + +import ( + "fmt" + "iter" + "net/url" + "os" + "strconv" + "strings" + + "github.com/drone/envsubst" +) + +type DSN struct { + driver string + original string + scheme string + + Host string + Port int64 + Username string + Password string + Database string + Options DSNOptions + + // schema is the extracted schema from the DSN schemaName option (if present) + schema string +} + +var driverMap = map[string]string{ + "psql": "postgres", + "postgres": "postgres", + "clickhouse": "clickhouse", + "parquet": "parquet", +} + +func ParseDSN(dsn string) (*DSN, error) { + expanded, err := envsubst.Eval(dsn, os.Getenv) + if err != nil { + return nil, fmt.Errorf("variables expansion failed: %w", err) + } + + dsnURL, err := url.Parse(expanded) + if err != nil { + return nil, fmt.Errorf("invalid url: %w", err) + } + + driver, ok := driverMap[dsnURL.Scheme] + if !ok { + keys := make([]string, len(driverMap)) + i := 0 + for k := range driverMap { + keys[i] = k + i++ + } + + return nil, fmt.Errorf("invalid scheme %s, allowed schemes: [%s]", dsnURL.Scheme, strings.Join(keys, ",")) + } + + host := dsnURL.Hostname() + + port := int64(5432) + if strings.Contains(dsnURL.Host, ":") { + port, _ = strconv.ParseInt(dsnURL.Port(), 10, 32) + } + + username := dsnURL.User.Username() + password, _ := dsnURL.User.Password() + database := dsnURL.EscapedPath() + if database != "parquet" { + database = strings.TrimPrefix(database, "/") + } + + d := &DSN{ + original: dsn, + driver: driver, + scheme: dsnURL.Scheme, + Host: host, + Port: port, + Username: username, + Password: password, + Database: database, + Options: DSNOptions(dsnURL.Query()), + } + + schemaName := d.Options.RemoveOr("schemaName", "") + + if driver == "clickhouse" { + if schemaName != "" { + d.schema = schemaName + } else { + d.schema = database + } + } else { + if schemaName == "" { + schemaName = "public" + } + + d.schema = schemaName + } + + return d, nil +} + +func (c *DSN) Driver() string { + return c.driver +} + +func (c *DSN) ConnString() string { + if c.driver == "clickhouse" { + scheme := c.driver + host := c.Host + + baseURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", scheme, c.Username, c.Password, host, c.Port, c.Database) + if len(c.Options) > 0 { + baseURL += "?" + c.Options.Encode() + } + + return baseURL + } + options := c.Options.EncodeWithSeparator(" ") + out := fmt.Sprintf("host=%s port=%d dbname=%s %s", c.Host, c.Port, c.Database, options) + if c.Username != "" { + out = out + " user=" + c.Username + } + if c.Password != "" { + out = out + " password=" + c.Password + } + return out +} + +func (c *DSN) Schema() string { + return c.schema +} + +func (c *DSN) Clone() *DSN { + return &DSN{ + driver: c.driver, + original: c.original, + scheme: c.scheme, + Host: c.Host, + Port: c.Port, + Username: c.Username, + Password: c.Password, + Database: c.Database, + Options: c.Options, + schema: c.schema, + } +} + +// DSNOptions is a thin wrapper around url.Values to provide helper methods and +// better names. +type DSNOptions url.Values + +// Iter iterates over the first value of each key. +func (v DSNOptions) Iter() iter.Seq2[string, string] { + return func(yield func(k string, v string) bool) { + for k, vs := range v { + if len(vs) > 0 { + if !yield(k, vs[0]) { + return + } + } + } + } +} + +// Encode encodes the values into "URL encoded" form ("bar=baz&foo=quux") sorted by key. +func (v DSNOptions) Encode() string { + return (url.Values(v)).Encode() +} + +// EncodeWithSeparator encodes the values into "URL encoded" like form ("bar=baz foo=quux") sorted by key +// where essentially the separator is used instead of '&'. +func (v DSNOptions) EncodeWithSeparator(sep string) string { + return strings.ReplaceAll((url.Values(v)).Encode(), "&", sep) +} + +// Get returns the value associated with the key. +func (v DSNOptions) Get(key string) string { + return (url.Values(v)).Get(key) +} + +// GetOr returns the value associated with the key or defaultValue if not found. +func (v DSNOptions) GetOr(key, defaultValue string) string { + if val := (url.Values(v)).Get(key); val != "" { + return val + } + + return defaultValue +} + +// RemoveOr removes the key from the options and returns its value or defaultValue if not found. +func (v DSNOptions) RemoveOr(key, defaultValue string) string { + val := (url.Values(v)).Get(key) + (url.Values(v)).Del(key) + if val != "" { + return val + } + return defaultValue +} diff --git a/sink/sql/db_changes/db/flush.go b/sink/sql/db_changes/db/flush.go new file mode 100644 index 000000000..86eb738a1 --- /dev/null +++ b/sink/sql/db_changes/db/flush.go @@ -0,0 +1,83 @@ +package db + +import ( + "context" + "fmt" + "time" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/streamingfast/logging/zapx" + sink "github.com/streamingfast/substreams/sink" + "go.uber.org/zap" +) + +func (l *Loader) Flush(ctx context.Context, outputModuleHash string, cursor *sink.Cursor, lastFinalBlock uint64) (rowFlushedCount int, err error) { + ctx = clickhouse.Context(context.Background(), clickhouse.WithStdAsync(false)) + + startAt := time.Now() + tx, err := l.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("failed to being db transaction: %w", err) + } + defer func() { + if err != nil { + if err := tx.Rollback(); err != nil { + l.logger.Warn("failed to rollback transaction", zap.Error(err)) + } + } + }() + + rowFlushedCount, err = l.dialect.Flush(tx, ctx, l, outputModuleHash, lastFinalBlock) + if err != nil { + return 0, fmt.Errorf("dialect flush: %w", err) + } + + rowFlushedCount += 1 + if err := l.UpdateCursor(ctx, tx, outputModuleHash, cursor); err != nil { + return 0, fmt.Errorf("update cursor: %w", err) + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("failed to commit db transaction: %w", err) + } + l.reset() + + l.logger.Debug("flushed table(s) rows to database", zap.Int("table_count", l.entries.Len()+1), zap.Int("row_count", rowFlushedCount), zapx.HumanDuration("took", time.Since(startAt))) + return rowFlushedCount, nil +} + +func (l *Loader) Revert(ctx context.Context, outputModuleHash string, cursor *sink.Cursor, lastValidBlock uint64) error { + tx, err := l.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to being db transaction: %w", err) + } + defer func() { + if err != nil { + if err := tx.Rollback(); err != nil { + l.logger.Warn("failed to rollback transaction", zap.Error(err)) + } + } + }() + + if err := l.dialect.Revert(tx, ctx, l, lastValidBlock); err != nil { + return err + } + + if err := l.UpdateCursor(ctx, tx, outputModuleHash, cursor); err != nil { + return fmt.Errorf("update cursor after revert: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit db transaction: %w", err) + } + + l.logger.Debug("reverted changes to database", zap.Uint64("last_valid_block", lastValidBlock)) + return nil +} + +func (l *Loader) reset() { + for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { + l.entries.Set(entriesPair.Key, NewOrderedMap[string, *Operation]()) + } + l.batchOrdinal = 0 +} diff --git a/sink/sql/db_changes/db/metrics.go b/sink/sql/db_changes/db/metrics.go new file mode 100644 index 000000000..44160d074 --- /dev/null +++ b/sink/sql/db_changes/db/metrics.go @@ -0,0 +1,14 @@ +package db + +import ( + "github.com/streamingfast/dmetrics" +) + +var metrics = dmetrics.NewSet(dmetrics.PrefixNameWith("substreams_sink_sql")) + +var QueryExecutionDuration = metrics.NewCounterVec("tx_query_execution_duration", []string{"query_type"}, "The amount of time spent executing queries by type (normal/undo) in nanoseconds") +var PruneReversibleSegmentDuration = metrics.NewCounter("prune_reversible_segment_duration", "The amount of time spent pruning reversible segment in nanoseconds") + +func RegisterMetrics() { + metrics.Register() +} diff --git a/sink/sql/db_changes/db/operations.go b/sink/sql/db_changes/db/operations.go new file mode 100644 index 000000000..1ff3a6402 --- /dev/null +++ b/sink/sql/db_changes/db/operations.go @@ -0,0 +1,336 @@ +package db + +import ( + "encoding/json" + "fmt" + "math/big" + "reflect" + "regexp" + "strings" + "time" +) + +type TypeGetter func(tableName string, columnName string) (reflect.Type, error) + +type Queryable interface { + query(d Dialect) (string, error) +} + +type OperationType string + +const ( + OperationTypeInsert OperationType = "INSERT" + OperationTypeUpsert OperationType = "UPSERT" + OperationTypeUpdate OperationType = "UPDATE" + OperationTypeDelete OperationType = "DELETE" +) + +// UpdateOp defines the operation to apply when updating a field on conflict +type UpdateOp int32 + +const ( + UpdateOpSet UpdateOp = 0 // Direct assignment: col = value + UpdateOpAdd UpdateOp = 1 // Accumulate: col = COALESCE(col, 0) + value + UpdateOpMax UpdateOp = 2 // Maximum: col = GREATEST(COALESCE(col, 0), value) + UpdateOpMin UpdateOp = 3 // Minimum: col = LEAST(COALESCE(col, 0), value) + UpdateOpSetIfNull UpdateOp = 4 // Set only if NULL: col = COALESCE(col, value) +) + +// FieldData holds a field's value and its update operation +type FieldData struct { + Value string + UpdateOp UpdateOp +} + +type Operation struct { + table *TableInfo + opType OperationType + primaryKey map[string]string + data map[string]FieldData + ordinal uint64 + reversibleBlockNum *uint64 // nil if that block is known to be irreversible +} + +func (o *Operation) String() string { + return fmt.Sprintf("%s/%s (%s)", o.table.identifier, createRowUniqueID(o.primaryKey), strings.ToLower(string(o.opType))) +} + +func (l *Loader) newInsertOperation(table *TableInfo, primaryKey map[string]string, data map[string]FieldData, ordinal uint64, reversibleBlockNum *uint64) *Operation { + return &Operation{ + table: table, + opType: OperationTypeInsert, + primaryKey: primaryKey, + data: data, + ordinal: ordinal, + reversibleBlockNum: reversibleBlockNum, + } +} + +func (l *Loader) newUpsertOperation(table *TableInfo, primaryKey map[string]string, data map[string]FieldData, ordinal uint64, reversibleBlockNum *uint64) *Operation { + return &Operation{ + table: table, + opType: OperationTypeUpsert, + primaryKey: primaryKey, + data: data, + ordinal: ordinal, + reversibleBlockNum: reversibleBlockNum, + } +} + +func (l *Loader) newUpdateOperation(table *TableInfo, primaryKey map[string]string, data map[string]FieldData, ordinal uint64, reversibleBlockNum *uint64) *Operation { + return &Operation{ + table: table, + opType: OperationTypeUpdate, + primaryKey: primaryKey, + data: data, + ordinal: ordinal, + reversibleBlockNum: reversibleBlockNum, + } +} + +func (l *Loader) newDeleteOperation(table *TableInfo, primaryKey map[string]string, ordinal uint64, reversibleBlockNum *uint64) *Operation { + return &Operation{ + table: table, + opType: OperationTypeDelete, + primaryKey: primaryKey, + ordinal: ordinal, + reversibleBlockNum: reversibleBlockNum, + } +} + +func (o *Operation) mergeData(newData map[string]FieldData) error { + if o.opType == OperationTypeDelete { + return fmt.Errorf("unable to merge data for a delete operation") + } + + for k, fd := range newData { + existing, exists := o.data[k] + if !exists { + o.data[k] = fd + continue + } + + // Validate transition based on strict rules (consistent with Rust library) + // SET can be followed by any op, but non-SET ops can only be followed by same type + if err := validateOpTransition(k, existing.UpdateOp, fd.UpdateOp); err != nil { + return err + } + + // Handle each incoming operation type + switch fd.UpdateOp { + case UpdateOpSet: + // SET: latest value wins, overwrites any previous operation + o.data[k] = fd + + case UpdateOpAdd: + // ADD: accumulate values (valid after SET or ADD) + existingDec, err1 := parseDecimal(existing.Value) + newDec, err2 := parseDecimal(fd.Value) + if err1 == nil && err2 == nil { + o.data[k] = FieldData{ + Value: existingDec.Add(newDec).String(), + UpdateOp: existing.UpdateOp, // Keep existing op: SET stays SET, ADD stays ADD + } + } else { + // Non-numeric: latest value wins + o.data[k] = fd + } + + case UpdateOpMax: + // MAX: compute maximum (valid after SET or MAX) + existingDec, err1 := parseDecimal(existing.Value) + newDec, err2 := parseDecimal(fd.Value) + if err1 == nil && err2 == nil { + maxVal := existingDec + if newDec.Cmp(existingDec.Rat) > 0 { + maxVal = newDec + } + o.data[k] = FieldData{ + Value: maxVal.String(), + UpdateOp: existing.UpdateOp, // Keep existing op: SET stays SET, MAX stays MAX + } + } else { + // Non-numeric: latest value wins + o.data[k] = fd + } + + case UpdateOpMin: + // MIN: compute minimum (valid after SET or MIN) + existingDec, err1 := parseDecimal(existing.Value) + newDec, err2 := parseDecimal(fd.Value) + if err1 == nil && err2 == nil { + minVal := existingDec + if newDec.Cmp(existingDec.Rat) < 0 { + minVal = newDec + } + o.data[k] = FieldData{ + Value: minVal.String(), + UpdateOp: existing.UpdateOp, // Keep existing op: SET stays SET, MIN stays MIN + } + } else { + // Non-numeric: latest value wins + o.data[k] = fd + } + + case UpdateOpSetIfNull: + // SET_IF_NULL: keep existing value (first value wins) + // Field already exists, so keep it and don't overwrite + continue + } + } + return nil +} + +// validateOpTransition checks if the transition from existing to incoming op is valid. +// Returns an error for invalid transitions. +// +// Valid transitions: +// - SET → any op: OK +// - any op → SET: OK (SET always overwrites) +// - ADD → ADD: OK (accumulates) +// - MAX → MAX: OK (computes max) +// - MIN → MIN: OK (computes min) +// - SET_IF_NULL → SET_IF_NULL: OK (first value wins) +// +// All other transitions are invalid. +func validateOpTransition(fieldName string, existing, incoming UpdateOp) error { + // SET can be followed by any operation + if existing == UpdateOpSet { + return nil + } + + // Any operation can be followed by SET (SET overwrites) + if incoming == UpdateOpSet { + return nil + } + + // Non-SET ops can only be followed by the same op type + if existing == incoming { + return nil + } + + // Invalid transition + return fmt.Errorf( + "invalid UpdateOp transition for field %q: cannot apply %s after %s (only %s \u2192 %s or SET \u2192 %s is allowed)", + fieldName, + updateOpName(incoming), + updateOpName(existing), + updateOpName(existing), + updateOpName(existing), + updateOpName(incoming), + ) +} + +func updateOpName(op UpdateOp) string { + switch op { + case UpdateOpSet: + return "SET" + case UpdateOpAdd: + return "ADD" + case UpdateOpMax: + return "MAX" + case UpdateOpMin: + return "MIN" + case UpdateOpSetIfNull: + return "SET_IF_NULL" + default: + return fmt.Sprintf("UNKNOWN(%d)", op) + } +} + +func parseDecimal(s string) (decimal, error) { + // Simple decimal parsing - just use big.Rat for precision + var d decimal + _, ok := d.SetString(s) + if !ok { + return decimal{}, fmt.Errorf("invalid decimal: %s", s) + } + return d, nil +} + +// decimal is a simple wrapper around big.Rat for delta accumulation +type decimal struct { + *big.Rat +} + +func (d *decimal) SetString(s string) (*decimal, bool) { + if d.Rat == nil { + d.Rat = new(big.Rat) + } + _, ok := d.Rat.SetString(s) + return d, ok +} + +func (d decimal) Add(other decimal) decimal { + result := new(big.Rat) + result.Add(d.Rat, other.Rat) + return decimal{result} +} + +func (d decimal) Sub(other decimal) decimal { + result := new(big.Rat) + result.Sub(d.Rat, other.Rat) + return decimal{result} +} + +func (d decimal) Neg() decimal { + result := new(big.Rat) + result.Neg(d.Rat) + return decimal{result} +} + +func (d decimal) Sign() int { + return d.Rat.Sign() +} + +func (d decimal) String() string { + return d.Rat.FloatString(18) +} + +// mergeOperation merges another operation into this one +func (o *Operation) mergeOperation(otherData map[string]FieldData) error { + if o.opType == OperationTypeDelete { + return fmt.Errorf("unable to merge operation for a delete operation") + } + + return o.mergeData(otherData) +} + +var integerRegex = regexp.MustCompile(`^\d+$`) +var dateRegex = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`) +var reflectTypeTime = reflect.TypeOf(time.Time{}) + +func EscapeIdentifier(valueToEscape string) string { + if strings.Contains(valueToEscape, `"`) { + valueToEscape = strings.ReplaceAll(valueToEscape, `"`, `""`) + } + + return `"` + valueToEscape + `"` +} + +func escapeStringValue(valueToEscape string) string { + if strings.Contains(valueToEscape, `'`) { + valueToEscape = strings.ReplaceAll(valueToEscape, `'`, `''`) + } + + return `'` + valueToEscape + `'` +} + +// primaryKeyToJSON serializes primary key to JSON for history storage +func primaryKeyToJSON(primaryKey map[string]string) string { + m, err := json.Marshal(primaryKey) + if err != nil { + panic(err) // should never happen with map[string]string + } + return string(m) +} + +// jsonToPrimaryKey deserializes primary key from JSON +func jsonToPrimaryKey(in string) (map[string]string, error) { + out := make(map[string]string) + err := json.Unmarshal([]byte(in), &out) + if err != nil { + return nil, err + } + return out, nil +} diff --git a/sink/sql/db_changes/db/ops.go b/sink/sql/db_changes/db/ops.go new file mode 100644 index 000000000..2ede4261c --- /dev/null +++ b/sink/sql/db_changes/db/ops.go @@ -0,0 +1,279 @@ +package db + +import ( + "fmt" + "maps" + "slices" + "strings" + + "go.uber.org/zap" +) + +// Insert a row in the DB, it is assumed the table exists, you can do a +// check before with HasTable() +func (l *Loader) Insert(tableName string, primaryKey map[string]string, data map[string]FieldData, reversibleBlockNum *uint64) error { + uniqueID := createRowUniqueID(primaryKey) + + if l.tracer.Enabled() { + l.logger.Debug("processing insert operation", zap.String("table_name", tableName), zap.String("primary_key", uniqueID), zap.Int("field_count", len(data))) + } + + table, found := l.tables[tableName] + if !found { + return fmt.Errorf("unknown table %q", tableName) + } + + entry, found := l.entries.Get(tableName) + if !found { + if l.tracer.Enabled() { + l.logger.Debug("adding tracking of table never seen before", zap.String("table_name", tableName)) + } + + entry = NewOrderedMap[string, *Operation]() + l.entries.Set(tableName, entry) + } + + if operation, found := entry.Get(uniqueID); found { + switch operation.opType { + case OperationTypeInsert: + if !l.dialect.AllowPkDuplicates() { + return fmt.Errorf("attempting to insert in table %q a primary key %q, that is already scheduled for insertion, insert should only be called once for a given primary key", tableName, primaryKey) + } + case OperationTypeDelete: + return fmt.Errorf("attempting to insert an object with primary key %q, that is scheduled to be deleted", primaryKey) + case OperationTypeUpdate: + return fmt.Errorf("attempting to insert an object with primary key %q, that is scheduled to be updated", primaryKey) + case OperationTypeUpsert: + return fmt.Errorf("attempting to insert an object with primary key %q, that is scheduled to be upserted", primaryKey) + } + } + + if l.tracer.Enabled() { + l.logger.Debug("primary key entry never existed for table, adding insert operation", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + // We need to make sure to add the primary key(s) in the data so that those column get created correctly, but only if there is data + for _, primary := range l.tables[tableName].primaryColumns { + if dataFromPrimaryKey, ok := primaryKey[primary.name]; ok { + data[primary.name] = FieldData{Value: dataFromPrimaryKey, UpdateOp: UpdateOpSet} + } + } + + entry.Set(uniqueID, l.newInsertOperation(table, primaryKey, data, l.NextBatchOrdinal(), reversibleBlockNum)) + l.entriesCount++ + return nil +} + +func createRowUniqueID(m map[string]string) string { + if len(m) == 1 { + for _, v := range m { + return v + } + } + + keys := slices.Collect(maps.Keys(m)) + slices.Sort(keys) + + values := make([]string, len(keys)) + for i, key := range keys { + values[i] = m[key] + } + + return strings.Join(values, "/") +} + +func (l *Loader) GetPrimaryKey(tableName string, pk string) (map[string]string, error) { + primaryKeyColumns := l.tables[tableName].primaryColumns + + switch len(primaryKeyColumns) { + case 0: + return nil, fmt.Errorf("substreams sent a single primary key, but our sql table has none, this is unsupported") + case 1: + return map[string]string{primaryKeyColumns[0].name: pk}, nil + } + + cols := make([]string, len(primaryKeyColumns)) + for i := range primaryKeyColumns { + cols[i] = primaryKeyColumns[i].name + } + return nil, fmt.Errorf("substreams sent a single primary key, but our sql table has a composite primary key (columns: %s), this is unsupported", strings.Join(cols, ",")) +} + +// Upsert a row in the DB, it is assumed the table exists, you can do a +// check before with HasTable(). +func (l *Loader) Upsert(tableName string, primaryKey map[string]string, data map[string]FieldData, reversibleBlockNum *uint64) error { + if l.dialect.OnlyInserts() { + return fmt.Errorf("update operation is not supported by the current database") + } + + uniqueID := createRowUniqueID(primaryKey) + if l.tracer.Enabled() { + l.logger.Debug("processing update operation", zap.String("table_name", tableName), zap.String("primary_key", uniqueID), zap.Int("field_count", len(data))) + } + + table, found := l.tables[tableName] + if !found { + return fmt.Errorf("unknown table %q", tableName) + } + + if len(table.primaryColumns) == 0 { + return fmt.Errorf("trying to perform an UPSERT operation but table %q don't have a primary key(s) set, this is not accepted", tableName) + } + + entry, found := l.entries.Get(tableName) + if !found { + if l.tracer.Enabled() { + l.logger.Debug("adding tracking of table never seen before", zap.String("table_name", tableName)) + } + + entry = NewOrderedMap[string, *Operation]() + l.entries.Set(tableName, entry) + } + + if op, found := entry.Get(uniqueID); found { + switch op.opType { + case OperationTypeInsert: + return fmt.Errorf("attempting to upsert an object with primary key %q, that is scheduled to be inserted, insert and upsert are exclusive", primaryKey) + case OperationTypeDelete: + return fmt.Errorf("attempting to upsert an object with primary key %q, that is scheduled to be deleted", primaryKey) + case OperationTypeUpdate: + // Accept existing update operation but change it to upsert, merge columns together + op.opType = OperationTypeUpsert + case OperationTypeUpsert: + // Fine, merge columns together + } + + if l.tracer.Enabled() { + l.logger.Debug("primary key entry already exist for table, merging columns together", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + op.mergeOperation(data) + entry.Set(uniqueID, op) + return nil + } else { + l.entriesCount++ + } + + if l.tracer.Enabled() { + l.logger.Debug("primary key entry never existed for table, adding upsert operation", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + // We need to make sure to add the primary key(s) in the data so that those column get created correctly, but only if there is data + for _, primary := range l.tables[tableName].primaryColumns { + if dataFromPrimaryKey, ok := primaryKey[primary.name]; ok { + data[primary.name] = FieldData{Value: dataFromPrimaryKey, UpdateOp: UpdateOpSet} + } + } + + entry.Set(uniqueID, l.newUpsertOperation(table, primaryKey, data, l.NextBatchOrdinal(), reversibleBlockNum)) + return nil +} + +// Update a row in the DB, it is assumed the table exists, you can do a +// check before with HasTable() +func (l *Loader) Update(tableName string, primaryKey map[string]string, data map[string]FieldData, reversibleBlockNum *uint64) error { + if l.dialect.OnlyInserts() { + return fmt.Errorf("update operation is not supported by the current database") + } + + uniqueID := createRowUniqueID(primaryKey) + if l.tracer.Enabled() { + l.logger.Debug("processing update operation", zap.String("table_name", tableName), zap.String("primary_key", uniqueID), zap.Int("field_count", len(data))) + } + + table, found := l.tables[tableName] + if !found { + return fmt.Errorf("unknown table %q", tableName) + } + + if len(table.primaryColumns) == 0 { + return fmt.Errorf("trying to perform an UPDATE operation but table %q don't have a primary key(s) set, this is not accepted", tableName) + } + + entry, found := l.entries.Get(tableName) + if !found { + if l.tracer.Enabled() { + l.logger.Debug("adding tracking of table never seen before", zap.String("table_name", tableName)) + } + + entry = NewOrderedMap[string, *Operation]() + l.entries.Set(tableName, entry) + } + + if op, found := entry.Get(uniqueID); found { + switch op.opType { + case OperationTypeInsert: + // Column is scheduled to be inserted, simply add our fields to the insert without changing its Insert type + case OperationTypeDelete: + return fmt.Errorf("attempting to update an object with primary key %q, that is scheduled to be deleted", primaryKey) + case OperationTypeUpdate: + // Fine, merge columns together + case OperationTypeUpsert: + // Accept existing upsert and our columns to it, but not change its type to keep it as an upsert + } + + if l.tracer.Enabled() { + l.logger.Debug("primary key entry already exist for table, merging fields together", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + op.mergeOperation(data) + entry.Set(uniqueID, op) + return nil + } else { + l.entriesCount++ + } + + if l.tracer.Enabled() { + l.logger.Debug("primary key entry never existed for table, adding update operation", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + entry.Set(uniqueID, l.newUpdateOperation(table, primaryKey, data, l.NextBatchOrdinal(), reversibleBlockNum)) + return nil +} + +// Delete a row in the DB, it is assumed the table exists, you can do a +// check before with HasTable() +func (l *Loader) Delete(tableName string, primaryKey map[string]string, reversibleBlockNum *uint64) error { + if l.dialect.OnlyInserts() { + return fmt.Errorf("delete operation is not supported by the current database") + } + + uniqueID := createRowUniqueID(primaryKey) + if l.tracer.Enabled() { + l.logger.Debug("processing delete operation", zap.String("table_name", tableName), zap.String("primary_key", uniqueID)) + } + + table, found := l.tables[tableName] + if !found { + return fmt.Errorf("unknown table %q", tableName) + } + + if len(table.primaryColumns) == 0 { + return fmt.Errorf("trying to perform a DELETE operation but table %q don't have a primary key(s) set, this is not accepted", tableName) + } + + entry, found := l.entries.Get(tableName) + if !found { + if l.tracer.Enabled() { + l.logger.Debug("adding tracking of table never seen before", zap.String("table_name", tableName)) + } + + entry = NewOrderedMap[string, *Operation]() + l.entries.Set(tableName, entry) + } + + if _, found := entry.Get(uniqueID); !found { + if l.tracer.Enabled() { + l.logger.Debug("primary key entry never existed for table", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + l.entriesCount++ + } + + if l.tracer.Enabled() { + l.logger.Debug("adding deleting operation", zap.String("primary_key", uniqueID), zap.String("table_name", tableName)) + } + + entry.Set(uniqueID, l.newDeleteOperation(table, primaryKey, l.NextBatchOrdinal(), reversibleBlockNum)) + return nil +} diff --git a/sink/sql/db_changes/db/testtx.go b/sink/sql/db_changes/db/testtx.go new file mode 100644 index 000000000..96009bdab --- /dev/null +++ b/sink/sql/db_changes/db/testtx.go @@ -0,0 +1,18 @@ +package db + +import ( + "context" + "database/sql" +) + +// TestTx is a stub Tx implementation used in testing. +type TestTx struct{} + +func (t *TestTx) Rollback() error { return nil } +func (t *TestTx) Commit() error { return nil } +func (t *TestTx) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) { + return nil, nil +} +func (t *TestTx) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) { + return nil, nil +} diff --git a/sink/sql/db_changes/db/types.go b/sink/sql/db_changes/db/types.go new file mode 100644 index 000000000..1caf0e95d --- /dev/null +++ b/sink/sql/db_changes/db/types.go @@ -0,0 +1,74 @@ +package db + +import ( + "fmt" + "reflect" +) + +//go:generate go-enum -f=$GOFILE --marshal --names -nocase + +// ENUM( +// +// Ignore +// Warn +// Error +// +// ) +type OnModuleHashMismatch uint + +type TableInfo struct { + schema string + schemaEscaped string + name string + nameEscaped string + columnsByName map[string]*ColumnInfo + primaryColumns []*ColumnInfo + + // Identifier is equivalent to 'escape().escape()' but pre-computed + // for usage when computing queries. + identifier string +} + +func NewTableInfo(schema, name string, pkList []string, columnsByName map[string]*ColumnInfo) (*TableInfo, error) { + schemaEscaped := EscapeIdentifier(schema) + nameEscaped := EscapeIdentifier(name) + primaryColumns := make([]*ColumnInfo, len(pkList)) + + for i, primaryKeyColumnName := range pkList { + primaryColumn, found := columnsByName[primaryKeyColumnName] + if !found { + return nil, fmt.Errorf("primary key column %q not found", primaryKeyColumnName) + } + primaryColumns[i] = primaryColumn + + } + if len(primaryColumns) == 0 { + return nil, fmt.Errorf("sql sink requires a primary key in every table, none was found in table %s.%s", schema, name) + } + + return &TableInfo{ + schema: schema, + schemaEscaped: schemaEscaped, + name: name, + nameEscaped: nameEscaped, + identifier: schemaEscaped + "." + nameEscaped, + primaryColumns: primaryColumns, + columnsByName: columnsByName, + }, nil +} + +type ColumnInfo struct { + name string + escapedName string + databaseTypeName string + scanType reflect.Type +} + +func NewColumnInfo(name string, databaseTypeName string, scanType any) *ColumnInfo { + return &ColumnInfo{ + name: name, + escapedName: EscapeIdentifier(name), + databaseTypeName: databaseTypeName, + scanType: reflect.TypeOf(scanType), + } +} diff --git a/sink/sql/db_changes/db/types_enum.go b/sink/sql/db_changes/db/types_enum.go new file mode 100644 index 000000000..8ba13a75a --- /dev/null +++ b/sink/sql/db_changes/db/types_enum.go @@ -0,0 +1,96 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: +// Revision: +// Build Date: +// Built By: + +package db + +import ( + "fmt" + "strings" +) + +const ( + // OnModuleHashMismatchIgnore is a OnModuleHashMismatch of type Ignore. + OnModuleHashMismatchIgnore OnModuleHashMismatch = iota + // OnModuleHashMismatchWarn is a OnModuleHashMismatch of type Warn. + OnModuleHashMismatchWarn + // OnModuleHashMismatchError is a OnModuleHashMismatch of type Error. + OnModuleHashMismatchError +) + +var ErrInvalidOnModuleHashMismatch = fmt.Errorf("not a valid OnModuleHashMismatch, try [%s]", strings.Join(_OnModuleHashMismatchNames, ", ")) + +const _OnModuleHashMismatchName = "IgnoreWarnError" + +var _OnModuleHashMismatchNames = []string{ + _OnModuleHashMismatchName[0:6], + _OnModuleHashMismatchName[6:10], + _OnModuleHashMismatchName[10:15], +} + +// OnModuleHashMismatchNames returns a list of possible string values of OnModuleHashMismatch. +func OnModuleHashMismatchNames() []string { + tmp := make([]string, len(_OnModuleHashMismatchNames)) + copy(tmp, _OnModuleHashMismatchNames) + return tmp +} + +var _OnModuleHashMismatchMap = map[OnModuleHashMismatch]string{ + OnModuleHashMismatchIgnore: _OnModuleHashMismatchName[0:6], + OnModuleHashMismatchWarn: _OnModuleHashMismatchName[6:10], + OnModuleHashMismatchError: _OnModuleHashMismatchName[10:15], +} + +// String implements the Stringer interface. +func (x OnModuleHashMismatch) String() string { + if str, ok := _OnModuleHashMismatchMap[x]; ok { + return str + } + return fmt.Sprintf("OnModuleHashMismatch(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x OnModuleHashMismatch) IsValid() bool { + _, ok := _OnModuleHashMismatchMap[x] + return ok +} + +var _OnModuleHashMismatchValue = map[string]OnModuleHashMismatch{ + _OnModuleHashMismatchName[0:6]: OnModuleHashMismatchIgnore, + strings.ToLower(_OnModuleHashMismatchName[0:6]): OnModuleHashMismatchIgnore, + _OnModuleHashMismatchName[6:10]: OnModuleHashMismatchWarn, + strings.ToLower(_OnModuleHashMismatchName[6:10]): OnModuleHashMismatchWarn, + _OnModuleHashMismatchName[10:15]: OnModuleHashMismatchError, + strings.ToLower(_OnModuleHashMismatchName[10:15]): OnModuleHashMismatchError, +} + +// ParseOnModuleHashMismatch attempts to convert a string to a OnModuleHashMismatch. +func ParseOnModuleHashMismatch(name string) (OnModuleHashMismatch, error) { + if x, ok := _OnModuleHashMismatchValue[name]; ok { + return x, nil + } + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _OnModuleHashMismatchValue[strings.ToLower(name)]; ok { + return x, nil + } + return OnModuleHashMismatch(0), fmt.Errorf("%s is %w", name, ErrInvalidOnModuleHashMismatch) +} + +// MarshalText implements the text marshaller method. +func (x OnModuleHashMismatch) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *OnModuleHashMismatch) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := ParseOnModuleHashMismatch(name) + if err != nil { + return err + } + *x = tmp + return nil +} diff --git a/sink/sql/db_changes/db/user.go b/sink/sql/db_changes/db/user.go new file mode 100644 index 000000000..0d26d174b --- /dev/null +++ b/sink/sql/db_changes/db/user.go @@ -0,0 +1,34 @@ +package db + +import ( + "context" + "fmt" + + "go.uber.org/zap" +) + +func (l *Loader) CreateUser(ctx context.Context, username string, password string, database string, readOnly bool) (err error) { + tx, err := l.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to being db transaction: %w", err) + } + defer func() { + if err != nil { + if err := tx.Rollback(); err != nil { + l.logger.Warn("failed to rollback transaction", zap.Error(err)) + } + } + }() + + err = l.dialect.CreateUser(tx, ctx, l, username, password, database, readOnly) + if err != nil { + return fmt.Errorf("create user: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit db transaction: %w", err) + } + l.reset() + + return nil +} diff --git a/sink/sql/db_changes/sinker/factory.go b/sink/sql/db_changes/sinker/factory.go new file mode 100644 index 000000000..b011573e7 --- /dev/null +++ b/sink/sql/db_changes/sinker/factory.go @@ -0,0 +1,72 @@ +package sinker + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/streamingfast/logging" + sink "github.com/streamingfast/substreams/sink" + db "github.com/streamingfast/substreams/sink/sql/db_changes/db" + "go.uber.org/zap" +) + +type SinkerFactoryFunc func(ctx context.Context, dsnString string, logger *zap.Logger, tracer logging.Tracer) (*SQLSinker, error) + +type SinkerFactoryOptions struct { + CursorTableName string + HistoryTableName string + ClickhouseCluster string + BatchBlockFlushInterval int + BatchRowFlushInterval int + LiveBlockFlushInterval int + OnModuleHashMismatch string + HandleReorgs bool + FlushRetryCount int + FlushRetryDelay time.Duration +} + +func SinkerFactory( + baseSink *sink.Sinker, + options SinkerFactoryOptions, +) SinkerFactoryFunc { + return func(ctx context.Context, dsnString string, logger *zap.Logger, tracer logging.Tracer) (*SQLSinker, error) { + dsn, err := db.ParseDSN(dsnString) + if err != nil { + return nil, fmt.Errorf("parsing dsn: %w", err) + } + + dbLoader, err := db.NewLoader( + dsn, + options.CursorTableName, + options.HistoryTableName, + options.ClickhouseCluster, + options.BatchBlockFlushInterval, + options.BatchRowFlushInterval, + options.LiveBlockFlushInterval, + options.OnModuleHashMismatch, + &options.HandleReorgs, + logger, + tracer, + ) + if err != nil { + return nil, fmt.Errorf("creating loader: %w", err) + } + + if err := dbLoader.LoadTables(dsn.Schema(), options.CursorTableName, options.HistoryTableName); err != nil { + var e *db.SystemTableError + if errors.As(err, &e) { + return nil, fmt.Errorf("error validating the system table: %w. Did you run setup?", e) + } + return nil, fmt.Errorf("load tables: %w", err) + } + + sinker, err := New(baseSink, dbLoader, logger, tracer, options.FlushRetryCount, options.FlushRetryDelay) + if err != nil { + return nil, fmt.Errorf("unable to setup SQL sinker: %w", err) + } + + return sinker, nil + } +} diff --git a/sink/sql/db_changes/sinker/metrics.go b/sink/sql/db_changes/sinker/metrics.go new file mode 100644 index 000000000..af115900e --- /dev/null +++ b/sink/sql/db_changes/sinker/metrics.go @@ -0,0 +1,19 @@ +package sinker + +import ( + "github.com/streamingfast/dmetrics" + db "github.com/streamingfast/substreams/sink/sql/db_changes/db" +) + +func RegisterMetrics() { + metrics.Register() + db.RegisterMetrics() +} + +var metrics = dmetrics.NewSet() + +var FlushCount = metrics.NewCounter("substreams_sink_postgres_store_flush_count", "The amount of flush that happened so far") +var FlushedRowsCount = metrics.NewCounter("substreams_sink_postgres_flushed_rows_count", "The number of flushed rows so far") +var FlushDuration = metrics.NewCounter("substreams_sink_postgres_store_flush_duration", "The amount of time spent flushing cache to db (in nanoseconds)") +var FlushedHeadBlockNumber = metrics.NewHeadBlockNumber("substreams_sink_postgres") +var FlushedHeadBlockTimeDrift = metrics.NewHeadTimeDrift("substreams_sink_postgres") diff --git a/sink/sql/db_changes/sinker/setup.go b/sink/sql/db_changes/sinker/setup.go new file mode 100644 index 000000000..a43a3bc8a --- /dev/null +++ b/sink/sql/db_changes/sinker/setup.go @@ -0,0 +1,99 @@ +package sinker + +import ( + "context" + "errors" + "fmt" + + "github.com/lib/pq" + "github.com/streamingfast/logging" + sinksql "github.com/streamingfast/substreams/sink/sql" + db2 "github.com/streamingfast/substreams/sink/sql/db_changes/db" + pbsubstreams "github.com/streamingfast/substreams/pb/sf/substreams/v1" + "go.uber.org/zap" +) + +const ( + deprecated_supportedDeployableService = "type.googleapis.com/sf.substreams.sink.sql.v1.Service" + supportedDeployableService = "type.googleapis.com/sf.substreams.sink.sql.service.v1.Service" +) + +// SinkerSetupOptions contains configuration for the setup operation +type SinkerSetupOptions struct { + CursorTableName string + HistoryTableName string + ClickhouseCluster string + OnModuleHashMismatch string + SystemTablesOnly bool + IgnoreDuplicateTableErrors bool + Postgraphile bool +} + +// SinkerSetup sets up the required infrastructure for a Substreams SQL sink +func SinkerSetup( + ctx context.Context, + dsnString string, + pkg *pbsubstreams.Package, + options SinkerSetupOptions, + logger *zap.Logger, + tracer logging.Tracer, +) error { + sinkConfig, err := sinksql.ExtractSinkService(pkg) + if err != nil { + return fmt.Errorf("extract sink config: %w", err) + } + + dsn, err := db2.ParseDSN(dsnString) + if err != nil { + return fmt.Errorf("parse dsn: %w", err) + } + + handleReorgs := false + dbLoader, err := db2.NewLoader( + dsn, + options.CursorTableName, + options.HistoryTableName, + options.ClickhouseCluster, + 0, 0, 0, + options.OnModuleHashMismatch, + &handleReorgs, + logger, tracer, + ) + if err != nil { + return fmt.Errorf("creating loader: %w", err) + } + defer dbLoader.Close() + + userSQLSchema := sinkConfig.Schema + if options.SystemTablesOnly { + userSQLSchema = "" + } + + err = dbLoader.Setup(ctx, dsn.Schema(), userSQLSchema, options.Postgraphile) + if err != nil { + if isDuplicateTableError(err) && options.IgnoreDuplicateTableErrors { + logger.Info("received duplicate table error, script did not execute successfully") + } else { + return fmt.Errorf("setup: %w", err) + } + } + logger.Info("setup completed successfully") + return nil +} + +// isDuplicateTableError checks if the error is a PostgreSQL duplicate table error +func isDuplicateTableError(err error) bool { + var sqlError *pq.Error + if !errors.As(err, &sqlError) { + return false + } + + // List at https://www.postgresql.org/docs/14/errcodes-appendix.html#ERRCODES-TABLE + switch sqlError.Code { + // Error code named `duplicate_table` + case "42P07": + return true + } + + return false +} diff --git a/sink/sql/db_changes/sinker/sinker.go b/sink/sql/db_changes/sinker/sinker.go new file mode 100644 index 000000000..3615f2bf3 --- /dev/null +++ b/sink/sql/db_changes/sinker/sinker.go @@ -0,0 +1,384 @@ +package sinker + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/streamingfast/logging" + "github.com/streamingfast/logging/zapx" + "github.com/streamingfast/shutter" + sink "github.com/streamingfast/substreams/sink" + pbdatabase "github.com/streamingfast/substreams/pb/sf/substreams/sink/database/v1" + db2 "github.com/streamingfast/substreams/sink/sql/db_changes/db" + pbsubstreamsrpc "github.com/streamingfast/substreams/pb/sf/substreams/rpc/v2" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" +) + +const BLOCK_FLUSH_INTERVAL_DISABLED = 0 + +type SQLSinker struct { + *shutter.Shutter + *sink.Sinker + + loader *db2.Loader + logger *zap.Logger + tracer logging.Tracer + + stats *Stats + lastAppliedBlockNum uint64 + lastAppliedBlockTime time.Time + + flushRetryCount int + flushRetryDelay time.Duration +} + +func New(sink *sink.Sinker, loader *db2.Loader, logger *zap.Logger, tracer logging.Tracer, flushRetryCount int, flushRetryDelay time.Duration) (*SQLSinker, error) { + return &SQLSinker{ + Shutter: shutter.New(), + Sinker: sink, + + loader: loader, + logger: logger, + tracer: tracer, + + stats: NewStats(logger), + lastAppliedBlockNum: 0, + flushRetryCount: flushRetryCount, + flushRetryDelay: flushRetryDelay, + }, nil +} + +func (s *SQLSinker) Close() error { + if s.IsTerminated() { + return nil + } + + s.logger.Info("closing SQL sinker") + if err := s.loader.Close(); err != nil { + return fmt.Errorf("loader close: %w", err) + } + + s.Shutdown(nil) + return nil +} + +func (s *SQLSinker) Run(ctx context.Context) { + cursor, mismatchDetected, err := s.loader.GetCursor(ctx, s.OutputModuleHash()) + if err != nil && !errors.Is(err, db2.ErrCursorNotFound) { + s.Shutdown(fmt.Errorf("unable to retrieve cursor: %w", err)) + return + } + + // We write an empty cursor right away in the database because the flush logic + // only performs an `update` operation so an initial cursor is required in the database + // for the flush to work correctly. + if errors.Is(err, db2.ErrCursorNotFound) { + if err := s.loader.InsertCursor(ctx, s.OutputModuleHash(), sink.NewBlankCursor()); err != nil { + s.Shutdown(fmt.Errorf("unable to write initial empty cursor: %w", err)) + return + } + + } else if mismatchDetected { + if err := s.loader.InsertCursor(ctx, s.OutputModuleHash(), cursor); err != nil { + s.Shutdown(fmt.Errorf("unable to write new cursor after module mismatch: %w", err)) + return + } + } + + // Works in all cases, even if the cursor is blank or nil (gives 0) + s.lastAppliedBlockNum = cursor.Block().Num() + + s.Sinker.OnTerminating(s.Shutdown) + s.OnTerminating(func(err error) { + s.stats.LogNow() + s.logger.Info("sql sinker terminating", zap.Stringer("last_block_written", s.stats.lastBlock)) + s.Sinker.Shutdown(err) + }) + + s.OnTerminating(func(_ error) { s.stats.Close() }) + s.stats.OnTerminated(func(err error) { s.Shutdown(err) }) + + logEach := 15 * time.Second + if s.logger.Core().Enabled(zap.DebugLevel) { + logEach = 5 * time.Second + } + + s.stats.Start(logEach, cursor) + + s.logger.Info("starting sql sink", + zapx.HumanDuration("stats_refresh_each", logEach), + zap.Stringer("restarting_at", cursor.Block()), + zap.String("loader", s.loader.GetIdentifier()), + ) + s.Sinker.Run(ctx, cursor, s) +} + +func (s *SQLSinker) flushWithRetry(ctx context.Context, moduleHash string, cursor *sink.Cursor, finalBlockHeight uint64, retries int) (int, error) { + var lastErr error + for attempt := 0; attempt <= retries; attempt++ { + if attempt > 0 { + // Do not retry if flush delay is 0, useful in tests + if s.flushRetryDelay == 0 { + return 0, lastErr + } + + delay := time.Duration(attempt) * s.flushRetryDelay + s.logger.Warn("retrying flush after error", + zap.Int("attempt", attempt), + zap.Int("max_retries", retries), + zapx.HumanDuration("delay", delay), + zap.Error(lastErr)) + + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-time.After(delay): + } + } + + rowCount, err := s.loader.Flush(ctx, moduleHash, cursor, finalBlockHeight) + if err == nil { + if attempt > 0 { + s.logger.Info("flush succeeded after retry", zap.Int("attempt", attempt)) + } + return rowCount, nil + } + lastErr = err + } + + return 0, fmt.Errorf("flush failed after %d retries: %w", retries, lastErr) +} + +func (s *SQLSinker) HandleBlockScopedData(ctx context.Context, data *pbsubstreamsrpc.BlockScopedData, isLive *bool, cursor *sink.Cursor) error { + blockReceivedAt := time.Now() + + output := data.Output + + if output.Name == "" { + return nil + } + + if output.Name != s.OutputModuleName() { + return fmt.Errorf("received data from wrong output module, expected to received from %q but got module's output for %q", s.OutputModuleName(), output.Name) + } + + dbChanges := &pbdatabase.DatabaseChanges{} + mapOutput := output.GetMapOutput() + + if mapOutput.String() != "" { + if !mapOutput.MessageIs(dbChanges) && mapOutput.TypeUrl != "type.googleapis.com/sf.substreams.database.v1.DatabaseChanges" { + return fmt.Errorf("mismatched message type: trying to unmarshal unknown type %q", mapOutput.MessageName()) + } + + // We do not use UnmarshalTo here because we need to parse an older proto type and + // UnmarshalTo enforces the type check. So we check manually the `TypeUrl` above and we use + // `Unmarshal` instead which only deals with the bytes value. + if err := proto.Unmarshal(mapOutput.Value, dbChanges); err != nil { + return fmt.Errorf("unmarshal database changes: %w", err) + } + + if err := s.applyDatabaseChanges(dbChanges, data.Clock.Number, data.FinalBlockHeight); err != nil { + return fmt.Errorf("apply database changes: %w", err) + } + } + + batchModulo := s.batchBlockModulo(isLive) + blockFlushNeeded := batchModulo > 0 && data.Clock.Number-s.lastAppliedBlockNum >= batchModulo + + s.logger.Debug("flush condition evaluation", + zap.Uint64("batch_modulo", batchModulo), + zap.Uint64("current_block", data.Clock.Number), + zap.Uint64("last_applied_block", s.lastAppliedBlockNum), + zap.Uint64("block_diff", data.Clock.Number-s.lastAppliedBlockNum), + zap.Bool("block_flush_needed_before_timing_check", blockFlushNeeded)) + + if blockFlushNeeded && isLive != nil && *isLive && s.stats.AverageFlushDuration() > data.Clock.Timestamp.AsTime().Sub(s.lastAppliedBlockTime) { + s.logger.Debug("skipping a flush because we are LIVE and flush average duration is above time between blocks", zapx.HumanDuration("flush_duration_average", s.stats.AverageFlushDuration()), zap.Time("last_block_time", s.lastAppliedBlockTime), zap.Time("block_time", data.Clock.Timestamp.AsTime())) + blockFlushNeeded = false + } + + rowFlushNeeded := s.loader.FlushNeeded() + s.logger.Debug("final flush decision", + zap.Bool("block_flush_needed", blockFlushNeeded), + zap.Bool("row_flush_needed", rowFlushNeeded)) + + if blockFlushNeeded || rowFlushNeeded { + s.logger.Debug("flushing to database", + zap.Stringer("block", cursor.Block()), + zap.Uint64("last_flushed_block", s.lastAppliedBlockNum), + zap.Bool("is_live", *isLive), + zap.Bool("block_flush_interval_reached", blockFlushNeeded), + zap.Bool("row_flush_interval_reached", rowFlushNeeded), + ) + + flushStart := time.Now() + rowFlushedCount, err := s.flushWithRetry(ctx, s.OutputModuleHash(), cursor, data.FinalBlockHeight, s.flushRetryCount) + if err != nil { + return fmt.Errorf("failed to flush at block %s: %w", cursor.Block(), err) + } + + flushDuration := time.Since(flushStart) + handleBlockDuration := time.Since(blockReceivedAt) + + if flushDuration > 5*time.Second { + level := zap.InfoLevel + if flushDuration > 30*time.Second { + level = zap.WarnLevel + } + + s.logger.Check(level, "flush to database took a long time to complete, could cause long sync time along the road").Write(zapx.HumanDuration("took", flushDuration)) + } + + FlushCount.Inc() + FlushedRowsCount.AddInt(rowFlushedCount) + FlushDuration.AddInt64(flushDuration.Nanoseconds()) + FlushedHeadBlockTimeDrift.SetBlockTime(data.Clock.GetTimestamp().AsTime()) + FlushedHeadBlockNumber.SetUint64(data.Clock.GetNumber()) + + s.stats.RecordBlock(cursor.Block()) + s.stats.RecordFlushDuration(flushDuration) + s.stats.RecordHandleBlockDuration(handleBlockDuration) + s.lastAppliedBlockNum = data.Clock.Number + s.lastAppliedBlockTime = data.Clock.Timestamp.AsTime() + } + + return nil +} + +func (s *SQLSinker) applyDatabaseChanges(dbChanges *pbdatabase.DatabaseChanges, blockNum, finalBlockNum uint64) error { + for _, change := range dbChanges.TableChanges { + if !s.loader.HasTable(change.Table) { + return fmt.Errorf( + "your Substreams sent us a change for a table named %s we don't know about on %s (available tables: %s)", + change.Table, + s.loader.GetIdentifier(), + strings.Join(s.loader.GetAvailableTablesInSchema(), ", "), + ) + } + + var primaryKeys map[string]string + switch u := change.PrimaryKey.(type) { + case *pbdatabase.TableChange_Pk: + var err error + primaryKeys, err = s.loader.GetPrimaryKey(change.Table, u.Pk) + if err != nil { + return err + } + case *pbdatabase.TableChange_CompositePk: + primaryKeys = u.CompositePk.Keys + default: + return fmt.Errorf("unknown primary key type: %T", change.PrimaryKey) + } + + changes := map[string]db2.FieldData{} + for _, field := range change.Fields { + changes[field.Name] = db2.FieldData{ + Value: field.Value, + UpdateOp: protoUpdateOpToDbUpdateOp(field.UpdateOp), + } + } + + var reversibleBlockNum *uint64 + if blockNum > finalBlockNum { + reversibleBlockNum = &blockNum + } + + switch change.Operation { + case pbdatabase.TableChange_OPERATION_CREATE: + err := s.loader.Insert(change.Table, primaryKeys, changes, reversibleBlockNum) + if err != nil { + return fmt.Errorf("database insert: %w", err) + } + case pbdatabase.TableChange_OPERATION_UPSERT: + err := s.loader.Upsert(change.Table, primaryKeys, changes, reversibleBlockNum) + if err != nil { + return fmt.Errorf("database upsert: %w", err) + } + case pbdatabase.TableChange_OPERATION_UPDATE: + err := s.loader.Update(change.Table, primaryKeys, changes, reversibleBlockNum) + if err != nil { + return fmt.Errorf("database update: %w", err) + } + case pbdatabase.TableChange_OPERATION_DELETE: + err := s.loader.Delete(change.Table, primaryKeys, reversibleBlockNum) + if err != nil { + return fmt.Errorf("database delete: %w", err) + } + default: + } + } + + return nil +} + +// protoUpdateOpToDbUpdateOp converts proto Field_UpdateOp to db UpdateOp +func protoUpdateOpToDbUpdateOp(op pbdatabase.Field_UpdateOp) db2.UpdateOp { + switch op { + case pbdatabase.Field_UPDATE_OP_ADD: + return db2.UpdateOpAdd + case pbdatabase.Field_UPDATE_OP_MAX: + return db2.UpdateOpMax + case pbdatabase.Field_UPDATE_OP_MIN: + return db2.UpdateOpMin + case pbdatabase.Field_UPDATE_OP_SET_IF_NULL: + return db2.UpdateOpSetIfNull + default: + return db2.UpdateOpSet + } +} + +func (s *SQLSinker) HandleBlockRangeCompletion(ctx context.Context, cursor *sink.Cursor) error { + // To be moved in the base sinker library, happens usually only on integration tests where the connection + // can close with "nil" error but we haven't completed the range for real yet. + stopBlock := s.Sinker.StopBlock() + if stopBlock > 0 && cursor.Block().Num() < stopBlock { + s.logger.Debug("range not completed yet, skipping", zap.Stringer("block", cursor.Block()), zap.Uint64("stop_block", stopBlock)) + return nil + } + + s.logger.Info("stream completed, flushing to database", zap.Stringer("block", cursor.Block())) + _, err := s.flushWithRetry(ctx, s.OutputModuleHash(), cursor, cursor.Block().Num(), s.flushRetryCount) + if err != nil { + return fmt.Errorf("failed to flush %s block on completion: %w", cursor.Block(), err) + } + + return nil +} + +func (s *SQLSinker) HandleBlockUndoSignal(ctx context.Context, data *pbsubstreamsrpc.BlockUndoSignal, cursor *sink.Cursor) error { + handlerStart := time.Now() + + err := s.loader.Revert(ctx, s.OutputModuleHash(), cursor, data.LastValidBlock.Number) + if err != nil { + return err + } + + handleUndoDuration := time.Since(handlerStart) + s.stats.RecordHandleUndoDuration(handleUndoDuration) + + return nil +} + +func (s *SQLSinker) batchBlockModulo(isLive *bool) uint64 { + if isLive == nil { + panic(fmt.Errorf("liveness checker has been disabled on the Sinker instance, this is invalid in the context of 'substreams-sink-sql'")) + } + + if *isLive { + return uint64(s.loader.LiveBlockFlushInterval()) + } + + if s.loader.BatchBlockFlushInterval() > 0 { + return uint64(s.loader.BatchBlockFlushInterval()) + } + + return BLOCK_FLUSH_INTERVAL_DISABLED +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/sink/sql/db_changes/sinker/stats.go b/sink/sql/db_changes/sinker/stats.go new file mode 100644 index 000000000..1f9a3aca2 --- /dev/null +++ b/sink/sql/db_changes/sinker/stats.go @@ -0,0 +1,114 @@ +package sinker + +import ( + "time" + + "github.com/streamingfast/bstream" + "github.com/streamingfast/dmetrics" + "github.com/streamingfast/shutter" + sink "github.com/streamingfast/substreams/sink" + "go.uber.org/zap" +) + +type Stats struct { + *shutter.Shutter + + dbFlushRate *dmetrics.AvgRatePromCounter + dbFlushAvgDuration *dmetrics.AvgDurationCounter + flushedRows *dmetrics.ValueFromMetric + dbFlushedRowsRate *dmetrics.AvgRatePromCounter + handleBlockDuration *dmetrics.AvgDurationCounter + handleUndoDuration *dmetrics.AvgDurationCounter + hasUndoSegments bool + lastBlock bstream.BlockRef + logger *zap.Logger +} + +func NewStats(logger *zap.Logger) *Stats { + return &Stats{ + Shutter: shutter.New(), + + dbFlushRate: dmetrics.MustNewAvgRateFromPromCounter(FlushCount, 1*time.Second, 30*time.Second, "flush"), + dbFlushAvgDuration: dmetrics.NewAvgDurationCounter(30*time.Second, dmetrics.InferUnit, "per flush"), + flushedRows: dmetrics.NewValueFromMetric(FlushedRowsCount, "rows"), + dbFlushedRowsRate: dmetrics.MustNewAvgRateFromPromCounter(FlushedRowsCount, 1*time.Second, 30*time.Second, "flushed rows"), + handleBlockDuration: dmetrics.NewAvgDurationCounter(30*time.Second, dmetrics.InferUnit, "per block"), + handleUndoDuration: dmetrics.NewAvgDurationCounter(30*time.Second, dmetrics.InferUnit, "per undo"), + logger: logger, + + lastBlock: unsetBlockRef{}, + } +} + +func (s *Stats) RecordBlock(block bstream.BlockRef) { + s.lastBlock = block +} + +func (s *Stats) AverageFlushDuration() time.Duration { + return s.dbFlushAvgDuration.Average() +} + +func (s *Stats) RecordFlushDuration(duration time.Duration) { + s.dbFlushAvgDuration.AddDuration(duration) +} + +func (s *Stats) RecordHandleBlockDuration(duration time.Duration) { + s.handleBlockDuration.AddDuration(duration) +} + +func (s *Stats) RecordHandleUndoDuration(duration time.Duration) { + s.handleUndoDuration.AddDuration(duration) + s.hasUndoSegments = true +} + +func (s *Stats) Start(each time.Duration, cursor *sink.Cursor) { + if !cursor.IsBlank() { + s.lastBlock = cursor.Block() + } + + if s.IsTerminating() || s.IsTerminated() { + panic("already shutdown, refusing to start again") + } + + go func() { + ticker := time.NewTicker(each) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.LogNow() + case <-s.Terminating(): + return + } + } + }() +} + +func (s *Stats) LogNow() { + fields := []zap.Field{ + zap.Stringer("db_flush_rate", s.dbFlushRate), + zap.Stringer("db_flush_duration_rate", s.dbFlushAvgDuration), + zap.Stringer("db_flushed_rows_rate", s.dbFlushedRowsRate), + zap.Stringer("handle_block_duration", s.handleBlockDuration), + } + + // Only log undo metrics if we've had any undo operations (typically in live mode) + if s.hasUndoSegments { + fields = append(fields, zap.Stringer("handle_undo_duration", s.handleUndoDuration)) + } + + fields = append(fields, zap.Stringer("last_block", s.lastBlock)) + + s.logger.Info("postgres sink stats", fields...) +} + +func (s *Stats) Close() { + s.Shutdown(nil) +} + +type unsetBlockRef struct{} + +func (unsetBlockRef) ID() string { return "" } +func (unsetBlockRef) Num() uint64 { return 0 } +func (unsetBlockRef) String() string { return "" } diff --git a/sink/sql/db_changes/state/file.go b/sink/sql/db_changes/state/file.go new file mode 100644 index 000000000..f2eb33f65 --- /dev/null +++ b/sink/sql/db_changes/state/file.go @@ -0,0 +1,232 @@ +package state + +import ( + "bytes" + "context" + "fmt" + "os" + "sync" + "time" + + "github.com/streamingfast/bstream" + "github.com/streamingfast/dhammer" + "github.com/streamingfast/dstore" + "github.com/streamingfast/shutter" + sink "github.com/streamingfast/substreams/sink" + "github.com/streamingfast/substreams/sink/sql/db_changes/bundler/writer" + "go.uber.org/zap" + "gopkg.in/yaml.v3" +) + +var _ Store = (*FileStateStore)(nil) + +type FileStateStore struct { + *shutter.Shutter + + startOnce sync.Once + + outputPath string + outputStore dstore.Store + uploadQueue *dhammer.Nailer + + logger *zap.Logger + + state *FileState +} + +func NewFileStateStore( + outputPath string, + outputStore dstore.Store, + logger *zap.Logger, +) (*FileStateStore, error) { + s := &FileState{} + + content, err := os.ReadFile(outputPath) + if err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("read file: %w", err) + } + if err != nil && os.IsNotExist(err) { + s = newFileState() + } + + if err := yaml.Unmarshal(content, s); err != nil { + return nil, fmt.Errorf("unmarshal state file %q: %w", outputPath, err) + } + outputStore.SetOverwrite(true) + f := &FileStateStore{ + Shutter: shutter.New(), + outputPath: outputPath, + outputStore: outputStore, + state: s, + logger: logger, + } + f.uploadQueue = dhammer.NewNailer(5, f.uploadCursor, dhammer.NailerLogger(logger)) + return f, nil +} + +func (s *FileStateStore) Start(ctx context.Context) { + s.OnTerminating(func(err error) { + s.logger.Info("shutting down file cursor", zap.String("store", s.outputPath), zap.Error(err)) + s.Close() + }) + + s.uploadQueue.Start(ctx) + + go func() { + for v := range s.uploadQueue.Out { + bf := v.(*cursorFile) + s.logger.Debug("uploaded file", zap.String("filename", bf.name)) + } + if s.uploadQueue.Err() != nil { + s.Shutdown(fmt.Errorf("upload queue failed: %w", s.uploadQueue.Err())) + } + }() + + s.uploadQueue.OnTerminating(func(err error) { + s.Shutdown(fmt.Errorf("upload queue failed: %w", s.uploadQueue.Err())) + }) +} + +type cursorFile struct { + name string + file writer.Uploadeable +} + +func (s *FileStateStore) uploadCursor(ctx context.Context, v interface{}) (interface{}, error) { + bf := v.(*cursorFile) + + outputPath, err := bf.file.Upload(ctx, s.outputStore) + if err != nil { + return nil, fmt.Errorf("unable to upload: %w", err) + } + s.logger.Debug("boundary file uploaded", + zap.String("boundary", bf.name), + zap.String("output_path", outputPath), + ) + + return bf, nil +} + +type localFile struct { + localFilePath string + outputFilename string +} + +func (l *localFile) Upload(ctx context.Context, store dstore.Store) (string, error) { + if err := store.PushLocalFile(ctx, l.localFilePath, l.outputFilename); err != nil { + return "", fmt.Errorf("pushing object: %w", err) + } + return store.ObjectPath(l.outputFilename), nil +} + +func (s *FileStateStore) UploadCursor(saveable Saveable) { + s.uploadQueue.In <- &cursorFile{ + name: "cursor.yaml", + file: saveable.GetUploadeable(), + } +} + +func (s *FileStateStore) Close() { + s.uploadQueue.Close() + s.logger.Debug("waiting till queue is drained") + s.uploadQueue.WaitUntilEmpty(context.Background()) +} + +func (s *FileStateStore) ReadCursor(ctx context.Context) (cursor *sink.Cursor, err error) { + fl, err := s.outputStore.OpenObject(ctx, "state.yaml") + if err != nil && err != dstore.ErrNotFound { + return nil, fmt.Errorf("opening csv: %w", err) + } + + if err != nil && err == dstore.ErrNotFound { + s.state = newFileState() + } else { + defer fl.Close() + buf := new(bytes.Buffer) + buf.ReadFrom(fl) + content := buf.Bytes() + + if err := yaml.Unmarshal(content, s.state); err != nil { + return nil, fmt.Errorf("unmarshal state file %q: %w", s.outputPath, err) + } + } + + return sink.NewCursor(s.state.Cursor) +} + +func (s *FileStateStore) NewBoundary(boundary *bstream.Range) { + s.state.ActiveBoundary.StartBlockNumber = boundary.StartBlock() + s.state.ActiveBoundary.EndBlockNumber = *boundary.EndBlock() +} + +func (s *FileStateStore) SetCursor(cursor *sink.Cursor) { + s.startOnce.Do(func() { + restartAt := time.Now() + if s.state.StartedAt.IsZero() { + s.state.StartedAt = restartAt + } + s.state.RestartedAt = restartAt + }) + + s.state.Cursor = cursor.String() + s.state.Block = BlockState{ + ID: cursor.Block().ID(), + Number: cursor.Block().Num(), + } +} + +func (s *FileStateStore) GetState() (Saveable, error) { + cnt, err := yaml.Marshal(s.state) + if err != nil { + return nil, fmt.Errorf("marshall: %w", err) + } + return &stateInstance{ + data: cnt, + path: s.outputPath + "-" + s.state.Block.ID, + }, nil +} + +type FileState struct { + Cursor string `yaml:"cursor" json:"cursor"` + Block BlockState `yaml:"block" json:"block"` + ActiveBoundary ActiveBoundary `yaml:"active_boundary" json:"active_boundary"` + + StartedAt time.Time `yaml:"started_at,omitempty" json:"started_at,omitempty"` + RestartedAt time.Time `yaml:"restarted_at,omitempty" json:"restarted_at,omitempty"` +} + +func newFileState() *FileState { + return &FileState{ + Cursor: "", + Block: BlockState{"", 0}, + } +} + +type BlockState struct { + ID string `yaml:"id" json:"id"` + Number uint64 `yaml:"number" json:"number"` +} + +type ActiveBoundary struct { + StartBlockNumber uint64 `yaml:"start_block_number" json:"start_block_number"` + EndBlockNumber uint64 `yaml:"end_block_number" json:"end_block_number"` +} + +type stateInstance struct { + data []byte + path string +} + +func (s *stateInstance) GetUploadeable() writer.Uploadeable { + return &localFile{ + localFilePath: s.path, + outputFilename: "state.yaml", + } +} + +func (s *stateInstance) Save() error { + if err := os.WriteFile(s.path, s.data, os.ModePerm); err != nil { + return fmt.Errorf("unable to write state file: %w", err) + } + return nil +} diff --git a/sink/sql/db_changes/state/interface.go b/sink/sql/db_changes/state/interface.go new file mode 100644 index 000000000..0e4c06a93 --- /dev/null +++ b/sink/sql/db_changes/state/interface.go @@ -0,0 +1,26 @@ +package state + +import ( + "context" + + "github.com/streamingfast/bstream" + sink "github.com/streamingfast/substreams/sink" + "github.com/streamingfast/substreams/sink/sql/db_changes/bundler/writer" +) + +type Store interface { + Start(context.Context) + Close() + NewBoundary(*bstream.Range) + ReadCursor(context.Context) (*sink.Cursor, error) + SetCursor(*sink.Cursor) + GetState() (Saveable, error) + UploadCursor(state Saveable) + Shutdown(error) + OnTerminating(func(error)) +} + +type Saveable interface { + Save() error + GetUploadeable() writer.Uploadeable +} diff --git a/sink/sql/db_proto/proto/utils.go b/sink/sql/db_proto/proto/utils.go new file mode 100644 index 000000000..151d4b2c0 --- /dev/null +++ b/sink/sql/db_proto/proto/utils.go @@ -0,0 +1,77 @@ +package proto + +import ( + "fmt" + "maps" + "slices" + "strings" + + "github.com/jhump/protoreflect/desc" + v1 "github.com/streamingfast/substreams/pb/sf/substreams/v1" + "google.golang.org/protobuf/types/descriptorpb" +) + +func FileDescriptorForOutputType(spkg *v1.Package, err error, deps map[string]*desc.FileDescriptor, outputType string) (*desc.FileDescriptor, error) { + for _, p := range spkg.ProtoFiles { + fd, err := desc.CreateFileDescriptor(p, slices.Collect(maps.Values(deps))...) + if err != nil { + return nil, fmt.Errorf("creating file descriptor: %w", err) + } + + for _, md := range fd.GetMessageTypes() { + if md.GetFullyQualifiedName() == outputType { + return fd, nil + } + } + } + + return nil, fmt.Errorf("could not find file descriptor") +} + +func ModuleOutputType(spkg *v1.Package, moduleName string) string { + outputType := "" + for _, m := range spkg.Modules.Modules { + if m.Name == moduleName { + outputType = strings.TrimPrefix(m.Output.Type, "proto:") + break + } + } + return outputType +} +func ResolveDependencies(protoFiles map[string]*descriptorpb.FileDescriptorProto) (map[string]*desc.FileDescriptor, error) { + out := map[string]*desc.FileDescriptor{} + for _, protoFile := range protoFiles { + err := resolveDependencies(protoFile, protoFiles, out) + if err != nil { + return nil, fmt.Errorf("error resolving dependencies: %w", err) + } + } + + return out, nil +} + +func resolveDependencies(protoFile *descriptorpb.FileDescriptorProto, protoFiles map[string]*descriptorpb.FileDescriptorProto, deps map[string]*desc.FileDescriptor) error { + if deps[protoFile.GetName()] != nil { + return nil + } + if len(protoFile.Dependency) != 0 { + for _, dep := range protoFile.Dependency { + depProtoFile, found := protoFiles[dep] + if !found { + return fmt.Errorf("could not find proto file for dependency %q", dep) + } + err := resolveDependencies(depProtoFile, protoFiles, deps) + if err != nil { + return fmt.Errorf("error resolving dependencies: %w", err) + } + } + } + + d, err := desc.CreateFileDescriptor(protoFile, slices.Collect(maps.Values(deps))...) + if err != nil { + return fmt.Errorf("creating file descriptor: %w", err) + } + + deps[protoFile.GetName()] = d + return nil +} diff --git a/sink/sql/db_proto/sinker.go b/sink/sql/db_proto/sinker.go new file mode 100644 index 000000000..fb7f3499e --- /dev/null +++ b/sink/sql/db_proto/sinker.go @@ -0,0 +1,271 @@ +package db_proto + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/streamingfast/logging/zapx" + sink "github.com/streamingfast/substreams/sink" + sql "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/stats" + pbsubstreamsrpc "github.com/streamingfast/substreams/pb/sf/substreams/rpc/v2" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +type multiError []error + +func (m multiError) Error() string { + msgs := make([]string, len(m)) + for i, e := range m { + msgs[i] = e.Error() + } + return strings.Join(msgs, "; ") +} + +type Sinker struct { + *sink.Sinker + db sql.Database + useTransaction bool + parallel bool + blockBatchSize uint64 + stats *stats.Stats + logger *zap.Logger + rootMessageDescriptor protoreflect.MessageDescriptor + useConstraints bool + flushLock sync.Mutex + lastAppliedBlockNum uint64 + lastAppliedBlockTime time.Time +} + +func NewSinker(rootMessageDescriptor protoreflect.MessageDescriptor, sink *sink.Sinker, db sql.Database, useTransaction bool, useConstraints bool, blockBatchSize int, parallel bool, stats *stats.Stats, logger *zap.Logger) *Sinker { + return &Sinker{ + db: db, + rootMessageDescriptor: rootMessageDescriptor, + useTransaction: useTransaction, + parallel: parallel, + blockBatchSize: uint64(blockBatchSize), + stats: stats, + Sinker: sink, + logger: logger, + } +} + +func (s *Sinker) Run(ctx context.Context) error { + // Show stats one last time before exiting run + defer s.LogStats() + + cursor, err := s.db.FetchCursor() + if err != nil { + return fmt.Errorf("fetch cursor: %w", err) + } + + //clean up the mess from running without a transaction + if cursor != nil { + err = s.db.HandleBlocksUndo(cursor.Block().Num()) + if err != nil { + return fmt.Errorf("handle blocks undo from %s: %w", cursor.Block(), err) + } + } + s.logger.Info("fetched cursor", zap.Stringer("block", cursor.Block())) + + s.stats.LastBlockProcessAt = time.Now() + s.Sinker.Run(ctx, cursor, s) + + return s.Sinker.Err() +} + +func (s *Sinker) LogStats() { + s.stats.Log() +} + +type Holder struct { + output *pbsubstreamsrpc.MapModuleOutput + data *pbsubstreamsrpc.BlockScopedData + isLive *bool + cursor *sink.Cursor +} + +var holding []*Holder + +func (s *Sinker) HandleBlockScopedData(ctx context.Context, data *pbsubstreamsrpc.BlockScopedData, isLive *bool, cursor *sink.Cursor) (err error) { + output := data.Output + + if output.Name == "" { + return nil + } + + if output.Name != s.OutputModuleName() { + return fmt.Errorf("received data from wrong output module, expected to received from %q but got module's output for %q", s.OutputModuleName(), output.Name) + } + + if (isLive != nil && *isLive) && s.useConstraints { + return fmt.Errorf("live mode is not supported without constraints") + } + + startAt := time.Now() + defer func() { + s.stats.LastBlockProcessAt = time.Now() + s.stats.BlockProcessingDuration.Add(time.Since(startAt)) + s.stats.TotalProcessingDuration += time.Since(startAt) + }() + + if s.stats.BlockCount > 0 { + s.stats.WaitDurationBetweenBlocks.Add(time.Since(s.stats.LastBlockProcessAt)) + s.stats.TotalDurationBetween += time.Since(s.stats.LastBlockProcessAt) + } + s.stats.BlockCount++ + + holder := &Holder{ + output: output, + data: data, + isLive: isLive, + cursor: cursor, + } + holding = append(holding, holder) + if data.Clock.Number > (s.lastAppliedBlockNum+s.blockBatchSize) || s.blockBatchSize == 1 || (isLive != nil && *isLive) { + if isLive != nil && *isLive && s.stats.FlushDuration.Average() > data.Clock.Timestamp.AsTime().Sub(s.lastAppliedBlockTime) { + s.logger.Debug("skipping a flush because we are LIVE and flush average duration is above time between blocks", zapx.HumanDuration("flush_duration_average", s.stats.FlushDuration.Average()), zap.Time("last_block_time", s.lastAppliedBlockTime), zap.Time("block_time", data.Clock.Timestamp.AsTime())) + return nil + } + + if s.useTransaction && !s.parallel { + if err := s.db.BeginTransaction(); err != nil { + return fmt.Errorf("begin tx: %w", err) + } + } + errs := multiError{} + if s.parallel { + wg := sync.WaitGroup{} + wg.Add(len(holding)) + + for _, h := range holding { + go func() { + db := s.db.Clone() + err := db.BeginTransaction() + if err != nil { + errs = append(errs, err) + } + + err = s.processHolder(h, s.stats) + if err != nil { + db.RollbackTransaction() + errs = append(errs, err) + } + err = db.CommitTransaction() + if err != nil { + errs = append(errs, err) + } + wg.Done() + }() + } + wg.Wait() + if len(errs) > 0 { + return fmt.Errorf("errors: %w", errs) + } + + } else { + for _, h := range holding { + err = s.processHolder(h, s.stats) + if err != nil { + if s.useTransaction { + s.logger.Error("rolling back transaction", zap.Error(err)) + s.db.RollbackTransaction() + } + return fmt.Errorf("process holder: %w", err) + } + } + } + + flushDuration, err := s.db.Flush() + if err != nil { + return fmt.Errorf("flushing: %w", err) + } + + var flushDurationPerBlock time.Duration + if len(holding) > 0 { + flushDurationPerBlock = flushDuration / time.Duration(len(holding)) + } + s.stats.FlushDuration.Add(flushDurationPerBlock) + + s.lastAppliedBlockNum = data.Clock.Number + s.lastAppliedBlockTime = data.Clock.Timestamp.AsTime() + err = s.db.StoreCursor(cursor) + if err != nil { + return fmt.Errorf("inserting cursor: %w", err) + } + + if s.useTransaction && !s.parallel { + if err := s.db.CommitTransaction(); err != nil { + return fmt.Errorf("commit tx: %w", err) + } + } + holding = []*Holder{} + } + + return nil +} + +func (s *Sinker) processHolder(h *Holder, stats *stats.Stats) (err error) { + if len(h.output.GetMapOutput().GetValue()) == 0 { + return nil + } + + unmarshalStartAt := time.Now() + md := s.rootMessageDescriptor + dm := dynamicpb.NewMessage(md) + err = proto.Unmarshal(h.data.Output.GetMapOutput().GetValue(), dm) + if err != nil { + return fmt.Errorf("unmarshaling message: %w", err) + } + + stats.UnmarshallingDuration.Add(time.Since(unmarshalStartAt)) + + err = processMessage(dm, s.db, h.data.Clock.Number, h.data.Clock.Id, h.data.Clock.Timestamp.AsTime(), stats) + if err != nil { + return fmt.Errorf("process entity: %w", err) + } + + return nil +} + +func processMessage(dm *dynamicpb.Message, database sql.Database, blockNum uint64, blockHash string, blockTimestamp time.Time, stats *stats.Stats) error { + startInsertBlock := time.Now() + err := database.InsertBlock(blockNum, blockHash, blockTimestamp) + if err != nil { + return fmt.Errorf("inserting block: %w", err) + } + stats.BlockInsertDuration.Add(time.Since(startInsertBlock)) + + sqlDuration, err := database.WalkMessageDescriptorAndInsert(dm, blockNum, blockTimestamp, nil) + if err != nil { + return fmt.Errorf("processing message %q: %w", string(dm.Descriptor().FullName()), err) + } + + stats.EntitiesInsertDuration.Add(sqlDuration) + + return nil +} + +func (s *Sinker) HandleBlockUndoSignal(ctx context.Context, undoSignal *pbsubstreamsrpc.BlockUndoSignal, cursor *sink.Cursor) (err error) { + lastValidBlockNum := undoSignal.LastValidBlock.Number + + s.logger.Info("Handling undo block signal", zap.Stringer("block", cursor.Block()), zap.Stringer("cursor", cursor)) + + err = s.db.HandleBlocksUndo(lastValidBlockNum) + if err != nil { + return fmt.Errorf("handle blocks undo from %d : %w", lastValidBlockNum, err) + } + + err = s.db.StoreCursor(cursor) + if err != nil { + return fmt.Errorf("inserting cursor: %w", err) + } + + return nil +} diff --git a/sink/sql/db_proto/sinker_factory.go b/sink/sql/db_proto/sinker_factory.go new file mode 100644 index 000000000..0972efeb7 --- /dev/null +++ b/sink/sql/db_proto/sinker_factory.go @@ -0,0 +1,185 @@ +package db_proto + +import ( + "context" + "fmt" + "time" + + "github.com/streamingfast/logging" + sink "github.com/streamingfast/substreams/sink" + "github.com/streamingfast/substreams/sink/sql/bytes" + "github.com/streamingfast/substreams/sink/sql/db_changes/db" + protosql "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + clickhouse "github.com/streamingfast/substreams/sink/sql/db_proto/sql/click_house" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/postgres" + schema2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + stats2 "github.com/streamingfast/substreams/sink/sql/db_proto/stats" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type SinkerFactoryFunc func(ctx context.Context, dsnString, schemaName string, logger *zap.Logger, tracer logging.Tracer) (*Sinker, error) + +type SinkerFactoryOptions struct { + UseProtoOption bool + UseConstraints bool + UseTransactions bool + BlockBatchSize int + Parallel bool + Encoding bytes.Encoding + Clickhouse SinkerFactoryClickhouse +} + +type SinkerFactoryClickhouse struct { + SinkInfoFolder string + CursorFilePath string + QueryRetryCount int + QueryRetrySleep time.Duration +} + +func (o SinkerFactoryOptions) Defaults() SinkerFactoryOptions { + if o.BlockBatchSize <= 0 { + o.BlockBatchSize = 25 + } + o.UseTransactions = true + if o.Encoding == 0 { + o.Encoding = bytes.EncodingRaw + } + return o +} + +func SinkerFactory( + baseSink *sink.Sinker, + outputModuleName string, + rootMessageDescriptor protoreflect.MessageDescriptor, + options SinkerFactoryOptions, +) SinkerFactoryFunc { + return func(ctx context.Context, dsnString string, schemaName string, logger *zap.Logger, tracer logging.Tracer) (*Sinker, error) { + dsn, err := db.ParseDSN(dsnString) + if err != nil { + return nil, fmt.Errorf("parsing dsn: %w", err) + } + + schema, err := schema2.NewSchema(schemaName, rootMessageDescriptor, options.UseProtoOption, logger) + if err != nil { + return nil, fmt.Errorf("creating schema: %w", err) + } + + var database protosql.Database + + switch dsn.Driver() { + case "postgres": + database, err = postgres.NewDatabase(schema, dsn, outputModuleName, rootMessageDescriptor, options.UseProtoOption, options.UseConstraints, options.Encoding, logger) + if err != nil { + return nil, fmt.Errorf("creating postgres database: %w", err) + } + + case "clickhouse": + database, err = clickhouse.NewDatabase( + ctx, + schema, + dsn, + outputModuleName, + rootMessageDescriptor, + options.Clickhouse.SinkInfoFolder, + options.Clickhouse.CursorFilePath, + true, + options.Encoding, + logger, + tracer, + options.Clickhouse.QueryRetryCount, + options.Clickhouse.QueryRetrySleep, + ) + if err != nil { + return nil, fmt.Errorf("creating clickhouse database: %w", err) + } + + default: + panic(fmt.Sprintf("unsupported driver: %s", dsn.Driver())) + + } + + sinkInfo, err := database.FetchSinkInfo(schema.Name) + if err != nil { + return nil, fmt.Errorf("fetching sink info: %w", err) + } + + logger.Info("sink info read", zap.Reflect("sink_info", sinkInfo)) + if sinkInfo == nil { + err := database.BeginTransaction() + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + err = database.CreateDatabase(options.UseConstraints) + if err != nil { + database.RollbackTransaction() + return nil, fmt.Errorf("creating database: %w", err) + } + + err = database.StoreSinkInfo(schemaName, database.GetDialect().SchemaHash()) + if err != nil { + database.RollbackTransaction() + return nil, fmt.Errorf("storing sink info: %w", err) + } + + err = database.CommitTransaction() + + } else { + migrationNeeded := sinkInfo.SchemaHash != database.GetDialect().SchemaHash() + if migrationNeeded { + + tempSchemaName := schema.Name + "_" + database.GetDialect().SchemaHash() + tempSinkInfo, err := database.FetchSinkInfo(tempSchemaName) + if err != nil { + return nil, fmt.Errorf("fetching temp schema sink info: %w", err) + } + if tempSinkInfo != nil { + hash, err := database.DatabaseHash(schema.Name) + if err != nil { + return nil, fmt.Errorf("fetching schema %q hash: %w", schema.Name, err) + } + dbTempHash, err := database.DatabaseHash(tempSchemaName) + if err != nil { + return nil, fmt.Errorf("fetching temp schema %q hash: %w", tempSchemaName, err) + } + + if hash != dbTempHash { + return nil, fmt.Errorf("schema %s and temp schema %s have different hash", schema.Name, tempSchemaName) + } + err = database.BeginTransaction() + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + err = database.UpdateSinkInfoHash(schemaName, tempSinkInfo.SchemaHash) + if err != nil { + database.RollbackTransaction() + return nil, fmt.Errorf("updating sink info hash: %w", err) + } + + err = database.CommitTransaction() + if err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + } + } + } + + err = database.Open() + if err != nil { + return nil, fmt.Errorf("opening database: %w", err) + } + + return NewSinker( + rootMessageDescriptor, + baseSink, + database, + options.UseTransactions, + options.UseConstraints, + options.BlockBatchSize, + options.Parallel, + stats2.NewStats(logger), + logger, + ), nil + } +} diff --git a/sink/sql/db_proto/sql/click_house/accumulator_inserter.go b/sink/sql/db_proto/sql/click_house/accumulator_inserter.go new file mode 100644 index 000000000..c32e78b45 --- /dev/null +++ b/sink/sql/db_proto/sql/click_house/accumulator_inserter.go @@ -0,0 +1,627 @@ +package clickhouse + +import ( + "database/sql" + "fmt" + "sort" + "strings" + "time" + + "github.com/ClickHouse/ch-go" + "github.com/ClickHouse/ch-go/proto" + "github.com/streamingfast/logging" + "github.com/streamingfast/logging/zapx" + "github.com/streamingfast/substreams/sink/sql/bytes" + sql2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + v1 "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type accumulator struct { + ordinal int + tableName string + columns map[int]*schema.Column + input map[string]proto.ColInput +} + +type AccumulatorInserter struct { + accumulators map[string]*accumulator + cursorStmt *sql.Stmt + logger *zap.Logger + tracer logging.Tracer + bytesEncoding bytes.Encoding +} + +func NewAccumulatorInserter(database *Database, logger *zap.Logger, tracer logging.Tracer) (*AccumulatorInserter, error) { + logger = logger.Named("clickhouse inserter") + + accumulators, err := createAccumulators(database.dialect) + if err != nil { + return nil, fmt.Errorf("creating accumulators: %w", err) + } + return &AccumulatorInserter{ + accumulators: accumulators, + logger: logger, + tracer: tracer, + bytesEncoding: database.bytesEncoding, + }, nil +} + +func createAccumulators(dialect *DialectClickHouse) (map[string]*accumulator, error) { + if dialect == nil { + panic("dialect is nil") + } + + accumulators := map[string]*accumulator{} + + accumulators[sql2.DialectTableBlock] = &accumulator{ + ordinal: -1, + tableName: sql2.DialectTableBlock, + columns: map[int]*schema.Column{ + 0: {Name: "number"}, + 1: {Name: "hash"}, + 2: {Name: "timestamp"}, + 3: {Name: "version"}, + 4: {Name: "deleted"}, + }, + input: map[string]proto.ColInput{ + "number": &proto.ColUInt64{}, + "hash": &proto.ColStr{}, + "timestamp": &proto.ColDateTime{}, + "version": &proto.ColInt64{}, + "deleted": &proto.ColBool{}, + }, + } + + tables := dialect.GetTables() + for _, table := range tables { + input := map[string]proto.ColInput{} + columns := map[int]*schema.Column{} + + input[sql2.DialectFieldBlockNumber] = &proto.ColUInt64{} + columns[0] = &schema.Column{Name: sql2.DialectFieldBlockNumber} + + input[sql2.DialectFieldBlockTimestamp] = &proto.ColDateTime{} + columns[1] = &schema.Column{Name: sql2.DialectFieldBlockTimestamp} + + input[sql2.DialectFieldVersion] = &proto.ColInt64{} + columns[2] = &schema.Column{Name: sql2.DialectFieldVersion} + + input[sql2.DialectFieldDeleted] = &proto.ColBool{} + columns[3] = &schema.Column{Name: sql2.DialectFieldDeleted} + + primaryName := "" + if table.PrimaryKey != nil { + pk := table.PrimaryKey + primaryName = pk.Name + + input[pk.Name] = ColInputForColumn(pk.FieldDescriptor, dialect.bytesEncoding, table.Columns[pk.Index]) + columns[4] = &schema.Column{Name: pk.Name} + } + + offset := len(columns) + if table.ChildOf != nil { + parentTable, parentFound := dialect.TableRegistry[table.ChildOf.ParentTable] + if !parentFound { + return nil, fmt.Errorf("parent table %q not found", table.ChildOf.ParentTable) + } + fieldFound := false + for _, parentField := range parentTable.Columns { + + if parentField.Name == table.ChildOf.ParentTableField { + input[parentField.Name] = ColInputForColumn(parentField.FieldDescriptor, dialect.bytesEncoding, parentField) + columns[offset] = parentField + fieldFound = true + break + } + } + if !fieldFound { + return nil, fmt.Errorf("field %q not found in table %q", table.ChildOf.ParentTableField, table.ChildOf.ParentTable) + } + } + + offset = len(columns) + skipCount := 0 + for i, column := range table.Columns { + if column.Name == primaryName { + skipCount++ + continue + } + if column.Nested != nil { + for _, nestedCol := range column.Nested.Columns { + nestedColName := fmt.Sprintf("%s.%s", column.Name, nestedCol.Name) + nestedInput := ColInputForColumn(nestedCol.FieldDescriptor, dialect.bytesEncoding, nestedCol) + if nestedInput != nil { + switch base := nestedInput.(type) { + case *proto.ColStr: + input[nestedColName] = proto.NewArray(base) + case *proto.ColInt32: + input[nestedColName] = proto.NewArray(base) + case *proto.ColInt64: + input[nestedColName] = proto.NewArray(base) + case *proto.ColUInt32: + input[nestedColName] = proto.NewArray(base) + case *proto.ColUInt64: + input[nestedColName] = proto.NewArray(base) + case *proto.ColFloat32: + input[nestedColName] = proto.NewArray(base) + case *proto.ColFloat64: + input[nestedColName] = proto.NewArray(base) + case *proto.ColBool: + input[nestedColName] = proto.NewArray(base) + case *proto.ColBytes: + input[nestedColName] = proto.NewArray(base) + case *proto.ColDateTime: + input[nestedColName] = proto.NewArray(base) + default: + return nil, fmt.Errorf("unsupported nested column type %T for column %s.%s", base, column.Name, nestedCol.Name) + } + nestedColEntry := &schema.Column{ + Name: nestedColName, + FieldDescriptor: nestedCol.FieldDescriptor, + } + columns[i+offset-skipCount] = nestedColEntry + offset++ + } + } + skipCount++ + continue + } + input[column.Name] = ColInputForColumn(column.FieldDescriptor, dialect.bytesEncoding, column) + columns[i+offset-skipCount] = column + } + + accumulators[table.Name] = &accumulator{ + tableName: table.Name, + ordinal: table.Ordinal, + columns: columns, + input: input, + } + } + + return accumulators, nil +} + +func (i *AccumulatorInserter) insert(table string, values []any) error { + accumulator := i.accumulators[table] + if accumulator == nil { + return fmt.Errorf("accumulator not found for table %q", table) + } + i.logger.Debug("inserting", zap.String("table", table), zap.Int("values", len(values))) + + for idx, value := range values { + column, found := accumulator.columns[idx] + if !found { + return fmt.Errorf("column not found for table %q at idx %d", table, idx) + } + input := accumulator.input[column.Name] + + if i.tracer.Enabled() { + i.logger.Debug("inserting column value", + zap.String("table", table), + zap.String("column", column.Name), + zapx.Type("column_type", input), + zapx.Type("value_type", value), + ) + } + + switch input := input.(type) { + case *proto.ColDateTime: + if t, ok := value.(*timestamppb.Timestamp); ok { + input.Append(t.AsTime()) + } else if t, ok := value.(time.Time); ok { + input.Append(t) + } else { + panic(fmt.Sprintf("unknown time base input type %T for column %s of table %s", input, column.Name, table)) + } + case *proto.ColInt32: + input.Append(value.(int32)) + case *proto.ColInt64: + input.Append(value.(int64)) + case *proto.ColUInt32: + input.Append(value.(uint32)) + case *proto.ColUInt64: + input.Append(value.(uint64)) + case *proto.ColFloat32: + input.Append(value.(float32)) + case *proto.ColFloat64: + input.Append(value.(float64)) + case *ColScaledDecimal128: + stringValue := value.(string) + scale := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal128).Decimal128.Scale + if column.IsOptional && stringValue == "" { + input.Append(proto.Decimal128{}) + } else { + v, err := StringToDecimal128(stringValue, scale) + if err != nil { + panic(fmt.Sprintf("failed to convert string to decimal128 for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(v) + } + case *ColScaledDecimal256: + stringValue := value.(string) + scale := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal256).Decimal256.Scale + if column.IsOptional && stringValue == "" { + input.Append(proto.Decimal256{}) + } else { + v, err := StringToDecimal256(stringValue, scale) + if err != nil { + panic(fmt.Sprintf("failed to convert string to decimal256 for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(v) + } + case *proto.ColInt128: + stringValue := value.(string) + if column.IsOptional && stringValue == "" { + input.Append(proto.Int128{}) + } else { + v, err := StringToInt128(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert string to int128 for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(v) + } + case *proto.ColUInt128: + stringValue := value.(string) + if column.IsOptional && stringValue == "" { + input.Append(proto.UInt128{}) + } else { + v, err := StringToUInt128(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert string to uint128 for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(v) + } + case *proto.ColInt256: + stringValue := value.(string) + if column.IsOptional && stringValue == "" { + input.Append(proto.Int256{}) + } else { + v, err := StringToInt256(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert string to int256 for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(v) + } + case *proto.ColUInt256: + stringValue := value.(string) + if column.IsOptional && stringValue == "" { + input.Append(proto.UInt256{}) + } else { + v, err := StringToUInt256(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert string to uint256 for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(v) + } + case *proto.ColStr: + if bytesValue, ok := value.([]byte); ok { + encoded, err := i.bytesEncoding.EncodeBytes(bytesValue) + if err != nil { + panic(fmt.Sprintf("failed to encode bytes for column %s of table %s: %v", column.Name, table, err)) + } + input.Append(encoded.(string)) + } else { + input.Append(value.(string)) + } + case *proto.ColBytes: + input.Append(value.([]byte)) + case *proto.ColBool: + input.Append(value.(bool)) + case *proto.ColArr[int32]: + if arr, ok := value.([]interface{}); ok { + int32Arr := make([]int32, len(arr)) + for i, v := range arr { + int32Arr[i] = v.(int32) + } + input.Append(int32Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[int64]: + if arr, ok := value.([]interface{}); ok { + int64Arr := make([]int64, len(arr)) + for i, v := range arr { + int64Arr[i] = v.(int64) + } + input.Append(int64Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[uint32]: + if arr, ok := value.([]interface{}); ok { + uint32Arr := make([]uint32, len(arr)) + for i, v := range arr { + uint32Arr[i] = v.(uint32) + } + input.Append(uint32Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[uint64]: + if arr, ok := value.([]interface{}); ok { + uint64Arr := make([]uint64, len(arr)) + for i, v := range arr { + uint64Arr[i] = v.(uint64) + } + input.Append(uint64Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[*proto.Int128]: + if arr, ok := value.([]interface{}); ok { + int128Arr := make([]*proto.Int128, len(arr)) + for i, v := range arr { + stringValue := v.(string) + if stringValue == "" { + zeroValue := proto.Int128{} + int128Arr[i] = &zeroValue + } else { + v, err := StringToInt128(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert array of string to int128 for column %s of table %s: %v", column.Name, table, err)) + } + int128Arr[i] = &v + } + } + input.Append(int128Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[*proto.UInt128]: + if arr, ok := value.([]interface{}); ok { + uint128Arr := make([]*proto.UInt128, len(arr)) + for i, v := range arr { + stringValue := v.(string) + if stringValue == "" { + zeroValue := proto.UInt128{} + uint128Arr[i] = &zeroValue + } else { + v, err := StringToUInt128(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert array of string to uint128 for column %s of table %s: %v", column.Name, table, err)) + } + uint128Arr[i] = &v + } + } + input.Append(uint128Arr) + } + case *proto.ColArr[*proto.Int256]: + if arr, ok := value.([]interface{}); ok { + int256Arr := make([]*proto.Int256, len(arr)) + for i, v := range arr { + stringValue := v.(string) + if stringValue == "" { + zeroValue := proto.Int256{} + int256Arr[i] = &zeroValue + } else { + v, err := StringToInt256(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert array of string to int256 for column %s of table %s: %v", column.Name, table, err)) + } + int256Arr[i] = &v + } + } + input.Append(int256Arr) + } + case *proto.ColArr[*proto.UInt256]: + if arr, ok := value.([]interface{}); ok { + uint256Arr := make([]*proto.UInt256, len(arr)) + for i, v := range arr { + stringValue := v.(string) + if stringValue == "" { + zeroValue := proto.UInt256{} + uint256Arr[i] = &zeroValue + } else { + v, err := StringToUInt256(stringValue) + if err != nil { + panic(fmt.Sprintf("failed to convert array of string to uint256 for column %s of table %s: %v", column.Name, table, err)) + } + uint256Arr[i] = &v + } + } + input.Append(uint256Arr) + } + case *proto.ColArr[float32]: + if arr, ok := value.([]interface{}); ok { + float32Arr := make([]float32, len(arr)) + for i, v := range arr { + float32Arr[i] = v.(float32) + } + input.Append(float32Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[float64]: + if arr, ok := value.([]interface{}); ok { + float64Arr := make([]float64, len(arr)) + for i, v := range arr { + float64Arr[i] = v.(float64) + } + input.Append(float64Arr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[*proto.Decimal128]: + scale := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal128).Decimal128.Scale + if arr, ok := value.([]interface{}); ok { + decimal128Arr := make([]*proto.Decimal128, len(arr)) + for i, v := range arr { + v, err := StringToDecimal128(v.(string), scale) + if err != nil { + panic(fmt.Sprintf("failed to convert array of string to decimal128 for column %s of table %s: %v", column.Name, table, err)) + } + decimal128Arr[i] = &v + } + input.Append(decimal128Arr) + } + case *proto.ColArr[*proto.Decimal256]: + scale := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal128).Decimal128.Scale + if arr, ok := value.([]interface{}); ok { + decimal256Arr := make([]*proto.Decimal256, len(arr)) + for i, v := range arr { + v, err := StringToDecimal256(v.(string), scale) + if err != nil { + panic(fmt.Sprintf("failed to convert array of string to decimal256 for column %s of table %s: %v", column.Name, table, err)) + } + decimal256Arr[i] = &v + } + input.Append(decimal256Arr) + } + case *proto.ColArr[bool]: + if arr, ok := value.([]interface{}); ok { + boolArr := make([]bool, len(arr)) + for i, v := range arr { + boolArr[i] = v.(bool) + } + input.Append(boolArr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[string]: + if arr, ok := value.([]interface{}); ok { + stringArr := make([]string, len(arr)) + for i, v := range arr { + stringArr[i] = v.(string) + } + input.Append(stringArr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[[]byte]: + if arr, ok := value.([]interface{}); ok { + bytesArr := make([][]byte, len(arr)) + for i, v := range arr { + bytesArr[i] = v.([]byte) + } + input.Append(bytesArr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + case *proto.ColArr[time.Time]: + if arr, ok := value.([]interface{}); ok { + timeArr := make([]time.Time, len(arr)) + for i, v := range arr { + if t, ok := v.(*timestamppb.Timestamp); ok { + timeArr[i] = t.AsTime() + } else if t, ok := v.(time.Time); ok { + timeArr[i] = t + } else { + panic(fmt.Sprintf("unknown time type %T in array for column %s of table %s", v, column.Name, table)) + } + } + input.Append(timeArr) + } else { + panic(fmt.Sprintf("expected []interface{} for array column %s of table %s, got %T", column.Name, table, value)) + } + default: + panic(fmt.Sprintf("unknown input type %T for column %s of table %s", input, column.Name, table)) + } + } + + return nil +} + +func (i *AccumulatorInserter) flush(database *Database) error { + i.logger.Debug("flushing started", zap.Int("accumulators", len(i.accumulators))) + var accumulators []accumulator + + start := time.Now() + for _, acc := range i.accumulators { + accumulators = append(accumulators, *acc) + } + + sort.Slice(accumulators, func(i, j int) bool { + return accumulators[i].ordinal < accumulators[j].ordinal + }) + + client, err := database.client() + if err != nil { + return fmt.Errorf("clickhouse accumulator inserter: creating client: %w", err) + } + + queryDuration := time.Duration(0) + rowCount := 0 + for _, acc := range accumulators { + qStart := time.Now() + + inputs := proto.Input{} + for n, inp := range acc.input { + if n == "block_number" { + rowCount += inp.Rows() + } + inputs = append(inputs, proto.InputColumn{ + Name: n, + Data: inp, + }) + } + + retryCount := database.queryRetryCount + retrySleep := database.queryRetrySleep + for attempt := 0; ; attempt++ { + if err := client.Do(database.ctx, ch.Query{ + Body: inputs.Into(acc.tableName), + Input: inputs, + }); err != nil { + if attempt >= retryCount { + return fmt.Errorf("clickhouse accumulator inserter: executing query on %q after %d retries: %w", acc.debugTableAndColumns(), attempt, err) + } + i.logger.Warn("clickhouse insert failed, will retry", zap.Int("attempt", attempt+1), zap.Int("max_attempts", retryCount), zap.String("table", acc.tableName), zap.Error(err)) + time.Sleep(retrySleep) + fresh, cErr := database.freshClient() + if cErr != nil { + return fmt.Errorf("clickhouse accumulator inserter: getting fresh client: %w", cErr) + } + client = fresh + continue + } + break + } + + queryDuration += time.Since(qStart) + } + + accs, err := createAccumulators(database.dialect) + if err != nil { + return fmt.Errorf("clickhouse accumulator inserter: creating accumulators: %w", err) + } + i.accumulators = accs + + i.logger.Debug("flushing done", zapx.HumanDuration("duration", time.Since(start)), zap.Int("rows", rowCount)) + + return nil +} + +func (acc *accumulator) debugTableAndColumns() string { + var b strings.Builder + b.WriteString(acc.tableName) + b.WriteString(" (") + for idx, col := range acc.columns { + if idx > 0 { + b.WriteString(", ") + } + b.WriteString(col.Name) + } + b.WriteString(")") + return b.String() +} + +type ColScaledDecimal256 struct { + *proto.ColDecimal256 + scale uint8 +} + +func (c *ColScaledDecimal256) Type() proto.ColumnType { + return proto.ColumnType(fmt.Sprintf("Decimal256(%d)", c.scale)) +} + +type ColScaledDecimal128 struct { + *proto.ColDecimal128 + scale uint8 +} + +func (c *ColScaledDecimal128) Type() proto.ColumnType { + return proto.ColumnType(fmt.Sprintf("Decimal128(%d)", c.scale)) +} diff --git a/sink/sql/db_proto/sql/click_house/database.go b/sink/sql/db_proto/sql/click_house/database.go new file mode 100644 index 000000000..91cc75523 --- /dev/null +++ b/sink/sql/db_proto/sql/click_house/database.go @@ -0,0 +1,472 @@ +package clickhouse + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "os" + "path" + "sort" + "time" + + "github.com/ClickHouse/ch-go" + "github.com/streamingfast/logging" + "github.com/streamingfast/logging/zapx" + sink "github.com/streamingfast/substreams/sink" + "github.com/streamingfast/substreams/sink/sql/bytes" + "github.com/streamingfast/substreams/sink/sql/db_changes/db" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +type Database struct { + *sql.BaseDatabase + schema *schema.Schema + sinkInfoFolder string + cursorFilePath string + logger *zap.Logger + dialect *DialectClickHouse + cachedClient *ch.Client + dsn *db.DSN + ctx context.Context + inserter *AccumulatorInserter + bytesEncoding bytes.Encoding + queryRetryCount int + queryRetrySleep time.Duration +} + +func NewDatabase( + ctx context.Context, + schema *schema.Schema, + dsn *db.DSN, + moduleOutputType string, + rootMessageDescriptor protoreflect.MessageDescriptor, + sinkInfoFolder string, + cursorFilePath string, + useProtoOptions bool, + bytesEncoding bytes.Encoding, + logger *zap.Logger, + tracer logging.Tracer, + queryRetryCount int, + queryRetrySleep time.Duration, +) (*Database, error) { + baseDB, err := sql.NewBaseDatabase(moduleOutputType, rootMessageDescriptor, useProtoOptions, logger) + if err != nil { + return nil, fmt.Errorf("creating base database: %w", err) + } + dialect, err := NewDialectClickHouse(schema, bytesEncoding, logger) + if err != nil { + return nil, fmt.Errorf("creating dialect: %w", err) + } + + database := &Database{ + ctx: ctx, + dsn: dsn, + BaseDatabase: baseDB, + dialect: dialect, + schema: schema, + sinkInfoFolder: sinkInfoFolder, + cursorFilePath: cursorFilePath, + logger: logger, + bytesEncoding: bytesEncoding, + queryRetryCount: queryRetryCount, + queryRetrySleep: queryRetrySleep, + } + if database.queryRetryCount <= 0 { + database.queryRetryCount = 3 + } + if database.queryRetrySleep <= 0 { + database.queryRetrySleep = time.Second + } + inserter, err := NewAccumulatorInserter(database, logger, tracer) + if err != nil { + return nil, fmt.Errorf("creating accumulator inserter: %w", err) + } + database.inserter = inserter + + return database, nil +} + +func (d *Database) Open() error { + return nil +} + +func newClient(dsn *db.DSN, logger *zap.Logger) (*ch.Client, error) { + chOption := ch.Options{ + Address: fmt.Sprintf("%s:%d", dsn.Host, dsn.Port), + Database: dsn.Database, + User: dsn.Username, + Password: dsn.Password, + DialTimeout: 30 * time.Second, + } + + for key, value := range dsn.Options.Iter() { + if key == "secure" && value == "true" { + chOption.TLS = &tls.Config{} + continue + } + if key == "username" { + chOption.User = value + continue + } + if key == "password" { + chOption.Password = value + continue + } + if key == "compress" && value == "true" { + chOption.Compression = ch.CompressionLZ4 + continue + } + } + + for { + client, err := ch.Dial(context.Background(), chOption) + if err != nil { + logger.Warn("dialing clickhouse failed, will retry", zap.Error(err)) + time.Sleep(time.Second) + continue + } + return client, nil + } +} + +func (d *Database) client() (*ch.Client, error) { + if d.cachedClient == nil || d.cachedClient.IsClosed() { + client, err := newClient(d.dsn, d.logger) + if err != nil { + return nil, fmt.Errorf("creating clickhouse client: %w", err) + } + d.cachedClient = client + } + + return d.cachedClient, nil +} + +func (d *Database) freshClient() (*ch.Client, error) { + client, err := newClient(d.dsn, d.logger) + if err != nil { + return nil, fmt.Errorf("creating clickhouse client: %w", err) + } + d.cachedClient = client + return client, nil +} + +func (d *Database) clientNoCache(dsn *db.DSN) (*ch.Client, error) { + client, err := newClient(dsn, d.logger) + if err != nil { + return nil, fmt.Errorf("creating clickhouse client: %w", err) + } + return client, nil +} + +func (d *Database) CreateDatabase(useConstraints bool) error { + dsn := d.dsn.Clone() + dsn.Database = "default" + client, err := d.clientNoCache(dsn) + if err != nil { + return fmt.Errorf("creating clickhouse client: %w", err) + } + + d.logger.Info("creating database", zap.String("schema_name", d.schema.Name)) + + err = client.Ping(d.ctx) + if err != nil { + return fmt.Errorf("pinging clickhouse: %w", err) + } + + if err := client.Do(d.ctx, ch.Query{ + Body: fmt.Sprintf(staticSqlCreatDatabase, d.schema.Name), + }); err != nil { + return fmt.Errorf("executing create database sql: %w", err) + } + + d.logger.Info("database created", zap.String("schema_name", d.schema.Name)) + + client, err = d.client() + if err != nil { + return fmt.Errorf("getting clickhouse client: %w", err) + } + + if err := client.Do(d.ctx, ch.Query{ + Body: fmt.Sprintf(staticSqlCreateBlock, d.schema.Name), + }); err != nil { + return fmt.Errorf("executing create block sql: %w", err) + } + + d.logger.Info("block table created", zap.String("schema_name", d.schema.Name)) + + if err := client.Do(d.ctx, ch.Query{ + Body: "SET flatten_nested = 1;", + }); err != nil { + return fmt.Errorf("executing flatten nested sql: %w", err) + } + + for _, statement := range d.dialect.CreateTableSql { + if err := client.Do(d.ctx, ch.Query{ + Body: statement, + }); err != nil { + return fmt.Errorf("executing create table sql: %w %q", err, statement) + } + d.logger.Info("table created", zap.String("table_name", statement), zap.String("schema_name", d.schema.Name)) + } + + return nil +} + +func (d *Database) Insert(table string, values []any) error { + return d.inserter.insert(table, values) +} + +func (d *Database) WalkMessageDescriptorAndInsert(dm *dynamicpb.Message, blockNum uint64, blockTimestamp time.Time, parent *sql.Parent) (time.Duration, error) { + return d.BaseDatabase.WalkMessageDescriptorAndInsertWithDialect(dm, blockNum, blockTimestamp, parent, d.dialect, d) +} + +func (d *Database) BeginTransaction() error { + return nil +} + +func (d *Database) CommitTransaction() error { + return nil +} + +func (d *Database) RollbackTransaction() { +} + +func (d *Database) Flush() (time.Duration, error) { + d.logger.Debug("flushing") + + startFlush := time.Now() + err := d.inserter.flush(d) + if err != nil { + return 0, fmt.Errorf("flushing: %w", err) + } + return time.Since(startFlush), nil +} + +func (d *Database) GetDialect() sql.Dialect { + return d.dialect +} + +func (d *Database) InsertBlock(blockNum uint64, hash string, timestamp time.Time) error { + d.logger.Debug("inserting _block_", zap.Uint64("block_num", blockNum), zap.String("block_hash", hash)) + err := d.inserter.insert("_blocks_", []any{blockNum, hash, timestamp, time.Now().UnixNano(), false}) + if err != nil { + return fmt.Errorf("inserting block %d: %w", blockNum, err) + } + + return nil +} + +func (d *Database) FetchSinkInfo(schemaName string) (*sql.SinkInfo, error) { + fileName := fmt.Sprintf("%s_schema_hash.txt", schemaName) + schemaFilePath := path.Join(d.sinkInfoFolder, fileName) + file, err := os.Open(schemaFilePath) + if err != nil { + if os.IsNotExist(err) { + d.logger.Warn("schema hash file does not exist", zap.String("file_path", schemaFilePath)) + return nil, nil + } + return nil, fmt.Errorf("opening schema hash file: %w", err) + } + defer file.Close() + + var schemaHash string + _, err = fmt.Fscanf(file, "%s", &schemaHash) + if err != nil { + return nil, fmt.Errorf("reading schema hash from file: %w", err) + } + + return &sql.SinkInfo{SchemaHash: schemaHash}, nil +} + +func (d *Database) StoreSinkInfo(schemaName string, schemaHash string) error { + fileName := fmt.Sprintf("%s_schema_hash.txt", schemaName) + schemaFilePath := path.Join(d.sinkInfoFolder, fileName) + + file, err := os.Create(schemaFilePath) + if err != nil { + return fmt.Errorf("creating schema hash file: %w", err) + } + defer file.Close() + + _, err = file.WriteString(schemaHash) + if err != nil { + return fmt.Errorf("writing schema hash to file: %w", err) + } + + return nil +} + +func (d *Database) UpdateSinkInfoHash(schemaName string, newHash string) error { + panic("implement me") +} + +func (d *Database) FetchCursor() (*sink.Cursor, error) { + if d.cursorFilePath == "" { + return nil, fmt.Errorf("cursor file path is not set") + } + + file, err := os.Open(d.cursorFilePath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("opening cursor file: %w", err) + } + defer file.Close() + + cursorData, err := io.ReadAll(file) + if err != nil { + return nil, fmt.Errorf("reading cursor file: %w", err) + } + + cursor, err := sink.NewCursor(string(cursorData)) + if err != nil { + return nil, fmt.Errorf("parsing cursor: %w", err) + } + + return cursor, nil +} + +func (d *Database) StoreCursor(cursor *sink.Cursor) error { + if d.cursorFilePath == "" { + return fmt.Errorf("cursor file path is not set") + } + + file, err := os.Create(d.cursorFilePath) + if err != nil { + return fmt.Errorf("creating cursor file: %w", err) + } + defer file.Close() + + _, err = file.WriteString(cursor.String()) + if err != nil { + return fmt.Errorf("writing cursor to file: %w", err) + } + + return nil +} + +func (d *Database) HandleBlocksUndo(lastValidBlockNum uint64) error { + tables := d.dialect.GetTables() + + sort.Slice(tables, func(i, j int) bool { + return tables[i].Ordinal > tables[j].Ordinal + }) + + client, err := d.client() + if err != nil { + return fmt.Errorf("creating clickhouse client: %w", err) + } + + doWithRetry := func(q string) error { + retryCount := d.queryRetryCount + retrySleep := d.queryRetrySleep + for attempt := 0; ; attempt++ { + if err := client.Do(d.ctx, ch.Query{Body: q}); err != nil { + if attempt >= retryCount { + return fmt.Errorf("executing clickhouse query after %d retries: %w", attempt, err) + } + d.logger.Warn("clickhouse query failed, will retry", zap.Int("attempt", attempt+1), zap.Int("max_attempts", retryCount), zap.Error(err)) + time.Sleep(retrySleep) + fresh, cErr := d.freshClient() + if cErr != nil { + return fmt.Errorf("getting fresh client: %w", cErr) + } + client = fresh + continue + } + break + } + return nil + } + + err = d.BeginTransaction() + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + + version := time.Now().UnixNano() + + d.logger.Info("undoing blocks", zap.String("table", "_block_"), zap.Uint64("last_valid_block_num", lastValidBlockNum)) + start := time.Now() + insertDeleteBlocks := fmt.Sprintf(` + INSERT INTO %s._blocks_ + SELECT number, hash, timestamp, %d, true + FROM %s._blocks_ WHERE number > %d + `, d.schema.Name, version, d.schema.Name, lastValidBlockNum) + + err = doWithRetry(insertDeleteBlocks) + if err != nil { + return fmt.Errorf("deleting block from %d: %w", lastValidBlockNum, err) + } + + d.logger.Info("undo completed", zap.String("table", "_block_"), zapx.HumanDuration("duration", time.Since(start))) + + for _, table := range tables { + d.logger.Info("undoing blocks", zap.String("table", table.Name), zap.Uint64("last_valid_block_num", lastValidBlockNum)) + start := time.Now() + tableFullName := d.dialect.FullTableName(table) + fields := "" + + if table.ChildOf != nil { + parentTable, parentFound := d.dialect.TableRegistry[table.ChildOf.ParentTable] + if !parentFound { + return fmt.Errorf("parent table %q not found", table.ChildOf.ParentTable) + } + fieldFound := false + for _, parentField := range parentTable.Columns { + if parentField.Name == table.ChildOf.ParentTableField { + fields += fmt.Sprintf(", %s", parentField.Name) + fieldFound = true + break + } + } + if !fieldFound { + return fmt.Errorf("field %q not found in table %q", table.ChildOf.ParentTableField, table.ChildOf.ParentTable) + } + } + + for _, column := range table.Columns { + if column.Nested != nil { + for _, nestedColumn := range column.Nested.Columns { + fields += fmt.Sprintf(", %s.%s", column.Name, nestedColumn.Name) + } + } else { + fields += fmt.Sprintf(", %s", column.Name) + } + } + query := fmt.Sprintf(` + INSERT INTO %s + SELECT %s, %s, %d, true %s + FROM %s WHERE %s > %d AND _deleted_ != 1 + `, tableFullName, sql.DialectFieldBlockNumber, sql.DialectFieldBlockTimestamp, version, fields, tableFullName, sql.DialectFieldBlockNumber, lastValidBlockNum) + + err := doWithRetry(query) + if err != nil { + return fmt.Errorf("deleting block from %d: %w", lastValidBlockNum, err) + } + + d.logger.Info("undo completed", zap.String("table", table.Name), zapx.HumanDuration("duration", time.Since(start))) + } + err = d.CommitTransaction() + if err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} + +func (d *Database) Clone() sql.Database { + base := d.BaseClone() + d.BaseDatabase = base + return d +} + +func (d *Database) DatabaseHash(schemaName string) (uint64, error) { + panic("not implemented") +} diff --git a/sink/sql/db_proto/sql/click_house/decimal.go b/sink/sql/db_proto/sql/click_house/decimal.go new file mode 100644 index 000000000..7a2904437 --- /dev/null +++ b/sink/sql/db_proto/sql/click_house/decimal.go @@ -0,0 +1,233 @@ +package clickhouse + +import ( + "fmt" + "math/big" + "strings" + + "github.com/ClickHouse/ch-go/proto" +) + +// StringToDecimal128 converts a decimal string to proto.Decimal128. +func StringToDecimal128(s string, scale int32) (proto.Decimal128, error) { + s = strings.TrimSpace(s) + if s == "" { + return proto.Decimal128{}, fmt.Errorf("empty string cannot be converted to decimal") + } + + negative := false + if strings.HasPrefix(s, "-") { + negative = true + s = s[1:] + } else if strings.HasPrefix(s, "+") { + s = s[1:] + } + + if scale > 38 { + return proto.Decimal128{}, fmt.Errorf("scale cannot exceed 38, got %d", scale) + } + if scale < 0 { + return proto.Decimal128{}, fmt.Errorf("scale cannot be negative, got %d", scale) + } + + parts := strings.Split(s, ".") + if len(parts) > 2 { + return proto.Decimal128{}, fmt.Errorf("invalid decimal format: %s", s) + } + + integerPart := parts[0] + fractionalPart := "" + if len(parts) == 2 { + fractionalPart = parts[1] + } + + if len(fractionalPart) > int(scale) { + fractionalPart = fractionalPart[:scale] + } else { + for len(fractionalPart) < int(scale) { + fractionalPart += "0" + } + } + + // Validate that all characters are digits + for _, r := range integerPart + fractionalPart { + if r < '0' || r > '9' { + return proto.Decimal128{}, fmt.Errorf("invalid character in decimal: %c", r) + } + } + + combinedStr := integerPart + fractionalPart + if combinedStr == "" { + combinedStr = "0" + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString(combinedStr, 10); !ok { + return proto.Decimal128{}, fmt.Errorf("failed to parse decimal: %s", combinedStr) + } + + if negative { + bigInt.Neg(bigInt) + } + + maxDecimal128 := new(big.Int) + maxDecimal128.Exp(big.NewInt(10), big.NewInt(38), nil) + minDecimal128 := new(big.Int).Neg(maxDecimal128) + + if bigInt.Cmp(maxDecimal128) >= 0 || bigInt.Cmp(minDecimal128) < 0 { + return proto.Decimal128{}, fmt.Errorf("decimal value out of range for Decimal128: %s", s) + } + + var low, high uint64 + + if bigInt.Sign() >= 0 { + low = bigInt.Uint64() + if bigInt.BitLen() > 64 { + bigInt.Rsh(bigInt, 64) + high = bigInt.Uint64() + } + } else { + absBigInt := new(big.Int).Abs(bigInt) + maxUint128 := new(big.Int) + maxUint128.SetBit(maxUint128, 128, 1) + twosComplement := new(big.Int).Sub(maxUint128, absBigInt) + low = twosComplement.Uint64() + if twosComplement.BitLen() > 64 { + twosComplement.Rsh(twosComplement, 64) + high = twosComplement.Uint64() + } else { + high = ^uint64(0) + } + } + + return proto.Decimal128(proto.Int128{Low: low, High: high}), nil +} + +// StringToDecimal256 converts a decimal string to proto.Decimal256. +func StringToDecimal256(s string, scale int32) (proto.Decimal256, error) { + s = strings.TrimSpace(s) + if s == "" { + return proto.Decimal256{}, fmt.Errorf("empty string cannot be converted to decimal") + } + + negative := false + if strings.HasPrefix(s, "-") { + negative = true + s = s[1:] + } else if strings.HasPrefix(s, "+") { + s = s[1:] + } + + if scale > 76 { + return proto.Decimal256{}, fmt.Errorf("scale cannot exceed 76, got %d", scale) + } + if scale < 0 { + return proto.Decimal256{}, fmt.Errorf("scale cannot be negative, got %d", scale) + } + + parts := strings.Split(s, ".") + if len(parts) > 2 { + return proto.Decimal256{}, fmt.Errorf("invalid decimal format: %s", s) + } + + integerPart := parts[0] + fractionalPart := "" + if len(parts) == 2 { + fractionalPart = parts[1] + } + + if len(fractionalPart) > int(scale) { + fractionalPart = fractionalPart[:scale] + } else { + for len(fractionalPart) < int(scale) { + fractionalPart += "0" + } + } + + // Validate that all characters are digits + for _, r := range integerPart + fractionalPart { + if r < '0' || r > '9' { + return proto.Decimal256{}, fmt.Errorf("invalid character in decimal: %c", r) + } + } + + combinedStr := integerPart + fractionalPart + if combinedStr == "" { + combinedStr = "0" + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString(combinedStr, 10); !ok { + return proto.Decimal256{}, fmt.Errorf("failed to parse decimal: %s", combinedStr) + } + + if negative { + bigInt.Neg(bigInt) + } + + maxDecimal256 := new(big.Int) + maxDecimal256.Exp(big.NewInt(2), big.NewInt(255), nil) + maxDecimal256.Sub(maxDecimal256, big.NewInt(1)) + minDecimal256 := new(big.Int) + minDecimal256.Exp(big.NewInt(2), big.NewInt(255), nil) + minDecimal256.Neg(minDecimal256) + + if bigInt.Cmp(maxDecimal256) > 0 || bigInt.Cmp(minDecimal256) < 0 { + return proto.Decimal256{}, fmt.Errorf("decimal value out of range for Decimal256: %s", s) + } + + var lowLow, lowHigh, highLow, highHigh uint64 + + if bigInt.Sign() >= 0 { + tempBig := new(big.Int).Set(bigInt) + lowLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + if tempBig.BitLen() > 0 { + lowHigh = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } + if tempBig.BitLen() > 0 { + highLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } + if tempBig.BitLen() > 0 { + highHigh = tempBig.Uint64() + } + } else { + absBigInt := new(big.Int).Abs(bigInt) + maxUint256 := new(big.Int) + maxUint256.SetBit(maxUint256, 256, 1) + twosComplement := new(big.Int).Sub(maxUint256, absBigInt) + tempBig := new(big.Int).Set(twosComplement) + lowLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + if tempBig.BitLen() > 0 { + lowHigh = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } else { + lowHigh = ^uint64(0) + } + if tempBig.BitLen() > 0 { + highLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } else { + highLow = ^uint64(0) + } + if tempBig.BitLen() > 0 { + highHigh = tempBig.Uint64() + } else { + highHigh = ^uint64(0) + } + } + + return proto.Decimal256(proto.Int256{ + Low: proto.UInt128{ + Low: lowLow, + High: lowHigh, + }, + High: proto.UInt128{ + Low: highLow, + High: highHigh, + }, + }), nil +} diff --git a/sink/sql/db_proto/sql/click_house/dialect.go b/sink/sql/db_proto/sql/click_house/dialect.go new file mode 100644 index 000000000..1d09c9b65 --- /dev/null +++ b/sink/sql/db_proto/sql/click_house/dialect.go @@ -0,0 +1,387 @@ +package clickhouse + +import ( + "encoding/hex" + "fmt" + "hash/fnv" + "sort" + "strings" + + "github.com/streamingfast/substreams/sink/sql/bytes" + sql2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + pbSchmema "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +const staticSqlCreatDatabase = ` + CREATE DATABASE IF NOT EXISTS %s; +` +const staticSqlCreateBlock = ` + CREATE TABLE IF NOT EXISTS %s._blocks_ ( + number UInt64, + hash text, + timestamp timestamp, + version Int64, + deleted bool + + ) + ENGINE = ReplacingMergeTree(version, deleted) + PARTITION BY (toYYYYMM(timestamp)) + PRIMARY KEY (number) + ORDER BY (number) + SETTINGS + allow_experimental_replacing_merge_with_cleanup = 1; +` + +const clickhouseTableOptionsErrorMsg = "schema annotation 'clickhouse_table_options' is required in table annotation 'option (schema.table) = { name: %q, ... }' , see: https://github.com/streamingfast/substreams-sink-sql#clickhouse-table-options for configuration details" + +type DialectClickHouse struct { + *sql2.BaseDialect + schemaName string + bytesEncoding bytes.Encoding +} + +func NewDialectClickHouse(schema *schema.Schema, bytesEncoding bytes.Encoding, logger *zap.Logger) (*DialectClickHouse, error) { + d := &DialectClickHouse{ + BaseDialect: sql2.NewBaseDialect(schema.TableRegistry, logger), + schemaName: schema.Name, + bytesEncoding: bytesEncoding, + } + + err := d.init() + if err != nil { + return nil, fmt.Errorf("initializing dialect: %w", err) + } + + for _, table := range schema.TableRegistry { + err := d.createTable(table) + if err != nil { + return nil, fmt.Errorf("handling table %q: %w", table.Name, err) + } + } + + return d, nil +} + +func (d *DialectClickHouse) UseVersionField() bool { + return true +} + +func (d *DialectClickHouse) UseDeletedField() bool { + return true +} + +func (d *DialectClickHouse) init() error { + return nil +} + +func (d *DialectClickHouse) createTable(table *schema.Table) error { + var sb strings.Builder + + tableName := d.FullTableName(table) + + sb.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName)) + + sb.WriteString(fmt.Sprintf(" %s UInt64 NOT NULL,", sql2.DialectFieldBlockNumber)) + sb.WriteString(fmt.Sprintf(" %s timestamp NOT NULL,", sql2.DialectFieldBlockTimestamp)) + sb.WriteString(fmt.Sprintf(" %s Int64 NOT NULL,", sql2.DialectFieldVersion)) + sb.WriteString(fmt.Sprintf(" %s bool NOT NULL,", sql2.DialectFieldDeleted)) + + var primaryKeyFieldName string + if table.PrimaryKey != nil { + pk := table.PrimaryKey + primaryKeyFieldName = pk.Name + sb.WriteString(fmt.Sprintf("%s %s,", pk.Name, MapFieldType(pk.FieldDescriptor, d.bytesEncoding, table.Columns[pk.Index]))) + } + + if table.ChildOf != nil { + parentTable, parentFound := d.TableRegistry[table.ChildOf.ParentTable] + if !parentFound { + return fmt.Errorf("parent table %q not found", table.ChildOf.ParentTable) + } + fieldFound := false + for _, parentField := range parentTable.Columns { + + if parentField.Name == table.ChildOf.ParentTableField { + sb.WriteString(fmt.Sprintf("%s %s NOT NULL,", parentField.Name, MapFieldType(parentField.FieldDescriptor, d.bytesEncoding, parentField))) + fieldFound = true + break + } + } + if !fieldFound { + return fmt.Errorf("field %q not found in table %q", table.ChildOf.ParentTableField, table.ChildOf.ParentTable) + } + } + + var buildNestedFields func(columns []*schema.Column) string + buildNestedFields = func(columns []*schema.Column) string { + var fields []string + for _, nestedCol := range columns { + if nestedCol.Nested != nil { + innerFields := buildNestedFields(nestedCol.Nested.Columns) + fields = append(fields, fmt.Sprintf("%s Nested(%s)", nestedCol.Name, innerFields)) + } else { + fieldType := MapFieldType(nestedCol.FieldDescriptor, d.bytesEncoding, nestedCol).String() + fields = append(fields, fmt.Sprintf("%s %s", nestedCol.Name, fieldType)) + } + } + return strings.Join(fields, ", ") + } + + var processColumn func(f *schema.Column, sb *strings.Builder) + processColumn = func(f *schema.Column, sb *strings.Builder) { + if f.Nested != nil { + nestedFieldsStr := buildNestedFields(f.Nested.Columns) + sb.WriteString(fmt.Sprintf("%s Nested(%s)", f.Name, nestedFieldsStr)) + sb.WriteString(",") + } else { + fieldType := MapFieldType(f.FieldDescriptor, d.bytesEncoding, f).String() + sb.WriteString(fmt.Sprintf("%s %s", f.Name, fieldType)) + sb.WriteString(",") + } + } + + for _, f := range table.Columns { + if f.Name == primaryKeyFieldName { + continue + } + processColumn(f, &sb) + } + + temp := sb.String() + temp = temp[:len(temp)-1] + sb = strings.Builder{} + sb.WriteString(temp) + + replacingMergeTree := "ReplacingMergeTree(_version_, _deleted_)" + + primaryKey := "" + if primaryKeyFieldName != "" { + primaryKey = fmt.Sprintf("PRIMARY KEY (%s)", primaryKeyFieldName) + } + + orderBy, err := orderByString(table) + if err != nil { + return fmt.Errorf("getting 'order by' string: %w", err) + } + + partitionBy, err := partitionByString(table) + if err != nil { + return fmt.Errorf("getting 'partition by' string: %w", err) + } + + indexes, err := indexString(table) + if err != nil { + return fmt.Errorf("getting 'index' string: %w", err) + } + + sb.WriteString(fmt.Sprintf(" %s) ENGINE = %s %s %s %s", indexes, replacingMergeTree, primaryKey, partitionBy, orderBy)) + sb.WriteString(" SETTINGS\n") + sb.WriteString(" allow_experimental_replacing_merge_with_cleanup = 1") + sb.WriteString(";") + + d.AddCreateTableSql(table.Name, sb.String()) + + return nil + +} + +func (d *DialectClickHouse) FullTableName(table *schema.Table) string { + return tableName(d.schemaName, table.Name) +} + +func (d *DialectClickHouse) AppendInlineFieldValues(fieldValues []any, fd protoreflect.FieldDescriptor, fv protoreflect.Value, dm *dynamicpb.Message) ([]any, error) { + if fd.IsList() { + list := fv.List() + if list.Len() > 0 { + firstMessage := list.Get(0).Message().Interface().(*dynamicpb.Message) + nestedFields := firstMessage.Descriptor().Fields() + + for j := 0; j < nestedFields.Len(); j++ { + nestedFd := nestedFields.Get(j) + var nestedValues []interface{} + + for k := 0; k < list.Len(); k++ { + fm := list.Get(k).Message().Interface().(*dynamicpb.Message) + nestedValue := fm.Get(nestedFd) + nestedValues = append(nestedValues, nestedValue.Interface()) + } + + fieldValues = append(fieldValues, nestedValues) + } + } else { + msgDesc := fd.Message() + nestedFields := msgDesc.Fields() + + for j := 0; j < nestedFields.Len(); j++ { + fieldValues = append(fieldValues, []interface{}{}) + } + } + } else { + fm := fv.Message().Interface().(*dynamicpb.Message) + nestedFields := fm.Descriptor().Fields() + for j := 0; j < nestedFields.Len(); j++ { + nestedFd := nestedFields.Get(j) + nestedValue := fm.Get(nestedFd) + fieldValues = append(fieldValues, []interface{}{nestedValue.Interface()}) + } + } + return fieldValues, nil +} + +func (d *DialectClickHouse) SchemaHash() string { + h := fnv.New64a() + + var buf []byte + + var sqls []string + for _, sql := range d.CreateTableSql { + sqls = append(sqls, sql) + } + + sort.Strings(sqls) + for _, sql := range sqls { + buf = append(buf, []byte(sql)...) + } + + var pk []string + for _, constraint := range d.PrimaryKeySql { + pk = append(pk, constraint.Sql) + } + sort.Strings(pk) + for _, constraint := range pk { + buf = append(buf, []byte(constraint)...) + } + + var fk []string + for _, constraint := range d.ForeignKeySql { + fk = append(fk, constraint.Sql) + } + sort.Strings(fk) + for _, constraint := range fk { + buf = append(buf, []byte(constraint)...) + } + + var uniques []string + for _, constraint := range d.UniqueConstraintSql { + uniques = append(uniques, constraint.Sql) + } + sort.Strings(uniques) + for _, constraint := range uniques { + buf = append(buf, []byte(constraint)...) + } + + _, err := h.Write(buf) + if err != nil { + panic("unable to write to hash") + } + + data := h.Sum(nil) + return hex.EncodeToString(data) +} + +func tableName(schemaName string, tableName string) string { + return fmt.Sprintf("%s.%s", schemaName, tableName) +} + +func orderByString(table *schema.Table) (string, error) { + info := table.PbTableInfo.ClickhouseTableOptions + if info == nil { + return "", fmt.Errorf(clickhouseTableOptionsErrorMsg, table.Name) + } + + if len(info.OrderByFields) == 0 { + return "", fmt.Errorf("clickhouse table options for table %q don't have any 'order_by_fields'. Require at least 1", table.Name) + } + + out := "" + for i, field := range info.OrderByFields { + w := wrapWithClickhouseFunction(field.Name, field.Function) + if field.Descending { + w += " desc" + } + out += w + if i < len(info.OrderByFields)-1 { + out += ", " + } + } + + return fmt.Sprintf("ORDER BY (%s)", out), nil +} + +func partitionByString(table *schema.Table) (string, error) { + info := table.PbTableInfo.ClickhouseTableOptions + if info == nil { + return "", fmt.Errorf(clickhouseTableOptionsErrorMsg, table.Name) + } + + var parts []string + + hasBlockTimestampFunction := false + for _, field := range info.PartitionFields { + if field.Name == sql2.DialectFieldBlockTimestamp { + hasBlockTimestampFunction = true + break + } + } + + if !hasBlockTimestampFunction { + parts = append(parts, wrapWithClickhouseFunction(sql2.DialectFieldBlockTimestamp, pbSchmema.Function_toYYYYMM)) + } + + for _, field := range info.PartitionFields { + w := wrapWithClickhouseFunction(field.Name, field.Function) + parts = append(parts, w) + } + + return fmt.Sprintf("PARTITION BY (%s)", strings.Join(parts, ", ")), nil +} + +func wrapWithClickhouseFunction(fieldName string, function pbSchmema.Function) string { + format := "%s" + switch function { + case pbSchmema.Function_unset: + case pbSchmema.Function_toMonth: + format = "toMonth(%s)" + case pbSchmema.Function_toDate: + format = "toDate(%s)" + case pbSchmema.Function_toStartOfMonth: + case pbSchmema.Function_toYear: + format = "toYear(%s)" + case pbSchmema.Function_toYYYYDD: + format = "toYYYYMMDD(%s)" + case pbSchmema.Function_toYYYYMM: + format = "toYYYYMM(%s)" + } + return fmt.Sprintf(format, fieldName) +} + +func indexString(table *schema.Table) (string, error) { + indexes := "" + if table.PbTableInfo != nil && table.PbTableInfo.ClickhouseTableOptions != nil { + if len(table.PbTableInfo.ClickhouseTableOptions.IndexFields) > 0 { + var indexStrings []string + for _, indexField := range table.PbTableInfo.ClickhouseTableOptions.IndexFields { + fieldName := indexField.FieldName + if indexField.Function != pbSchmema.Function_unset { + fieldName = fmt.Sprintf("%s(%s)", indexField.Function.String(), fieldName) + } + + indexStr := fmt.Sprintf("INDEX %s %s TYPE %s GRANULARITY %d", + indexField.Name, + fieldName, + indexField.Type.String(), + indexField.Granularity) + indexStrings = append(indexStrings, indexStr) + } + + if len(indexStrings) > 0 { + indexes = ", " + strings.Join(indexStrings, ", ") + } + } + } + return indexes, nil +} diff --git a/sink/sql/db_proto/sql/click_house/integer.go b/sink/sql/db_proto/sql/click_house/integer.go new file mode 100644 index 000000000..f73e8de18 --- /dev/null +++ b/sink/sql/db_proto/sql/click_house/integer.go @@ -0,0 +1,233 @@ +package clickhouse + +import ( + "fmt" + "math/big" + "strings" + + "github.com/ClickHouse/ch-go/proto" +) + +// StringToInt128 converts a string to proto.Int128. +func StringToInt128(s string) (proto.Int128, error) { + s = strings.TrimSpace(s) + if s == "" { + return proto.Int128{}, fmt.Errorf("empty string cannot be converted to int128") + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString(s, 10); !ok { + return proto.Int128{}, fmt.Errorf("invalid integer format: %s", s) + } + + max128 := new(big.Int) + max128.Exp(big.NewInt(2), big.NewInt(127), nil) + max128.Sub(max128, big.NewInt(1)) + + min128 := new(big.Int) + min128.Exp(big.NewInt(2), big.NewInt(127), nil) + min128.Neg(min128) + + if bigInt.Cmp(max128) > 0 || bigInt.Cmp(min128) < 0 { + return proto.Int128{}, fmt.Errorf("integer value out of range for Int128: %s", s) + } + + var low, high uint64 + + if bigInt.Sign() >= 0 { + low = bigInt.Uint64() + if bigInt.BitLen() > 64 { + bigInt.Rsh(bigInt, 64) + high = bigInt.Uint64() + } + } else { + absBigInt := new(big.Int).Abs(bigInt) + maxUint128 := new(big.Int) + maxUint128.SetBit(maxUint128, 128, 1) + twosComplement := new(big.Int).Sub(maxUint128, absBigInt) + low = twosComplement.Uint64() + if twosComplement.BitLen() > 64 { + twosComplement.Rsh(twosComplement, 64) + high = twosComplement.Uint64() + } else { + high = ^uint64(0) + } + } + + return proto.Int128{Low: low, High: high}, nil +} + +// StringToUInt128 converts a string to proto.UInt128. +func StringToUInt128(s string) (proto.UInt128, error) { + s = strings.TrimSpace(s) + if s == "" { + return proto.UInt128{}, fmt.Errorf("empty string cannot be converted to uint128") + } + + if strings.HasPrefix(s, "-") { + return proto.UInt128{}, fmt.Errorf("negative values not allowed for UInt128: %s", s) + } + + if strings.HasPrefix(s, "+") { + s = s[1:] + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString(s, 10); !ok { + return proto.UInt128{}, fmt.Errorf("invalid integer format: %s", s) + } + + maxUint128 := new(big.Int) + maxUint128.Exp(big.NewInt(2), big.NewInt(128), nil) + maxUint128.Sub(maxUint128, big.NewInt(1)) + + if bigInt.Sign() < 0 || bigInt.Cmp(maxUint128) > 0 { + return proto.UInt128{}, fmt.Errorf("integer value out of range for UInt128: %s", s) + } + + var low, high uint64 + low = bigInt.Uint64() + if bigInt.BitLen() > 64 { + bigInt.Rsh(bigInt, 64) + high = bigInt.Uint64() + } + + return proto.UInt128{Low: low, High: high}, nil +} + +// StringToInt256 converts a string to proto.Int256. +func StringToInt256(s string) (proto.Int256, error) { + s = strings.TrimSpace(s) + if s == "" { + return proto.Int256{}, fmt.Errorf("empty string cannot be converted to int256") + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString(s, 10); !ok { + return proto.Int256{}, fmt.Errorf("invalid integer format: %s", s) + } + + max256 := new(big.Int) + max256.Exp(big.NewInt(2), big.NewInt(255), nil) + max256.Sub(max256, big.NewInt(1)) + + min256 := new(big.Int) + min256.Exp(big.NewInt(2), big.NewInt(255), nil) + min256.Neg(min256) + + if bigInt.Cmp(max256) > 0 || bigInt.Cmp(min256) < 0 { + return proto.Int256{}, fmt.Errorf("integer value out of range for Int256: %s", s) + } + + var lowLow, lowHigh, highLow, highHigh uint64 + + if bigInt.Sign() >= 0 { + tempBig := new(big.Int).Set(bigInt) + lowLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + if tempBig.BitLen() > 0 { + lowHigh = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } + if tempBig.BitLen() > 0 { + highLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } + if tempBig.BitLen() > 0 { + highHigh = tempBig.Uint64() + } + } else { + absBigInt := new(big.Int).Abs(bigInt) + maxUint256 := new(big.Int) + maxUint256.SetBit(maxUint256, 256, 1) + twosComplement := new(big.Int).Sub(maxUint256, absBigInt) + tempBig := new(big.Int).Set(twosComplement) + lowLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + if tempBig.BitLen() > 0 { + lowHigh = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } else { + lowHigh = ^uint64(0) + } + if tempBig.BitLen() > 0 { + highLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } else { + highLow = ^uint64(0) + } + if tempBig.BitLen() > 0 { + highHigh = tempBig.Uint64() + } else { + highHigh = ^uint64(0) + } + } + + return proto.Int256{ + Low: proto.UInt128{ + Low: lowLow, + High: lowHigh, + }, + High: proto.UInt128{ + Low: highLow, + High: highHigh, + }, + }, nil +} + +// StringToUInt256 converts a string to proto.UInt256. +func StringToUInt256(s string) (proto.UInt256, error) { + s = strings.TrimSpace(s) + if s == "" { + return proto.UInt256{}, fmt.Errorf("empty string cannot be converted to uint256") + } + + if strings.HasPrefix(s, "-") { + return proto.UInt256{}, fmt.Errorf("negative values not allowed for UInt256: %s", s) + } + + if strings.HasPrefix(s, "+") { + s = s[1:] + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString(s, 10); !ok { + return proto.UInt256{}, fmt.Errorf("invalid integer format: %s", s) + } + + maxUint256 := new(big.Int) + maxUint256.Exp(big.NewInt(2), big.NewInt(256), nil) + maxUint256.Sub(maxUint256, big.NewInt(1)) + + if bigInt.Sign() < 0 || bigInt.Cmp(maxUint256) > 0 { + return proto.UInt256{}, fmt.Errorf("integer value out of range for UInt256: %s", s) + } + + var lowLow, lowHigh, highLow, highHigh uint64 + + tempBig := new(big.Int).Set(bigInt) + lowLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + if tempBig.BitLen() > 0 { + lowHigh = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } + if tempBig.BitLen() > 0 { + highLow = tempBig.Uint64() + tempBig.Rsh(tempBig, 64) + } + if tempBig.BitLen() > 0 { + highHigh = tempBig.Uint64() + } + + return proto.UInt256{ + Low: proto.UInt128{ + Low: lowLow, + High: lowHigh, + }, + High: proto.UInt128{ + Low: highLow, + High: highHigh, + }, + }, nil +} diff --git a/sink/sql/db_proto/sql/click_house/types.go b/sink/sql/db_proto/sql/click_house/types.go new file mode 100644 index 000000000..3fb420e5b --- /dev/null +++ b/sink/sql/db_proto/sql/click_house/types.go @@ -0,0 +1,211 @@ +package clickhouse + +import ( + "fmt" + + "github.com/ClickHouse/ch-go/proto" + "github.com/streamingfast/substreams/sink/sql/bytes" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + v1 "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type DataType string + +const ( + TypeInteger8 DataType = "Int8" + TypeInteger16 DataType = "Int16" + TypeInteger32 DataType = "Int32" + TypeInteger64 DataType = "Int64" + TypeInteger128 DataType = "Int128" + TypeInteger256 DataType = "Int256" + + TypeUInt8 DataType = "UInt8" + TypeUInt16 DataType = "UInt16" + TypeUInt32 DataType = "UInt32" + TypeUInt64 DataType = "UInt64" + TypeUInt128 DataType = "UInt128" + TypeUInt256 DataType = "UInt256" + + TypeFloat32 DataType = "Float32" + TypeFloat64 DataType = "Float64" + + TypeDecimal128 = "Decimal128" + TypeDecimal256 = "Decimal256" + + TypeBool DataType = "Bool" + TypeVarchar DataType = "VARCHAR" + + TypeDateTime DataType = "DateTime" +) + +func (s DataType) String() string { + return string(s) +} + +func MapFieldType(fd protoreflect.FieldDescriptor, bytesEncoding bytes.Encoding, column *schema.Column) DataType { + kind := fd.Kind() + var baseType DataType + + switch kind { + case protoreflect.MessageKind: + switch { + case fd.Message().FullName() == "google.protobuf.Timestamp": + baseType = TypeDateTime + case column.Nested != nil: + return DataType("") + default: + panic(fmt.Sprintf("Message type not supported: %s", string(fd.Message().FullName()))) + } + case protoreflect.EnumKind: + baseType = TypeInteger32 + case protoreflect.BoolKind: + baseType = TypeBool + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + baseType = TypeInteger32 + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + baseType = TypeInteger64 + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + baseType = TypeUInt64 + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + baseType = TypeUInt32 + case protoreflect.FloatKind: + baseType = TypeFloat32 + case protoreflect.DoubleKind: + baseType = TypeFloat64 + case protoreflect.StringKind: + if column.ConvertTo != nil && column.ConvertTo.Convertion != nil { + switch column.ConvertTo.Convertion.(type) { + case *v1.StringConvertion_Int128: + baseType = TypeInteger128 + case *v1.StringConvertion_Uint128: + baseType = TypeUInt128 + case *v1.StringConvertion_Int256: + baseType = TypeInteger256 + case *v1.StringConvertion_Uint256: + baseType = TypeUInt256 + case *v1.StringConvertion_Decimal128: + decimal128Conv := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal128) + baseType = DataType(fmt.Sprintf("Decimal128(%d)", decimal128Conv.Decimal128.Scale)) + case *v1.StringConvertion_Decimal256: + decimal256Conv := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal256) + baseType = DataType(fmt.Sprintf("Decimal256(%d)", decimal256Conv.Decimal256.Scale)) + default: + panic(fmt.Sprintf("unsupported type: %s", kind)) + } + } else { + baseType = TypeVarchar + } + + case protoreflect.BytesKind: + baseType = TypeVarchar + default: + panic(fmt.Sprintf("unsupported type: %s", kind)) + } + + if fd.IsList() { + return DataType(fmt.Sprintf("Array(%s)", baseType)) + } + + return baseType +} + +func ColInputForColumn(fd protoreflect.FieldDescriptor, bytesEncoding bytes.Encoding, column *schema.Column) proto.ColInput { + var baseInput proto.ColInput + + switch fd.Kind() { + case protoreflect.MessageKind: + switch { + case fd.Message().FullName() == "google.protobuf.Timestamp": + baseInput = &proto.ColDateTime{} + case column.Nested != nil: + return nil + default: + panic(fmt.Sprintf("Message type not supported: %s", string(fd.Message().FullName()))) + } + case protoreflect.EnumKind: + baseInput = &proto.ColInt32{} + case protoreflect.BoolKind: + baseInput = &proto.ColBool{} + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + baseInput = &proto.ColInt32{} + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + baseInput = &proto.ColInt64{} + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + baseInput = &proto.ColUInt64{} + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + baseInput = &proto.ColUInt32{} + case protoreflect.FloatKind: + baseInput = &proto.ColFloat32{} + case protoreflect.DoubleKind: + baseInput = &proto.ColFloat64{} + case protoreflect.StringKind: + if column.ConvertTo != nil && column.ConvertTo.Convertion != nil { + switch column.ConvertTo.Convertion.(type) { + case *v1.StringConvertion_Int128: + baseInput = &proto.ColInt128{} + case *v1.StringConvertion_Uint128: + baseInput = &proto.ColUInt128{} + case *v1.StringConvertion_Int256: + baseInput = &proto.ColInt256{} + case *v1.StringConvertion_Uint256: + baseInput = &proto.ColUInt256{} + case *v1.StringConvertion_Decimal128: + innerCol := &proto.ColDecimal128{} + scale := (column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal128)).Decimal128.Scale + baseInput = &ColScaledDecimal128{ + ColDecimal128: innerCol, + scale: uint8(scale), + } + case *v1.StringConvertion_Decimal256: + innerCol := &proto.ColDecimal256{} + scale := (column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal256)).Decimal256.Scale + baseInput = &ColScaledDecimal256{ + ColDecimal256: innerCol, + scale: uint8(scale), + } + default: + panic(fmt.Sprintf("unsupported type: %s", fd.Kind())) + } + } else { + baseInput = &proto.ColStr{} + } + case protoreflect.BytesKind: + if bytesEncoding.IsStringType() { + baseInput = &proto.ColStr{} + } else { + baseInput = &proto.ColBytes{} + } + default: + panic(fmt.Sprintf("unsupported type: %s", fd.Kind())) + } + + if fd.IsList() { + switch base := baseInput.(type) { + case *proto.ColInt32: + return proto.NewArray(base) + case *proto.ColInt64: + return proto.NewArray(base) + case *proto.ColUInt32: + return proto.NewArray(base) + case *proto.ColUInt64: + return proto.NewArray(base) + case *proto.ColFloat32: + return proto.NewArray(base) + case *proto.ColFloat64: + return proto.NewArray(base) + case *proto.ColBool: + return proto.NewArray(base) + case *proto.ColStr: + return proto.NewArray(base) + case *proto.ColBytes: + return proto.NewArray(base) + case *proto.ColDateTime: + return proto.NewArray(base) + default: + panic(fmt.Sprintf("unsupported array base type: %T", base)) + } + } + + return baseInput +} diff --git a/sink/sql/db_proto/sql/constraint.go b/sink/sql/db_proto/sql/constraint.go new file mode 100644 index 000000000..f3ebfbf1e --- /dev/null +++ b/sink/sql/db_proto/sql/constraint.go @@ -0,0 +1,20 @@ +package sql + +import "fmt" + +type ForeignKey struct { + Name string + Table string + Field string + ForeignTable string + ForeignField string +} + +type Constraint struct { + Table string + Sql string +} + +func (f *ForeignKey) String() string { + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)", f.Table, f.Name, f.Field, f.ForeignTable, f.ForeignField) +} diff --git a/sink/sql/db_proto/sql/context.go b/sink/sql/db_proto/sql/context.go new file mode 100644 index 000000000..8eb582c14 --- /dev/null +++ b/sink/sql/db_proto/sql/context.go @@ -0,0 +1,17 @@ +package sql + +type Context struct { + blockNumber int +} + +func NewContext() *Context { + return &Context{} +} + +func (c *Context) SetNumber(id int) { + c.blockNumber = id +} + +func (c *Context) BlockNumber() int { + return c.blockNumber +} diff --git a/sink/sql/db_proto/sql/database.go b/sink/sql/db_proto/sql/database.go new file mode 100644 index 000000000..fadd3d81b --- /dev/null +++ b/sink/sql/db_proto/sql/database.go @@ -0,0 +1,235 @@ +package sql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + pbSchema "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "github.com/streamingfast/substreams/sink/sql/proto" + sink "github.com/streamingfast/substreams/sink" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type Database interface { + FetchSinkInfo(schemaName string) (*SinkInfo, error) + UpdateSinkInfoHash(schemaName string, newHash string) error + StoreSinkInfo(schemaName string, schemaHash string) error + + CreateDatabase(useConstraints bool) error + WalkMessageDescriptorAndInsert(dm *dynamicpb.Message, blockNum uint64, blockTimestamp time.Time, parent *Parent) (time.Duration, error) + InsertBlock(blockNum uint64, hash string, timestamp time.Time) error + + HandleBlocksUndo(lastValidBlockNumber uint64) error + + FetchCursor() (*sink.Cursor, error) + StoreCursor(cursor *sink.Cursor) error + + BeginTransaction() error + CommitTransaction() error + RollbackTransaction() + Flush() (time.Duration, error) + + DatabaseHash(schemaName string) (uint64, error) + + GetDialect() Dialect + + Clone() Database + Open() error +} + +type BaseDatabase struct { + logger *zap.Logger + mapOutputType string + insertStatements map[string]*sql.Stmt + RootMessageDescriptor protoreflect.MessageDescriptor + useProtoOptions bool +} + +func NewBaseDatabase(moduleOutputType string, rootMessageDescriptor protoreflect.MessageDescriptor, useProtoOptions bool, logger *zap.Logger) (database *BaseDatabase, err error) { + logger = logger.Named("database") + + return &BaseDatabase{ + logger: logger, + mapOutputType: moduleOutputType, + RootMessageDescriptor: rootMessageDescriptor, + insertStatements: make(map[string]*sql.Stmt), + useProtoOptions: useProtoOptions, + }, nil +} + +func (d *BaseDatabase) BaseClone() *BaseDatabase { + return &BaseDatabase{ + logger: d.logger, + mapOutputType: d.mapOutputType, + RootMessageDescriptor: d.RootMessageDescriptor, + insertStatements: d.insertStatements, + } +} + +type Parent struct { + field string + id interface{} +} + +func (d *BaseDatabase) WalkMessageDescriptorAndInsertWithDialect(dm *dynamicpb.Message, blockNum uint64, blockTimestamp time.Time, parent *Parent, dialect Dialect, inserter Inserter) (time.Duration, error) { + if dm == nil { + return 0, fmt.Errorf("received a nil message") + } + + var fieldValues []any + fieldValues = append(fieldValues, blockNum) + fieldValues = append(fieldValues, blockTimestamp) + + primaryKeyOffset := 2 + if dialect.UseVersionField() { + fieldValues = append(fieldValues, time.Now().UnixNano()) + primaryKeyOffset += 1 + } + + if dialect.UseDeletedField() { + fieldValues = append(fieldValues, false) + primaryKeyOffset += 1 + } + + md := dm.Descriptor() + tableInfo := proto.TableInfo(md) + + if tableInfo == nil && !d.useProtoOptions { + tableInfo = &pbSchema.Table{ + Name: string(md.Name()), + } + } + + d.logger.Debug("Walking message descriptor", zap.String("message_descriptor_name", string(md.Name())), zap.Any("table_info", tableInfo)) + primaryKey := "" + if tableInfo != nil { + if table := dialect.GetTable(tableInfo.Name); table != nil { + if table.PrimaryKey != nil { + primaryKey = table.PrimaryKey.Name + pkField := md.Fields().ByName(protoreflect.Name(primaryKey)) + if pkField == nil { + return 0, fmt.Errorf("missing primary key field %q for table %q", primaryKey, tableInfo.Name) + } + pkValue := dm.Get(pkField) + fieldValues = append(fieldValues, pkValue.Interface()) + } + } + } + + totalSqlDuration := time.Duration(0) + + if parent != nil { + fieldValues = append(fieldValues, parent.id) + } + + var childs []*dynamicpb.Message + + fields := md.Fields() + for i := 0; i < fields.Len(); i++ { + fd := fields.Get(i) + if string(fd.Name()) == primaryKey { + continue + } + fv := dm.Get(fd) + + if fd.IsList() { + list := fv.List() + if fd.Kind() == protoreflect.MessageKind { + fieldInfo := proto.FieldInfo(fd) + if fieldInfo != nil && fieldInfo.Inline { + var err error + fieldValues, err = dialect.AppendInlineFieldValues(fieldValues, fd, fv, dm) + if err != nil { + return 0, fmt.Errorf("appending inline field values for %q: %w", string(fd.Name()), err) + } + } else if list.Len() > 0 { + for j := 0; j < list.Len(); j++ { + fm := list.Get(j).Message().Interface().(*dynamicpb.Message) + childs = append(childs, fm) + } + } + } else if list.Len() > 0 { + var values []interface{} + for j := 0; j < list.Len(); j++ { + values = append(values, list.Get(j).Interface()) + } + fieldValues = append(fieldValues, values) + } else { + fieldValues = append(fieldValues, []interface{}{}) + } + } else if fd.Kind() == protoreflect.MessageKind { + if fv.Message().IsValid() { + fm := fv.Message().Interface().(*dynamicpb.Message) + if fm.Descriptor().FullName() == "google.protobuf.Timestamp" { + timestamp := ×tamppb.Timestamp{} + timestamp.Seconds = fm.Get(fm.Descriptor().Fields().ByName("seconds")).Int() + timestamp.Nanos = int32(fm.Get(fm.Descriptor().Fields().ByName("nanos")).Int()) + fieldValues = append(fieldValues, timestamp) + continue + } + + fieldInfo := proto.FieldInfo(fd) + if fieldInfo != nil && fieldInfo.Inline { + var err error + fieldValues, err = dialect.AppendInlineFieldValues(fieldValues, fd, fv, dm) + if err != nil { + return 0, fmt.Errorf("appending inline field values for %q: %w", string(fd.Name()), err) + } + continue + } + + childs = append(childs, fm) + } + } else { + fieldValues = append(fieldValues, fv.Interface()) + } + } + + var p *Parent + + if tableInfo != nil { + insertStartAt := time.Now() + table := dialect.GetTable(tableInfo.Name) + if table != nil { + err := inserter.Insert(table.Name, fieldValues) + if err != nil { + d.logger.Info("failed to insert into table, printing field values for debugging", zap.String("table_name", table.Name), zap.Any("field_values", fieldValues)) + return 0, fmt.Errorf("inserting into table %q: %w", table.Name, err) + } + if len(childs) > 0 && d.useProtoOptions { + if table.PrimaryKey == nil { + for _, child := range childs { + fmt.Println("child:", child.Descriptor().FullName()) + } + return 0, fmt.Errorf("table %q has no primary key and has %d associated children table", table.Name, len(childs)) + } + idx := table.PrimaryKey.Index + primaryKeyOffset + id := fieldValues[idx] + p = &Parent{ + field: strings.ToLower(string(md.Name())), + id: id, + } + } + totalSqlDuration += time.Since(insertStartAt) + } + } + + for _, fm := range childs { + sqlDuration, err := d.WalkMessageDescriptorAndInsertWithDialect(fm, blockNum, blockTimestamp, p, dialect, inserter) + if err != nil { + return 0, fmt.Errorf("processing child %q: %w", string(fm.Descriptor().FullName()), err) + } + totalSqlDuration += sqlDuration + } + + return totalSqlDuration, nil +} + +type SinkInfo struct { + SchemaHash string `json:"schema_hash"` +} diff --git a/sink/sql/db_proto/sql/dialect.go b/sink/sql/db_proto/sql/dialect.go new file mode 100644 index 000000000..1734ff7ae --- /dev/null +++ b/sink/sql/db_proto/sql/dialect.go @@ -0,0 +1,72 @@ +package sql + +import ( + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + "go.uber.org/zap" + "golang.org/x/exp/maps" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +const DialectTableBlock = "_blocks_" +const DialectTableCursor = "_cursors_" + +const DialectFieldBlockNumber = "_block_number_" +const DialectFieldBlockTimestamp = "_block_timestamp_" +const DialectFieldVersion = "_version_" +const DialectFieldDeleted = "_deleted_" + +type Dialect interface { + SchemaHash() string + FullTableName(table *schema.Table) string + GetTable(table string) *schema.Table + GetTables() []*schema.Table + UseVersionField() bool + UseDeletedField() bool + AppendInlineFieldValues(fieldValues []any, fd protoreflect.FieldDescriptor, fv protoreflect.Value, dm *dynamicpb.Message) ([]any, error) +} + +type BaseDialect struct { + CreateTableSql map[string]string + PrimaryKeySql []*Constraint + ForeignKeySql []*Constraint + UniqueConstraintSql []*Constraint + TableRegistry map[string]*schema.Table + Logger *zap.Logger +} + +func NewBaseDialect(registry map[string]*schema.Table, logger *zap.Logger) *BaseDialect { + return &BaseDialect{ + CreateTableSql: make(map[string]string), + TableRegistry: registry, + Logger: logger, + } +} + +func (d *BaseDialect) AddCreateTableSql(table string, sql string) { + d.CreateTableSql[table] = sql +} + +func (d *BaseDialect) GetCreateTableSql(table string) string { + return d.CreateTableSql[table] +} + +func (d *BaseDialect) AddPrimaryKeySql(table string, sql string) { + d.PrimaryKeySql = append(d.PrimaryKeySql, &Constraint{Table: table, Sql: sql}) +} + +func (d *BaseDialect) AddForeignKeySql(table string, sql string) { + d.ForeignKeySql = append(d.ForeignKeySql, &Constraint{Table: table, Sql: sql}) +} + +func (d *BaseDialect) AddUniqueConstraintSql(table string, sql string) { + d.UniqueConstraintSql = append(d.UniqueConstraintSql, &Constraint{Table: table, Sql: sql}) +} + +func (d *BaseDialect) GetTable(table string) *schema.Table { + return d.TableRegistry[table] +} + +func (d *BaseDialect) GetTables() []*schema.Table { + return maps.Values(d.TableRegistry) +} diff --git a/sink/sql/db_proto/sql/inserter.go b/sink/sql/db_proto/sql/inserter.go new file mode 100644 index 000000000..f517bc546 --- /dev/null +++ b/sink/sql/db_proto/sql/inserter.go @@ -0,0 +1,5 @@ +package sql + +type Inserter interface { + Insert(table string, values []any) error +} diff --git a/sink/sql/db_proto/sql/postgres/accumulator_inserter.go b/sink/sql/db_proto/sql/postgres/accumulator_inserter.go new file mode 100644 index 000000000..102ff1373 --- /dev/null +++ b/sink/sql/db_proto/sql/postgres/accumulator_inserter.go @@ -0,0 +1,149 @@ +package postgres + +import ( + "database/sql" + "fmt" + "strings" + + sql2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + "go.uber.org/zap" +) + +type accumulator struct { + query string + rowValues [][]string +} + +type AccumulatorInserter struct { + accumulators map[string]*accumulator + cursorStmt *sql.Stmt + logger *zap.Logger +} + +func NewAccumulatorInserter(logger *zap.Logger) (*AccumulatorInserter, error) { + logger = logger.Named("postgres inserter") + + return &AccumulatorInserter{ + logger: logger, + }, nil +} + +func (i *AccumulatorInserter) init(database *Database) error { + tables := database.dialect.GetTables() + accumulators := map[string]*accumulator{} + + for _, table := range tables { + query, err := createInsertFromDescriptorAcc(table, database.dialect) + if err != nil { + return fmt.Errorf("creating insert from descriptor for table %q: %w", table.Name, err) + } + accumulators[table.Name] = &accumulator{ + query: query, + } + } + accumulators["_blocks_"] = &accumulator{ + query: fmt.Sprintf("INSERT INTO %s (number, hash, timestamp) VALUES ", tableName(database.schema.Name, "_blocks_")), + } + + cursorQuery := fmt.Sprintf("INSERT INTO %s (name, cursor) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET cursor = $2", tableName(database.schema.Name, "_cursor_")) + cs, err := database.db.Prepare(cursorQuery) + if err != nil { + return fmt.Errorf("preparing statement %q: %w", cursorQuery, err) + } + + i.accumulators = accumulators + i.cursorStmt = cs + + return nil +} + +func createInsertFromDescriptorAcc(table *schema.Table, dialect sql2.Dialect) (string, error) { + tableName := dialect.FullTableName(table) + fields := table.Columns + + var fieldNames []string + fieldNames = append(fieldNames, sql2.DialectFieldBlockNumber) + fieldNames = append(fieldNames, sql2.DialectFieldBlockTimestamp) + + if pk := table.PrimaryKey; pk != nil { + fieldNames = append(fieldNames, pk.Name) + } + + if table.ChildOf != nil { + fieldNames = append(fieldNames, table.ChildOf.ParentTableField) + } + + for _, field := range fields { + if table.PrimaryKey != nil && field.Name == table.PrimaryKey.Name { + continue + } + + if field.IsExtension { + continue + } + if field.IsRepeated { + if field.IsMessage { + continue + } + } + fieldNames = append(fieldNames, field.QuotedName()) + } + + return fmt.Sprintf("INSERT INTO %s (%s) VALUES ", + tableName, + strings.Join(fieldNames, ", "), + ), nil + +} + +func (i *AccumulatorInserter) insert(table string, values []any, database *Database) error { + var v []string + if table == "_cursor_" { + stmt := database.wrapInsertStatement(i.cursorStmt) + _, err := stmt.Exec(values...) + if err != nil { + return fmt.Errorf("executing insert: %w", err) + } + return nil + } + for _, value := range values { + v = append(v, ValueToString(value, database.dialect.bytesEncoding)) + } + accumulator := i.accumulators[table] + if accumulator == nil { + return fmt.Errorf("accumulator not found for table %q", table) + } + accumulator.rowValues = append(accumulator.rowValues, v) + + return nil +} + +func (i *AccumulatorInserter) flush(database *Database) error { + for _, acc := range i.accumulators { + if len(acc.rowValues) == 0 { + continue + } + var b strings.Builder + b.WriteString(acc.query) + for _, values := range acc.rowValues { + b.WriteString("(") + b.WriteString(strings.Join(values, ",")) + b.WriteString("),") + } + insert := strings.Trim(b.String(), ",") + + _, err := database.tx.Exec(insert) + if err != nil { + shortInsert := insert + if len(insert) > 256 { + shortInsert = insert[:256] + "..." + } + fmt.Println("insert query:", insert) + return fmt.Errorf("executing insert %s: %w", shortInsert, err) + } + acc.rowValues = acc.rowValues[:0] + } + + return nil +} diff --git a/sink/sql/db_proto/sql/postgres/database.go b/sink/sql/db_proto/sql/postgres/database.go new file mode 100644 index 000000000..dc1504bd8 --- /dev/null +++ b/sink/sql/db_proto/sql/postgres/database.go @@ -0,0 +1,422 @@ +package postgres + +import ( + "context" + pgsql "database/sql" + "fmt" + "hash/fnv" + "time" + + "github.com/streamingfast/logging/zapx" + sink "github.com/streamingfast/substreams/sink" + "github.com/streamingfast/substreams/sink/sql/bytes" + "github.com/streamingfast/substreams/sink/sql/db_changes/db" + sql2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +type Database struct { + *sql2.BaseDatabase + db *pgsql.DB + tx *pgsql.Tx + schema *schema.Schema + logger *zap.Logger + dialect *DialectPostgres + inserter pgInserter + flusher pgFlusher + useConstraints bool +} + +func NewDatabase(schema *schema.Schema, dsn *db.DSN, moduleOutputType string, rootMessageDescriptor protoreflect.MessageDescriptor, useProtoOptions bool, useConstraints bool, bytesEncoding bytes.Encoding, logger *zap.Logger) (*Database, error) { + logger = logger.Named("postgres") + + logger.Info("connecting to db", zap.String("host", dsn.Host), zap.Int64("port", dsn.Port), zap.String("database", dsn.Database)) + sqlDB, err := pgsql.Open(dsn.Driver(), dsn.ConnString()) + if err != nil { + return nil, fmt.Errorf("open db connection: %w", err) + } + + if reachable, err := isDatabaseReachable(sqlDB); !reachable { + return nil, fmt.Errorf("database not reachable: %w", err) + } + + dialect, err := NewDialectPostgres(schema, bytesEncoding, logger) + if err != nil { + return nil, fmt.Errorf("creating postgres dialect: %w", err) + } + + baseDB, err := sql2.NewBaseDatabase(moduleOutputType, rootMessageDescriptor, useProtoOptions, logger) + if err != nil { + return nil, fmt.Errorf("failed to create base database: %w", err) + } + database := &Database{ + db: sqlDB, + schema: schema, + useConstraints: useConstraints, + BaseDatabase: baseDB, + dialect: dialect, + logger: logger, + } + + return database, nil +} + +func (d *Database) Open() error { + if d.useConstraints { + inserter, err := NewRowInserter(d.logger) + if err != nil { + return fmt.Errorf("creating row inserter: %w", err) + } + if err := inserter.init(d); err != nil { + return fmt.Errorf("initializing row inserter: %w", err) + } + d.inserter = inserter + d.flusher = inserter + } else { + inserter, err := NewAccumulatorInserter(d.logger) + if err != nil { + return fmt.Errorf("creating accumulator inserter: %w", err) + } + if err := inserter.init(d); err != nil { + return fmt.Errorf("initializing row inserter: %w", err) + } + d.inserter = inserter + d.flusher = inserter + } + return nil +} + +func (d *Database) GetDialect() sql2.Dialect { + return d.dialect +} + +func (d *Database) CreateDatabase(useConstraints bool) error { + err := d.createDatabase() + if err != nil { + return fmt.Errorf("creating database: %w", err) + } + + if useConstraints { + err = d.applyConstraints() + if err != nil { + return fmt.Errorf("applying constraints: %w", err) + } + } + + return nil +} + +func (d *Database) createDatabase() error { + staticSql := fmt.Sprintf(postgresStaticSql, d.schema.Name, d.schema.Name, d.schema.Name, d.schema.Name) + _, err := d.tx.Exec(staticSql) + if err != nil { + return fmt.Errorf("executing static staticSql: %w\n%s", err, staticSql) + } + + for _, statement := range d.dialect.CreateTableSql { + d.logger.Info("executing create statement", zap.String("sql", statement)) + _, err := d.tx.Exec(statement) + if err != nil { + return fmt.Errorf("executing create statement: %w %s", err, statement) + } + } + return nil +} + +func (d *Database) applyConstraints() error { + startAt := time.Now() + for _, constraint := range d.dialect.PrimaryKeySql { + d.logger.Info("executing pk statement", zap.String("sql", constraint.Sql)) + _, err := d.tx.Exec(constraint.Sql) + if err != nil { + return fmt.Errorf("executing pk statement: %w %s", err, constraint.Sql) + } + } + for _, constraint := range d.dialect.UniqueConstraintSql { + d.logger.Info("executing unique statement", zap.String("sql", constraint.Sql)) + _, err := d.tx.Exec(constraint.Sql) + if err != nil { + return fmt.Errorf("executing unique statement: %w %s", err, constraint.Sql) + } + } + for _, constraint := range d.dialect.ForeignKeySql { + d.logger.Info("executing fk constraint statement", zap.String("sql", constraint.Sql)) + _, err := d.tx.Exec(constraint.Sql) + if err != nil { + return fmt.Errorf("executing fk constraint statement: %w %s", err, constraint.Sql) + } + } + d.logger.Info("applying constraints", zapx.HumanDuration("duration", time.Since(startAt))) + return nil +} + +func (d *Database) BeginTransaction() (err error) { + d.tx, err = d.db.Begin() + if err != nil { + return fmt.Errorf("beginning transaction: %w", err) + } + return nil +} + +func (d *Database) CommitTransaction() (err error) { + err = d.tx.Commit() + if err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + d.tx = nil + return nil +} + +func (d *Database) RollbackTransaction() { + err := d.tx.Rollback() + if err != nil { + panic("RollbackTransaction failed: " + err.Error()) + } +} + +func (d *Database) wrapInsertStatement(stmt *pgsql.Stmt) *pgsql.Stmt { + if d.tx != nil { + stmt = d.tx.Stmt(stmt) + } + return stmt +} + +func (d *Database) Insert(table string, values []any) error { + return d.inserter.insert(table, values, d) +} + +func (d *Database) WalkMessageDescriptorAndInsert(dm *dynamicpb.Message, blockNum uint64, blockTimestamp time.Time, parent *sql2.Parent) (time.Duration, error) { + return d.WalkMessageDescriptorAndInsertWithDialect(dm, blockNum, blockTimestamp, parent, d.dialect, d) +} + +func (d *Database) InsertBlock(blockNum uint64, hash string, timestamp time.Time) error { + d.logger.Debug("inserting _blocks_", zap.Uint64("block_num", blockNum), zap.String("block_hash", hash)) + err := d.inserter.insert("_blocks_", []any{blockNum, hash, timestamp}, d) + if err != nil { + return fmt.Errorf("inserting block %d: %w", blockNum, err) + } + + return nil +} + +func (d *Database) Flush() (time.Duration, error) { + startFlush := time.Now() + err := d.flusher.flush(d) + if err != nil { + return 0, fmt.Errorf("flushing: %w", err) + } + return time.Since(startFlush), nil +} + +func (d *Database) FetchSinkInfo(schemaName string) (*sql2.SinkInfo, error) { + query := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = '%s' AND table_name = '_sink_info_')", schemaName) + + var exist bool + err := d.db.QueryRow(query).Scan(&exist) + if err != nil { + return nil, fmt.Errorf("checking if sync_info table exists: %w", err) + } + if !exist { + return nil, nil + } + + out := &sql2.SinkInfo{} + + err = d.db.QueryRow(fmt.Sprintf("SELECT schema_hash FROM %s._sink_info_", d.schema.Name)).Scan(&out.SchemaHash) + if err != nil { + return nil, fmt.Errorf("fetching sync info: %w", err) + } + return out, nil + +} + +func (d *Database) StoreSinkInfo(schemaName string, schemaHash string) error { + _, err := d.tx.Exec(fmt.Sprintf("INSERT INTO %s._sink_info_ (schema_hash) VALUES ($1)", schemaName), schemaHash) + if err != nil { + return fmt.Errorf("storing schema hash: %w", err) + } + return nil +} + +func (d *Database) UpdateSinkInfoHash(schemaName string, newHash string) error { + _, err := d.tx.Exec(fmt.Sprintf("UPDATE %s._sink_info_ SET schema_hash = $1", schemaName), newHash) + if err != nil { + return fmt.Errorf("updating schema hash: %w", err) + } + return nil +} + +func (d *Database) FetchCursor() (*sink.Cursor, error) { + query := fmt.Sprintf("SELECT cursor FROM %s WHERE name = $1", tableName(d.schema.Name, "_cursor_")) + + rows, err := d.db.Query(query, "cursor") + if err != nil { + return nil, fmt.Errorf("selecting cursor: %w", err) + } + defer rows.Close() + + if rows.Next() { + var cursor string + err = rows.Scan(&cursor) + + return sink.NewCursor(cursor) + } + return nil, nil +} + +func (d *Database) StoreCursor(cursor *sink.Cursor) error { + err := d.inserter.insert("_cursor_", []any{"cursor", cursor.String()}, d) + if err != nil { + return fmt.Errorf("inserting cursor: %w", err) + } + + return err +} + +func (d *Database) HandleBlocksUndo(lastValidBlockNum uint64) (err error) { + tx, err := d.db.Begin() + if err != nil { + return fmt.Errorf("HandleBlocksUndo beginning transaction: %w", err) + } + defer func() { + if err != nil { + e := tx.Rollback() + if e != nil { + err = fmt.Errorf("HandleBlocksUndo rolling back transaction: %w", e) + } + err = fmt.Errorf("HandleBlocksUndo processing entity: %w", err) + + return + } + err = tx.Commit() + }() + + d.logger.Info("undoing blocks", zap.Uint64("last_valid_block_num", lastValidBlockNum)) + query := fmt.Sprintf(`DELETE FROM %s._blocks_ WHERE "number" > $1`, d.schema.Name) + result, err := tx.Exec(query, lastValidBlockNum) + if err != nil { + return fmt.Errorf("deleting block from %d: %w", lastValidBlockNum, err) + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("fetching rows affected: %w", err) + } + d.logger.Info("undo completed", zap.Int64("row_affected", rowsAffected)) + + return nil +} + +func (d *Database) Clone() sql2.Database { + base := d.BaseClone() + d.BaseDatabase = base + return d +} + +func (d *Database) DatabaseHash(schemaName string) (uint64, error) { + query := ` +SELECT + c.table_name, + c.column_name, + c.is_nullable, + c.data_type, + c.character_maximum_length, + c.numeric_precision, + c.numeric_precision_radix, + c.numeric_scale, + c.datetime_precision, + c.interval_precision, + c.is_generated, + c.is_updatable, + tc.constraint_name, + tc.table_name, + tc.constraint_type, + kcu.column_name, + kcu.table_name, + kcu.column_name, + ccu.constraint_name, + ccu.table_name, + ccu.column_name +FROM + information_schema.columns c + LEFT JOIN + information_schema.constraint_column_usage ccu + ON c.table_name = ccu.table_name + AND c.column_name = ccu.column_name + AND c.table_schema = ccu.table_schema + LEFT JOIN + information_schema.key_column_usage kcu + ON ccu.constraint_name = kcu.constraint_name + AND c.table_schema = kcu.table_schema + LEFT JOIN + information_schema.table_constraints tc + ON kcu.constraint_name = tc.constraint_name + AND kcu.table_schema = tc.table_schema +WHERE + c.table_schema = '%s' +ORDER BY + c.table_name, + c.column_name, + tc.table_name, + tc.constraint_name, + kcu.table_name, + kcu.column_name, + kcu.constraint_name; +` + + query = fmt.Sprintf(query, schemaName) + + rows, err := d.db.Query(query) + if err != nil { + return 0, fmt.Errorf("executing query to compute schema hash: %w", err) + } + defer rows.Close() + + h := fnv.New64a() + columns, err := rows.Columns() + if err != nil { + return 0, fmt.Errorf("fetching columns for hashing: %w", err) + } + + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + for rows.Next() { + err = rows.Scan(valuePtrs...) + if err != nil { + return 0, fmt.Errorf("scanning row for hashing: %w", err) + } + + for _, val := range values { + var str string + if val != nil { + str = fmt.Sprintf("%v", val) + } + _, err = h.Write([]byte(str)) + if err != nil { + return 0, fmt.Errorf("hashing value %q: %w", str, err) + } + } + } + + if err = rows.Err(); err != nil { + return 0, fmt.Errorf("iterating rows: %w", err) + } + + return h.Sum64(), nil +} + +func isDatabaseReachable(db *pgsql.DB) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := db.PingContext(ctx) + if err != nil { + return false, err + } + return true, nil +} diff --git a/sink/sql/db_proto/sql/postgres/dialect.go b/sink/sql/db_proto/sql/postgres/dialect.go new file mode 100644 index 000000000..3afdc4950 --- /dev/null +++ b/sink/sql/db_proto/sql/postgres/dialect.go @@ -0,0 +1,288 @@ +package postgres + +import ( + "encoding/hex" + "fmt" + "hash/fnv" + "sort" + "strings" + + "github.com/lib/pq" + "github.com/streamingfast/substreams/sink/sql/bytes" + sql2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + "go.uber.org/zap" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +const postgresStaticSql = ` + CREATE SCHEMA IF NOT EXISTS "%s"; + + CREATE TABLE IF NOT EXISTS "%s"._sink_info_ ( + schema_hash TEXT PRIMARY KEY + ); + + CREATE TABLE IF NOT EXISTS "%s"._cursor_ ( + name TEXT PRIMARY KEY, + cursor TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS "%s"._blocks_ ( + number integer, + hash TEXT NOT NULL, + timestamp TIMESTAMP NOT NULL + ); +` + +type DialectPostgres struct { + *sql2.BaseDialect + schemaName string + bytesEncoding bytes.Encoding +} + +func NewDialectPostgres(schema *schema.Schema, bytesEncoding bytes.Encoding, logger *zap.Logger) (*DialectPostgres, error) { + logger = logger.Named("postgres dialect") + + d := &DialectPostgres{ + BaseDialect: sql2.NewBaseDialect(schema.TableRegistry, logger), + schemaName: schema.Name, + bytesEncoding: bytesEncoding, + } + + err := d.init() + if err != nil { + return nil, fmt.Errorf("initializing dialect: %w", err) + } + + for _, table := range schema.TableRegistry { + err := d.createTable(table) + if err != nil { + return nil, fmt.Errorf("handling table %q: %w", table.Name, err) + } + } + + return d, nil +} + +func (d *DialectPostgres) UseVersionField() bool { + return false +} +func (d *DialectPostgres) UseDeletedField() bool { + return false +} + +func (d *DialectPostgres) init() error { + d.AddPrimaryKeySql(sql2.DialectTableBlock, fmt.Sprintf("alter table %s.%s add constraint block_pk primary key (number);", d.schemaName, sql2.DialectTableBlock)) + return nil +} + +func (d *DialectPostgres) createTable(table *schema.Table) error { + var sb strings.Builder + + tableName := d.FullTableName(table) + + sb.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName)) + + sb.WriteString(fmt.Sprintf(" %s INTEGER NOT NULL,", sql2.DialectFieldBlockNumber)) + sb.WriteString(fmt.Sprintf(" %s TIMESTAMP NOT NULL,", sql2.DialectFieldBlockTimestamp)) + + var primaryKeyFieldName string + if table.PrimaryKey != nil { + pk := table.PrimaryKey + primaryKeyFieldName = pk.Name + d.AddPrimaryKeySql(table.Name, fmt.Sprintf("alter table %s add constraint %s_pk primary key (%s);", tableName, table.Name, primaryKeyFieldName)) + sb.WriteString(fmt.Sprintf("%s %s,", pk.Name, MapFieldType(pk.FieldDescriptor, d.bytesEncoding, table.Columns[pk.Index]))) + } + + if table.ChildOf != nil { + parentTable, parentFound := d.TableRegistry[table.ChildOf.ParentTable] + if !parentFound { + return fmt.Errorf("parent table %q not found", table.ChildOf.ParentTable) + } + fieldFound := false + for _, parentField := range parentTable.Columns { + + if parentField.Name == table.ChildOf.ParentTableField { + + sb.WriteString(fmt.Sprintf("%s %s NOT NULL,", parentField.Name, MapFieldType(parentField.FieldDescriptor, d.bytesEncoding, parentField))) + + foreignKey := &sql2.ForeignKey{ + Name: "fk_" + table.ChildOf.ParentTable, + Table: tableName, + Field: table.ChildOf.ParentTableField, + ForeignTable: d.FullTableName(parentTable), + ForeignField: parentField.Name, + } + + d.AddForeignKeySql(table.Name, foreignKey.String()) + + fieldFound = true + break + } + } + if !fieldFound { + return fmt.Errorf("field %q not found in table %q", table.ChildOf.ParentTableField, table.ChildOf.ParentTable) + } + } + + for _, f := range table.Columns { + if f.Name == primaryKeyFieldName { + continue + } + + fieldQuotedName := f.QuotedName() + + switch { + case f.IsRepeated: + case f.Nested != nil: + fmt.Println("found nested type") + case f.IsMessage && !IsWellKnownType(f.FieldDescriptor): + childTable, found := d.TableRegistry[f.Message] + if !found { + continue + } + if childTable.PrimaryKey == nil { + continue + } + foreignKey := &sql2.ForeignKey{ + Name: "fk_" + childTable.Name, + Table: tableName, + Field: fieldQuotedName, + ForeignTable: d.FullTableName(childTable), + ForeignField: childTable.PrimaryKey.Name, + } + d.AddForeignKeySql(table.Name, foreignKey.String()) + + case f.ForeignKey != nil: + foreignTable, found := d.TableRegistry[f.ForeignKey.Table] + if !found { + return fmt.Errorf("foreign table %q not found", f.ForeignKey.Table) + } + + var foreignField *schema.Column + for _, field := range foreignTable.Columns { + if field.Name == f.ForeignKey.TableField { + foreignField = field + break + } + } + if foreignField == nil { + return fmt.Errorf("foreign field %q not found in table %q", f.ForeignKey.TableField, f.ForeignKey.Table) + } + + foreignKey := &sql2.ForeignKey{ + Name: "fk_" + f.Name, + Table: tableName, + Field: f.Name, + ForeignTable: d.FullTableName(foreignTable), + ForeignField: foreignField.Name, + } + d.AddForeignKeySql(table.Name, foreignKey.String()) + } + fieldType := MapFieldType(f.FieldDescriptor, d.bytesEncoding, f) + if f.IsUnique { + d.AddUniqueConstraintSql(table.Name, fmt.Sprintf("alter table %s add constraint %s_%s_unique unique (%s);", tableName, table.Name, f.Name, fieldQuotedName)) + } + + sb.WriteString(fmt.Sprintf("%s %s", fieldQuotedName, fieldType)) + sb.WriteString(",") + } + + temp := sb.String() + temp = temp[:len(temp)-1] + sb = strings.Builder{} + sb.WriteString(temp) + + sb.WriteString(");\n") + + d.AddForeignKeySql(tableName, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT fk_block FOREIGN KEY (%s) REFERENCES %s.%s(number) ON DELETE CASCADE", tableName, sql2.DialectFieldBlockNumber, d.schemaName, sql2.DialectTableBlock)) + d.AddCreateTableSql(table.Name, sb.String()) + + return nil + +} + +func (d *DialectPostgres) FullTableName(table *schema.Table) string { + return tableName(d.schemaName, table.Name) +} + +func (d *DialectPostgres) AppendInlineFieldValues(fieldValues []any, fd protoreflect.FieldDescriptor, fv protoreflect.Value, dm *dynamicpb.Message) ([]any, error) { + if fd.IsList() { + list := fv.List() + var jsonStrings []string + for j := 0; j < list.Len(); j++ { + fm := list.Get(j).Message().Interface().(*dynamicpb.Message) + jsonBytes, err := protojson.Marshal(fm) + if err != nil { + return nil, fmt.Errorf("failed to marshal protobuf message to JSON: %w", err) + } + jsonStrings = append(jsonStrings, string(jsonBytes)) + } + fieldValues = append(fieldValues, pq.Array(jsonStrings)) + } else { + fm := fv.Message().Interface().(*dynamicpb.Message) + jsonBytes, err := protojson.Marshal(fm) + if err != nil { + return nil, fmt.Errorf("failed to marshal protobuf message to JSON: %w", err) + } + fieldValues = append(fieldValues, string(jsonBytes)) + } + return fieldValues, nil +} + +func (d *DialectPostgres) SchemaHash() string { + h := fnv.New64a() + + var buf []byte + + var sqls []string + for _, sql := range d.CreateTableSql { + sqls = append(sqls, sql) + } + + sort.Strings(sqls) + for _, sql := range sqls { + buf = append(buf, []byte(sql)...) + } + + var pk []string + for _, constraint := range d.PrimaryKeySql { + pk = append(pk, constraint.Sql) + } + sort.Strings(pk) + for _, constraint := range pk { + buf = append(buf, []byte(constraint)...) + } + + var fk []string + for _, constraint := range d.ForeignKeySql { + fk = append(fk, constraint.Sql) + } + sort.Strings(fk) + for _, constraint := range fk { + buf = append(buf, []byte(constraint)...) + } + + var uniques []string + for _, constraint := range d.UniqueConstraintSql { + uniques = append(uniques, constraint.Sql) + } + sort.Strings(uniques) + for _, constraint := range uniques { + buf = append(buf, []byte(constraint)...) + } + + _, err := h.Write(buf) + if err != nil { + panic("unable to write to hash") + } + + data := h.Sum(nil) + return hex.EncodeToString(data) +} + +func tableName(schemaName string, tableName string) string { + return fmt.Sprintf("%s.%s", schemaName, tableName) +} diff --git a/sink/sql/db_proto/sql/postgres/inserter.go b/sink/sql/db_proto/sql/postgres/inserter.go new file mode 100644 index 000000000..d33ad882f --- /dev/null +++ b/sink/sql/db_proto/sql/postgres/inserter.go @@ -0,0 +1,9 @@ +package postgres + +type pgInserter interface { + insert(table string, values []any, database *Database) error +} + +type pgFlusher interface { + flush(database *Database) error +} diff --git a/sink/sql/db_proto/sql/postgres/row_inserter.go b/sink/sql/db_proto/sql/postgres/row_inserter.go new file mode 100644 index 000000000..3a333d315 --- /dev/null +++ b/sink/sql/db_proto/sql/postgres/row_inserter.go @@ -0,0 +1,193 @@ +package postgres + +import ( + "database/sql" + "encoding/base64" + "fmt" + "strconv" + "strings" + "time" + + sql2 "github.com/streamingfast/substreams/sink/sql/db_proto/sql" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type RowInserter struct { + insertQueries map[string]string + insertStatements map[string]*sql.Stmt + logger *zap.Logger + database *Database +} + +func NewRowInserter(logger *zap.Logger) (*RowInserter, error) { + logger = logger.Named("postgres inserter") + + return &RowInserter{ + logger: logger, + }, nil +} + +func (i *RowInserter) init(database *Database) error { + tables := database.dialect.GetTables() + insertStatements := map[string]*sql.Stmt{} + insertQueries := map[string]string{} + + i.database = database + + for _, table := range tables { + query, err := createInsertFromDescriptor(table, database.dialect) + if err != nil { + return fmt.Errorf("creating insert from descriptor for table %q: %w", table.Name, err) + } + insertQueries[table.Name] = query + + stmt, err := database.db.Prepare(query) + if err != nil { + return fmt.Errorf("preparing statement %q: %w", query, err) + } + insertStatements[table.Name] = stmt + } + + insertQueries["_blocks_"] = fmt.Sprintf("INSERT INTO %s (number, hash, timestamp) VALUES ($1, $2, $3) RETURNING number", tableName(database.schema.Name, "_blocks_")) + bs, err := database.db.Prepare(insertQueries["_blocks_"]) + if err != nil { + return fmt.Errorf("preparing statement %q: %w", insertQueries["_blocks_"], err) + } + insertStatements["_blocks_"] = bs + + insertQueries["_cursor_"] = fmt.Sprintf("INSERT INTO %s (name, cursor) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET cursor = $2", tableName(database.schema.Name, "_cursor_")) + cs, err := database.db.Prepare(insertQueries["_cursor_"]) + if err != nil { + return fmt.Errorf("preparing statement %q: %w", insertQueries["_cursor_"], err) + } + insertStatements["_cursor_"] = cs + + i.insertQueries = insertQueries + i.insertStatements = insertStatements + + return nil +} + +func createInsertFromDescriptor(table *schema.Table, dialect sql2.Dialect) (string, error) { + tableName := dialect.FullTableName(table) + fields := table.Columns + + var fieldNames []string + var placeholders []string + + fieldCount := 0 + returningField := "" + if table.PrimaryKey != nil { + returningField = table.PrimaryKey.Name + } + + fieldCount++ + fieldNames = append(fieldNames, sql2.DialectFieldBlockNumber) + placeholders = append(placeholders, fmt.Sprintf("$%d", fieldCount)) + fieldCount++ + fieldNames = append(fieldNames, sql2.DialectFieldBlockTimestamp) + placeholders = append(placeholders, fmt.Sprintf("$%d", fieldCount)) + + if pk := table.PrimaryKey; pk != nil { + fieldCount++ + returningField = pk.Name + fieldNames = append(fieldNames, pk.Name) + placeholders = append(placeholders, fmt.Sprintf("$%d", fieldCount)) + } + + if table.ChildOf != nil { + fieldCount++ + fieldNames = append(fieldNames, table.ChildOf.ParentTableField) + placeholders = append(placeholders, fmt.Sprintf("$%d", fieldCount)) + } + + for _, field := range fields { + if field.Name == returningField { + continue + } + if field.IsExtension { + continue + } + if field.IsRepeated && field.Nested == nil { + if field.IsMessage { + continue + } + } + fieldCount++ + fieldNames = append(fieldNames, field.QuotedName()) + placeholders = append(placeholders, fmt.Sprintf("$%d", fieldCount)) + } + + return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + tableName, + strings.Join(fieldNames, ", "), + strings.Join(placeholders, ", "), + ), nil + +} + +func (i *RowInserter) insert(table string, values []any, database *Database) error { + i.logger.Debug("inserting row", zap.String("table", table), zap.Any("values", values)) + stmt := i.insertStatements[table] + stmt = database.wrapInsertStatement(stmt) + + t := i.database.dialect.TableRegistry[table] + + fieldIndexOffset := 2 + if t != nil && t.ChildOf != nil { + fieldIndexOffset = 3 + } + + for i, value := range values { + + var column *schema.Column + fieldIndex := i - fieldIndexOffset + + if t != nil && fieldIndex >= 0 { + column = t.Columns[fieldIndex] + } + + switch v := value.(type) { + case string: + if column != nil && column.ConvertTo != nil && column.ConvertTo.Convertion != nil { + if v == "" { + values[i] = 0 + } + } + case uint64: + values[i] = strconv.FormatUint(v, 10) + case []uint8: + if database.dialect.bytesEncoding.IsStringType() { + encoded, err := database.dialect.bytesEncoding.EncodeBytes(v) + if err != nil { + return fmt.Errorf("failed to encode bytes: %v", err) + } + values[i] = encoded.(string) + continue + } + values[i] = "'" + base64.StdEncoding.EncodeToString(v) + "'" + case *timestamppb.Timestamp: + values[i] = "'" + v.AsTime().Format(time.RFC3339) + "'" + case []interface{}: + var elements []string + for _, elem := range v { + elements = append(elements, ValueToString(elem, database.dialect.bytesEncoding)) + } + values[i] = "{" + strings.Join(elements, ",") + "}" + } + } + + _, err := stmt.Exec(values...) + if err != nil { + insert := i.insertQueries[table] + return fmt.Errorf("pg accumalator inserter: querying insert %q: %w", insert, err) + } + + return nil +} + +func (i *RowInserter) flush(database *Database) error { + return nil +} diff --git a/sink/sql/db_proto/sql/postgres/types.go b/sink/sql/db_proto/sql/postgres/types.go new file mode 100644 index 000000000..d2fc488ab --- /dev/null +++ b/sink/sql/db_proto/sql/postgres/types.go @@ -0,0 +1,189 @@ +package postgres + +import ( + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + + "github.com/streamingfast/substreams/sink/sql/bytes" + "github.com/streamingfast/substreams/sink/sql/db_proto/sql/schema" + v1 "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type DataType string + +const ( + TypeNumeric DataType = "NUMERIC" + TypeInteger DataType = "INTEGER" + TypeBool DataType = "BOOLEAN" + TypeBigInt DataType = "BIGINT" + TypeDecimal DataType = "DECIMAL" + TypeDouble DataType = "DOUBLE PRECISION" + TypeText DataType = "TEXT" + TypeBlob DataType = "BLOB" + TypeVarchar DataType = "VARCHAR(255)" + TypeBytea DataType = "BYTEA" + TypeTimestamp DataType = "TIMESTAMP" + TypeJSONB DataType = "JSONB" +) + +func (s DataType) String() string { + return string(s) +} + +func IsWellKnownType(fd protoreflect.FieldDescriptor) bool { + if fd.Kind() != protoreflect.MessageKind { + return false + } + switch string(fd.Message().FullName()) { + case "google.protobuf.Timestamp": + return true + default: + return false + } +} + +func MapFieldType(fd protoreflect.FieldDescriptor, bytesEncoding bytes.Encoding, column *schema.Column) DataType { + kind := fd.Kind() + var baseType DataType + + switch kind { + case protoreflect.MessageKind: + if column.Nested != nil { + baseType = TypeJSONB + } else { + switch string(fd.Message().FullName()) { + case "google.protobuf.Timestamp": + baseType = TypeTimestamp + default: + panic(fmt.Sprintf("Message type not supported: %s", string(fd.Message().FullName()))) + } + } + case protoreflect.BoolKind: + baseType = TypeBool + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + baseType = TypeInteger + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + baseType = TypeBigInt + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + baseType = TypeNumeric + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + baseType = TypeNumeric + case protoreflect.FloatKind: + baseType = TypeDecimal + case protoreflect.DoubleKind: + baseType = TypeDouble + case protoreflect.StringKind: + if column.ConvertTo != nil && column.ConvertTo.Convertion != nil { + switch column.ConvertTo.Convertion.(type) { + case *v1.StringConvertion_Int128: + baseType = TypeNumeric + case *v1.StringConvertion_Uint128: + baseType = TypeNumeric + case *v1.StringConvertion_Int256: + baseType = TypeNumeric + case *v1.StringConvertion_Uint256: + baseType = TypeNumeric + case *v1.StringConvertion_Decimal128: + decimal128Conv := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal128) + baseType = DataType(fmt.Sprintf("DECIMAL(38,%d)", decimal128Conv.Decimal128.Scale)) + case *v1.StringConvertion_Decimal256: + decimal256Conv := column.ConvertTo.Convertion.(*v1.StringConvertion_Decimal256) + baseType = DataType(fmt.Sprintf("DECIMAL(76,%d)", decimal256Conv.Decimal256.Scale)) + default: + baseType = TypeVarchar + } + } else { + baseType = TypeVarchar + } + case protoreflect.BytesKind: + if bytesEncoding.IsStringType() { + baseType = TypeText + } else { + baseType = TypeBytea + } + case protoreflect.EnumKind: + baseType = TypeText + default: + panic(fmt.Sprintf("unsupported type: %s", kind)) + } + + if fd.IsList() { + return DataType(fmt.Sprintf("%s[]", baseType)) + } + + return baseType +} + +func ValueToString(value any, bytesEncoding bytes.Encoding) (s string) { + switch v := value.(type) { + case string: + s = "'" + strings.ReplaceAll(strings.ReplaceAll(v, "'", "''"), "\\", "\\\\") + "'" + case int64: + s = strconv.FormatInt(v, 10) + case int32: + s = strconv.FormatInt(int64(v), 10) + case int: + s = strconv.FormatInt(int64(v), 10) + case uint64: + s = strconv.FormatUint(v, 10) + case uint32: + s = strconv.FormatUint(uint64(v), 10) + case uint: + s = strconv.FormatUint(uint64(v), 10) + case float64: + s = strconv.FormatFloat(v, 'f', -1, 64) + case float32: + s = strconv.FormatFloat(float64(v), 'f', -1, 32) + case []uint8: + if bytesEncoding == bytes.EncodingRaw { + s = "E'" + hex.EncodeToString(v) + "'::BYTEA" + } else { + encoded, err := bytesEncoding.EncodeBytes(v) + if err != nil { + panic(fmt.Sprintf("failed to encode bytes: %v", err)) + } + s = "'" + encoded.(string) + "'" + } + case bool: + s = strconv.FormatBool(v) + case time.Time: + s = "'" + v.Format(time.RFC3339) + "'" + case *timestamppb.Timestamp: + s = "'" + v.AsTime().Format(time.RFC3339) + "'" + case []interface{}: + if len(v) == 0 { + s = "'{}'" + return + } + + var elements []string + for _, elem := range v { + elements = append(elements, ValueToString(elem, bytesEncoding)) + } + s = "array[" + strings.Join(elements, ",") + "]" + case protoreflect.Message: + jsonBytes, err := protojson.Marshal(v.Interface()) + if err != nil { + panic(fmt.Sprintf("failed to marshal protobuf message to JSON: %v", err)) + } + s = "'" + strings.ReplaceAll(strings.ReplaceAll(string(jsonBytes), "'", "''"), "\\", "\\\\") + "'" + return + default: + if msg, ok := v.(protoreflect.ProtoMessage); ok { + jsonBytes, err := protojson.Marshal(msg) + if err != nil { + panic(fmt.Sprintf("failed to marshal protobuf message to JSON: %v", err)) + } + s = "'" + strings.ReplaceAll(strings.ReplaceAll(string(jsonBytes), "'", "''"), "\\", "\\\\") + "'" + return + } + panic(fmt.Sprintf("unsupported type: %T", v)) + } + return +} diff --git a/sink/sql/db_proto/sql/schema/column.go b/sink/sql/db_proto/sql/schema/column.go new file mode 100644 index 000000000..17f2df846 --- /dev/null +++ b/sink/sql/db_proto/sql/schema/column.go @@ -0,0 +1,90 @@ +package schema + +import ( + "fmt" + "strings" + + v1 "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type Column struct { + Name string + ForeignKey *ForeignKey + FieldDescriptor protoreflect.FieldDescriptor + IsPrimaryKey bool + IsUnique bool + IsRepeated bool + IsExtension bool + IsMessage bool + IsOptional bool + Nested *Table + Message string + ConvertTo *v1.StringConvertion +} + +func NewColumn(d protoreflect.FieldDescriptor, fieldInfo *v1.Column, ordinal int, inlineDepth int) (*Column, error) { + out := &Column{ + Name: string(d.Name()), + FieldDescriptor: d, + IsRepeated: d.IsList(), + IsMessage: d.Kind() == protoreflect.MessageKind, + IsExtension: d.IsExtension(), + IsOptional: d.HasOptionalKeyword(), + } + + if fieldInfo != nil { + if fieldInfo.Inline { + if inlineDepth >= 1 { + return nil, fmt.Errorf("inline nesting level %d is not supported for column %q: only 1 level of inline nesting is allowed", inlineDepth+1, out.Name) + } + ti := &v1.Table{ + Name: out.Name, + } + nested, err := NewTable(d.Message(), ti, ordinal+1, inlineDepth+1) + if err != nil { + return nil, fmt.Errorf("creating nested column %s: %w", out.Name, err) + } + out.Nested = nested + } + + if fieldInfo.Name != nil { + out.Name = *fieldInfo.Name + } + if fieldInfo.ForeignKey != nil { + fk, err := NewForeignKey(*fieldInfo.ForeignKey) + if err != nil { + return nil, fmt.Errorf("error parsing foreign key %s: %w", *fieldInfo.ForeignKey, err) + } + out.ForeignKey = fk + } + out.IsPrimaryKey = fieldInfo.PrimaryKey + out.IsUnique = fieldInfo.Unique + out.ConvertTo = fieldInfo.ConvertTo + } + + if out.IsMessage { + out.Message = string(d.Message().Name()) + } + return out, nil +} + +func (c *Column) QuotedName() string { + return fmt.Sprintf("%q", c.Name) +} + +type ForeignKey struct { + Table string + TableField string +} + +func NewForeignKey(foreignKey string) (*ForeignKey, error) { + parts := strings.Split(foreignKey, " on ") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid foreign key format %q. expecting 'table_name on field_name' format", foreignKey) + } + return &ForeignKey{ + Table: strings.TrimSpace(parts[0]), + TableField: strings.TrimSpace(parts[1]), + }, nil +} diff --git a/sink/sql/db_proto/sql/schema/schema.go b/sink/sql/db_proto/sql/schema/schema.go new file mode 100644 index 000000000..bf5635075 --- /dev/null +++ b/sink/sql/db_proto/sql/schema/schema.go @@ -0,0 +1,107 @@ +package schema + +import ( + "fmt" + + schema "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "github.com/streamingfast/substreams/sink/sql/proto" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type Schema struct { + Name string + TableRegistry map[string]*Table + logger *zap.Logger + rootMessageDescriptor protoreflect.MessageDescriptor + withProtoOption bool +} + +func NewSchema(name string, rootMessageDescriptor protoreflect.MessageDescriptor, withProtoOption bool, logger *zap.Logger) (*Schema, error) { + logger.Info("creating schema", zap.String("name", name), zap.String("root_message_descriptor", string(rootMessageDescriptor.Name())), zap.Bool("with_proto_option", withProtoOption)) + s := &Schema{ + Name: name, + TableRegistry: make(map[string]*Table), + logger: logger, + rootMessageDescriptor: rootMessageDescriptor, + withProtoOption: withProtoOption, + } + + err := s.init(rootMessageDescriptor) + if err != nil { + return nil, fmt.Errorf("initializing schema: %w", err) + } + return s, nil +} + +func (s *Schema) ChangeName(name string) error { + s.Name = name + s.TableRegistry = make(map[string]*Table) + err := s.init(s.rootMessageDescriptor) + if err != nil { + return fmt.Errorf("changing schema name: %w", err) + } + + return nil +} + +func (s *Schema) init(rootMessageDescriptor protoreflect.MessageDescriptor) error { + s.logger.Info("initializing schema", zap.String("name", s.Name), zap.String("root_message_descriptor", string(rootMessageDescriptor.Name()))) + err := s.walkMessageDescriptor(rootMessageDescriptor, 0, func(md protoreflect.MessageDescriptor, ordinal int) error { + s.logger.Debug("creating table message descriptor", zap.String("message_descriptor_name", string(md.Name())), zap.Int("ordinal", ordinal)) + tableInfo := proto.TableInfo(md) + if tableInfo == nil { + if s.withProtoOption { + return nil + } + tableInfo = &schema.Table{ + Name: string(md.Name()), + ChildOf: nil, + } + } + if _, found := s.TableRegistry[tableInfo.Name]; found { + return nil + } + table, err := NewTable(md, tableInfo, ordinal, 0) + if err != nil { + return fmt.Errorf("creating table message descriptor: %w", err) + } + if table != nil { + s.logger.Debug("created table message descriptor", zap.String("message_descriptor_name", string(md.Name())), zap.Int("ordinal", ordinal), zap.String("table_name", table.Name)) + s.TableRegistry[tableInfo.Name] = table + } + return nil + }) + + if err != nil { + return fmt.Errorf("walking and creating table message descriptors registry: %q: %w", string(rootMessageDescriptor.Name()), err) + } + + return nil +} + +func (s *Schema) walkMessageDescriptor(md protoreflect.MessageDescriptor, ordinal int, task func(md protoreflect.MessageDescriptor, ordinal int) error) error { + s.logger.Debug("walking message descriptor", zap.String("message_descriptor_name", string(md.Name())), zap.Int("ordinal", ordinal)) + fields := md.Fields() + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) + s.logger.Debug("walking field", zap.String("field_name", string(field.Name())), zap.String("field_type", field.Kind().String())) + if field.Kind() == protoreflect.MessageKind { + err := s.walkMessageDescriptor(field.Message(), ordinal+1, task) + if err != nil { + return fmt.Errorf("walking field %q message descriptor: %w", string(field.Name()), err) + } + } + } + + err := task(md, ordinal) + if err != nil { + return fmt.Errorf("running task on message descriptor %q: %w", string(md.Name()), err) + } + + return nil +} + +func (s *Schema) String() string { + return fmt.Sprintf("%s", s.Name) +} diff --git a/sink/sql/db_proto/sql/schema/table.go b/sink/sql/db_proto/sql/schema/table.go new file mode 100644 index 000000000..fdf5c9298 --- /dev/null +++ b/sink/sql/db_proto/sql/schema/table.go @@ -0,0 +1,141 @@ +package schema + +import ( + "fmt" + "strings" + + pbSchmema "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "github.com/streamingfast/substreams/sink/sql/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type PrimaryKey struct { + Name string + FieldDescriptor protoreflect.FieldDescriptor + Index int +} + +type ChildOf struct { + ParentTable string + ParentTableField string +} + +func NewChildOf(childOf string) (*ChildOf, error) { + parts := strings.Split(childOf, " on ") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid child of format %q. expecting 'table_name on field_name' format", childOf) + } + + return &ChildOf{ + ParentTable: strings.TrimSpace(parts[0]), + ParentTableField: strings.TrimSpace(parts[1]), + }, nil +} + +type Table struct { + Name string + PrimaryKey *PrimaryKey + ChildOf *ChildOf + Columns []*Column + Ordinal int + InlineDepth int + PbTableInfo *pbSchmema.Table +} + +func NewTable(descriptor protoreflect.MessageDescriptor, tableInfo *pbSchmema.Table, ordinal int, inlineDepth int) (*Table, error) { + table := &Table{ + Name: string(descriptor.Name()), + Ordinal: ordinal, + InlineDepth: inlineDepth, + PbTableInfo: tableInfo, + } + table.Name = tableInfo.Name + + typeName := string(descriptor.Name()) + isTimestamp := typeName == ".google.protobuf.Timestamp" || typeName == "Timestamp" + if isTimestamp { + return nil, nil + } + + if tableInfo.ChildOf != nil { + co, err := NewChildOf(*tableInfo.ChildOf) + if err != nil { + return nil, fmt.Errorf("error parsing child of: %w", err) + } + table.ChildOf = co + } + + err := table.processColumns(descriptor) + if err != nil { + return nil, fmt.Errorf("error processing fields for table %q: %w", string(descriptor.Name()), err) + } + + if len(table.Columns) == 0 { + return nil, nil + } + + return table, nil +} + +func (t *Table) processColumns(descriptor protoreflect.MessageDescriptor) error { + fields := descriptor.Fields() + for idx := 0; idx < fields.Len(); idx++ { + fieldDescriptor := fields.Get(idx) + fieldInfo := proto.FieldInfo(fieldDescriptor) + + if fieldDescriptor.ContainingOneof() != nil && !fieldDescriptor.HasOptionalKeyword() { + continue + } + + if fieldDescriptor.IsList() { + if fieldDescriptor.Kind() == protoreflect.MessageKind { + if fieldInfo != nil && fieldInfo.Inline { + // Allow inline repeated message fields to be processed as nested columns + } else { + // This will be handled by table relations + continue + } + } + // Allow repeated scalar fields to be processed as array columns + } + + if fieldDescriptor.Kind() == protoreflect.MessageKind { + typeName := string(fieldDescriptor.Message().Name()) + isTimestamp := typeName == ".google.protobuf.Timestamp" || typeName == "Timestamp" + + isInline := fieldInfo != nil && fieldInfo.Inline + if !isTimestamp && !isInline { + continue + } + } + column, err := NewColumn(fieldDescriptor, fieldInfo, t.Ordinal, t.InlineDepth) + if err != nil { + return fmt.Errorf("error processing column %q: %w", string(fieldDescriptor.Name()), err) + } + + if column.IsPrimaryKey { + if t.PrimaryKey != nil { + return fmt.Errorf("multiple field mark has primary keys are not supported") + } + + t.PrimaryKey = &PrimaryKey{ + Name: column.Name, + FieldDescriptor: fieldDescriptor, + Index: idx, + } + } + t.Columns = append(t.Columns, column) + } + + return nil +} + +// ColumnByFieldName returns the column matching the given protobuf field name, or nil if not found. +func (t *Table) ColumnByFieldName(fieldName string) *Column { + for _, col := range t.Columns { + if col.FieldDescriptor != nil && string(col.FieldDescriptor.Name()) == fieldName { + return col + } + } + return nil +} diff --git a/sink/sql/db_proto/sql/utils.go b/sink/sql/db_proto/sql/utils.go new file mode 100644 index 000000000..b509913d2 --- /dev/null +++ b/sink/sql/db_proto/sql/utils.go @@ -0,0 +1,25 @@ +package sql + +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +func fieldName(f protoreflect.FieldDescriptor) string { + fieldNameSuffix := "" + if f.Kind() == protoreflect.MessageKind { + fieldNameSuffix = "_id" + } + + return fmt.Sprintf("%s%s", strings.ToLower(string(f.Name())), fieldNameSuffix) +} + +func fieldQuotedName(f protoreflect.FieldDescriptor) string { + return Quoted(fieldName(f)) +} + +func Quoted(value string) string { + return fmt.Sprintf("\"%s\"", value) +} diff --git a/sink/sql/db_proto/stats/stats.go b/sink/sql/db_proto/stats/stats.go new file mode 100644 index 000000000..478df30d4 --- /dev/null +++ b/sink/sql/db_proto/stats/stats.go @@ -0,0 +1,121 @@ +package stats + +import ( + "time" + + "github.com/streamingfast/logging/zapx" + "go.uber.org/zap" +) + +type Average struct { + Duration []time.Duration + windowSize int + title string + lastX int +} + +func NewAverage(title string, windowSize int, lastX int) *Average { + return &Average{ + title: title, + windowSize: windowSize, + lastX: lastX, + } +} +func (a *Average) Add(d time.Duration) { + a.Duration = append(a.Duration, d) + if len(a.Duration) > a.windowSize { + a.Duration = a.Duration[1:] + } +} + +func (a *Average) Average() time.Duration { + if len(a.Duration) == 0 { + return 0 + } + var total time.Duration + for _, d := range a.Duration { + total += d + } + return time.Duration(total / time.Duration(len(a.Duration))) +} + +func (a *Average) LastItemsAverage(count int) time.Duration { + if len(a.Duration) == 0 { + return 0 + } + if count <= 0 || count > len(a.Duration) { + count = len(a.Duration) + } + var total int64 + for _, d := range a.Duration[len(a.Duration)-count:] { + total += d.Nanoseconds() + } + return time.Duration(total / int64(count)) +} + +func (a *Average) Log(logger *zap.Logger) { + logger.Info(a.title, + zapx.HumanDuration("average", a.Average()), + zapx.HumanDuration("last X average", a.LastItemsAverage(a.lastX)), + ) +} + +type Stats struct { + logger *zap.Logger + BlockCount int + WaitDurationBetweenBlocks *Average + BlockProcessingDuration *Average + UnmarshallingDuration *Average + BlockInsertDuration *Average + EntitiesInsertDuration *Average + FlushDuration *Average + LastBlockProcessAt time.Time + TotalProcessingDuration time.Duration + TotalDurationBetween time.Duration +} + +func NewStats(logger *zap.Logger) *Stats { + s := &Stats{ + logger: logger, + WaitDurationBetweenBlocks: NewAverage(" Wait Duration Between Blocks", 250_000, 1000), + BlockProcessingDuration: NewAverage(" Block Processing Duration", 250_000, 1000), + UnmarshallingDuration: NewAverage(" Unmarshalling Duration", 250_000, 1000), + BlockInsertDuration: NewAverage(" Block Insert Duration", 250_000, 1000), + EntitiesInsertDuration: NewAverage(" Entities Insert Duration", 250_000, 1000), + FlushDuration: NewAverage(" Flush duration", 1000, 10), + } + + go func() { + for { + time.Sleep(30 * time.Second) + s.Log() + } + }() + + return s +} + +func (s *Stats) Log() { + s.logger.Info("-----------------------------------") + + if s.BlockCount == 0 { + s.logger.Info("Stats: no blocks processed yet") + } else { + s.logger.Info("Stats", + zap.Int("block_count", s.BlockCount), + zapx.HumanDuration("Processing Time", s.TotalProcessingDuration), + zapx.HumanDuration("Total Wait Duration", s.TotalDurationBetween), + zapx.HumanDuration("Total Duration", s.TotalDurationBetween+s.TotalProcessingDuration), + zap.Time("Last Block Process At", s.LastBlockProcessAt), + ) + + s.WaitDurationBetweenBlocks.Log(s.logger) + s.BlockProcessingDuration.Log(s.logger) + s.UnmarshallingDuration.Log(s.logger) + s.BlockInsertDuration.Log(s.logger) + s.EntitiesInsertDuration.Log(s.logger) + s.FlushDuration.Log(s.logger) + } + + s.logger.Info("-----------------------------------") +} diff --git a/sink/sql/proto/utils.go b/sink/sql/proto/utils.go new file mode 100644 index 000000000..b14cb4f22 --- /dev/null +++ b/sink/sql/proto/utils.go @@ -0,0 +1,38 @@ +package proto + +import ( + "fmt" + + schema "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/schema/v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func TableInfo(d protoreflect.MessageDescriptor) *schema.Table { + msgOptions := d.Options() + + if proto.HasExtension(msgOptions, schema.E_Table) { + ext := proto.GetExtension(msgOptions, schema.E_Table) + table, ok := ext.(*schema.Table) + if ok { + if table.Name == "" { + panic(fmt.Sprintf("table name is required for message %q", string(d.Name()))) + } + return table + } + } + return nil +} + +func FieldInfo(d protoreflect.FieldDescriptor) *schema.Column { + options := d.Options() + + if proto.HasExtension(options, schema.E_Field) { + ext := proto.GetExtension(options, schema.E_Field) + f, ok := ext.(*schema.Column) + if ok { + return f + } + } + return nil +} diff --git a/sink/sql/shared.go b/sink/sql/shared.go new file mode 100644 index 000000000..8238105be --- /dev/null +++ b/sink/sql/shared.go @@ -0,0 +1,44 @@ +package sinksql + +import ( + "fmt" + "strings" + + pbsql "github.com/streamingfast/substreams/pb/sf/substreams/sink/sql/services/v1" + pbsubstreams "github.com/streamingfast/substreams/pb/sf/substreams/v1" + "google.golang.org/protobuf/proto" +) + +var ( + supportedDeployableUnits []string + deprecated_supportedDeployableService = "sf.substreams.sink.sql.v1.Service" + supportedDeployableService = "sf.substreams.sink.sql.service.v1.Service" +) + +func init() { + supportedDeployableUnits = []string{ + deprecated_supportedDeployableService, + } +} + +const typeUrlPrefix = "type.googleapis.com/" + +func ExtractSinkService(pkg *pbsubstreams.Package) (*pbsql.Service, error) { + if pkg.SinkConfig == nil { + return nil, fmt.Errorf("no sink config found in spkg") + } + + configPackageID := strings.TrimPrefix(pkg.SinkConfig.TypeUrl, typeUrlPrefix) + + switch configPackageID { + case deprecated_supportedDeployableService, supportedDeployableService: + service := &pbsql.Service{} + + if err := proto.Unmarshal(pkg.SinkConfig.Value, service); err != nil { + return nil, fmt.Errorf("failed to proto unmarshal: %w", err) + } + return service, nil + } + + return nil, fmt.Errorf("invalid config type %q, supported configs are %q", pkg.SinkConfig.TypeUrl, strings.Join(supportedDeployableUnits, ", ")) +}