diff --git a/chasm/lib/activity/activity.go b/chasm/lib/activity/activity.go index df45a9ac490..59f02aff550 100644 --- a/chasm/lib/activity/activity.go +++ b/chasm/lib/activity/activity.go @@ -201,6 +201,29 @@ func (a *Activity) createAddActivityTaskRequest(ctx chasm.Context, namespaceID s }, nil } +// buildCancelCommandTaskToken builds the serialized task token for a cancel command. +// This token identifies the same activity as the poll response token but is not byte-identical — +// matching builds poll tokens with additional fields (Clock, Version, etc.). +func (a *Activity) buildCancelCommandTaskToken(ctx chasm.Context, activityRef chasm.ComponentRef) ([]byte, error) { + componentRefBytes, err := ctx.Ref(a) + if err != nil { + return nil, err + } + + attempt := a.LastAttempt.Get(ctx) + key := ctx.ExecutionKey() + + token := &tokenspb.Task{ + NamespaceId: key.NamespaceID, + ActivityId: key.BusinessID, + ActivityType: a.GetActivityType().GetName(), + Attempt: attempt.GetCount(), + ComponentRef: componentRefBytes, + } + + return token.Marshal() +} + // HandleStarted updates the activity on recording activity task started and populates the response. func (a *Activity) HandleStarted(ctx chasm.MutableContext, request *historyservice.RecordActivityTaskStartedRequest) ( *historyservice.RecordActivityTaskStartedResponse, error, @@ -505,6 +528,13 @@ func (a *Activity) Terminate( return chasm.TerminateComponentResponse{}, nil } + // If the activity is running on a worker, proactively notify the worker via Nexus. + // Must be done before the transition since it checks current status. + if a.GetStatus() == activitypb.ACTIVITY_EXECUTION_STATUS_STARTED || + a.GetStatus() == activitypb.ACTIVITY_EXECUTION_STATUS_CANCEL_REQUESTED { + a.addCancelCommandDispatchTask(ctx) + } + metricsHandler, err := a.enrichMetricsHandler(ctx, metrics.ActivityTerminatedScope) if err != nil { return chasm.TerminateComponentResponse{}, err @@ -527,6 +557,23 @@ func (a *Activity) getOrCreateLastHeartbeat(ctx chasm.MutableContext) *activityp return heartbeat } +// addCancelCommandDispatchTask schedules a side-effect task to dispatch a cancel command to the +// worker via the Nexus worker commands control queue. No-op if the worker doesn't support worker +// commands (i.e., has no control queue). +func (a *Activity) addCancelCommandDispatchTask(ctx chasm.MutableContext) { + controlQueue := a.LastAttempt.Get(ctx).GetWorkerControlTaskQueue() + if controlQueue == "" { + return + } + ctx.AddTask( + a, + chasm.TaskAttributes{ + Destination: controlQueue, + }, + &activitypb.CancelCommandDispatchTask{}, + ) +} + func (a *Activity) handleCancellationRequested(ctx chasm.MutableContext, request *activitypb.RequestCancelActivityExecutionRequest) ( *activitypb.RequestCancelActivityExecutionResponse, error, ) { @@ -551,6 +598,10 @@ func (a *Activity) handleCancellationRequested(ctx chasm.MutableContext, request return nil, err } + if !isCancelImmediately { + a.addCancelCommandDispatchTask(ctx) + } + if isCancelImmediately { details := &commonpb.Payloads{ Payloads: []*commonpb.Payload{ diff --git a/chasm/lib/activity/activity_tasks.go b/chasm/lib/activity/activity_tasks.go index e22b2f586a6..3396f5539dc 100644 --- a/chasm/lib/activity/activity_tasks.go +++ b/chasm/lib/activity/activity_tasks.go @@ -2,14 +2,30 @@ package activity import ( "context" + "errors" + "fmt" + "time" + "github.com/nexus-rpc/sdk-go/nexus" + commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" + nexuspb "go.temporal.io/api/nexus/v1" + workerservicepb "go.temporal.io/api/nexusservices/workerservice/v1" + taskqueuepb "go.temporal.io/api/taskqueue/v1" + workerpb "go.temporal.io/api/worker/v1" + "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/chasm/lib/activity/gen/activitypb/v1" + "go.temporal.io/server/common/debug" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" + commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/resource" "go.temporal.io/server/common/util" + "go.temporal.io/server/service/history/configs" "go.uber.org/fx" + "google.golang.org/protobuf/proto" ) type activityDispatchTaskHandlerOptions struct { @@ -277,3 +293,171 @@ func (h *heartbeatTimeoutTaskHandler) Execute( fromStatus: activity.GetStatus(), }) } + +// cancelCommandDispatchTaskHandler dispatches a cancel command to the worker via the Nexus +// worker commands control queue. This is a best-effort mechanism — the activity will eventually +// time out if the worker doesn't respond. +type cancelCommandDispatchTaskHandler struct { + chasm.SideEffectTaskHandlerBase[*activitypb.CancelCommandDispatchTask] + opts cancelCommandDispatchTaskHandlerOptions +} + +type cancelCommandDispatchTaskHandlerOptions struct { + fx.In + + MatchingClient resource.MatchingClient + Config *configs.Config + MetricsHandler metrics.Handler + Logger log.Logger +} + +func newCancelCommandDispatchTaskHandler(opts cancelCommandDispatchTaskHandlerOptions) *cancelCommandDispatchTaskHandler { + return &cancelCommandDispatchTaskHandler{opts: opts} +} + +func (h *cancelCommandDispatchTaskHandler) Validate( + _ chasm.Context, + activity *Activity, + _ chasm.TaskAttributes, + _ *activitypb.CancelCommandDispatchTask, +) (bool, error) { + // Valid if the activity is in a state where it has been requested to cancel or terminated + // (meaning it was running on a worker when the cancel/terminate was issued). + return activity.Status == activitypb.ACTIVITY_EXECUTION_STATUS_CANCEL_REQUESTED || + activity.Status == activitypb.ACTIVITY_EXECUTION_STATUS_TERMINATED, nil +} + +const ( + cancelCommandDispatchTimeout = time.Second * 10 * debug.TimeoutMultiplier + + workerCommandsServiceName = "temporal.api.nexusservices.workerservice.v1.WorkerService" + workerCommandsOperationName = "ExecuteCommands" +) + +func (h *cancelCommandDispatchTaskHandler) Execute( + ctx context.Context, + activityRef chasm.ComponentRef, + taskAttrs chasm.TaskAttributes, + _ *activitypb.CancelCommandDispatchTask, +) error { + if !h.opts.Config.EnableCancelActivityWorkerCommand() { + return nil + } + + // Read the activity to build the task token for the cancel command. + taskToken, err := chasm.ReadComponent( + ctx, + activityRef, + (*Activity).buildCancelCommandTaskToken, + activityRef, + ) + if err != nil { + return err + } + + command := &workerpb.WorkerCommand{ + Type: &workerpb.WorkerCommand_CancelActivity{ + CancelActivity: &workerpb.CancelActivityCommand{ + TaskToken: taskToken, + }, + }, + } + + return h.dispatchToWorker(ctx, activityRef.NamespaceID, taskAttrs.Destination, []*workerpb.WorkerCommand{command}) +} + +func (h *cancelCommandDispatchTaskHandler) dispatchToWorker( + ctx context.Context, + namespaceID string, + controlQueue string, + commands []*workerpb.WorkerCommand, +) error { + ctx, cancel := context.WithTimeout(ctx, cancelCommandDispatchTimeout) + defer cancel() + + request := &workerservicepb.ExecuteCommandsRequest{ + Commands: commands, + } + requestData, err := proto.Marshal(request) + if err != nil { + return fmt.Errorf("failed to encode worker commands request: %w", err) + } + requestPayload := &commonpb.Payload{ + Metadata: map[string][]byte{ + "encoding": []byte("binary/protobuf"), + }, + Data: requestData, + } + + nexusRequest := &nexuspb.Request{ + Header: map[string]string{}, + Variant: &nexuspb.Request_StartOperation{ + StartOperation: &nexuspb.StartOperationRequest{ + Service: workerCommandsServiceName, + Operation: workerCommandsOperationName, + Payload: requestPayload, + }, + }, + } + + resp, err := h.opts.MatchingClient.DispatchNexusTask(ctx, &matchingservice.DispatchNexusTaskRequest{ + NamespaceId: namespaceID, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: controlQueue, + Kind: enumspb.TASK_QUEUE_KIND_WORKER_COMMANDS, + }, + Request: nexusRequest, + }) + if err != nil { + h.opts.Logger.Warn("Failed to dispatch cancel command", + tag.NewStringTag("control_queue", controlQueue), + tag.Error(err)) + metrics.WorkerCommandsSent.With(h.opts.MetricsHandler).Record(1, metrics.OutcomeTag("rpc_error")) + return err + } + + nexusErr := commonnexus.DispatchResponseToError(resp) + if nexusErr == nil { + metrics.WorkerCommandsSent.With(h.opts.MetricsHandler).Record(1, metrics.OutcomeTag("success")) + return nil + } + + return h.handleDispatchError(nexusErr, controlQueue) +} + +func (h *cancelCommandDispatchTaskHandler) handleDispatchError(nexusErr error, controlQueue string) error { + var handlerErr *nexus.HandlerError + if errors.As(nexusErr, &handlerErr) { + // Handler-level error (transport, timeout, internal). + if handlerErr.Type == nexus.HandlerErrorTypeUpstreamTimeout { + h.opts.Logger.Warn("No worker polling control queue", + tag.NewStringTag("control_queue", controlQueue)) + metrics.WorkerCommandsSent.With(h.opts.MetricsHandler).Record(1, metrics.OutcomeTag("no_poller")) + return nexusErr + } + + if !handlerErr.Retryable() { + h.opts.Logger.Error("Cancel command non-retryable handler error", + tag.NewStringTag("control_queue", controlQueue), + tag.Error(nexusErr)) + metrics.WorkerCommandsSent.With(h.opts.MetricsHandler).Record(1, metrics.OutcomeTag("non_retryable_error")) + return nil + } + + h.opts.Logger.Warn("Cancel command transport failure", + tag.NewStringTag("control_queue", controlQueue), + tag.Error(nexusErr)) + metrics.WorkerCommandsSent.With(h.opts.MetricsHandler).Record(1, metrics.OutcomeTag("transport_error")) + return nexusErr + } + + // Worker-returned failure (ApplicationError, CanceledError, etc.). The worker received + // and processed the request but returned an error. Permanent — the worker contract + // requires success for all defined commands, so this indicates a bug or version + // incompatibility. Retrying won't help. + h.opts.Logger.Error("Worker returned failure for cancel command", + tag.NewStringTag("control_queue", controlQueue), + tag.Error(nexusErr)) + metrics.WorkerCommandsSent.With(h.opts.MetricsHandler).Record(1, metrics.OutcomeTag("worker_error")) + return nil +} diff --git a/chasm/lib/activity/activity_test.go b/chasm/lib/activity/activity_test.go index b614c3eba7a..70a39e69066 100644 --- a/chasm/lib/activity/activity_test.go +++ b/chasm/lib/activity/activity_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" taskqueuepb "go.temporal.io/api/taskqueue/v1" + "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/chasm/lib/activity/gen/activitypb/v1" @@ -305,3 +306,214 @@ func TestContextMetadata(t *testing.T) { require.Nil(t, md) }) } + +func TestTransitionStartedStoresWorkerControlTaskQueue(t *testing.T) { + testTime := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + ctx := &chasm.MockMutableContext{ + MockContext: chasm.MockContext{ + HandleNow: func(chasm.Component) time.Time { return testTime }, + HandleExecutionKey: func() chasm.ExecutionKey { + return chasm.ExecutionKey{BusinessID: "test-activity-id", RunID: "test-run-id"} + }, + }, + } + + attemptState := &activitypb.ActivityAttemptState{Count: 1, Stamp: 1} + a := &Activity{ + ActivityState: &activitypb.ActivityState{ + ActivityType: &commonpb.ActivityType{Name: "test-type"}, + Status: activitypb.ACTIVITY_EXECUTION_STATUS_SCHEDULED, + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-queue"}, + StartToCloseTimeout: durationpb.New(3 * time.Minute), + }, + LastAttempt: chasm.NewDataField(ctx, attemptState), + RequestData: chasm.NewDataField(ctx, &activitypb.ActivityRequestData{}), + Outcome: chasm.NewDataField(ctx, &activitypb.ActivityOutcome{}), + } + + request := &historyservice.RecordActivityTaskStartedRequest{ + Stamp: 1, + RequestId: "req-1", + PollRequest: &workflowservice.PollActivityTaskQueueRequest{ + WorkerControlTaskQueue: "test-control-queue", + }, + } + + _, err := a.HandleStarted(ctx, request) + require.NoError(t, err) + require.Equal(t, "test-control-queue", a.LastAttempt.Get(ctx).GetWorkerControlTaskQueue()) +} + +func TestCancelRequestDispatchesCancelCommand(t *testing.T) { + testTime := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + + testCases := []struct { + name string + activityStatus activitypb.ActivityExecutionStatus + controlQueue string + expectDispatchTask bool + }{ + { + name: "started with control queue dispatches cancel task", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_STARTED, + controlQueue: "test-control-queue", + expectDispatchTask: true, + }, + { + name: "started without control queue does not dispatch", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_STARTED, + controlQueue: "", + expectDispatchTask: false, + }, + { + name: "scheduled cancels immediately, no dispatch", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_SCHEDULED, + controlQueue: "", + expectDispatchTask: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + nsRegistry := namespace.NewMockRegistry(ctrl) + nsRegistry.EXPECT().GetNamespaceName(gomock.Any()).Return(namespace.Name("test-ns"), nil).AnyTimes() + + ctx := &chasm.MockMutableContext{ + MockContext: chasm.MockContext{ + HandleNow: func(chasm.Component) time.Time { return testTime }, + GoCtx: context.WithValue(context.Background(), ctxKeyActivityContext, &activityContext{ + config: &Config{ + BreakdownMetricsByTaskQueue: dynamicconfig.GetBoolPropertyFnFilteredByTaskQueue(true), + }, + namespaceRegistry: nsRegistry, + }), + }, + } + + a := &Activity{ + ActivityState: &activitypb.ActivityState{ + ActivityType: &commonpb.ActivityType{Name: "test-type"}, + Status: tc.activityStatus, + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-queue"}, + ScheduleToCloseTimeout: durationpb.New(10 * time.Minute), + StartToCloseTimeout: durationpb.New(3 * time.Minute), + }, + LastAttempt: chasm.NewDataField(ctx, &activitypb.ActivityAttemptState{ + Count: 1, + Stamp: 1, + WorkerControlTaskQueue: tc.controlQueue, + }), + Outcome: chasm.NewDataField(ctx, &activitypb.ActivityOutcome{}), + } + + req := &activitypb.RequestCancelActivityExecutionRequest{ + FrontendRequest: &workflowservice.RequestCancelActivityExecutionRequest{ + RequestId: "cancel-req-1", + Identity: "test-identity", + }, + } + _, err := a.handleCancellationRequested(ctx, req) + require.NoError(t, err) + + hasCancelTask := false + for _, task := range ctx.Tasks { + if _, ok := task.Payload.(*activitypb.CancelCommandDispatchTask); ok { + hasCancelTask = true + require.Equal(t, tc.controlQueue, task.Attributes.Destination) + } + } + require.Equal(t, tc.expectDispatchTask, hasCancelTask, + "expected dispatch task: %v, but found: %v", tc.expectDispatchTask, hasCancelTask) + }) + } +} + +func TestTerminateDispatchesCancelCommand(t *testing.T) { + testTime := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + + testCases := []struct { + name string + activityStatus activitypb.ActivityExecutionStatus + controlQueue string + expectDispatchTask bool + }{ + { + name: "started with control queue dispatches cancel task", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_STARTED, + controlQueue: "test-control-queue", + expectDispatchTask: true, + }, + { + name: "cancel_requested with control queue dispatches cancel task", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_CANCEL_REQUESTED, + controlQueue: "test-control-queue", + expectDispatchTask: true, + }, + { + name: "started without control queue does not dispatch", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_STARTED, + controlQueue: "", + expectDispatchTask: false, + }, + { + name: "scheduled does not dispatch", + activityStatus: activitypb.ACTIVITY_EXECUTION_STATUS_SCHEDULED, + controlQueue: "", + expectDispatchTask: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + nsRegistry := namespace.NewMockRegistry(ctrl) + nsRegistry.EXPECT().GetNamespaceName(gomock.Any()).Return(namespace.Name("test-ns"), nil).AnyTimes() + + ctx := &chasm.MockMutableContext{ + MockContext: chasm.MockContext{ + HandleNow: func(chasm.Component) time.Time { return testTime }, + GoCtx: context.WithValue(context.Background(), ctxKeyActivityContext, &activityContext{ + config: &Config{ + BreakdownMetricsByTaskQueue: dynamicconfig.GetBoolPropertyFnFilteredByTaskQueue(true), + }, + namespaceRegistry: nsRegistry, + }), + }, + } + + a := &Activity{ + ActivityState: &activitypb.ActivityState{ + ActivityType: &commonpb.ActivityType{Name: "test-type"}, + Status: tc.activityStatus, + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-queue"}, + ScheduleToCloseTimeout: durationpb.New(10 * time.Minute), + StartToCloseTimeout: durationpb.New(3 * time.Minute), + }, + LastAttempt: chasm.NewDataField(ctx, &activitypb.ActivityAttemptState{ + Count: 1, + Stamp: 1, + WorkerControlTaskQueue: tc.controlQueue, + }), + Outcome: chasm.NewDataField(ctx, &activitypb.ActivityOutcome{}), + } + + _, err := a.Terminate(ctx, chasm.TerminateComponentRequest{ + Reason: "test terminate", + }) + require.NoError(t, err) + + hasCancelTask := false + for _, task := range ctx.Tasks { + if _, ok := task.Payload.(*activitypb.CancelCommandDispatchTask); ok { + hasCancelTask = true + require.Equal(t, tc.controlQueue, task.Attributes.Destination) + } + } + require.Equal(t, tc.expectDispatchTask, hasCancelTask, + "expected dispatch task: %v, but found: %v", tc.expectDispatchTask, hasCancelTask) + }) + } +} diff --git a/chasm/lib/activity/fx.go b/chasm/lib/activity/fx.go index 905042382c2..0639862b381 100644 --- a/chasm/lib/activity/fx.go +++ b/chasm/lib/activity/fx.go @@ -12,6 +12,7 @@ var HistoryModule = fx.Module( fx.Provide( ConfigProvider, newActivityDispatchTaskHandler, + newCancelCommandDispatchTaskHandler, newScheduleToStartTimeoutTaskHandler, newScheduleToCloseTimeoutTaskHandler, newStartToCloseTimeoutTaskHandler, diff --git a/chasm/lib/activity/gen/activitypb/v1/activity_state.pb.go b/chasm/lib/activity/gen/activitypb/v1/activity_state.pb.go index 3e95ee84f59..90034d167b6 100644 --- a/chasm/lib/activity/gen/activitypb/v1/activity_state.pb.go +++ b/chasm/lib/activity/gen/activitypb/v1/activity_state.pb.go @@ -456,8 +456,11 @@ type ActivityAttemptState struct { // The request ID that came from matching's RecordActivityTaskStarted API call. Used to make this API idempotent in // case of implicit retries. StartRequestId string `protobuf:"bytes,9,opt,name=start_request_id,json=startRequestId,proto3" json:"start_request_id,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // The worker's control task queue for sending commands (e.g. cancel) via Nexus. + // Set when the worker reports it during poll. Empty if the worker doesn't support worker commands. + WorkerControlTaskQueue string `protobuf:"bytes,10,opt,name=worker_control_task_queue,json=workerControlTaskQueue,proto3" json:"worker_control_task_queue,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ActivityAttemptState) Reset() { @@ -553,6 +556,13 @@ func (x *ActivityAttemptState) GetStartRequestId() string { return "" } +func (x *ActivityAttemptState) GetWorkerControlTaskQueue() string { + if x != nil { + return x.WorkerControlTaskQueue + } + return "" +} + type ActivityHeartbeatState struct { state protoimpl.MessageState `protogen:"open.v1"` // Details provided in the last recorded activity heartbeat. @@ -934,7 +944,7 @@ const file_temporal_server_chasm_lib_activity_proto_v1_activity_state_proto_rawD "\x06reason\x18\x04 \x01(\tR\x06reason\"7\n" + "\x16ActivityTerminateState\x12\x1d\n" + "\n" + - "request_id\x18\x01 \x01(\tR\trequestId\"\xe8\x05\n" + + "request_id\x18\x01 \x01(\tR\trequestId\"\xa3\x06\n" + "\x14ActivityAttemptState\x12\x14\n" + "\x05count\x18\x01 \x01(\x05R\x05count\x12O\n" + "\x16current_retry_interval\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\x14currentRetryInterval\x12=\n" + @@ -944,7 +954,9 @@ const file_temporal_server_chasm_lib_activity_proto_v1_activity_state_proto_rawD "\x05stamp\x18\x06 \x01(\x05R\x05stamp\x120\n" + "\x14last_worker_identity\x18\a \x01(\tR\x12lastWorkerIdentity\x12k\n" + "\x17last_deployment_version\x18\b \x01(\v23.temporal.api.deployment.v1.WorkerDeploymentVersionR\x15lastDeploymentVersion\x12(\n" + - "\x10start_request_id\x18\t \x01(\tR\x0estartRequestId\x1a\x80\x01\n" + + "\x10start_request_id\x18\t \x01(\tR\x0estartRequestId\x129\n" + + "\x19worker_control_task_queue\x18\n" + + " \x01(\tR\x16workerControlTaskQueue\x1a\x80\x01\n" + "\x12LastFailureDetails\x12.\n" + "\x04time\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\x04time\x12:\n" + "\afailure\x18\x02 \x01(\v2 .temporal.api.failure.v1.FailureR\afailure\"\xc9\x01\n" + diff --git a/chasm/lib/activity/gen/activitypb/v1/tasks.go-helpers.pb.go b/chasm/lib/activity/gen/activitypb/v1/tasks.go-helpers.pb.go index d7628a6e9e6..a4173d9659f 100644 --- a/chasm/lib/activity/gen/activitypb/v1/tasks.go-helpers.pb.go +++ b/chasm/lib/activity/gen/activitypb/v1/tasks.go-helpers.pb.go @@ -189,3 +189,40 @@ func (this *HeartbeatTimeoutTask) Equal(that interface{}) bool { return proto.Equal(this, that1) } + +// Marshal an object of type CancelCommandDispatchTask to the protobuf v3 wire format +func (val *CancelCommandDispatchTask) Marshal() ([]byte, error) { + return proto.Marshal(val) +} + +// Unmarshal an object of type CancelCommandDispatchTask from the protobuf v3 wire format +func (val *CancelCommandDispatchTask) Unmarshal(buf []byte) error { + return proto.Unmarshal(buf, val) +} + +// Size returns the size of the object, in bytes, once serialized +func (val *CancelCommandDispatchTask) Size() int { + return proto.Size(val) +} + +// Equal returns whether two CancelCommandDispatchTask values are equivalent by recursively +// comparing the message's fields. +// For more information see the documentation for +// https://pkg.go.dev/google.golang.org/protobuf/proto#Equal +func (this *CancelCommandDispatchTask) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + var that1 *CancelCommandDispatchTask + switch t := that.(type) { + case *CancelCommandDispatchTask: + that1 = t + case CancelCommandDispatchTask: + that1 = &t + default: + return false + } + + return proto.Equal(this, that1) +} diff --git a/chasm/lib/activity/gen/activitypb/v1/tasks.pb.go b/chasm/lib/activity/gen/activitypb/v1/tasks.pb.go index 796574e7db2..23fc96a8db5 100644 --- a/chasm/lib/activity/gen/activitypb/v1/tasks.pb.go +++ b/chasm/lib/activity/gen/activitypb/v1/tasks.pb.go @@ -239,6 +239,44 @@ func (x *HeartbeatTimeoutTask) GetStamp() int32 { return 0 } +// CancelCommandDispatchTask is a side-effect task that dispatches a cancel command to the worker +// via the Nexus worker commands control queue. +type CancelCommandDispatchTask struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CancelCommandDispatchTask) Reset() { + *x = CancelCommandDispatchTask{} + mi := &file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CancelCommandDispatchTask) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CancelCommandDispatchTask) ProtoMessage() {} + +func (x *CancelCommandDispatchTask) ProtoReflect() protoreflect.Message { + mi := &file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CancelCommandDispatchTask.ProtoReflect.Descriptor instead. +func (*CancelCommandDispatchTask) Descriptor() ([]byte, []int) { + return file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDescGZIP(), []int{5} +} + var File_temporal_server_chasm_lib_activity_proto_v1_tasks_proto protoreflect.FileDescriptor const file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDesc = "" + @@ -252,7 +290,8 @@ const file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDesc = "" "\x17StartToCloseTimeoutTask\x12\x14\n" + "\x05stamp\x18\x01 \x01(\x05R\x05stamp\",\n" + "\x14HeartbeatTimeoutTask\x12\x14\n" + - "\x05stamp\x18\x01 \x01(\x05R\x05stampBDZBgo.temporal.io/server/chasm/lib/activity/gen/activitypb;activitypbb\x06proto3" + "\x05stamp\x18\x01 \x01(\x05R\x05stamp\"\x1b\n" + + "\x19CancelCommandDispatchTaskBDZBgo.temporal.io/server/chasm/lib/activity/gen/activitypb;activitypbb\x06proto3" var ( file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDescOnce sync.Once @@ -266,13 +305,14 @@ func file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDescGZIP() return file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDescData } -var file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_goTypes = []any{ (*ActivityDispatchTask)(nil), // 0: temporal.server.chasm.lib.activity.proto.v1.ActivityDispatchTask (*ScheduleToStartTimeoutTask)(nil), // 1: temporal.server.chasm.lib.activity.proto.v1.ScheduleToStartTimeoutTask (*ScheduleToCloseTimeoutTask)(nil), // 2: temporal.server.chasm.lib.activity.proto.v1.ScheduleToCloseTimeoutTask (*StartToCloseTimeoutTask)(nil), // 3: temporal.server.chasm.lib.activity.proto.v1.StartToCloseTimeoutTask (*HeartbeatTimeoutTask)(nil), // 4: temporal.server.chasm.lib.activity.proto.v1.HeartbeatTimeoutTask + (*CancelCommandDispatchTask)(nil), // 5: temporal.server.chasm.lib.activity.proto.v1.CancelCommandDispatchTask } var file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -293,7 +333,7 @@ func file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDesc), len(file_temporal_server_chasm_lib_activity_proto_v1_tasks_proto_rawDesc)), NumEnums: 0, - NumMessages: 5, + NumMessages: 6, NumExtensions: 0, NumServices: 0, }, diff --git a/chasm/lib/activity/library.go b/chasm/lib/activity/library.go index 83e3d9067af..269648a4a0f 100644 --- a/chasm/lib/activity/library.go +++ b/chasm/lib/activity/library.go @@ -79,6 +79,7 @@ type library struct { handler *handler activityDispatchTaskHandler *activityDispatchTaskHandler + cancelCommandDispatchTaskHandler *cancelCommandDispatchTaskHandler scheduleToStartTimeoutTaskHandler *scheduleToStartTimeoutTaskHandler scheduleToCloseTimeoutTaskHandler *scheduleToCloseTimeoutTaskHandler startToCloseTimeoutTaskHandler *startToCloseTimeoutTaskHandler @@ -88,6 +89,7 @@ type library struct { func newLibrary( handler *handler, activityDispatchTaskHandler *activityDispatchTaskHandler, + cancelCommandDispatchTaskHandler *cancelCommandDispatchTaskHandler, scheduleToStartTimeoutTaskHandler *scheduleToStartTimeoutTaskHandler, scheduleToCloseTimeoutTaskHandler *scheduleToCloseTimeoutTaskHandler, startToCloseTimeoutTaskHandler *startToCloseTimeoutTaskHandler, @@ -99,6 +101,7 @@ func newLibrary( componentOnlyLibrary: *newComponentOnlyLibrary(config, namespaceRegistry), handler: handler, activityDispatchTaskHandler: activityDispatchTaskHandler, + cancelCommandDispatchTaskHandler: cancelCommandDispatchTaskHandler, scheduleToStartTimeoutTaskHandler: scheduleToStartTimeoutTaskHandler, scheduleToCloseTimeoutTaskHandler: scheduleToCloseTimeoutTaskHandler, startToCloseTimeoutTaskHandler: startToCloseTimeoutTaskHandler, @@ -132,5 +135,9 @@ func (l *library) Tasks() []*chasm.RegistrableTask { "heartbeatTimer", l.heartbeatTimeoutTaskHandler, ), + chasm.NewRegistrableSideEffectTask( + "cancelCommandDispatch", + l.cancelCommandDispatchTaskHandler, + ), } } diff --git a/chasm/lib/activity/proto/v1/activity_state.proto b/chasm/lib/activity/proto/v1/activity_state.proto index 931afb0b881..7519ff46be0 100644 --- a/chasm/lib/activity/proto/v1/activity_state.proto +++ b/chasm/lib/activity/proto/v1/activity_state.proto @@ -155,6 +155,10 @@ message ActivityAttemptState { // The request ID that came from matching's RecordActivityTaskStarted API call. Used to make this API idempotent in // case of implicit retries. string start_request_id = 9; + + // The worker's control task queue for sending commands (e.g. cancel) via Nexus. + // Set when the worker reports it during poll. Empty if the worker doesn't support worker commands. + string worker_control_task_queue = 10; } message ActivityHeartbeatState { diff --git a/chasm/lib/activity/proto/v1/tasks.proto b/chasm/lib/activity/proto/v1/tasks.proto index 9a1996e3dd2..70dd3ea992a 100644 --- a/chasm/lib/activity/proto/v1/tasks.proto +++ b/chasm/lib/activity/proto/v1/tasks.proto @@ -26,3 +26,7 @@ message HeartbeatTimeoutTask { // The current stamp for this activity execution. Used for task validation. See also [ActivityAttemptState]. int32 stamp = 1; } + +// CancelCommandDispatchTask is a side-effect task that dispatches a cancel command to the worker +// via the Nexus worker commands control queue. +message CancelCommandDispatchTask {} diff --git a/chasm/lib/activity/statemachine.go b/chasm/lib/activity/statemachine.go index b594e56a6d1..5cd7e9c8a86 100644 --- a/chasm/lib/activity/statemachine.go +++ b/chasm/lib/activity/statemachine.go @@ -146,6 +146,7 @@ var TransitionStarted = chasm.NewTransition( attempt.StartedTime = timestamppb.New(ctx.Now(a)) attempt.StartRequestId = request.GetRequestId() attempt.LastWorkerIdentity = request.GetPollRequest().GetIdentity() + attempt.WorkerControlTaskQueue = request.GetPollRequest().GetWorkerControlTaskQueue() if versionDirective := request.GetVersionDirective().GetDeploymentVersion(); versionDirective != nil { attempt.LastDeploymentVersion = &deploymentpb.WorkerDeploymentVersion{ BuildId: versionDirective.GetBuildId(), diff --git a/tests/standalone_activity_test.go b/tests/standalone_activity_test.go index 49ce3fe4dc4..d88ab8c9fcc 100644 --- a/tests/standalone_activity_test.go +++ b/tests/standalone_activity_test.go @@ -15,6 +15,7 @@ import ( commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" + workerservicepb "go.temporal.io/api/nexusservices/workerservice/v1" "go.temporal.io/api/operatorservice/v1" sdkpb "go.temporal.io/api/sdk/v1" "go.temporal.io/api/serviceerror" @@ -6328,3 +6329,146 @@ func (s *standaloneActivityTestSuite) TestCallbacks() { require.Equal(t, enumspb.ACTIVITY_EXECUTION_STATUS_TIMED_OUT, descResp.GetInfo().GetStatus()) }) } + +func (s *standaloneActivityTestSuite) TestDispatchCancelCommandToWorker() { + env := s.newTestEnv() + t := s.T() + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) + defer cancel() + + env.OverrideDynamicConfig(dynamicconfig.EnableCancelActivityWorkerCommand, true) + + tv := env.Tv() + controlQueueName := tv.ControlQueueName(env.Namespace().String()) + + tokenSerializer := tasktoken.NewSerializer() + + // assertCancelTokenMatchesPoll verifies the cancel command's task token identifies the same + // activity as the poll response's token. The tokens won't be byte-identical because: + // 1. Matching builds poll tokens with additional fields (Clock, Version, etc.) + // 2. The ComponentRef version advances after state mutations (cancel/terminate) + // We compare the stable identity fields that the SDK uses to find the running activity. + assertCancelTokenMatchesPoll := func(t *testing.T, pollToken, cancelToken []byte) { + t.Helper() + pollTask, err := tokenSerializer.Deserialize(pollToken) + require.NoError(t, err) + cancelTask, err := tokenSerializer.Deserialize(cancelToken) + require.NoError(t, err) + require.Equal(t, pollTask.GetActivityId(), cancelTask.GetActivityId()) + require.Equal(t, pollTask.GetNamespaceId(), cancelTask.GetNamespaceId()) + require.Equal(t, pollTask.GetActivityType(), cancelTask.GetActivityType()) + require.Equal(t, pollTask.GetAttempt(), cancelTask.GetAttempt()) + require.NotEmpty(t, cancelTask.GetComponentRef(), "cancel token must have a ComponentRef") + } + + // pollNexusControlQueue polls the worker commands control queue for a cancel command and + // returns the decoded ExecuteCommandsRequest. Returns nil if no task is received. + pollNexusControlQueue := func() *workerservicepb.ExecuteCommandsRequest { + pollCtx, pollCancel := context.WithTimeout(ctx, 5*time.Second) + defer pollCancel() + resp, err := env.FrontendClient().PollNexusTaskQueue(pollCtx, &workflowservice.PollNexusTaskQueueRequest{ + Namespace: env.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{Name: controlQueueName, Kind: enumspb.TASK_QUEUE_KIND_WORKER_COMMANDS}, + Identity: tv.WorkerIdentity(), + }) + if err != nil || resp == nil || resp.Request == nil { + return nil + } + startOp := resp.Request.GetStartOperation() + if startOp == nil { + return nil + } + var executeReq workerservicepb.ExecuteCommandsRequest + if err := payload.Decode(startOp.Payload, &executeReq); err != nil { + return nil + } + return &executeReq + } + + t.Run("CancelRequest", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + startResp := env.startAndValidateActivity(ctx, t, activityID, taskQueue) + runID := startResp.RunId + + // Poll with a worker control task queue so the activity stores it. + pollTaskResp, err := env.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: env.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + Kind: enumspb.TASK_QUEUE_KIND_NORMAL, + }, + Identity: tv.WorkerIdentity(), + WorkerInstanceKey: tv.WorkerInstanceKey(), + WorkerControlTaskQueue: controlQueueName, + }) + require.NoError(t, err) + require.NotEmpty(t, pollTaskResp.TaskToken) + + // Request cancellation — should dispatch cancel command to the control queue. + _, err = env.FrontendClient().RequestCancelActivityExecution(ctx, &workflowservice.RequestCancelActivityExecutionRequest{ + Namespace: env.Namespace().String(), + ActivityId: activityID, + RunId: runID, + Identity: "canceller", + RequestId: tv.RequestID(), + Reason: "test cancel", + }) + require.NoError(t, err) + + var executeReq *workerservicepb.ExecuteCommandsRequest + require.Eventually(t, func() bool { + executeReq = pollNexusControlQueue() + return executeReq != nil + }, 15*time.Second, 100*time.Millisecond, "cancel command not received on control queue") + + require.Len(t, executeReq.Commands, 1) + cancelCmd := executeReq.Commands[0].GetCancelActivity() + require.NotNil(t, cancelCmd, "expected CancelActivity command") + assertCancelTokenMatchesPoll(t, pollTaskResp.TaskToken, cancelCmd.TaskToken) + }) + + t.Run("Terminate", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + startResp := env.startAndValidateActivity(ctx, t, activityID, taskQueue) + runID := startResp.RunId + + // Poll with a worker control task queue. + pollTaskResp, err := env.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: env.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + Kind: enumspb.TASK_QUEUE_KIND_NORMAL, + }, + Identity: tv.WorkerIdentity(), + WorkerInstanceKey: tv.WorkerInstanceKey(), + WorkerControlTaskQueue: controlQueueName, + }) + require.NoError(t, err) + require.NotEmpty(t, pollTaskResp.TaskToken) + + // Terminate — should dispatch cancel command to the control queue. + _, err = env.FrontendClient().TerminateActivityExecution(ctx, &workflowservice.TerminateActivityExecutionRequest{ + Namespace: env.Namespace().String(), + ActivityId: activityID, + RunId: runID, + Reason: "test terminate", + Identity: "terminator", + }) + require.NoError(t, err) + + var executeReq *workerservicepb.ExecuteCommandsRequest + require.Eventually(t, func() bool { + executeReq = pollNexusControlQueue() + return executeReq != nil + }, 15*time.Second, 100*time.Millisecond, "cancel command not received on control queue after terminate") + + require.Len(t, executeReq.Commands, 1) + cancelCmd := executeReq.Commands[0].GetCancelActivity() + require.NotNil(t, cancelCmd, "expected CancelActivity command") + assertCancelTokenMatchesPoll(t, pollTaskResp.TaskToken, cancelCmd.TaskToken) + }) +}