Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions node/aggregation_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package node

import (
"sync"

"github.com/geanlabs/gean/xmss"
)

// Per-aggregate scratch slices reused across calls via sync.Pool. Cuts the
// young-gen GC pressure from re-allocating these four slices per data root
// (~5-20 per pass × every interval-2 tick). Initial capacities are first-fit
// guesses; pool reuse grows them to the working set and preserves the
// capacity across passes.
//
// Same pattern as xmss/proof_pool.go: pool stores *[]T, get/put wrappers
// reset length on return so callers always see an empty slice.

var (
childProofsPool = sync.Pool{New: func() any { s := make([]xmss.ChildProof, 0, 8); return &s }}
rawPubkeysPool = sync.Pool{New: func() any { s := make([]xmss.CPubKey, 0, 32); return &s }}
rawSigsPool = sync.Pool{New: func() any { s := make([]xmss.CSig, 0, 32); return &s }}
rawIDsPool = sync.Pool{New: func() any { s := make([]uint64, 0, 32); return &s }}
)

func getChildProofsBuf() *[]xmss.ChildProof { return childProofsPool.Get().(*[]xmss.ChildProof) }
func putChildProofsBuf(b *[]xmss.ChildProof) { *b = (*b)[:0]; childProofsPool.Put(b) }

func getRawPubkeysBuf() *[]xmss.CPubKey { return rawPubkeysPool.Get().(*[]xmss.CPubKey) }
func putRawPubkeysBuf(b *[]xmss.CPubKey) { *b = (*b)[:0]; rawPubkeysPool.Put(b) }

func getRawSigsBuf() *[]xmss.CSig { return rawSigsPool.Get().(*[]xmss.CSig) }
func putRawSigsBuf(b *[]xmss.CSig) { *b = (*b)[:0]; rawSigsPool.Put(b) }

func getRawIDsBuf() *[]uint64 { return rawIDsPool.Get().(*[]uint64) }
func putRawIDsBuf(b *[]uint64) { *b = (*b)[:0]; rawIDsPool.Put(b) }
225 changes: 116 additions & 109 deletions node/store_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,138 +127,145 @@ func aggregateFromSnapshot(snap *AggregationSnapshot, cache *xmss.PubKeyCache) (
}

