
feat: reduce Eagle3 training memory spike via all-to-all sharding#524

Open
laoconeth wants to merge 5 commits into sgl-project:main from laoconeth:feat/shard-target-output

Conversation

@laoconeth

Motivation

Currently, Eagle3 training suffers from memory spikes that make it difficult to increase batch size (as reported in #466).

During Eagle3 training with the SGLang backend under tensor parallelism, the lm_head is vocab-parallel: each TP rank computes a shard of the logits with shape (global_batch_size * seq_len, vocab_size // tp_size). The current code uses all_gather to reconstruct the full logits, leaving every rank holding (global_batch_size * seq_len, vocab_size) - a tp_size-fold memory increase on its own. The problem compounds in generate_eagle3_data(): the gathered logit tensors are concatenated via torch.cat (eagle3_target_model.py:793), and then padding() (utils.py:41) performs another torch.cat to shift the sequence. Because each cat allocates a new full-size copy while the original is still alive, peak memory climbs far beyond the size of the final tensor. With large vocab sizes (e.g. ~150K for Qwen3), this makes it difficult to increase the batch size.
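A back-of-envelope sketch of the footprint difference (the shapes follow the description above; the batch/sequence/vocab numbers below are illustrative assumptions, not measurements from this PR):

```python
# Rough per-rank logits footprint: all_gather leaves every rank with the
# full (batch*seq, vocab) tensor, while all_to_all leaves each rank with
# the full vocab but only its own DP shard of the batch.
# Numbers are hypothetical (bf16 = 2 bytes, Qwen3-like vocab).

def logits_bytes(rows: int, vocab: int, dtype_bytes: int = 2) -> int:
    """Size of a (rows, vocab) logits tensor in bytes (bf16 by default)."""
    return rows * vocab * dtype_bytes

global_batch, seq_len, vocab, tp = 8, 4096, 151_936, 8
rows = global_batch * seq_len

per_rank_before = logits_bytes(rows, vocab // tp)   # vocab-sharded lm_head output
after_all_gather = logits_bytes(rows, vocab)        # full vocab on every rank
after_all_to_all = logits_bytes(rows // tp, vocab)  # full vocab, batch shard only

print(after_all_gather / 2**30, "GiB vs", after_all_to_all / 2**30, "GiB")
```

The base tensor after all_to_all is exactly 1/tp_size of the all_gather result, and every downstream torch.cat copy shrinks by the same factor.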

This PR adds a --shard-target-output flag that replaces all_gather with all_to_all. Instead of gathering the full vocab on every rank, each rank exchanges its vocab shard for a batch shard, ending up with ((global_batch_size // dp_size) * seq_len, vocab_size): each rank holds the full vocab, but only for its own DP shard of the batch. The same redistribution is applied to aux_hidden_states and last_hidden_states. Since the base tensor is 1/tp_size the size of the unsharded one, the subsequent padding/cat operations also operate on proportionally smaller tensors, eliminating the memory spike.
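The redistribution can be illustrated single-process, with plain Python lists standing in for tensors and for torch.distributed.all_to_all (a toy sketch of the exchange, not the PR's implementation):

```python
# Toy sketch: each of tp ranks starts with the full batch but only a
# vocab shard; after the exchange, each rank has its batch shard with
# the full vocab. Lists stand in for tensors.

def all_to_all_reshard(per_rank, tp):
    """per_rank[r]: rank r's (batch, vocab//tp) slice as a list of rows.
    Returns out[r]: rank r's (batch//tp, vocab) slice after the exchange."""
    batch = len(per_rank[0])
    chunk = batch // tp
    out = []
    for r in range(tp):                                # receiving rank r
        rows = []
        for b in range(r * chunk, (r + 1) * chunk):    # its batch shard
            full_row = []
            for src in range(tp):                      # concat vocab shards
                full_row.extend(per_rank[src][b])
            rows.append(full_row)
        out.append(rows)
    return out

# tp=2, batch=4, vocab=4: each rank starts with a vocab shard of width 2
full = [[r * 10 + c for c in range(4)] for r in range(4)]
sharded = [[row[:2] for row in full], [row[2:] for row in full]]
resharded = all_to_all_reshard(sharded, tp=2)
assert resharded[0] == full[:2] and resharded[1] == full[2:]
```

No element is duplicated or changed, which is why the loss curves with sharding on and off match exactly.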

The tensor contents are not altered, only distributed differently - the mathematical result is identical (a training run shows identical loss values for 50 steps with sharding on and off).

Note: --shard-target-output is currently implemented for SGLang backend only, and is not tested for VLM models.

Modifications

specforge/modeling/target/sglang_backend/utils.py

  • Added tensor_all_to_all helper that redistributes a tensor from vocab-sharded to batch-sharded layout
  • Modified LogitsProcessor.forward to use all_to_all instead of all_gather when chunk_sizes are provided
  • Added LogitsProcessorForEAGLE3 wrapper with shard_returns property to control the sharding behavior
  • Modified replaced_logits_processor_forward_for_eagle3 to return sharded logits, aux hidden states, and last hidden states

specforge/modeling/target/eagle3_target_model.py

  • Added shard_returns parameter to SGLangEagle3TargetModel._extend() and generate_eagle3_data()
  • Added _get_sharded_return() helper to split and select per-rank batch shards from concatenated outputs
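The exact body of _get_sharded_return() is not shown in this description; a minimal sketch of the split-and-select idea it names, with hypothetical names and plain list slicing in place of torch chunking:

```python
# Hypothetical sketch (not the PR's code): split outputs concatenated
# along the batch dim into equal per-rank chunks and keep only this
# rank's contiguous shard.

def get_sharded_return(concatenated, dp_rank, dp_size):
    """concatenated: list of batch items; returns dp_rank's shard."""
    assert len(concatenated) % dp_size == 0, "batch must divide evenly"
    chunk = len(concatenated) // dp_size
    return concatenated[dp_rank * chunk:(dp_rank + 1) * chunk]

batch = list(range(8))
assert get_sharded_return(batch, dp_rank=1, dp_size=4) == [2, 3]
```

In the real model the items would be rows of the concatenated logits / hidden-state tensors rather than list elements.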

scripts/train_eagle3.py

  • Added --shard-target-output CLI flag
  • Added sanity_check restrictions: sglang backend only, non-VLM only
  • Added get_dp_data_shard_from_tp() helper that bypasses the standard DP sharding when outputs are already sharded
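A sketch of the bypass logic that get_dp_data_shard_from_tp() implies (assumed behavior inferred from the bullet above; names and signature are hypothetical):

```python
# Hypothetical sketch: with --shard-target-output off, each DP rank
# slices its portion of the global batch; with it on, the all_to_all
# already produced per-rank batch shards, so data passes through as-is.

def get_dp_data_shard(data, dp_rank, dp_size, already_sharded):
    if already_sharded:           # outputs were batch-sharded by all_to_all
        return data
    chunk = len(data) // dp_size  # standard DP sharding path
    return data[dp_rank * chunk:(dp_rank + 1) * chunk]

global_batch = ["s0", "s1", "s2", "s3"]
assert get_dp_data_shard(global_batch, 1, 2, False) == ["s2", "s3"]
assert get_dp_data_shard(["s2", "s3"], 1, 2, True) == ["s2", "s3"]
```

Skipping the second slice is what keeps the sharded and unsharded paths feeding identical per-rank data to the draft model.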

Related Issues

Closes #466

Accuracy Test

Verified on Qwen3-8B with tp_size=8 over 30 training steps. Loss values are the same between baseline (all_gather) and --shard-target-output (all_to_all):

Loss comparison (steps 1-30)
step baseline shard match
1 10.39 10.39 yes
2 10.23 10.23 yes
3 9.36 9.36 yes
4 9.42 9.42 yes
5 9.37 9.37 yes
6 7.57 7.57 yes
7 2.68 2.68 yes
8 2.55 2.55 yes
9 2.48 2.48 yes
10 8.24 8.24 yes
11 2.55 2.55 yes
12 2.38 2.38 yes
13 2.42 2.42 yes
14 4.94 4.94 yes
15 2.42 2.42 yes
16 2.40 2.40 yes
17 6.54 6.54 yes
18 4.93 4.93 yes
19 4.03 4.03 yes
20 2.03 2.03 yes
21 2.37 2.37 yes
22 2.37 2.37 yes
23 5.82 5.82 yes
24 5.01 5.01 yes
25 2.33 2.33 yes
26 2.35 2.35 yes
27 2.32 2.32 yes
28 2.36 2.36 yes
29 3.87 3.87 yes
30 5.88 5.88 yes

All 30 steps produce identical loss: the all_to_all path is mathematically equivalent to all_gather.

Benchmark & Profiling

Setup: Qwen3-8B, tp_size=8, batch_size=1 (per DP rank), max_length=4096, 8x H200 GPUs

Memory snapshot visualization (at rank 0):

Baseline (all_gather), showing the memory spike from full-vocab materialization:
[memory snapshot: snapshot_baseline]

With --shard-target-output (all_to_all), the spike is gone:
[memory snapshot: snapshot_shard]

Peak memory during training step (observed at rank 0):

Peak allocated: baseline (all_gather) 60.7 GiB vs --shard-target-output (all_to_all) 16.1 GiB, a 73% reduction.

The baseline peak (60.7 GiB) reflects the cascading allocations in generate_eagle3_data(): all_gather materializes the full-vocab logits, then torch.cat and padding() each allocate additional full-size copies on top. With all_to_all, the base tensor is 1/tp_size the size, so the same downstream operations peak at only 16.1 GiB.

cc @cih9088 @Kthyeon

Checklist

laoconeth and others added 2 commits April 6, 2026 00:52
Co-authored-by: Inhyuk Cho <cih9088@gmail.com>
Co-authored-by: Kthyeon <potter32@kaist.ac.kr>
ModelRunner (sglang==0.5.9) always returns ModelRunnerOutput
with logits_output field, so the hasattr guard is unnecessary.

Co-authored-by: Inhyuk Cho <cih9088@gmail.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the --shard-target-output feature for EAGLE3 training, enabling model outputs to be sharded across the batch dimension when using the SGLang backend to optimize memory usage. The changes include updates to the training script, the EAGLE3 target model, and the SGLang-specific logits processor, which now utilizes all_to_all communication. Review feedback points out a potential RuntimeError during tensor concatenation if shards are empty and suggests explicitly deleting intermediate tensors in the SGLang backend to ensure effective memory reclamation.

Comment on lines 782 to 786
aux_hidden_states_out = torch.cat(aux_hidden_states_out, dim=0)

loss_mask_out = torch.cat(loss_mask_out, dim=0)
attention_mask_out = torch.cat(attention_mask_out, dim=0)
input_ids_out = torch.cat(input_ids_out, dim=0)
Contributor


Severity: medium

These torch.cat operations will raise a RuntimeError if the input lists are empty. This scenario can occur when shard_returns is enabled and a rank is not assigned any batch items (e.g., if the global batch size is smaller than the world size). Although current sanity checks in the training script might prevent this, the model implementation should be robust to empty shards.


@cih9088 cih9088 Apr 6, 2026


This cannot happen.

sample_indices,
logits_metadata,
)
del hidden_states
Contributor


Severity: medium

To ensure the large un-sharded tensors are eligible for garbage collection and truly eliminate the memory spike, it is recommended to also delete pruned_states, aux_pruned_states, and pruned_states_before_norm. These variables hold views into the original full-sized tensors, which prevents the underlying memory from being reclaimed even after hidden_states is deleted.

Suggested change:
- del hidden_states
+ del hidden_states, pruned_states, aux_pruned_states, pruned_states_before_norm


@cih9088 cih9088 Apr 6, 2026


The reference counter will be decremented at the end of the function anyway, just a few lines below. del hidden_states is kept for consistency with upstream, sglang==0.5.9

@jiapingW
Collaborator

jiapingW commented Apr 6, 2026

Please help fix the unit-test.

@cih9088

cih9088 commented Apr 8, 2026

@jiapingW
Could you approve and re-run the unittests?

@cih9088

cih9088 commented Apr 9, 2026

@jiapingW
My bad. Fixed typo.



Development

Successfully merging this pull request may close these issues.

[Bug] Abnormal memory usage and Out-of-Memory in eagle3 training.

3 participants