Problem
`_aggregate_metrics` (lines 368-415) and `_aggregate_eval_metrics` (lines 285-321) in `eagle3_trainer.py` share ~30 lines of nearly identical logic:
- `torch.stack` + `mean` + `all_reduce` for plosses and acces
- `simulated_acc_len` cumulative calculation (`cumulative *= acces[i]`)
- `0.8**i` weighted loss computation
- Per-position metric extraction

Any bug fix must be applied in two places, which is error-prone.
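The duplicated core is a small amount of per-position arithmetic. A minimal sketch of that math in plain Python (the distributed `torch.stack`/`all_reduce` reduction is omitted, and the exact accumulation in the trainer may differ slightly from this assumption):

```python
def simulated_acc_len(acces):
    """Sum of cumulative products of per-position acceptance rates.

    Approximates the expected accepted draft length:
    p0 + p0*p1 + p0*p1*p2 + ...
    """
    total, cumulative = 0.0, 1.0
    for acc in acces:
        cumulative *= acc  # mirrors `cumulative *= acces[i]` in the trainer
        total += cumulative
    return total

def weighted_loss(plosses, decay=0.8):
    """Down-weight later draft positions by decay**i."""
    return sum(decay**i * loss for i, loss in enumerate(plosses))
```

With two positions at 50% acceptance, `simulated_acc_len([0.5, 0.5])` gives `0.5 + 0.25 = 0.75`; two unit losses give `weighted_loss([1.0, 1.0]) = 1.0 + 0.8 = 1.8`.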
Proposed Solution
Extract a shared helper method:
```python
def _compute_weighted_loss_and_acc(self, avg_plosses, avg_acces, prefix="train"):
    """Shared logic: simulated_acc_len, weighted loss, per-position metrics."""
    ...
```
Both `_aggregate_metrics` and `_aggregate_eval_metrics` would call this helper, adding only their unique fields (e.g., `grad_norm` and `lr` for training).
Files
torchspec/training/eagle3_trainer.py