Problem
The TTT position loss weighting logic 0.8**i is hardcoded and duplicated in 3 separate locations in eagle3_trainer.py:
- Line 251 (_backward): ploss_weight = [0.8**i for i in range(len(plosses))]
- Line 302 (_aggregate_eval_metrics): [0.8**i for i in range(avg_plosses.shape[0])]
- Line 394 (_aggregate_metrics): [0.8**i for i in range(avg_plosses.shape[0])]
The weighting strategy, decay factor, and reduction are all implicitly embedded in each call site with no abstraction.
Proposed Solution
Extract a ploss function that takes only the per-position losses and fully encapsulates the weighting strategy internally:
def weighted_ploss(plosses: list[torch.Tensor]) -> torch.Tensor:
    """Compute the weighted position loss.

    Encapsulates the full weighting strategy; callers don't need to know
    whether it's exponential decay, linear decay, or something else.
    """
    weights = [0.8 ** i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)
The key design choice: the function signature is just plosses in, scalar out. The weighting strategy is an implementation detail hidden inside. This makes it easy to swap in a completely different strategy (e.g., linear decay, learned weights, a position-dependent schedule) without changing any call sites.
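As a sanity check, the extracted function reproduces the inline weighting at the call sites (plus the normalization by the weight sum). The sketch below uses plain floats standing in for scalar torch.Tensors, since the function only relies on arithmetic; the values are hypothetical:

```python
def weighted_ploss(plosses):
    # Exponential-decay weights, hidden from callers.
    weights = [0.8 ** i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)

# Plain floats stand in for scalar torch.Tensors (hypothetical values).
plosses = [1.0, 2.0, 3.0]

# The same computation written inline, as at the current call sites,
# with the normalization by sum(weights) added.
weights = [0.8 ** i for i in range(len(plosses))]
manual = sum(w * p for w, p in zip(weights, plosses)) / sum(weights)

assert abs(weighted_ploss(plosses) - manual) < 1e-12
```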
This function should be used in:
_backward for the training loss
_aggregate_metrics for the training metric
_aggregate_eval_metrics for the eval metric
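Because the strategy lives entirely inside the function, switching it later only touches the function body; every call site keeps the same one-line form. A sketch with a swapped-in linear-decay strategy (hypothetical, floats standing in for scalar tensors):

```python
def weighted_ploss(plosses):
    # Swapped-in strategy: linear decay instead of 0.8**i.
    # Only this body changes; no call site is touched.
    n = len(plosses)
    weights = [(n - i) / n for i in range(n)]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)

# Call sites look the same regardless of the internal strategy:
plosses = [1.0, 2.0, 3.0]  # hypothetical per-position losses
loss = weighted_ploss(plosses)
```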
Files
torchspec/training/eagle3_trainer.py