Reduce peak GPU memory in Eagle3 online target generation by avoiding an extra logits copy#528

Open
zijiexia wants to merge 2 commits into sgl-project:main from zijiexia:fix_target_out_oom

Conversation

@zijiexia zijiexia commented Apr 9, 2026

Motivation

This PR fixes an out-of-memory issue in Eagle3 online training caused by an unnecessary full-tensor copy when shifting target logits.

Previously, generate_eagle3_data() accumulated per-sample logits in a list, concatenated them into a [B, T, V] tensor, and then called padding(target_out, left=False) to shift the logits left and append a zero row at the end. For large-vocabulary models, that final padding step materialized a second full [B, T, V] allocation and could trigger multi-gigabyte peak-memory spikes.

This change pre-allocates the final target_out tensor once and writes the shifted logits directly into it:

  • target_out[idx, :-1] = logits[..., 1:, :]
  • target_out[idx, -1] = 0

That preserves the original semantics while removing the extra full-size allocation.
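A toy equivalence check (illustrative only, not the repo's actual code) showing that the two in-place writes reproduce what the old shift-then-pad path produced:

```python
import torch

# Toy sizes; the real V is the target model's vocabulary (1e5+ entries).
B, T, V = 2, 6, 16
logits = torch.randn(B, T, V)

# Old semantics: shift left by one, append a zero row at the end.
# The cat() materializes a second full-size [B, T, V] tensor.
zeros = torch.zeros(B, 1, V)
old = torch.cat([logits[:, 1:], zeros], dim=1)

# New semantics: write the shifted logits directly into a
# pre-allocated buffer; no extra full-size copy.
new = torch.empty_like(logits)
for idx in range(B):
    new[idx, :-1] = logits[idx, 1:]
    new[idx, -1] = 0
```

Every position of `new` is written exactly once, so the `empty_like` buffer is fully initialized by the loop.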

Root Cause
The old implementation created peak memory pressure in two stages:

  1. Concatenate per-sample logits into a full target_out tensor.
  2. Call padding(target_out, left=False), which internally builds a zero padding tensor and concatenates again, creating another full-sized [B, T, V] tensor.

For Eagle3 online training, V is the target model's vocabulary size, so this duplicate tensor is extremely expensive. In practice it surfaced as an OOM during generate_eagle3_data() even though steady-state memory usage otherwise nearly fit.
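A back-of-envelope sketch of the peak in the two-stage path. The batch and sequence sizes below are illustrative assumptions, with a Llama-3-scale vocabulary (V = 128256) and bf16 logits:

```python
# Hypothetical sizes: B and T are illustrative, V matches a
# Llama-3-class target model's vocabulary, logits stored in bf16.
B, T, V = 8, 4096, 128256
bytes_per_elem = 2  # bf16

slab = B * T * V * bytes_per_elem  # one full [B, T, V] tensor

# Old path: the concatenated tensor and the padded copy coexist.
peak_old = 2 * slab
# New path: only the single pre-allocated output tensor.
peak_new = 1 * slab

print(f"one slab: {slab / 2**30:.1f} GiB")
print(f"old peak: {peak_old / 2**30:.1f} GiB, new peak: {peak_new / 2**30:.1f} GiB")
```

At these sizes a single slab is already close to 8 GiB, so the transient duplicate alone accounts for the multi-GB spike described above.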

Modifications

  • Stop collecting target_out as a Python list of per-sample logits tensors.
  • Detect whether logits are present with has_logits = logits_list[0] is not None.
  • Pre-allocate target_out with shape [B, T, V] using the first logits tensor's device and dtype.
  • Write the shifted logits directly into the pre-allocated output tensor during the main loop.
  • Remove the padding(target_out, left=False) call entirely.
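The steps above might look roughly like the following sketch; `logits_list`, the tensor shapes, and the loop structure are assumed from the PR description rather than copied from the actual diff:

```python
import torch

B, T, V = 2, 8, 32  # toy sizes; the real V is the target model's vocab
logits_list = [torch.randn(T, V) for _ in range(B)]  # per-sample logits

# Detect whether logits are present (as described in the PR).
has_logits = logits_list[0] is not None
if has_logits:
    # Pre-allocate the final output once, on the first logits
    # tensor's device and dtype.
    target_out = torch.empty(
        B, T, V, device=logits_list[0].device, dtype=logits_list[0].dtype
    )
    for idx, logits in enumerate(logits_list):
        # Shift left by one position and write in place: no padding()
        # call, no second full-size [B, T, V] allocation.
        target_out[idx, :-1] = logits[1:]
        target_out[idx, -1] = 0
```

The in-place slice assignments write straight into the pre-allocated buffer, so peak memory stays at one [B, T, V] tensor instead of two.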

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

