Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/namespace/nsreplication/replication_admitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
replicationspb "go.temporal.io/server/api/replication/v1"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/testing/testhooks"
"go.uber.org/mock/gomock"
)

Expand Down Expand Up @@ -77,6 +78,7 @@ func newExecutorForAdmitterTest(t *testing.T, admitter NamespaceReplicationAdmit
NewNoopDataMerger(),
admitter,
log.NewTestLogger(),
testhooks.NewTestHooks(),
).(*taskExecutorImpl)
return exec, mockMgr
}
Expand Down
19 changes: 19 additions & 0 deletions common/namespace/nsreplication/replication_task_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (
replicationspb "go.temporal.io/server/api/replication/v1"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/testing/testhooks"
)

var (
Expand Down Expand Up @@ -53,6 +55,7 @@ type (
dataMerger NamespaceDataMerger
admitter NamespaceReplicationAdmitter
logger log.Logger
testHooks testhooks.TestHooks
}
)

Expand All @@ -63,6 +66,7 @@ func NewTaskExecutor(
dataMerger NamespaceDataMerger,
admitter NamespaceReplicationAdmitter,
logger log.Logger,
testHooks testhooks.TestHooks,
) TaskExecutor {

return &taskExecutorImpl{
Expand All @@ -71,6 +75,7 @@ func NewTaskExecutor(
dataMerger: dataMerger,
admitter: admitter,
logger: logger,
testHooks: testHooks,
}
}

Expand All @@ -82,6 +87,20 @@ func (h *taskExecutorImpl) Execute(
if err := h.validateNamespaceReplicationTask(task); err != nil {
return err
}
if hook, ok := testhooks.Get(
h.testHooks,
testhooks.NamespaceReplicationTaskInterceptor,
namespace.Name(task.GetInfo().GetName()),
); ok {
return hook(ctx, task, h.executeValidatedTask)
}
return h.executeValidatedTask(ctx, task)
}

func (h *taskExecutorImpl) executeValidatedTask(
ctx context.Context,
task *replicationspb.NamespaceTaskAttributes,
) error {
if shouldProcess, err := h.shouldProcessTask(ctx, task); !shouldProcess || err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
replicationspb "go.temporal.io/server/api/replication/v1"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/testing/testhooks"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -54,6 +55,7 @@ func (s *namespaceReplicationTaskExecutorSuite) SetupTest() {
NewNoopDataMerger(),
NewDefaultAdmitter(),
logger,
testhooks.NewTestHooks(),
).(*taskExecutorImpl)
}

Expand Down
6 changes: 6 additions & 0 deletions common/testing/testhooks/hooks.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package testhooks

import (
"context"
"time"

replicationspb "go.temporal.io/server/api/replication/v1"
"go.temporal.io/server/common/namespace"
)

Expand All @@ -18,6 +20,10 @@ var (
MatchingIgnoreRoutingConfigRevisionCheck = newKey[bool, namespace.ID]()
MatchingDeploymentRegisterErrorBackoff = newKey[time.Duration, namespace.ID]()
MatchingForwardTaskDelay = newKey[time.Duration, namespace.ID]()
ReplicationDLQWrite = newKey[func(any), namespace.ID]()
HistoryReplicationTaskInterceptor = newKey[func(any, func() error) error, namespace.ID]()
HistoryReplicationTaskAfterConvert = newKey[func(string, any, any) any, global]()
NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func(context.Context, *replicationspb.NamespaceTaskAttributes) error) error, namespace.Name]()
)

