feat: reduce Eagle3 training memory spike via all-to-all sharding#524
laoconeth wants to merge 5 commits into sgl-project:main
Co-authored-by: Inhyuk Cho <cih9088@gmail.com> Co-authored-by: Kthyeon <potter32@kaist.ac.kr>
ModelRunner (sglang==0.5.9) always returns ModelRunnerOutput with logits_output field, so the hasattr guard is unnecessary. Co-authored-by: Inhyuk Cho <cih9088@gmail.com>
Code Review
This pull request introduces the --shard-target-output feature for EAGLE3 training, enabling model outputs to be sharded across the batch dimension when using the SGLang backend to optimize memory usage. The changes include updates to the training script, the EAGLE3 target model, and the SGLang-specific logits processor, which now utilizes all_to_all communication. Review feedback points out a potential RuntimeError during tensor concatenation if shards are empty and suggests explicitly deleting intermediate tensors in the SGLang backend to ensure effective memory reclamation.
```python
aux_hidden_states_out = torch.cat(aux_hidden_states_out, dim=0)
loss_mask_out = torch.cat(loss_mask_out, dim=0)
attention_mask_out = torch.cat(attention_mask_out, dim=0)
input_ids_out = torch.cat(input_ids_out, dim=0)
```
These torch.cat operations will raise a RuntimeError if the input lists are empty. This scenario can occur when shard_returns is enabled and a rank is not assigned any batch items (e.g., if the global batch size is smaller than the world size). Although current sanity checks in the training script might prevent this, the model implementation should be robust to empty shards.
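The empty-list failure can be reproduced in isolation. Below is a minimal guard sketch (`safe_cat` is an illustrative name, not the PR's code; the actual fix may handle empty shards differently):

```python
import torch

def safe_cat(tensors, dim=0):
    # torch.cat raises "RuntimeError: torch.cat(): expected a non-empty
    # list of Tensors" when given an empty list, so return an empty
    # tensor instead. Illustrative guard, not the PR's implementation.
    if len(tensors) == 0:
        return torch.empty(0)
    return torch.cat(tensors, dim=dim)

shards = []  # a rank that was assigned no batch items
out = safe_cat(shards)
print(out.shape)  # torch.Size([0])
```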
```python
        sample_indices,
        logits_metadata,
    )
    del hidden_states
```
To ensure the large un-sharded tensors are eligible for garbage collection and truly eliminate the memory spike, it is recommended to also delete pruned_states, aux_pruned_states, and pruned_states_before_norm. These variables hold views into the original full-sized tensors, which prevents the underlying memory from being reclaimed even after hidden_states is deleted.
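The reviewer's point is easy to verify: a sliced view pins the entire underlying storage, so deleting only the base variable frees nothing. A minimal illustration (variable names echo the review; the sizes are arbitrary):

```python
import torch

# A sliced view keeps the base tensor's storage alive even after the base
# variable is deleted, so `del hidden_states` alone does not release memory.
hidden_states = torch.randn(1024, 4096)          # float32, 16 MiB
pruned_states = hidden_states[:8]                # view into the same storage
base_ptr = hidden_states.untyped_storage().data_ptr()

del hidden_states
# the view still pins the full 1024 x 4096 x 4-byte storage
assert pruned_states.untyped_storage().data_ptr() == base_ptr
assert pruned_states.untyped_storage().nbytes() == 1024 * 4096 * 4

del pruned_states  # now the storage is actually eligible for reclamation
```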
Suggested change:

```diff
-del hidden_states
+del hidden_states, pruned_states, aux_pruned_states, pruned_states_before_norm
```
The reference count will be decremented at the end of the function anyway, just a few lines below. `del hidden_states` is kept for consistency with upstream sglang==0.5.9.
Please help fix the unit-test.
Motivation
Currently, Eagle3 training suffers from memory spikes that make it difficult to increase batch size (as reported in #466).
During Eagle3 training with the SGLang backend under tensor parallelism, the `lm_head` is vocab-parallel - each TP rank computes a shard of the logits with shape `(global_batch_size * seq_len, vocab_size // tp_size)`. The current code uses `all_gather` to reconstruct the full logits, so every rank ends up holding `(global_batch_size * seq_len, vocab_size)` - already a `tp_size`-fold memory increase. The problem compounds in `generate_eagle3_data()`: the gathered logit tensors are concatenated via `torch.cat` (`eagle3_target_model.py:793`), and then `padding()` (`utils.py:41`) performs another `torch.cat` to shift the sequence. Because `cat` allocates a new full-size copy while the original is still alive, peak memory climbs far higher. With large vocab sizes (e.g. ~150K for Qwen3), this makes it difficult to increase batch size.

This PR adds a `--shard-target-output` flag that replaces `all_gather` with `all_to_all`. Instead of gathering the full vocab on every rank, each rank exchanges its vocab shard for a batch shard, ending up with `((global_batch_size // dp_size) * seq_len, vocab_size)` - each rank holds the full vocab but only for its own DP shard of the batch. The same redistribution is applied to `aux_hidden_states` and `last_hidden_states`. Since the base tensor is much smaller (`1/tp_size` of the unsharded tensor), the subsequent `padding`/`cat` operations also operate on smaller tensors, eliminating the memory spike. Tensor contents are unchanged, just distributed differently - the mathematical result is identical (a training run shows identical loss values for 50 steps with sharding on and off).
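The vocab-shard-for-batch-shard exchange can be illustrated in a single process. The sketch below simulates what `all_to_all` does across `tp_size` ranks using plain tensor ops (the function name and equal-split assumption are mine, not the PR's code; the real implementation uses `torch.distributed` collectives):

```python
import torch

def simulate_all_to_all(vocab_shards, tp_size):
    """Simulate exchanging vocab shards for batch shards across tp_size ranks.

    vocab_shards[r] is rank r's logits shard, shape (batch*seq, vocab // tp_size).
    Returns out where out[r] has shape ((batch*seq) // tp_size, vocab):
    rank r keeps only its batch shard, but with the full vocabulary.
    """
    out = []
    for r in range(tp_size):
        # rank r receives row-chunk r from every rank and stitches the
        # vocab dimension back together
        pieces = [shard.chunk(tp_size, dim=0)[r] for shard in vocab_shards]
        out.append(torch.cat(pieces, dim=1))
    return out

# The result matches all_gather followed by selecting the rank's batch chunk,
# which is why the loss is bit-identical with sharding on or off.
tp_size = 4
full = torch.randn(8, 12)                  # (batch*seq, vocab)
shards = list(full.chunk(tp_size, dim=1))  # vocab-parallel shards
out = simulate_all_to_all(shards, tp_size)
for r in range(tp_size):
    assert torch.equal(out[r], full.chunk(tp_size, dim=0)[r])
```

In the actual PR the batch dimension is redistributed across DP ranks, so per-rank memory drops from the full `(global_batch_size * seq_len, vocab_size)` tensor to a `1/tp_size` batch shard of it.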
Note: `--shard-target-output` is currently implemented for the SGLang backend only, and is not tested with VLM models.

Modifications
`specforge/modeling/target/sglang_backend/utils.py`

- Added a `tensor_all_to_all` helper that redistributes a tensor from a vocab-sharded to a batch-sharded layout
- Changed `LogitsProcessor.forward` to use `all_to_all` instead of `all_gather` when `chunk_sizes` are provided
- Added a `LogitsProcessorForEAGLE3` wrapper with a `shard_returns` property to control the sharding behavior
- Changed `replaced_logits_processor_forward_for_eagle3` to return sharded logits, aux hidden states, and last hidden states

`specforge/modeling/target/eagle3_target_model.py`

- Added a `shard_returns` parameter to `SGLangEagle3TargetModel._extend()` and `generate_eagle3_data()`
- Added a `_get_sharded_return()` helper to split and select per-rank batch shards from concatenated outputs

`scripts/train_eagle3.py`

- Added the `--shard-target-output` CLI flag
- Added `sanity_check` restrictions: sglang backend only, non-VLM only
- Added a `get_dp_data_shard_from_tp()` helper that bypasses the standard DP sharding when outputs are already sharded

Related Issues
Closes #466
Accuracy Test
Verified on Qwen3-8B with tp_size=8 over 30 training steps. Loss values are the same between baseline (`all_gather`) and `--shard-target-output` (`all_to_all`).

Loss comparison (steps 1-30)

All 30 steps produce identical loss: the `all_to_all` path is mathematically equivalent to `all_gather`.
Setup: Qwen3-8B, tp_size=8, batch_size=1 (per DP rank), max_length=4096, 8x H200 GPUs
Memory snapshot visualization (at rank 0):
Baseline (`all_gather`) - memory spike from full-vocab materialization:

With `--shard-target-output` (`all_to_all`) - spike gone:

Peak memory during training step (observed at rank 0):

| Configuration | Peak memory |
| --- | --- |
| Baseline (`all_gather`) | 60.7 GiB |
| `--shard-target-output` (`all_to_all`) | 16.1 GiB |
The baseline peak (60.7 GiB) reflects the cascading allocations in `generate_eagle3_data()`: `all_gather` materializes the full-vocab logits, and then `torch.cat` and `padding()` enlarge it further. With `all_to_all`, the base tensor is `1/tp_size` the size, so the same downstream operations peak at only 16.1 GiB.

cc @cih9088 @Kthyeon
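The `1/tp_size` arithmetic can be sanity-checked with a back-of-envelope calculation. The dtype and exact vocab size below are assumptions based on the description, not measured values, and the real peaks also include the `cat`/`padding` copies on top of the base tensor:

```python
# Rough per-rank logits memory for the benchmarked setup.
GiB = 1024 ** 3
BYTES = 2                 # assume bf16 activations
vocab_size = 150_000      # ~150K, per the PR description for Qwen3
seq_len = 4096
tp_size = 8
global_batch = 8          # batch_size=1 per DP rank on 8 GPUs
tokens = global_batch * seq_len

# all_gather: every rank materializes the full-vocab logits
full_logits_gib = tokens * vocab_size * BYTES / GiB
# all_to_all: each rank keeps only its 1/tp_size batch shard
sharded_logits_gib = (tokens // tp_size) * vocab_size * BYTES / GiB

print(f"all_gather logits per rank: {full_logits_gib:.1f} GiB")
print(f"all_to_all logits per rank: {sharded_logits_gib:.1f} GiB")
```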