diff --git a/pkg/fsm/last_sign_state.go b/pkg/fsm/last_sign_state.go index ba1edda..1886659 100644 --- a/pkg/fsm/last_sign_state.go +++ b/pkg/fsm/last_sign_state.go @@ -32,7 +32,13 @@ func (s *LastSignState) CopyToFilePV(pv *privval.FilePVLastSignState) { if pv == nil { return } - + if s.Less(&LastSignState{ + Height: pv.Height, + Round: pv.Round, + Step: pv.Step, + }) { + return + } pv.Height = s.Height pv.Round = s.Round pv.Step = s.Step diff --git a/pkg/keyserver/signing_integration_test.go b/pkg/keyserver/signing_integration_test.go index 8edc2e8..6276c51 100644 --- a/pkg/keyserver/signing_integration_test.go +++ b/pkg/keyserver/signing_integration_test.go @@ -426,6 +426,103 @@ func TestKeyserverMultipleValidatorsRejectConflictingVotes(t *testing.T) { } } +func TestKeyserverRejectsConflictingVotesAtSameHeight(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + peers := makeTestPeers(t, 3) + cluster := startRaftCluster(t, ctx, peers) + t.Cleanup(func() { + for _, node := range cluster { + closeClusterNode(t, node) + } + }) + + baseKeyDir := t.TempDir() + baseKeyPath := filepath.Join(baseKeyDir, "priv_key.json") + baseStatePath := filepath.Join(baseKeyDir, "priv_state.json") + basePV := privval.GenFilePV(baseKeyPath, baseStatePath) + basePV.Save() + + validatorAddrA, endpointA, signerA := startValidator(t, testChainID) + validatorAddrB, endpointB, signerB := startValidator(t, testChainID) + t.Cleanup(func() { + if err := signerA.Close(); err != nil { + t.Logf("signer A close: %v", err) + } + if err := signerB.Close(); err != nil { + t.Logf("signer B close: %v", err) + } + endpointA.Stop() + endpointB.Stop() + }) + + startKeyservers(t, ctx, cluster, baseKeyPath, baseStatePath, []string{validatorAddrA, validatorAddrB}, testChainID) + + _ = waitForLeaderNode(t, cluster, 10*time.Second) + waitForValidatorConnection(t, signerA, 10*time.Second) + waitForValidatorConnection(t, signerB, 10*time.Second) + + validatorAddress := ed25519.GenPrivKey().PubKey().Address().Bytes() + height := int64(40) + + conflictA := makeVote(height, 0x33, validatorAddress) + conflictB := makeVote(height, 0x44, validatorAddress) + + var hookCount atomic.Int32 + syncLastSignStateHook = func(point string, state *fsm.LastSignState) { + if point == "before-write" { + if res := hookCount.Add(1); res == 1 { + time.Sleep(time.Second) + } + } + } + t.Cleanup(func() { + syncLastSignStateHook = nil + }) + + results := make(chan error, 2) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + <-start + attempt := cloneVote(conflictA) + results <- signerA.SignVote(testChainID, attempt) + }() + + go func() { + defer wg.Done() + <-start + attempt := cloneVote(conflictB) + results <- signerB.SignVote(testChainID, attempt) + }() + + close(start) + + wg.Wait() + close(results) + + errCount := 0 + for res := range results { + if res != nil { + var remoteErr *privval.RemoteSignerError + if !errors.As(res, &remoteErr) { + t.Fatalf("expected remote signer error, got %v", res) + } + if !strings.Contains(remoteErr.Description, "conflicting") { + t.Fatalf("unexpected error: %v", remoteErr) + } + errCount++ + } + } + if errCount == 0 { + t.Fatal("expected at least one conflicting signature to fail") + } +} + func startKeyservers(t *testing.T, ctx context.Context, cluster []*clusterNode, baseKeyPath, baseStatePath string, validatorAddrs []string, chainID string) { t.Helper() diff --git a/pkg/keyserver/validator.go b/pkg/keyserver/validator.go index ac028e3..f4efaaa 100644 --- a/pkg/keyserver/validator.go +++ b/pkg/keyserver/validator.go @@ -20,6 +20,10 @@ type PrivValidator struct { mu sync.Mutex } +// syncLastSignStateHook is for tests to coordinate interleavings. +// It should be nil in production. +var syncLastSignStateHook func(point string, state *fsm.LastSignState) + // NewPrivValidator returns a validator that defers to inner once the // CometKMS lease is available. func NewPrivValidator(inner *privval.FilePV, node *raftnode.Node) *PrivValidator { @@ -80,21 +84,27 @@ func (l *PrivValidator) SignProposal(chainID string, proposal *cmtproto.Proposal // on-disk priv-validator state so leadership changes cannot re-sign old blocks. func (l *PrivValidator) syncLastSignState() error { l.mu.Lock() + defer l.mu.Unlock() + lastSignState := fsm.FromFilePV(&l.inner.LastSignState) if lastSignState.Equal(l.node.GetLastSignState()) { - l.mu.Unlock() return nil } - l.mu.Unlock() + + if syncLastSignStateHook != nil { + syncLastSignStateHook("before-raft", lastSignState) + } state, err := l.node.SyncLastSignState(lastSignState) if err != nil { return err } - l.mu.Lock() + if syncLastSignStateHook != nil { + syncLastSignStateHook("before-write", state) + } + state.CopyToFilePV(&l.inner.LastSignState) - l.mu.Unlock() return nil }