for dataRoot := range dataRoots {
prepStart := time.Now()
gossipEntry := snap.attSigs[dataRoot]
newEntry := snap.newEntries[dataRoot]
knownEntry := snap.knownEntries[dataRoot]

// Need attestation data from any available source.
var attData *types.AttestationData
if gossipEntry != nil {
attData = gossipEntry.Data
} else if newEntry != nil {
attData = newEntry.Data
} else if knownEntry != nil {
attData = knownEntry.Data
}
if attData == nil {
continue
}

targetState := snap.targetStates[attData.Target.Root]
if targetState == nil {
continue
}

// Phase 1: Select — greedy pick existing child proofs.
var childProofs []xmss.ChildProof
covered := make(map[uint64]bool)
// Anonymous func per iteration so pooled scratch slices (and the
// defer xmss.FreeSignature inside the sig loop) release per data
// root rather than accumulating until aggregateFromSnapshot returns.
func() {
childProofsBuf := getChildProofsBuf()
defer putChildProofsBuf(childProofsBuf)
rawPubkeysBuf := getRawPubkeysBuf()
defer putRawPubkeysBuf(rawPubkeysBuf)
rawSigsBuf := getRawSigsBuf()
defer putRawSigsBuf(rawSigsBuf)
rawIDsBuf := getRawIDsBuf()
defer putRawIDsBuf(rawIDsBuf)

prepStart := time.Now()
gossipEntry := snap.attSigs[dataRoot]
newEntry := snap.newEntries[dataRoot]
knownEntry := snap.knownEntries[dataRoot]

// Need attestation data from any available source.
var attData *types.AttestationData
if gossipEntry != nil {
attData = gossipEntry.Data
} else if newEntry != nil {
attData = newEntry.Data
} else if knownEntry != nil {
attData = knownEntry.Data
}
if attData == nil {
return
}

selectChildProofs(newEntry, targetState, &childProofs, covered, cache)
selectChildProofs(knownEntry, targetState, &childProofs, covered, cache)
targetState := snap.targetStates[attData.Target.Root]
if targetState == nil {
return
}

// Phase 2: Fill — collect raw gossip signatures for uncovered validators.
var rawPubkeys []xmss.CPubKey
var rawSigs []xmss.CSig
var rawIDs []uint64
// Phase 1: Select — greedy pick existing child proofs.
covered := make(map[uint64]bool)
selectChildProofs(newEntry, targetState, childProofsBuf, covered, cache)
selectChildProofs(knownEntry, targetState, childProofsBuf, covered, cache)

// Phase 2: Fill — collect raw gossip signatures for uncovered validators.
if gossipEntry != nil && len(gossipEntry.Signatures) > 0 {
sortedSigs := make([]AttestationSignatureEntry, len(gossipEntry.Signatures))
copy(sortedSigs, gossipEntry.Signatures)
sort.Slice(sortedSigs, func(i, j int) bool {
return sortedSigs[i].ValidatorID < sortedSigs[j].ValidatorID
})

if gossipEntry != nil && len(gossipEntry.Signatures) > 0 {
sortedSigs := make([]AttestationSignatureEntry, len(gossipEntry.Signatures))
copy(sortedSigs, gossipEntry.Signatures)
sort.Slice(sortedSigs, func(i, j int) bool {
return sortedSigs[i].ValidatorID < sortedSigs[j].ValidatorID
})
for _, sigEntry := range sortedSigs {
if covered[sigEntry.ValidatorID] {
continue
}
if sigEntry.ValidatorID >= uint64(len(targetState.Validators)) {
continue
}

for _, sigEntry := range sortedSigs {
if covered[sigEntry.ValidatorID] {
continue
}
if sigEntry.ValidatorID >= uint64(len(targetState.Validators)) {
continue
}
sigHandle := sigEntry.SigHandle
if sigHandle == nil {
parsed, err := xmss.ParseSignature(sigEntry.Signature[:])
if err != nil {
continue
}
defer xmss.FreeSignature(parsed)
sigHandle = parsed
}

sigHandle := sigEntry.SigHandle
if sigHandle == nil {
parsed, err := xmss.ParseSignature(sigEntry.Signature[:])
pk, err := cache.Get(targetState.Validators[sigEntry.ValidatorID].AttestationPubkey)
if err != nil {
continue
}
defer xmss.FreeSignature(parsed)
sigHandle = parsed
}

pk, err := cache.Get(targetState.Validators[sigEntry.ValidatorID].AttestationPubkey)
if err != nil {
continue
*rawPubkeysBuf = append(*rawPubkeysBuf, pk)
*rawSigsBuf = append(*rawSigsBuf, sigHandle)
*rawIDsBuf = append(*rawIDsBuf, sigEntry.ValidatorID)
}

rawPubkeys = append(rawPubkeys, pk)
rawSigs = append(rawSigs, sigHandle)
rawIDs = append(rawIDs, sigEntry.ValidatorID)
}
}

// Prover requires at least 2 total inputs.
totalInputs := len(rawIDs) + len(childProofs)
if totalInputs < 2 {
continue
}
// Prover requires at least 2 total inputs.
if len(*rawIDsBuf)+len(*childProofsBuf) < 2 {
return
}

// Phase 3: Aggregate — produce recursive proof.
dataRootHash, _ := attData.HashTreeRoot()
slot := uint32(attData.Slot)
// Phase 3: Aggregate — produce recursive proof.
dataRootHash, _ := attData.HashTreeRoot()
slot := uint32(attData.Slot)

ObserveAggregationPrepTime(time.Since(prepStart).Seconds())
ObserveAggregationPrepTime(time.Since(prepStart).Seconds())

aggStart := time.Now()
proofBytes, err := xmss.AggregateWithChildren(rawPubkeys, rawSigs, childProofs, dataRootHash, slot)
aggDuration := time.Since(aggStart)
if err != nil {
logger.Error(logger.Signature, "aggregate: failed slot=%d raw=%d children=%d duration=%v: %v",
slot, len(rawIDs), len(childProofs), aggDuration, err)
continue
}
aggStart := time.Now()
proofBytes, err := xmss.AggregateWithChildren(*rawPubkeysBuf, *rawSigsBuf, *childProofsBuf, dataRootHash, slot)
aggDuration := time.Since(aggStart)
if err != nil {
logger.Error(logger.Signature, "aggregate: failed slot=%d raw=%d children=%d duration=%v: %v",
slot, len(*rawIDsBuf), len(*childProofsBuf), aggDuration, err)
return
}

allIDs := make([]uint64, 0, len(rawIDs))
allIDs = append(allIDs, rawIDs...)
for vid := range covered {
allIDs = append(allIDs, vid)
}
allIDs := make([]uint64, 0, len(*rawIDsBuf)+len(covered))
allIDs = append(allIDs, (*rawIDsBuf)...)
for vid := range covered {
allIDs = append(allIDs, vid)
}

participants := AggregationBitsFromIndices(allIDs)
proof := &types.AggregatedSignatureProof{
Participants: participants,
ProofData: proofBytes,
}
participants := AggregationBitsFromIndices(allIDs)
proof := &types.AggregatedSignatureProof{
Participants: participants,
ProofData: proofBytes,
}

logger.Info(logger.Signature, "aggregate: slot=%d raw=%d children=%d total=%d proof=%d bytes duration=%v",
slot, len(rawIDs), len(childProofs), len(allIDs), len(proofBytes), aggDuration)
logger.Info(logger.Signature, "aggregate: slot=%d raw=%d children=%d total=%d proof=%d bytes duration=%v",
slot, len(*rawIDsBuf), len(*childProofsBuf), len(allIDs), len(proofBytes), aggDuration)

if AggregateMetricsFunc != nil {
AggregateMetricsFunc(aggDuration.Seconds(), len(allIDs))
}
if AggregateMetricsFunc != nil {
AggregateMetricsFunc(aggDuration.Seconds(), len(allIDs))
}

postStart := time.Now()
newAggregates = append(newAggregates, &types.SignedAggregatedAttestation{
Data: attData,
Proof: proof,
})
postStart := time.Now()
newAggregates = append(newAggregates, &types.SignedAggregatedAttestation{
Data: attData,
Proof: proof,
})

mut.PayloadEntries = append(mut.PayloadEntries, PayloadKV{
DataRoot: dataRoot,
Data: attData,
Proof: proof,
})
mut.PayloadEntries = append(mut.PayloadEntries, PayloadKV{
DataRoot: dataRoot,
Data: attData,
Proof: proof,
})

if gossipEntry != nil {
for _, sig := range gossipEntry.Signatures {
mut.KeysToDelete = append(mut.KeysToDelete, AttestationDeleteKey{
ValidatorID: sig.ValidatorID,
DataRoot: dataRoot,
})
if gossipEntry != nil {
for _, sig := range gossipEntry.Signatures {
mut.KeysToDelete = append(mut.KeysToDelete, AttestationDeleteKey{
ValidatorID: sig.ValidatorID,
DataRoot: dataRoot,
})
}
}
}
ObserveAggregationPostTime(time.Since(postStart).Seconds())
ObserveAggregationPostTime(time.Since(postStart).Seconds())
}()
}

return newAggregates, mut
Expand Down