diff --git a/common/namespace/nsreplication/replication_admitter_test.go b/common/namespace/nsreplication/replication_admitter_test.go index f7b3119518b..14764357c56 100644 --- a/common/namespace/nsreplication/replication_admitter_test.go +++ b/common/namespace/nsreplication/replication_admitter_test.go @@ -12,6 +12,7 @@ import ( replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/log" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/testing/testhooks" "go.uber.org/mock/gomock" ) @@ -77,6 +78,7 @@ func newExecutorForAdmitterTest(t *testing.T, admitter NamespaceReplicationAdmit NewNoopDataMerger(), admitter, log.NewTestLogger(), + testhooks.NewTestHooks(), ).(*taskExecutorImpl) return exec, mockMgr } diff --git a/common/namespace/nsreplication/replication_task_executor.go b/common/namespace/nsreplication/replication_task_executor.go index e64b080be66..704bd0b18b6 100644 --- a/common/namespace/nsreplication/replication_task_executor.go +++ b/common/namespace/nsreplication/replication_task_executor.go @@ -13,7 +13,9 @@ import ( replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/testing/testhooks" ) var ( @@ -53,6 +55,7 @@ type ( dataMerger NamespaceDataMerger admitter NamespaceReplicationAdmitter logger log.Logger + testHooks testhooks.TestHooks } ) @@ -63,6 +66,7 @@ func NewTaskExecutor( dataMerger NamespaceDataMerger, admitter NamespaceReplicationAdmitter, logger log.Logger, + testHooks testhooks.TestHooks, ) TaskExecutor { return &taskExecutorImpl{ @@ -71,6 +75,7 @@ func NewTaskExecutor( dataMerger: dataMerger, admitter: admitter, logger: logger, + testHooks: testHooks, } } @@ -82,6 +87,20 @@ func (h *taskExecutorImpl) Execute( if err := h.validateNamespaceReplicationTask(task); err != nil { return err } + if hook, ok := testhooks.Get( + h.testHooks, + testhooks.NamespaceReplicationTaskInterceptor, + namespace.Name(task.GetInfo().GetName()), + ); ok { + return hook(ctx, task, h.executeValidatedTask) + } + return h.executeValidatedTask(ctx, task) +} + +func (h *taskExecutorImpl) executeValidatedTask( + ctx context.Context, + task *replicationspb.NamespaceTaskAttributes, +) error { if shouldProcess, err := h.shouldProcessTask(ctx, task); !shouldProcess || err != nil { return err } diff --git a/common/namespace/nsreplication/replication_task_executor_test.go b/common/namespace/nsreplication/replication_task_executor_test.go index 9c3a0d57231..cb22d9398cc 100644 --- a/common/namespace/nsreplication/replication_task_executor_test.go +++ b/common/namespace/nsreplication/replication_task_executor_test.go @@ -17,6 +17,7 @@ import ( replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/log" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/testing/testhooks" "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -54,6 +55,7 @@ func (s *namespaceReplicationTaskExecutorSuite) SetupTest() { NewNoopDataMerger(), NewDefaultAdmitter(), logger, + testhooks.NewTestHooks(), ).(*taskExecutorImpl) } diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index 8cbb2545e44..ff0f9d266b6 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -1,8 +1,10 @@ package testhooks import ( + "context" "time" + replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/namespace" ) @@ -18,6 +20,10 @@ var ( MatchingIgnoreRoutingConfigRevisionCheck = newKey[bool, namespace.ID]() MatchingDeploymentRegisterErrorBackoff = newKey[time.Duration, namespace.ID]() MatchingForwardTaskDelay = newKey[time.Duration, namespace.ID]() + ReplicationDLQWrite = newKey[func(any), namespace.ID]() + HistoryReplicationTaskInterceptor = newKey[func(any, func() error) error, namespace.ID]() + HistoryReplicationTaskAfterConvert = newKey[func(string, any, any) any, global]() + NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func(context.Context, *replicationspb.NamespaceTaskAttributes) error) error, namespace.Name]() ) // keyID is a unique identifier for a key, used as a map key. diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go index b7a5c28d8f2..b42354a6ba2 100644 --- a/common/testing/testhooks/test_impl.go +++ b/common/testing/testhooks/test_impl.go @@ -92,7 +92,7 @@ func newKey[T any, S any]() Key[T, S] { var zero S var s ScopeType switch any(zero).(type) { - case namespace.ID: + case namespace.ID, namespace.Name: s = ScopeNamespace case global: s = ScopeGlobal diff --git a/service/frontend/admin_handler_test.go b/service/frontend/admin_handler_test.go index 138cfd738ee..548851dee4d 100644 --- a/service/frontend/admin_handler_test.go +++ b/service/frontend/admin_handler_test.go @@ -56,6 +56,7 @@ import ( "go.temporal.io/server/common/testing/historyrequire" "go.temporal.io/server/common/testing/mocksdk" "go.temporal.io/server/common/testing/protorequire" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/testing/testvars" "go.temporal.io/server/service/history/tasks" "go.temporal.io/server/service/worker/dlq" @@ -198,6 +199,7 @@ func (s *adminHandlerSuite) SetupTest() { nsreplication.NewDefaultAdmitter(), s.mockResource.GetNamespaceReplicationQueue(), s.mockResource.GetLogger(), + testhooks.TestHooks{}, ) s.handler = NewAdminHandler(args, namespaceDLQHandler) s.handler.Start() diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 53320dbaddd..ab3d44f2527 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -43,6 +43,7 @@ import ( "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/telemetry" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service" "go.temporal.io/server/service/frontend/configs" "go.temporal.io/server/service/history/tasks" @@ -793,6 +794,7 @@ func NamespaceDLQHandlerProvider( namespaceAdmitter nsreplication.NamespaceReplicationAdmitter, namespaceReplicationQueue persistence.NamespaceReplicationQueue, logger log.SnTaggedLogger, + testHooks testhooks.TestHooks, ) nsreplication.DLQMessageHandler { taskExecutor := nsreplication.NewTaskExecutor( clusterMetadata.GetCurrentClusterName(), @@ -800,6 +802,7 @@ func NamespaceDLQHandlerProvider( namespaceDataMerger, namespaceAdmitter, logger, + testHooks, ) return nsreplication.NewDLQMessageHandler( taskExecutor, diff --git a/service/history/history_engine.go b/service/history/history_engine.go index 2ee68ae48f7..24a94714d28 100644 --- a/service/history/history_engine.go +++ b/service/history/history_engine.go @@ -329,6 +329,7 @@ func NewEngineWithShardContext( serializer, replicationTaskFetcherFactory, replicationTaskExecutorProvider, + testHooks, dlqWriter, ) diff --git a/service/history/replication/dlq_writer.go b/service/history/replication/dlq_writer.go index cd3fff193da..fa1fb810ed1 100644 --- a/service/history/replication/dlq_writer.go +++ b/service/history/replication/dlq_writer.go @@ -4,7 +4,9 @@ import ( "context" persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/queues" "go.uber.org/fx" @@ -49,6 +51,7 @@ type ( Config *configs.Config ExecutionManagerDLQWriter *executionManagerDLQWriter DLQWriterAdapter *DLQWriterAdapter + TestHooks testhooks.TestHooks } dlqWriterToggle struct { *dlqWriterToggleParams @@ -88,10 +91,16 @@ func newDLQWriterToggle( // - QueueV1: [ExecutionManagerDLQWriter.WriteTaskToDLQ] // - QueueV2: [DLQWriterAdapter.WriteTaskToDLQ] func (d *dlqWriterToggle) WriteTaskToDLQ(ctx context.Context, request DLQWriteRequest) error { + var err error if d.Config.HistoryReplicationDLQV2() { - return d.DLQWriterAdapter.WriteTaskToDLQ(ctx, request) + err = d.DLQWriterAdapter.WriteTaskToDLQ(ctx, request) + } else { + err = d.ExecutionManagerDLQWriter.WriteTaskToDLQ(ctx, request) } - return d.ExecutionManagerDLQWriter.WriteTaskToDLQ(ctx, request) + if hook, ok := testhooks.Get(d.TestHooks, testhooks.ReplicationDLQWrite, namespace.ID(request.ReplicationTaskInfo.GetNamespaceId())); ok { + hook(request) + } + return err } // WriteTaskToDLQ implements [DLQWriter.WriteTaskToDLQ] by calling [persistence.ExecutionManager.PutReplicationTaskToDLQ]. diff --git a/service/history/replication/executable_task.go b/service/history/replication/executable_task.go index 726ac615daa..63a5cef3098 100644 --- a/service/history/replication/executable_task.go +++ b/service/history/replication/executable_task.go @@ -27,6 +27,7 @@ import ( serviceerrors "go.temporal.io/server/common/serviceerror" "go.temporal.io/server/common/softassert" ctasks "go.temporal.io/server/common/tasks" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/consts" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/tasks" @@ -168,6 +169,41 @@ func (e *ExecutableTaskImpl) ReplicationTask() *replicationspb.ReplicationTask { return e.replicationTask } +func executeReplicationTask( + testHooks testhooks.TestHooks, + replicationTask *replicationspb.ReplicationTask, + execute func() error, +) error { + if hook, ok := testhooks.Get(testHooks, testhooks.HistoryReplicationTaskInterceptor, replicationTaskNamespaceID(replicationTask)); ok { + return hook(replicationTask, execute) + } + return execute() +} + +func replicationTaskNamespaceID(replicationTask *replicationspb.ReplicationTask) namespace.ID { + if rawTaskInfo := replicationTask.GetRawTaskInfo(); rawTaskInfo != nil { + return namespace.ID(rawTaskInfo.GetNamespaceId()) + } + switch attr := replicationTask.GetAttributes().(type) { + case *replicationspb.ReplicationTask_SyncWorkflowStateTaskAttributes: + return namespace.ID(attr.SyncWorkflowStateTaskAttributes.GetWorkflowState().GetExecutionInfo().GetNamespaceId()) + case *replicationspb.ReplicationTask_SyncActivityTaskAttributes: + return namespace.ID(attr.SyncActivityTaskAttributes.GetNamespaceId()) + case *replicationspb.ReplicationTask_HistoryTaskAttributes: + return namespace.ID(attr.HistoryTaskAttributes.GetNamespaceId()) + case *replicationspb.ReplicationTask_SyncHsmAttributes: + return namespace.ID(attr.SyncHsmAttributes.GetNamespaceId()) + case *replicationspb.ReplicationTask_BackfillHistoryTaskAttributes: + return namespace.ID(attr.BackfillHistoryTaskAttributes.GetNamespaceId()) + case *replicationspb.ReplicationTask_VerifyVersionedTransitionTaskAttributes: + return namespace.ID(attr.VerifyVersionedTransitionTaskAttributes.GetNamespaceId()) + case *replicationspb.ReplicationTask_SyncVersionedTransitionTaskAttributes: + return namespace.ID(attr.SyncVersionedTransitionTaskAttributes.GetNamespaceId()) + default: + return "" + } +} + func (e *ExecutableTaskImpl) Ack() { if atomic.LoadInt32(&e.taskState) != taskStatePending { return diff --git a/service/history/replication/executable_task_converter.go b/service/history/replication/executable_task_converter.go index 3a2de1c6916..426e6875204 100644 --- a/service/history/replication/executable_task_converter.go +++ b/service/history/replication/executable_task_converter.go @@ -7,6 +7,7 @@ import ( enumsspb "go.temporal.io/server/api/enums/v1" replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/testing/testhooks" ) type ( @@ -47,6 +48,15 @@ func (e *executableTaskConverterImpl) Convert( metrics.OperationTag(TaskOperationTag(replicationTask)), ) tasks[index] = e.convertOne(sourceClusterName, serverShardKey, replicationTask) + if hook, ok := testhooks.Get( + e.processToolBox.TestHooks, + testhooks.HistoryReplicationTaskAfterConvert, + testhooks.GlobalScope, + ); ok { + if task, ok := hook(sourceClusterName, replicationTask, tasks[index]).(TrackableExecutableTask); ok { + tasks[index] = task + } + } } return tasks } diff --git a/service/history/replication/executable_task_tool_box.go b/service/history/replication/executable_task_tool_box.go index ef39b14db62..86094703769 100644 --- a/service/history/replication/executable_task_tool_box.go +++ b/service/history/replication/executable_task_tool_box.go @@ -9,6 +9,7 @@ import ( "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/serialization" ctasks "go.temporal.io/server/common/tasks" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/replication/eventhandler" "go.temporal.io/server/service/history/shard" @@ -36,6 +37,7 @@ type ( Logger log.Logger ThrottledLogger log.ThrottledLogger Serializer serialization.Serializer + TestHooks testhooks.TestHooks DLQWriter DLQWriter HistoryEventsHandler eventhandler.HistoryEventsHandler WorkflowCache wcache.Cache diff --git a/service/history/replication/fx.go b/service/history/replication/fx.go index e5cf5c310f1..b3fd2382c45 100644 --- a/service/history/replication/fx.go +++ b/service/history/replication/fx.go @@ -19,6 +19,7 @@ import ( "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/quotas" ctasks "go.temporal.io/server/common/tasks" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/queues" @@ -83,6 +84,7 @@ func eagerNamespaceRefresherProvider( dataMerger nsreplication.NamespaceDataMerger, admitter nsreplication.NamespaceReplicationAdmitter, metricsHandler metrics.Handler, + testHooks testhooks.TestHooks, ) EagerNamespaceRefresher { return NewEagerNamespaceRefresher( metadataManager, @@ -95,6 +97,7 @@ func eagerNamespaceRefresherProvider( dataMerger, admitter, logger, + testHooks, ), clusterMetadata.GetCurrentClusterName(), metricsHandler, diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go index 94015b05650..b58d36c88c3 100644 --- a/service/history/replication/task_processor.go +++ b/service/history/replication/task_processor.go @@ -28,6 +28,7 @@ import ( "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" serviceerrors "go.temporal.io/server/common/serviceerror" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/shard" @@ -64,6 +65,7 @@ type ( config *configs.Config metricsHandler metrics.Handler logger log.Logger + testHooks testhooks.TestHooks replicationTaskExecutor TaskExecutor dlqWriter DLQWriter @@ -99,6 +101,7 @@ func NewTaskProcessor( replicationTaskFetcher taskFetcher, replicationTaskExecutor TaskExecutor, eventSerializer serialization.Serializer, + testHooks testhooks.TestHooks, dlqWriter DLQWriter, ) TaskProcessor { shardID := shardContext.GetShardID() @@ -125,6 +128,7 @@ func NewTaskProcessor( config: config, metricsHandler: metricsHandler, logger: shardContext.GetLogger(), + testHooks: testHooks, replicationTaskExecutor: replicationTaskExecutor, dlqWriter: dlqWriter, rateLimiter: quotas.NewMultiRateLimiter([]quotas.RateLimiter{ @@ -313,7 +317,9 @@ func (p *taskProcessorImpl) handleReplicationTask( operationTagValue := p.getOperationTagValue(replicationTask) operation := func() error { - err := p.replicationTaskExecutor.Execute(ctx, replicationTask, false) + err := executeReplicationTask(p.testHooks, replicationTask, func() error { + return p.replicationTaskExecutor.Execute(ctx, replicationTask, false) + }) p.emitTaskMetrics(operationTagValue, err) return err } diff --git a/service/history/replication/task_processor_manager.go b/service/history/replication/task_processor_manager.go index c6b3911ae22..386b91b8af6 100644 --- a/service/history/replication/task_processor_manager.go +++ b/service/history/replication/task_processor_manager.go @@ -17,6 +17,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/deletemanager" historyi "go.temporal.io/server/service/history/interfaces" @@ -46,6 +47,7 @@ type ( taskPollerManager pollerManager metricsHandler metrics.Handler logger log.Logger + testHooks testhooks.TestHooks dlqWriter DLQWriter enableFetcher bool @@ -66,6 +68,7 @@ func NewTaskProcessorManager( eventSerializer serialization.Serializer, replicationTaskFetcherFactory TaskFetcherFactory, taskExecutorProvider TaskExecutorProvider, + testHooks testhooks.TestHooks, dlqWriter DLQWriter, ) *taskProcessorManagerImpl { historyFetcher := eventhandler.NewHistoryPaginatedFetcher(shardContext.GetNamespaceRegistry(), clientBean, eventSerializer, shardContext.GetLogger()) @@ -81,6 +84,7 @@ func NewTaskProcessorManager( removeHistoryFetcher: historyFetcher, logger: shardContext.GetLogger(), metricsHandler: shardContext.GetMetricsHandler(), + testHooks: testHooks, dlqWriter: dlqWriter, enableFetcher: !config.EnableReplicationStream(), @@ -193,6 +197,7 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate( WorkflowCache: r.workflowCache, }), r.eventSerializer, + r.testHooks, r.dlqWriter, ) replicationTaskProcessor.Start() diff --git a/service/history/replication/task_processor_manager_test.go b/service/history/replication/task_processor_manager_test.go index fca3224e560..b0ba14ddbca 100644 --- a/service/history/replication/task_processor_manager_test.go +++ b/service/history/replication/task_processor_manager_test.go @@ -18,6 +18,7 @@ import ( "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/shard" @@ -106,6 +107,7 @@ func (s *taskProcessorManagerSuite) SetupTest() { func(params TaskExecutorParams) TaskExecutor { return s.mockReplicationTaskExecutor }, + testhooks.NewTestHooks(), NewExecutionManagerDLQWriter(s.mockExecutionManager), ) } diff --git a/service/history/replication/task_processor_test.go b/service/history/replication/task_processor_test.go index 22d8dfcc864..0e41c064319 100644 --- a/service/history/replication/task_processor_test.go +++ b/service/history/replication/task_processor_test.go @@ -30,6 +30,7 @@ import ( "go.temporal.io/server/common/quotas" "go.temporal.io/server/common/resourcetest" "go.temporal.io/server/common/testing/protorequire" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/shard" @@ -124,6 +125,7 @@ func (s *taskProcessorSuite) SetupTest() { s.mockReplicationTaskFetcher, s.mockReplicationTaskExecutor, serialization.NewSerializer(), + testhooks.NewTestHooks(), NewExecutionManagerDLQWriter(s.mockExecutionManager), ).(*taskProcessorImpl) } diff --git a/service/worker/fx.go b/service/worker/fx.go index d1e735e4b9c..366c85316d0 100644 --- a/service/worker/fx.go +++ b/service/worker/fx.go @@ -28,6 +28,7 @@ import ( "go.temporal.io/server/common/resource" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service" "go.temporal.io/server/service/worker/batcher" workercommon "go.temporal.io/server/service/worker/common" @@ -87,6 +88,7 @@ var Module = fx.Options( dataMerger nsreplication.NamespaceDataMerger, admitter nsreplication.NamespaceReplicationAdmitter, logger log.Logger, + testHooks testhooks.TestHooks, ) nsreplication.TaskExecutor { return nsreplication.NewTaskExecutor( clusterMetadata.GetCurrentClusterName(), @@ -94,6 +96,7 @@ var Module = fx.Options( dataMerger, admitter, logger, + testHooks, ) }), fx.Provide(nsreplication.NewNoopDataMerger), diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go index 7d6832dc500..42925866eb4 100644 --- a/tests/testcore/functional_test_base.go +++ b/tests/testcore/functional_test_base.go @@ -125,6 +125,8 @@ func init() { // This is similar to the pattern of plumbing dependencies through the TestClusterConfig, but it's much more convenient, // scalable and flexible. The reason we need to do this on a per-service basis is that there are separate fx apps for // each one. +// +// Deprecated: prefer dedicated TestClusterOption helpers or testhooks over injecting arbitrary Fx options. func WithFxOptionsForService(serviceName primitives.ServiceName, options ...fx.Option) TestClusterOption { return func(params *TestClusterParams) { params.ServiceOptions[serviceName] = append(params.ServiceOptions[serviceName], options...) diff --git a/tests/testcore/test_cluster.go b/tests/testcore/test_cluster.go index 27850223e43..b2e8d53d40a 100644 --- a/tests/testcore/test_cluster.go +++ b/tests/testcore/test_cluster.go @@ -45,6 +45,7 @@ import ( "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/telemetry" "go.temporal.io/server/common/testing/freeport" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/temporal" "go.temporal.io/server/temporal/environment" "go.temporal.io/server/tests/testutils" @@ -339,7 +340,7 @@ func newClusterWithPersistenceTestBaseFactory( MatchingConfig: clusterConfig.MatchingConfig, WorkerConfig: clusterConfig.WorkerConfig, MockAdminClient: clusterConfig.MockAdminClient, - NamespaceReplicationTaskExecutor: nsreplication.NewTaskExecutor(clusterConfig.ClusterMetadata.CurrentClusterName, testBase.MetadataManager, nsreplication.NewNoopDataMerger(), nsreplication.NewDefaultAdmitter(), logger), + NamespaceReplicationTaskExecutor: nsreplication.NewTaskExecutor(clusterConfig.ClusterMetadata.CurrentClusterName, testBase.MetadataManager, nsreplication.NewNoopDataMerger(), nsreplication.NewDefaultAdmitter(), logger, testhooks.TestHooks{}), DynamicConfigOverrides: clusterConfig.DynamicConfigOverrides, TLSConfigProvider: tlsConfigProvider, ServiceFxOptions: clusterConfig.ServiceFxOptions, @@ -600,6 +601,10 @@ func (tc *TestCluster) Host() *TemporalImpl { return tc.host } +func (tc *TestCluster) InjectHook(t *testing.T, hook testhooks.Hook, scope any) func() { + return tc.host.injectHook(t, hook, scope) +} + func (tc *TestCluster) WorkerGRPCAddress() string { return tc.host.WorkerGRPCAddress() } diff --git a/tests/xdc/history_replication_dlq_test.go b/tests/xdc/history_replication_dlq_test.go index 074619aaa97..45d21dc6584 100644 --- a/tests/xdc/history_replication_dlq_test.go +++ b/tests/xdc/history_replication_dlq_test.go @@ -27,15 +27,13 @@ import ( replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" - "go.temporal.io/server/common/namespace/nsreplication" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/serialization" - "go.temporal.io/server/common/primitives" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/replication" "go.temporal.io/server/service/history/tasks" - "go.temporal.io/server/tests/testcore" "go.temporal.io/server/tools/tdbg" "go.temporal.io/server/tools/tdbg/tdbgtest" - "go.uber.org/fx" "google.golang.org/protobuf/types/known/durationpb" ) @@ -73,35 +71,14 @@ type ( workflowIDToFail atomic.Pointer[string] workflowIDToObserve atomic.Pointer[string] } - testReplicationTaskExecutor struct { - *replicationTaskExecutorParams - taskExecutor replication.TaskExecutor - } dlqWriterParams struct { // This channel is sent to once we're done processing a request to add a message to the DLQ. processedDLQRequests chan replication.DLQWriteRequest } - testDLQWriter struct { - *dlqWriterParams - replication.DLQWriter - } namespaceReplicationTaskExecutorParams struct { // This channel is sent to once we're done processing a namespace replication task. tasks chan *replicationspb.NamespaceTaskAttributes } - testNamespaceReplicationTaskExecutor struct { - *namespaceReplicationTaskExecutorParams - replicationTaskExecutor nsreplication.TaskExecutor - } - testExecutableTaskConverter struct { - *replicationTaskExecutorParams - converter replication.ExecutableTaskConverter - } - testExecutableTask struct { - *replicationTaskExecutorParams - replication.TrackableExecutableTask - replicationTask *replicationspb.ReplicationTask - } ) func TestHistoryReplicationDLQSuite(t *testing.T) { @@ -165,35 +142,8 @@ func (s *historyReplicationDLQSuite) SetupSuite() { s.replicationTaskExecutors.workflowIDToFail.Store(&workflowIDToFail) s.replicationTaskExecutors.workflowIDToObserve.Store(&workflowIDToFail) - // This can't be very long, so we just use a UUID instead of a more descriptive name. - // We also don't escape this string in many places, so it can't contain any dashes. - taskExecutorDecorator := s.getTaskExecutorDecorator() s.logger = log.NewTestLogger() - s.setupSuite( - testcore.WithFxOptionsForService(primitives.HistoryService, - fx.Decorate( - taskExecutorDecorator, - func(dlqWriter replication.DLQWriter) replication.DLQWriter { - // Replace the dlq writer with one that records DLQ requests so that we can wait until a task is - // added to the DLQ before querying tdbg. - return &testDLQWriter{ - dlqWriterParams: &s.dlqWriters, - DLQWriter: dlqWriter, - } - }, - ), - ), - testcore.WithFxOptionsForService(primitives.WorkerService, - fx.Decorate( - func(executor nsreplication.TaskExecutor) nsreplication.TaskExecutor { - return &testNamespaceReplicationTaskExecutor{ - replicationTaskExecutor: executor, - namespaceReplicationTaskExecutorParams: &s.namespaceReplicationTaskExecutors, - } - }, - ), - ), - ) + s.setupSuite() } func (s *historyReplicationDLQSuite) TearDownSuite() { @@ -217,6 +167,21 @@ func (s *historyReplicationDLQSuite) TestWorkflowReplicationTaskFailure() { // Register a namespace. ns := "history-replication-dlq-test-namespace" + s.clusters[1].InjectHook( + s.T(), + testhooks.NewHook(testhooks.NamespaceReplicationTaskInterceptor, func( + ctx context.Context, + task *replicationspb.NamespaceTaskAttributes, + execute func(context.Context, *replicationspb.NamespaceTaskAttributes) error, + ) error { + err := execute(ctx, task) + if err == nil { + s.namespaceReplicationTaskExecutors.tasks <- task + } + return err + }), + namespace.Name(ns), + ) _, err := s.clusters[0].FrontendClient().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ Namespace: ns, Clusters: s.clusterReplicationConfig(), @@ -228,6 +193,30 @@ func (s *historyReplicationDLQSuite) TestWorkflowReplicationTaskFailure() { WorkflowExecutionRetentionPeriod: durationpb.New(time.Hour * 24), }) s.NoError(err) + describeResp, err := s.clusters[0].FrontendClient().DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{ + Namespace: ns, + }) + s.NoError(err) + namespaceID := namespace.ID(describeResp.NamespaceInfo.Id) + + replicationDLQWriteHook := testhooks.NewHook(testhooks.ReplicationDLQWrite, func(request any) { + s.dlqWriters.processedDLQRequests <- request.(replication.DLQWriteRequest) + }) + historyReplicationTaskHook := testhooks.NewHook( + testhooks.HistoryReplicationTaskInterceptor, + func(task any, execute func() error) error { + replicationTask := task.(*replicationspb.ReplicationTask) + defer s.afterReplicationTaskExecute(replicationTask) + if replicationTaskWorkflowID(replicationTask) == *s.replicationTaskExecutors.workflowIDToFail.Load() { + return serviceerror.NewInvalidArgument("failed to apply replication task") + } + return execute() + }, + ) + for _, cluster := range s.clusters { + cluster.InjectHook(s.T(), replicationDLQWriteHook, namespaceID) + cluster.InjectHook(s.T(), historyReplicationTaskHook, namespaceID) + } // Create a worker and register a workflow on the active cluster. activeClient, err := sdkclient.Dial(sdkclient.Options{ @@ -546,124 +535,21 @@ func (s *historyReplicationDLQSuite) runTDBGCommand( s.T().Log("========================================") } -func (s *historyReplicationDLQSuite) getTaskExecutorDecorator() any { - if s.enableReplicationStream { - // The replication stream uses a different code path which converts tasks into executables using this interface, - // so that's a good injection point for us. - return func(converter replication.ExecutableTaskConverter) replication.ExecutableTaskConverter { - return &testExecutableTaskConverter{ - replicationTaskExecutorParams: &s.replicationTaskExecutors, - converter: converter, - } - } - } - // Without the replication stream, we use polling that relies on a task executor, so we can inject our own - // faulty version here. - return func(provider replication.TaskExecutorProvider) replication.TaskExecutorProvider { - return func(params replication.TaskExecutorParams) replication.TaskExecutor { - taskExecutor := provider(params) - return &testReplicationTaskExecutor{ - replicationTaskExecutorParams: &s.replicationTaskExecutors, - taskExecutor: taskExecutor, - } - } - } -} - -// Execute the replication task as-normal, but also send it to the channel so that the test can wait for it to -// know that the namespace data has been replicated. -func (t *testNamespaceReplicationTaskExecutor) Execute( - ctx context.Context, - task *replicationspb.NamespaceTaskAttributes, -) error { - err := t.replicationTaskExecutor.Execute(ctx, task) - if err != nil { - return err - } - t.tasks <- task - return nil -} - -// WriteTaskToDLQ is the same as the normal dlq writer, but also sends the request to the channel so that the test can -// wait for it to know that the replication task has been added to the DLQ. -func (t *testDLQWriter) WriteTaskToDLQ( - ctx context.Context, - request replication.DLQWriteRequest, -) error { - err := t.DLQWriter.WriteTaskToDLQ(ctx, request) - t.processedDLQRequests <- request - return err -} - -// Execute the replication task as-normal or return an error if the workflow ID matches the one that we want to fail. -// This is run only when streaming is disabled for replication. -func (f testReplicationTaskExecutor) Execute( - ctx context.Context, - replicationTask *replicationspb.ReplicationTask, - forceApply bool, -) error { - err := f.execute(ctx, replicationTask, forceApply) - if attr := replicationTask.GetHistoryTaskAttributes(); attr != nil && attr.WorkflowId == *f.workflowIDToObserve.Load() { - f.executedTasks <- replicationTask - } - if attr := replicationTask.GetSyncVersionedTransitionTaskAttributes(); attr != nil && attr.WorkflowId == *f.workflowIDToObserve.Load() { - f.executedTasks <- replicationTask - } - return err -} - -func (f testReplicationTaskExecutor) execute( - ctx context.Context, - replicationTask *replicationspb.ReplicationTask, - forceApply bool, -) error { - if attr := replicationTask.GetHistoryTaskAttributes(); attr != nil && attr.WorkflowId == *f.workflowIDToFail.Load() { - return serviceerror.NewInvalidArgument("failed to apply replication task") - } - if attr := replicationTask.GetSyncVersionedTransitionTaskAttributes(); attr != nil && attr.WorkflowId == *f.workflowIDToFail.Load() { - return serviceerror.NewInvalidArgument("failed to apply replication task") +func (s *historyReplicationDLQSuite) afterReplicationTaskExecute(replicationTask *replicationspb.ReplicationTask) { + if replicationTaskWorkflowID(replicationTask) == *s.replicationTaskExecutors.workflowIDToObserve.Load() { + s.replicationTaskExecutors.executedTasks <- replicationTask } - err := f.taskExecutor.Execute(ctx, replicationTask, forceApply) - return err } -// Convert the replication tasks using the testcore converter, but then wrap them in our own faulty executable tasks. -func (t *testExecutableTaskConverter) Convert( - taskClusterName string, - clientShardKey replication.ClusterShardKey, - serverShardKey replication.ClusterShardKey, - replicationTasks ...*replicationspb.ReplicationTask, -) []replication.TrackableExecutableTask { - convertedTasks := t.converter.Convert(taskClusterName, clientShardKey, serverShardKey, replicationTasks...) - testExecutableTasks := make([]replication.TrackableExecutableTask, len(convertedTasks)) - for i, task := range convertedTasks { - testExecutableTasks[i] = &testExecutableTask{ - replicationTaskExecutorParams: t.replicationTaskExecutorParams, - TrackableExecutableTask: task, - replicationTask: replicationTasks[i], - } +func replicationTaskWorkflowID(replicationTask *replicationspb.ReplicationTask) string { + if attr := replicationTask.GetHistoryTaskAttributes(); attr != nil { + return attr.WorkflowId } - return testExecutableTasks -} - -// Execute the replication task as-normal or return an error if the workflow ID matches the one that we want to fail. -// This is run only when streaming is enabled for replication. -func (t *testExecutableTask) Execute() error { - err := t.execute() - t.replicationTaskExecutorParams.executedTasks <- t.replicationTask - return err -} - -func (t *testExecutableTask) execute() error { - if et, ok := t.TrackableExecutableTask.(*replication.ExecutableHistoryTask); ok { - if et.WorkflowID == *t.workflowIDToFail.Load() { - return serviceerror.NewInvalidArgument("failed to apply replication task") - } + if attr := replicationTask.GetSyncVersionedTransitionTaskAttributes(); attr != nil { + return attr.WorkflowId } - if et, ok := t.TrackableExecutableTask.(*replication.ExecutableSyncVersionedTransitionTask); ok { - if et.WorkflowID == *t.workflowIDToFail.Load() { - return serviceerror.NewInvalidArgument("failed to apply replication task") - } + if attr := replicationTask.GetVerifyVersionedTransitionTaskAttributes(); attr != nil { + return attr.WorkflowId } - return t.TrackableExecutableTask.Execute() + return "" } diff --git a/tests/xdc/history_replication_signals_and_updates_test.go b/tests/xdc/history_replication_signals_and_updates_test.go index 3a7cde3006c..1cea7376bec 100644 --- a/tests/xdc/history_replication_signals_and_updates_test.go +++ b/tests/xdc/history_replication_signals_and_updates_test.go @@ -25,15 +25,14 @@ import ( replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" - "go.temporal.io/server/common/namespace/nsreplication" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/payloads" "go.temporal.io/server/common/persistence/serialization" - "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/testing/protoutils" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/testing/testvars" "go.temporal.io/server/service/history/replication" "go.temporal.io/server/tests/testcore" - "go.uber.org/fx" "google.golang.org/protobuf/types/known/durationpb" ) @@ -47,14 +46,11 @@ type ( // that push their tasks into test-specific (i.e. workflow-specific) buffers. hrsuTestSuite struct { xdcBaseSuite - namespaceTaskExecutor nsreplication.TaskExecutor // The injection is performed once, at the level of the test suite, but we need the modified executors to be - // able to route tasks to test-specific (i.e. workflow-specific) buffers. The following two maps serve that - // purpose (each test registers itself in these maps as it starts). Workflow ID and namespace name are both - // unique per test (due to the use of TestVars). - testMapLock sync.Mutex - testsByWorkflowId map[string]*hrsuTest - testsByNamespaceName map[string]*hrsuTest + // able to route history replication tasks to test-specific (i.e. workflow-specific) buffers. Each test + // registers itself as it starts. Workflow ID is unique per test due to the use of TestVars. + testMapLock sync.Mutex + testsByWorkflowID map[string]*hrsuTest } // Each test starts its own workflow, in its own namespace. hrsuTest struct { @@ -62,7 +58,7 @@ type ( // Per-test buffer of namespace replication tasks. // TODO (dan): buffer namespace replication tasks from each cluster separately, as we do for history replication // tasks. - namespaceReplicationTasks chan *replicationspb.NamespaceTaskAttributes + namespaceReplicationTasks chan *hrsuNamespaceReplicationTask cluster1 hrsuTestCluster cluster2 hrsuTestCluster s *hrsuTestSuite @@ -74,15 +70,9 @@ type ( inboundHistoryReplicationTasks chan *hrsuTestExecutableTask t *hrsuTest } - // Used to inject a modified namespace replication task executor. - hrsuTestNamespaceReplicationTaskExecutor struct { - replicationTaskExecutor nsreplication.TaskExecutor - s *hrsuTestSuite - } - // Used to inject a modified history event replication task executor. - hrsuTestExecutableTaskConverter struct { - converter replication.ExecutableTaskConverter - s *hrsuTestSuite + hrsuNamespaceReplicationTask struct { + task *replicationspb.NamespaceTaskAttributes + execute func(context.Context, *replicationspb.NamespaceTaskAttributes) error } // Used to inject a modified history event replication task executor. hrsuTestExecutableTask struct { @@ -110,31 +100,22 @@ func (s *hrsuTestSuite) SetupSuite() { dynamicconfig.HistoryLongPollExpirationInterval.Key(): 100 * time.Millisecond, } s.logger = log.NewTestLogger() - s.setupSuite( - testcore.WithFxOptionsForService(primitives.WorkerService, - fx.Decorate( - func(executor nsreplication.TaskExecutor) nsreplication.TaskExecutor { - s.namespaceTaskExecutor = executor - return &hrsuTestNamespaceReplicationTaskExecutor{ - replicationTaskExecutor: executor, - s: s, - } - }, - ), - ), - testcore.WithFxOptionsForService(primitives.HistoryService, - fx.Decorate( - func(converter replication.ExecutableTaskConverter) replication.ExecutableTaskConverter { - return &hrsuTestExecutableTaskConverter{ - converter: converter, - s: s, - } - }, - ), - ), - ) - s.testsByWorkflowId = make(map[string]*hrsuTest) - s.testsByNamespaceName = make(map[string]*hrsuTest) + s.setupSuite() + for _, cluster := range s.clusters { + cluster.InjectHook(s.T(), testhooks.NewHook( + testhooks.HistoryReplicationTaskAfterConvert, + func(sourceCluster string, replicationTask any, executable any) any { + return &hrsuTestExecutableTask{ + sourceCluster: sourceCluster, + s: s, + TrackableExecutableTask: executable.(replication.TrackableExecutableTask), + replicationTask: replicationTask.(*replicationspb.ReplicationTask), + result: make(chan error), + } + }, + ), testhooks.GlobalScope) + } + s.testsByWorkflowID = make(map[string]*hrsuTest) } func (s *hrsuTestSuite) SetupTest() { @@ -151,15 +132,13 @@ func (s *hrsuTestSuite) startHrsuTest() (*hrsuTest, context.Context, context.Can ns := tv.NamespaceName().String() t := hrsuTest{ tv: tv, - namespaceReplicationTasks: make(chan *replicationspb.NamespaceTaskAttributes, taskBufferCapacity), + namespaceReplicationTasks: make(chan *hrsuNamespaceReplicationTask, taskBufferCapacity), s: s, } // Register test with the suite, so that globally modified task executors can push tasks to test-specific buffers. s.testMapLock.Lock() - s.testsByWorkflowId[tv.WorkflowID()] = &t - s.testsByNamespaceName[ns] = &t + s.testsByWorkflowID[tv.WorkflowID()] = &t s.testMapLock.Unlock() - t.cluster1 = t.newHrsuTestCluster(ns, s.clusters[0]) t.cluster2 = t.newHrsuTestCluster(ns, s.clusters[1]) t.registerMultiRegionNamespace(ctx) @@ -173,6 +152,26 @@ func (t *hrsuTest) newHrsuTestCluster(ns string, cluster *testcore.TestCluster) Logger: log.NewSdkLogger(t.s.logger), }) t.s.NoError(err) + + interceptNamespaceReplicationTask := func( + _ context.Context, + task *replicationspb.NamespaceTaskAttributes, + execute func(context.Context, *replicationspb.NamespaceTaskAttributes) error, + ) error { + t.namespaceReplicationTasks <- &hrsuNamespaceReplicationTask{ + task: task, + execute: execute, + } + return nil + } + + // Inject a hook to intercept namespace replication tasks and buffer them for inspection. + cluster.InjectHook( + t.s.T(), + testhooks.NewHook(testhooks.NamespaceReplicationTaskInterceptor, interceptNamespaceReplicationTask), + namespace.Name(ns), + ) + return hrsuTestCluster{ testCluster: cluster, client: sdkClient, @@ -657,10 +656,10 @@ func (t *hrsuTest) enterSplitBrainState(ctx context.Context) { // type is encountered with the specified failover version. func (t *hrsuTest) executeNamespaceReplicationTasksUntil(ctx context.Context, operation enumsspb.NamespaceOperation) { for { - task := <-t.namespaceReplicationTasks - err := t.s.namespaceTaskExecutor.Execute(ctx, task) + bufferedTask := <-t.namespaceReplicationTasks + err := bufferedTask.execute(ctx, bufferedTask.task) t.s.NoError(err) - if task.NamespaceOperation == operation { + if bufferedTask.task.NamespaceOperation == operation { return } } @@ -694,47 +693,10 @@ func (s *hrsuTestSuite) executeHistoryReplicationTask(task *hrsuTestExecutableTa return events } -func (e *hrsuTestNamespaceReplicationTaskExecutor) Execute(_ context.Context, task *replicationspb.NamespaceTaskAttributes) error { - // TODO (dan) Use one channel per cluster, as we do for history replication tasks in this test suite. This is - // currently blocked by the fact that namespace tasks don't expose the current cluster name. - ns := task.Info.Name - e.s.testMapLock.Lock() - test := e.s.testsByNamespaceName[ns] - e.s.testMapLock.Unlock() - if test == nil { - // This can happen after a test has completed - return fmt.Errorf("failed to retrieve test for namespace %s", ns) - } - test.namespaceReplicationTasks <- task - // Report success, although we have merely buffered the task and will execute it later. - return nil -} - -// Convert the replication tasks using the testcore converter, and wrap them in our own executable tasks. -func (t *hrsuTestExecutableTaskConverter) Convert( - taskClusterName string, - clientShardKey replication.ClusterShardKey, - serverShardKey replication.ClusterShardKey, - replicationTasks ...*replicationspb.ReplicationTask, -) []replication.TrackableExecutableTask { - convertedTasks := t.converter.Convert(taskClusterName, clientShardKey, serverShardKey, replicationTasks...) - testExecutableTasks := make([]replication.TrackableExecutableTask, len(convertedTasks)) - for i, task := range convertedTasks { - testExecutableTasks[i] = &hrsuTestExecutableTask{ - sourceCluster: taskClusterName, - s: t.s, - TrackableExecutableTask: task, - replicationTask: replicationTasks[i], - result: make(chan error), - } - } - return testExecutableTasks -} - // Execute pushes the task to a buffer and waits for it to be executed. func (task *hrsuTestExecutableTask) Execute() error { task.s.testMapLock.Lock() - test := task.s.testsByWorkflowId[task.workflowId()] + test := task.s.testsByWorkflowID[task.workflowId()] task.s.testMapLock.Unlock() if test == nil { return fmt.Errorf("failed to retrieve test for workflow %s", task.workflowId())