Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions core/capabilities/remote/trigger_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ type ackKey struct {
}

type pubRegState struct {
callback <-chan commoncap.TriggerResponse
request commoncap.TriggerRegistrationRequest
cancel context.CancelFunc
callback <-chan commoncap.TriggerResponse
request commoncap.TriggerRegistrationRequest
cancel context.CancelFunc
registrationFailed bool // true if RegisterTrigger returned an error; used to suppress retries
}

type batchedResponse struct {
Expand Down Expand Up @@ -216,9 +217,12 @@ func (p *triggerPublisher) Receive(_ context.Context, msg *types.MessageBody) {
p.mu.Lock()
defer p.mu.Unlock()
p.messageCache.Insert(key, sender, nowMs, msg.Payload)
_, exists := p.registrations[key]
if exists {
p.lggr.Debugw("trigger registration already exists", "workflowId", req.Metadata.WorkflowID, "triggerID", req.TriggerID)
if existing, exists := p.registrations[key]; exists {
if existing.registrationFailed {
p.lggr.Debugw("skipping trigger registration; previous attempt failed", "workflowId", req.Metadata.WorkflowID, "triggerID", req.TriggerID)
} else {
p.lggr.Debugw("trigger registration already exists", "workflowId", req.Metadata.WorkflowID, "triggerID", req.TriggerID)
}
return
}
// NOTE: require 2F+1 by default, introduce different strategies later (KS-76)
Expand Down Expand Up @@ -251,6 +255,7 @@ func (p *triggerPublisher) Receive(_ context.Context, msg *types.MessageBody) {
p.lggr.Debugw("updated trigger registration", "workflowId", req.Metadata.WorkflowID, "triggerID", req.TriggerID)
} else {
cancel()
p.registrations[key] = &pubRegState{registrationFailed: true}
p.lggr.Errorw("failed to register trigger", "workflowId", req.Metadata.WorkflowID, "triggerID", req.TriggerID, "err", err)
}
case types.MethodTriggerEvent:
Expand Down Expand Up @@ -333,16 +338,20 @@ func (p *triggerPublisher) cacheCleanupLoop() {
now := time.Now().UnixMilli()

p.mu.Lock()
for key, req := range p.registrations {
for key, reg := range p.registrations {
callerDon := cfg.workflowDONs[key.callerDonID]
ready, _ := p.messageCache.Ready(key, uint32(2*callerDon.F+1), now-cfg.remoteConfig.RegistrationExpiry.Milliseconds(), false)
if !ready {
p.lggr.Infow("trigger registration expired", "callerDonID", key.callerDonID, "workflowId", key.workflowID, "triggerID", key.triggerID)
ctx, cancel := p.stopCh.NewCtx()
err := cfg.underlying.UnregisterTrigger(ctx, req.request)
cancel()
p.registrations[key].cancel() // Cancel context on register trigger
p.lggr.Infow("unregistered trigger", "callerDonID", key.callerDonID, "workflowId", key.workflowID, "triggerID", key.triggerID, "err", err)
if !reg.registrationFailed {
ctx, cancel := p.stopCh.NewCtx()
err := cfg.underlying.UnregisterTrigger(ctx, reg.request)
cancel()
reg.cancel()
p.lggr.Infow("unregistered trigger", "callerDonID", key.callerDonID, "workflowId", key.workflowID, "triggerID", key.triggerID, "err", err)
} else {
p.lggr.Debugw("removing failed registration attempt from local state", "callerDonID", key.callerDonID, "workflowId", key.workflowID, "triggerID", key.triggerID)
}
// after calling UnregisterTrigger, the underlying trigger will not send any more events to the channel
delete(p.registrations, key)
p.messageCache.Delete(key)
Expand Down
80 changes: 80 additions & 0 deletions core/capabilities/remote/trigger_publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package remote_test

import (
"context"
"errors"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -576,6 +577,85 @@ func TestTriggerPublisher_ResendBehavior_MultiTriggerBatch(t *testing.T) {
})
}

func TestTriggerPublisher_RegisterTrigger_FailureShortCircuit(t *testing.T) {
ctx := testutils.Context(t)
lggr := logger.Test(t)
capabilityDONID, workflowDONID := uint32(1), uint32(2)

peers := make([]p2ptypes.PeerID, 2)
require.NoError(t, peers[0].UnmarshalText([]byte(peerID1)))
require.NoError(t, peers[1].UnmarshalText([]byte(peerID2)))

capDonInfo := commoncap.DON{
ID: capabilityDONID,
Members: []p2ptypes.PeerID{peers[0]},
F: 0,
}
workflowDonInfo := commoncap.DON{
ID: workflowDONID,
Members: []p2ptypes.PeerID{peers[1]},
F: 0,
}
workflowDONs := map[uint32]commoncap.DON{
workflowDonInfo.ID: workflowDonInfo,
}

capInfo := commoncap.CapabilityInfo{
ID: capID,
CapabilityType: commoncap.CapabilityTypeTrigger,
}
underlying := &errTrigger{info: capInfo, err: errors.New("registration error")}

dispatcher := mocks.NewDispatcher(t)
config := &commoncap.RemoteTriggerConfig{
RegistrationRefresh: 100 * time.Millisecond,
RegistrationExpiry: 100 * time.Second,
MinResponsesToAggregate: 1,
MessageExpiry: 100 * time.Second,
MaxBatchSize: 1,
BatchCollectionPeriod: time.Second,
}
publisher := remote.NewTriggerPublisher(capInfo.ID, "", dispatcher, lggr)
require.NoError(t, publisher.SetConfig(config, underlying, capDonInfo, workflowDONs))
require.NoError(t, publisher.Start(ctx))

// First message reaches quorum and triggers a RegisterTrigger call that fails.
regMsg := newRegisterTriggerMessage(t, workflowDONID, peers[1])
publisher.Receive(ctx, regMsg)
require.Equal(t, 1, underlying.callCount, "RegisterTrigger should be called once on first quorum")

// Subsequent messages for the same key must be short-circuited without retrying.
publisher.Receive(ctx, regMsg)
publisher.Receive(ctx, regMsg)
require.Equal(t, 1, underlying.callCount, "RegisterTrigger must not be retried after a failure")

require.NoError(t, publisher.Close())
}

// errTrigger is a TriggerCapability that always returns an error from RegisterTrigger.
type errTrigger struct {
info commoncap.CapabilityInfo
err error
callCount int
}

func (tr *errTrigger) Info(_ context.Context) (commoncap.CapabilityInfo, error) {
return tr.info, nil
}

func (tr *errTrigger) RegisterTrigger(_ context.Context, _ commoncap.TriggerRegistrationRequest) (<-chan commoncap.TriggerResponse, error) {
tr.callCount++
return nil, tr.err
}

func (tr *errTrigger) UnregisterTrigger(_ context.Context, _ commoncap.TriggerRegistrationRequest) error {
return nil
}

func (tr *errTrigger) AckEvent(_ context.Context, _ string, _ string, _ string) error {
return nil
}

func newRegisterTriggerMessageWithTriggerID(t *testing.T, callerDonID uint32, sender p2ptypes.PeerID, triggerID string) *remotetypes.MessageBody {
triggerRequest := commoncap.TriggerRegistrationRequest{
TriggerID: triggerID,
Expand Down
Loading