// keyID is a unique identifier for a key, used as a map key.
Expand Down
2 changes: 1 addition & 1 deletion common/testing/testhooks/test_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func newKey[T any, S any]() Key[T, S] {
var zero S
var s ScopeType
switch any(zero).(type) {
case namespace.ID:
case namespace.ID, namespace.Name:
s = ScopeNamespace
case global:
s = ScopeGlobal
Expand Down
2 changes: 2 additions & 0 deletions service/frontend/admin_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"go.temporal.io/server/common/testing/historyrequire"
"go.temporal.io/server/common/testing/mocksdk"
"go.temporal.io/server/common/testing/protorequire"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/common/testing/testvars"
"go.temporal.io/server/service/history/tasks"
"go.temporal.io/server/service/worker/dlq"
Expand Down Expand Up @@ -198,6 +199,7 @@ func (s *adminHandlerSuite) SetupTest() {
nsreplication.NewDefaultAdmitter(),
s.mockResource.GetNamespaceReplicationQueue(),
s.mockResource.GetLogger(),
testhooks.TestHooks{},
)
s.handler = NewAdminHandler(args, namespaceDLQHandler)
s.handler.Start()
Expand Down
3 changes: 3 additions & 0 deletions service/frontend/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"go.temporal.io/server/common/sdk"
"go.temporal.io/server/common/searchattribute"
"go.temporal.io/server/common/telemetry"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service"
"go.temporal.io/server/service/frontend/configs"
"go.temporal.io/server/service/history/tasks"
Expand Down Expand Up @@ -793,13 +794,15 @@ func NamespaceDLQHandlerProvider(
namespaceAdmitter nsreplication.NamespaceReplicationAdmitter,
namespaceReplicationQueue persistence.NamespaceReplicationQueue,
logger log.SnTaggedLogger,
testHooks testhooks.TestHooks,
) nsreplication.DLQMessageHandler {
taskExecutor := nsreplication.NewTaskExecutor(
clusterMetadata.GetCurrentClusterName(),
persistenceMetadataManager,
namespaceDataMerger,
namespaceAdmitter,
logger,
testHooks,
)
return nsreplication.NewDLQMessageHandler(
taskExecutor,
Expand Down
1 change: 1 addition & 0 deletions service/history/history_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ func NewEngineWithShardContext(
serializer,
replicationTaskFetcherFactory,
replicationTaskExecutorProvider,
testHooks,
dlqWriter,
)

Expand Down
13 changes: 11 additions & 2 deletions service/history/replication/dlq_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"

persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service/history/configs"
"go.temporal.io/server/service/history/queues"
"go.uber.org/fx"
Expand Down Expand Up @@ -49,6 +51,7 @@ type (
Config *configs.Config
ExecutionManagerDLQWriter *executionManagerDLQWriter
DLQWriterAdapter *DLQWriterAdapter
TestHooks testhooks.TestHooks
}
dlqWriterToggle struct {
*dlqWriterToggleParams
Expand Down Expand Up @@ -88,10 +91,16 @@ func newDLQWriterToggle(
// - QueueV1: [ExecutionManagerDLQWriter.WriteTaskToDLQ]
// - QueueV2: [DLQWriterAdapter.WriteTaskToDLQ]
func (d *dlqWriterToggle) WriteTaskToDLQ(ctx context.Context, request DLQWriteRequest) error {
var err error
if d.Config.HistoryReplicationDLQV2() {
return d.DLQWriterAdapter.WriteTaskToDLQ(ctx, request)
err = d.DLQWriterAdapter.WriteTaskToDLQ(ctx, request)
} else {
err = d.ExecutionManagerDLQWriter.WriteTaskToDLQ(ctx, request)
}
return d.ExecutionManagerDLQWriter.WriteTaskToDLQ(ctx, request)
if hook, ok := testhooks.Get(d.TestHooks, testhooks.ReplicationDLQWrite, namespace.ID(request.ReplicationTaskInfo.GetNamespaceId())); ok {
hook(request)
}
return err
}

// WriteTaskToDLQ implements [DLQWriter.WriteTaskToDLQ] by calling [persistence.ExecutionManager.PutReplicationTaskToDLQ].
Expand Down
36 changes: 36 additions & 0 deletions service/history/replication/executable_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
serviceerrors "go.temporal.io/server/common/serviceerror"
"go.temporal.io/server/common/softassert"
ctasks "go.temporal.io/server/common/tasks"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service/history/consts"
historyi "go.temporal.io/server/service/history/interfaces"
"go.temporal.io/server/service/history/tasks"
Expand Down Expand Up @@ -168,6 +169,41 @@ func (e *ExecutableTaskImpl) ReplicationTask() *replicationspb.ReplicationTask {
return e.replicationTask
}

func executeReplicationTask(
testHooks testhooks.TestHooks,
replicationTask *replicationspb.ReplicationTask,
execute func() error,
) error {
if hook, ok := testhooks.Get(testHooks, testhooks.HistoryReplicationTaskInterceptor, replicationTaskNamespaceID(replicationTask)); ok {
return hook(replicationTask, execute)
}
return execute()
}

func replicationTaskNamespaceID(replicationTask *replicationspb.ReplicationTask) namespace.ID {
if rawTaskInfo := replicationTask.GetRawTaskInfo(); rawTaskInfo != nil {
return namespace.ID(rawTaskInfo.GetNamespaceId())
}
switch attr := replicationTask.GetAttributes().(type) {
case *replicationspb.ReplicationTask_SyncWorkflowStateTaskAttributes:
return namespace.ID(attr.SyncWorkflowStateTaskAttributes.GetWorkflowState().GetExecutionInfo().GetNamespaceId())
case *replicationspb.ReplicationTask_SyncActivityTaskAttributes:
return namespace.ID(attr.SyncActivityTaskAttributes.GetNamespaceId())
case *replicationspb.ReplicationTask_HistoryTaskAttributes:
return namespace.ID(attr.HistoryTaskAttributes.GetNamespaceId())
case *replicationspb.ReplicationTask_SyncHsmAttributes:
return namespace.ID(attr.SyncHsmAttributes.GetNamespaceId())
case *replicationspb.ReplicationTask_BackfillHistoryTaskAttributes:
return namespace.ID(attr.BackfillHistoryTaskAttributes.GetNamespaceId())
case *replicationspb.ReplicationTask_VerifyVersionedTransitionTaskAttributes:
return namespace.ID(attr.VerifyVersionedTransitionTaskAttributes.GetNamespaceId())
case *replicationspb.ReplicationTask_SyncVersionedTransitionTaskAttributes:
return namespace.ID(attr.SyncVersionedTransitionTaskAttributes.GetNamespaceId())
default:
return ""
}
}

func (e *ExecutableTaskImpl) Ack() {
if atomic.LoadInt32(&e.taskState) != taskStatePending {
return
Expand Down
10 changes: 10 additions & 0 deletions service/history/replication/executable_task_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
enumsspb "go.temporal.io/server/api/enums/v1"
replicationspb "go.temporal.io/server/api/replication/v1"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/testing/testhooks"
)

type (
Expand Down Expand Up @@ -47,6 +48,15 @@ func (e *executableTaskConverterImpl) Convert(
metrics.OperationTag(TaskOperationTag(replicationTask)),
)
tasks[index] = e.convertOne(sourceClusterName, serverShardKey, replicationTask)
if hook, ok := testhooks.Get(
e.processToolBox.TestHooks,
testhooks.HistoryReplicationTaskAfterConvert,
testhooks.GlobalScope,
); ok {
if task, ok := hook(sourceClusterName, replicationTask, tasks[index]).(TrackableExecutableTask); ok {
tasks[index] = task
}
}
}
return tasks
}
Expand Down
2 changes: 2 additions & 0 deletions service/history/replication/executable_task_tool_box.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence/serialization"
ctasks "go.temporal.io/server/common/tasks"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service/history/configs"
"go.temporal.io/server/service/history/replication/eventhandler"
"go.temporal.io/server/service/history/shard"
Expand Down Expand Up @@ -36,6 +37,7 @@ type (
Logger log.Logger
ThrottledLogger log.ThrottledLogger
Serializer serialization.Serializer
TestHooks testhooks.TestHooks
DLQWriter DLQWriter
HistoryEventsHandler eventhandler.HistoryEventsHandler
WorkflowCache wcache.Cache
Expand Down
3 changes: 3 additions & 0 deletions service/history/replication/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"go.temporal.io/server/common/persistence/serialization"
"go.temporal.io/server/common/quotas"
ctasks "go.temporal.io/server/common/tasks"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service/history/configs"
historyi "go.temporal.io/server/service/history/interfaces"
"go.temporal.io/server/service/history/queues"
Expand Down Expand Up @@ -83,6 +84,7 @@ func eagerNamespaceRefresherProvider(
dataMerger nsreplication.NamespaceDataMerger,
admitter nsreplication.NamespaceReplicationAdmitter,
metricsHandler metrics.Handler,
testHooks testhooks.TestHooks,
) EagerNamespaceRefresher {
return NewEagerNamespaceRefresher(
metadataManager,
Expand All @@ -95,6 +97,7 @@ func eagerNamespaceRefresherProvider(
dataMerger,
admitter,
logger,
testHooks,
),
clusterMetadata.GetCurrentClusterName(),
metricsHandler,
Expand Down
8 changes: 7 additions & 1 deletion service/history/replication/task_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"go.temporal.io/server/common/primitives/timestamp"
"go.temporal.io/server/common/quotas"
serviceerrors "go.temporal.io/server/common/serviceerror"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service/history/configs"
historyi "go.temporal.io/server/service/history/interfaces"
"go.temporal.io/server/service/history/shard"
Expand Down Expand Up @@ -64,6 +65,7 @@ type (
config *configs.Config
metricsHandler metrics.Handler
logger log.Logger
testHooks testhooks.TestHooks
replicationTaskExecutor TaskExecutor
dlqWriter DLQWriter

Expand Down Expand Up @@ -99,6 +101,7 @@ func NewTaskProcessor(
replicationTaskFetcher taskFetcher,
replicationTaskExecutor TaskExecutor,
eventSerializer serialization.Serializer,
testHooks testhooks.TestHooks,
dlqWriter DLQWriter,
) TaskProcessor {
shardID := shardContext.GetShardID()
Expand All @@ -125,6 +128,7 @@ func NewTaskProcessor(
config: config,
metricsHandler: metricsHandler,
logger: shardContext.GetLogger(),
testHooks: testHooks,
replicationTaskExecutor: replicationTaskExecutor,
dlqWriter: dlqWriter,
rateLimiter: quotas.NewMultiRateLimiter([]quotas.RateLimiter{
Expand Down Expand Up @@ -313,7 +317,9 @@ func (p *taskProcessorImpl) handleReplicationTask(
operationTagValue := p.getOperationTagValue(replicationTask)

operation := func() error {
err := p.replicationTaskExecutor.Execute(ctx, replicationTask, false)
err := executeReplicationTask(p.testHooks, replicationTask, func() error {
return p.replicationTaskExecutor.Execute(ctx, replicationTask, false)
})
p.emitTaskMetrics(operationTagValue, err)
return err
}
Expand Down
5 changes: 5 additions & 0 deletions service/history/replication/task_processor_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/persistence/serialization"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/service/history/configs"
"go.temporal.io/server/service/history/deletemanager"
historyi "go.temporal.io/server/service/history/interfaces"
Expand Down Expand Up @@ -46,6 +47,7 @@ type (
taskPollerManager pollerManager
metricsHandler metrics.Handler
logger log.Logger
testHooks testhooks.TestHooks
dlqWriter DLQWriter

enableFetcher bool
Expand All @@ -66,6 +68,7 @@ func NewTaskProcessorManager(
eventSerializer serialization.Serializer,
replicationTaskFetcherFactory TaskFetcherFactory,
taskExecutorProvider TaskExecutorProvider,
testHooks testhooks.TestHooks,
dlqWriter DLQWriter,
) *taskProcessorManagerImpl {
historyFetcher := eventhandler.NewHistoryPaginatedFetcher(shardContext.GetNamespaceRegistry(), clientBean, eventSerializer, shardContext.GetLogger())
Expand All @@ -81,6 +84,7 @@ func NewTaskProcessorManager(
removeHistoryFetcher: historyFetcher,
logger: shardContext.GetLogger(),
metricsHandler: shardContext.GetMetricsHandler(),
testHooks: testHooks,
dlqWriter: dlqWriter,

enableFetcher: !config.EnableReplicationStream(),
Expand Down Expand Up @@ -193,6 +197,7 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate(
WorkflowCache: r.workflowCache,
}),
r.eventSerializer,
r.testHooks,
r.dlqWriter,
)
replicationTaskProcessor.Start()
Expand Down
Loading
Loading