
feat: paged KV cache for GRPO rollout#42

Open
haanjack wants to merge 7 commits into `main` from `feat/paged-attention`

Conversation


@haanjack haanjack commented May 5, 2026

Summary

  • Add block-based paged KV cache (BlockKVCacheManager) with on-demand block allocation, reference counting, and prefix sharing for efficient GRPO rollout generation
  • Implement generate_rollouts_paged() — vLLM-style paged attention with batched decode across all B×G sequences in parallel
  • Fix multiple issues in LanguageModel.forward() for paged cache path: per-sequence cache_position, correct return tuple, eval mode gating
  • Add E2E smoke tests (baseline + paged) and standalone example script (examples/train_grpo.py)

Changes

| Area | Details |
| --- | --- |
| `ironcore/layers/block_kv_cache.py` | New: block allocator, paged KV cache manager with prefix sharing |
| `ironcore/layers/paged_attention.py` | New: flash attention varlen wrapper for paged cache |
| `ironcore/alignment/rollout.py` | New: `generate_rollouts_paged()` with block cache integration |
| `ironcore/language_model.py` | Fixed paged cache path in forward, added helper methods |
| `ironcore/layers/attention.py` | Guard against zero-length KV in attention |
| `ironcore/models/transformer.py` | Wire paged cache through transformer layers |
| `ironcore/trainers/grpo_trainer.py` | GRPO-specific data loading initialization |
| `tests/unit/kvcache/test_block_kv_cache.py` | 32 unit tests for block cache manager |
| `tests/integration/test_grpo_smoke.py` | 2 E2E training smoke tests (baseline + paged) |
| `examples/train_grpo.py` | Standalone GRPO training example script |

Test plan

  • pytest tests/unit/kvcache/test_block_kv_cache.py — 32 tests pass
  • pytest tests/integration/test_grpo_smoke.py -m rlvr — 2/2 E2E tests pass (2 GPUs)
  • pytest tests/ -m "not cuda and not mp" — 521 pass, 1 pre-existing failure, 15 skipped
  • ruff check and ruff format clean

The 1 failure (test_reward_manager_config_trains) is pre-existing on main — caused by a stale ../data/grpo_gsm8k config path in grpo_gsm8k_smoke_fsdp.yaml, unrelated to this PR.

haanjack added 4 commits May 5, 2026 06:46
Fix 8 bugs blocking the paged attention rollout path in GRPO training:

- Switch model to eval mode during rollout so block_kv_cache_manager
  path activates inside LanguageModel.forward (guarded by not self.training)
- Fix get_layer_kv_gathered returning 0 tokens when token_positions=0
  but blocks are already written during multi-layer prefill; fall back
  to num_valid_blocks * block_size
- Fix LanguageModel.forward returning 1 value instead of 2 when block
  cache is active; use has_cache flag for return condition
- Fix dtype mismatch (float32 vs bfloat16) in batched decode scatter
- Fix torch.full receiving tensor instead of scalar for pad value
- Guard flash_attn_varlen_func against seq_len_kv <= 0
- Fix total_mem -> total_memory for PyTorch 2.6 API
- Add per-sequence cache_position from block cache for batched decode
Add automated integration tests for GRPO training (baseline and paged
rollout) that run via subprocess/torchrun and validate successful
completion. Also add examples/train_grpo.py as a minimal standalone
entry point for manual GRPO training.
The path resolver incorrectly removed .yaml extensions, breaking
data.config_path which needs the full file path to exist.
Copilot AI review requested due to automatic review settings May 5, 2026 01:38

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements a block-based paged KV cache system to optimize Group Relative Policy Optimization (GRPO) rollouts, including a new cache manager and integration with the trainer. The review identifies critical bugs in sequence ID mapping during decoding and cache freeing, as well as redundant block allocations across transformer layers. Additionally, feedback highlights performance bottlenecks caused by serialized prefill steps and frequent cache re-initialization, alongside a logic error in prompt padding for block alignment.

Comment thread ironcore/alignment/rollout.py Outdated
Comment thread ironcore/alignment/rollout.py Outdated
Comment thread ironcore/layers/block_kv_cache.py
Comment on lines +485 to +488
```python
unwrapped.initialize_cache(
    prompt_ids.size(0) + prompt_ids.size(0) * chunk_group_size,
    prompt_ids.device,
)
```

critical

initialize_cache is called inside the training loop for every rollout chunk. In BlockKVCacheManager, this method re-allocates the entire physical KV cache pool (large tensors). This will cause massive memory churn, significant performance degradation, and likely OOM due to memory fragmentation. The KV cache pool should be initialized once outside the training loop (e.g., in _initialize or _post_checkpoint_load).
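The suggested fix can be sketched with a one-time initialization guard, shown here as a toy stand-in for `BlockKVCacheManager` (the class, method, and field names in this sketch are invented for illustration, not the repo's real API):

```python
class PagedKVPool:
    """Toy stand-in for a paged KV cache manager's physical pool."""

    def __init__(self):
        self.pool = None
        self.alloc_count = 0  # tracks how many times we actually allocated

    def ensure_initialized(self, num_blocks: int, block_size: int) -> None:
        # Allocate the physical KV pool exactly once; later calls are no-ops,
        # so even a call from inside the rollout loop causes no memory churn.
        if self.pool is not None:
            return
        self.pool = [[0.0] * block_size for _ in range(num_blocks)]
        self.alloc_count += 1


pool = PagedKVPool()
for _ in range(100):  # simulate 100 rollout chunks
    pool.ensure_initialized(num_blocks=8, block_size=16)

print(pool.alloc_count)  # 1 — the pool was allocated only once
```

The same effect is achieved by moving the allocation to a setup hook such as `_initialize`; the guard just makes repeated calls harmless either way.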

Comment thread ironcore/alignment/rollout.py Outdated
Comment on lines +402 to +408
```python
pad = torch.full(
    (B, padded_len - prompt_len),
    prompt_ids[0, -1].item(),
    dtype=prompt_ids.dtype,
    device=device,
)
padded_prompts = torch.cat([prompt_ids, pad], dim=1)
```

high

Padding the prompt with prompt_ids[0, -1] (the last token of the first prompt) to align with block boundaries is incorrect. These padding tokens will be processed by the model during prefill, which will alter the hidden states and KV cache of the actual prompt, leading to incorrect generation. If padding is required for block alignment, it should be handled using a proper pad_token_id and an attention mask, or the paged cache should be updated to handle partial blocks without requiring prompt modification.
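The suggested alternative — pad with a real `pad_token_id` and carry a mask so attention can ignore the padding — can be sketched in plain Python (function name and token values here are hypothetical):

```python
def pad_to_block_boundary(prompt, block_size, pad_token_id):
    """Pad a token list up to the next block boundary and return a mask
    marking real tokens (1) vs padding (0), so attention can ignore the pad."""
    padded_len = -(-len(prompt) // block_size) * block_size  # ceil division
    pad_count = padded_len - len(prompt)
    padded = prompt + [pad_token_id] * pad_count
    mask = [1] * len(prompt) + [0] * pad_count
    return padded, mask


tokens, mask = pad_to_block_boundary([11, 12, 13, 14, 15], block_size=4, pad_token_id=0)
print(tokens)  # [11, 12, 13, 14, 15, 0, 0, 0]
print(mask)    # [1, 1, 1, 1, 1, 0, 0, 0]
```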

Comment thread ironcore/alignment/rollout.py Outdated
Comment on lines +418 to +431
```python
for i in range(B):
    # Allocate blocks for this prompt
    bkv.allocate_blocks(seq_id=i, count=blocks_per_prompt)

    # Forward through all layers (seq_id routes to paged cache path)
    single_prompt = padded_prompts[i : i + 1]  # [1, padded_prompt_len]
    logits, _ = unwrapped_model.forward(single_prompt, labels=None, seq_id=i)

    # Advance position after all layers have written
    bkv.advance_position(seq_id=i, tokens=padded_prompt_len)

    prefill_logits_list.append(logits[:, -1, :])  # [1, vocab]

prefill_logits = torch.cat(prefill_logits_list, dim=0)  # [B, vocab]
```

high

The prefill step is performed one sequence at a time in a Python loop. This is highly inefficient on GPUs as it serializes the computation and fails to leverage batch parallelism. Prefill should be batched across all prompts in padded_prompts by passing a list of seq_ids to a single model.forward() call.
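The suggested restructuring — one batched forward instead of a per-sequence loop — looks schematically like this (the `model_forward` and `bkv_allocate` callables are stand-ins for illustration, not the repo's real signatures):

```python
def prefill_batched(model_forward, padded_prompts, bkv_allocate, blocks_per_prompt):
    """Replace the per-sequence Python loop with one batched forward call.

    `model_forward` accepts a batch [B, L] plus a list of seq_ids and returns
    last-token logits [B, V]; in the real code each row's KV writes would be
    routed to its own block table via its seq_id.
    """
    B = len(padded_prompts)
    seq_ids = list(range(B))
    for sid in seq_ids:  # cheap host-side bookkeeping can stay a loop
        bkv_allocate(sid, blocks_per_prompt)
    return model_forward(padded_prompts, seq_ids)  # one batched device call


# Toy usage: the "logits" are just the last token of each prompt.
allocated = {}
out = prefill_batched(
    model_forward=lambda batch, sids: [row[-1] for row in batch],
    padded_prompts=[[1, 2, 3], [4, 5, 6]],
    bkv_allocate=lambda sid, n: allocated.__setitem__(sid, n),
    blocks_per_prompt=2,
)
print(out)        # [3, 6]
print(allocated)  # {0: 2, 1: 2}
```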


Copilot AI left a comment


Pull request overview

Adds a vLLM-style block/paged KV cache and a new GRPO rollout path intended to improve rollout-generation efficiency via on-demand block allocation and prefix sharing.

Changes:

  • Introduces BlockKVCacheManager and paged KV gather utilities to support block-based caching.
  • Adds generate_rollouts_paged() and wires paged-cache plumbing through LanguageModel and Transformer forward paths.
  • Adds GRPO smoke/integration coverage plus configs and an example entrypoint script.

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 14 comments.

| File | Description |
| --- | --- |
| `ironcore/layers/block_kv_cache.py` | New block-based KV cache manager with allocation, refcounted prefix sharing, and gather helpers. |
| `ironcore/layers/paged_attention.py` | New utilities to gather non-contiguous KV blocks into contiguous tensors (single + batched). |
| `ironcore/models/transformer.py` | Adds paged-cache branch in transformer layer forward to write/gather KV via block cache. |
| `ironcore/language_model.py` | Adds block cache manager, forwards `seq_id`, and attempts to support paged generation/stateful cache interactions. |
| `ironcore/alignment/rollout.py` | Adds `generate_rollouts_paged()` for GRPO rollouts using block cache + prefix sharing. |
| `ironcore/layers/attention.py` | Guards flash-attn path against zero-length KV. |
| `ironcore/trainers/grpo_trainer.py` | Adds GRPO-specific initialization and toggles paged rollout path via config. |
| `ironcore/config/config_model.py` | Extends KV cache config with paged-cache settings. |
| `ironcore/config/config_alignment.py` | Adds `grpo_use_paged_rollout` toggle. |
| `tests/unit/kvcache/test_block_kv_cache.py` | Adds unit tests for block cache allocation, writes, sharing, and gather behavior. |
| `tests/integration/test_grpo_smoke.py` | Adds subprocess-based E2E GRPO smoke tests (baseline + paged). |
| `tests/fixtures/configs/grpo_paged_smoke.yaml` | Adds paged rollout smoke config. |
| `tests/fixtures/configs/grpo_baseline_smoke.yaml` | Adds baseline rollout smoke config. |
| `examples/train_grpo.py` | Adds standalone GRPO training example runner. |
| `configs/data/grpo_qwen_keyword_toy.yaml` | Adds toy GRPO dataset config for smoke/example. |


Comment thread ironcore/language_model.py Outdated
Comment on lines 101 to +114
```diff
  # Determine cache position
  # For batched paged decode, use per-sequence positions from block cache
  if cache_position is None:
-     cache_position = 0
+     bkv = block_kv_cache_manager
+     if bkv is None and self.block_kv_cache_manager is not None and not self.training:
+         bkv = self.block_kv_cache_manager
+     if bkv is not None and seq_id is not None and isinstance(seq_id, list):
+         cache_position = torch.tensor(
+             [bkv.token_positions[sid].item() for sid in seq_id],
+             dtype=torch.long,
+             device=input_ids.device,
+         )
+     else:
+         cache_position = 0
```
Comment on lines 237 to 268
```python
use_stateful = self.kv_cache_manager is not None
use_paged = self.block_kv_cache_manager is not None and not use_stateful

if use_stateful:
    self.initialize_cache(batch_size, input_ids.device)
elif use_paged:
    self.initialize_cache(batch_size, input_ids.device)

for step in range(max_new_tokens):
    cur_input = input_ids if step == 0 else next_token
    cur_cache_pos = self.kv_cache_manager.get_cache_position() if use_stateful else None

    out = self.forward(
        cur_input,
        labels=None,
        use_cache=not use_stateful,
        past_key_values=past_key_values,
        cache_position=cur_cache_pos,
    )

    if use_stateful:
        cur_cache_pos = self.kv_cache_manager.get_cache_position()
        out = self.forward(
            cur_input,
            labels=None,
            use_cache=False,
            cache_position=cur_cache_pos,
        )
        logits, _ = out
    elif use_paged:
        out = self.forward(
            cur_input,
            labels=None,
            use_cache=False,
            seq_id=0,
        )
        logits, _ = out
        # Advance position after all layers have written
        tokens_written = cur_input.size(1)
        self.advance_cache_position(0, tokens_written)
    else:
```
Comment on lines 380 to +384
```python
def reset_cache(self, batch_indices: list[int] | None = None):
    if self.kv_cache_manager is not None:
        self.kv_cache_manager.reset(batch_indices)
    if self.block_kv_cache_manager is not None:
        self.block_kv_cache_manager.free_sequence(batch_indices[0] if batch_indices else None)
```
Comment on lines +155 to +182
```python
elif (
    block_kv_cache_manager is not None
    and seq_id is not None
    and block_kv_cache_manager.is_initialized
):
    # Paged KV cache path (supports single int or batched list)
    is_batched = isinstance(seq_id, list)
    if is_batched:
        if seq_len > 1:
            for i, sid in enumerate(seq_id):
                block_kv_cache_manager.write_prefill(
                    self.layer_idx, sid, key[i : i + 1], value[i : i + 1]
                )
        else:
            block_kv_cache_manager.write_decode_batched(self.layer_idx, seq_id, key, value)
        full_key, full_value = block_kv_cache_manager.get_layer_kv_gathered_batched(
            self.layer_idx, seq_id
        )
    else:
        if seq_len > 1:
            block_kv_cache_manager.write_prefill(self.layer_idx, seq_id, key, value)
        else:
            block_kv_cache_manager.write_decode(self.layer_idx, seq_id, key, value)
        full_key, full_value = block_kv_cache_manager.get_layer_kv_gathered(
            self.layer_idx, seq_id
        )
    attn_output = self.self_attention(query, full_key, full_value, attention_mask)
    new_kv = None
```
Comment on lines +556 to +564
```python
elif remainder == 0:
    # All blocks are fully filled
    block_indices = self.block_tables[seq_id, :num_valid].long()
    key = self.physical_key_caches[layer_idx][block_indices].reshape(
        1, total_tokens, -1, self.head_dim
    )
    value = self.physical_value_caches[layer_idx][block_indices].reshape(
        1, total_tokens, -1, self.head_dim
    )
```
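For intuition, the fully-filled-blocks gather in the snippet above amounts to indexing the physical pool by the sequence's block table and concatenating in table order. A toy version with lists (values stand in for KV entries; block ids are made up):

```python
block_size = 4
# Physical pool: 6 blocks, each holding block_size token slots.
physical = [[i * 10 + j for j in range(block_size)] for i in range(6)]
# Sequence 0's block table: its tokens live in physical blocks 3, 1, 4.
block_table = [3, 1, 4]

# Gather: concatenate the blocks in table order -> one contiguous token view.
gathered = [slot for b in block_table for slot in physical[b]]
print(len(gathered))  # 12 tokens = 3 fully filled blocks * block_size
print(gathered[:4])   # [30, 31, 32, 33] — contents of physical block 3
```

The real code does the same with a single fancy-index plus `reshape`, which is why the non-contiguous blocks come out as one contiguous `[1, total_tokens, heads, head_dim]` tensor.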
Comment on lines +468 to +475
```python
for dst_id in dst_seq_ids:
    # Copy block table entries and token position
    self.block_tables[dst_id, :src_num_blocks] = self.block_tables[
        src_seq_id, :src_num_blocks
    ]
    self.num_valid_blocks[dst_id] = src_num_blocks
    self.token_positions[dst_id] = self.token_positions[src_seq_id]
```
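The copy loop above shares physical blocks by duplicating block-table entries; what makes that safe is reference counting on the blocks themselves, so a shared block is only returned to the free list when its last holder releases it. A toy allocator illustrating the invariant (all names invented for this sketch):

```python
class ToyBlockAllocator:
    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))
        self.refs = [0] * num_blocks
        self.tables = {}  # seq_id -> list of block ids

    def allocate(self, seq_id, count):
        blocks = [self.free.pop() for _ in range(count)]
        for b in blocks:
            self.refs[b] = 1
        self.tables[seq_id] = blocks

    def share_prefix(self, src_id, dst_ids):
        for d in dst_ids:
            # Guard: dst must not already own blocks, or they would leak.
            assert d not in self.tables, f"seq {d} already has blocks"
            self.tables[d] = list(self.tables[src_id])
            for b in self.tables[src_id]:
                self.refs[b] += 1  # shared block: bump refcount

    def free_sequence(self, seq_id):
        for b in self.tables.pop(seq_id):
            self.refs[b] -= 1
            if self.refs[b] == 0:  # only truly free with no remaining holders
                self.free.append(b)


alloc = ToyBlockAllocator(num_blocks=4)
alloc.allocate(0, count=2)             # prompt occupies 2 blocks
alloc.share_prefix(0, dst_ids=[1, 2])  # two completions share the prompt blocks
alloc.free_sequence(1)
print(len(alloc.free))  # 2 — blocks survive while seqs 0 and 2 hold refs
alloc.free_sequence(0)
alloc.free_sequence(2)
print(len(alloc.free))  # 4 — all blocks returned once refcounts hit zero
```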

Comment on lines +67 to +71
```python
config_dir = config_file.parent
temp_file = config_dir / f"resolved_{config_file.stem}.yaml"
with open(temp_file, "w") as f:
    yaml.dump(config, f, default_flow_style=False)
return str(temp_file)
```
Comment thread tests/unit/kvcache/test_block_kv_cache.py Outdated
Comment thread ironcore/alignment/rollout.py Outdated
haanjack added 3 commits May 5, 2026 19:07
…ecks

- Fix seq_id offset bug: decode/free loops now use completion_seq_ids
  [B, B+B*G) instead of [0, B*G), which was wrong for B>1
- Wrap generate_rollouts_paged in try/finally to restore model.training
  on exceptions
- Initialize bkv unconditionally in LanguageModel.forward() to prevent
  UnboundLocalError when cache_position is passed explicitly
- Add safety check in share_prefix() to prevent block leaks when dst
  already has allocated blocks
- Fix scatter indexing: use active_indices not active_seq_ids for
  tensor indexing
- Clean up temp resolved_*.yaml files after subprocess in smoke tests
- Use try/finally for monkeypatch restoration in block cache tests
…profiling

- Fix train_grpo.py API mismatch (load_trainer_config, forward_step, loss_fn)
- Add fallback in dataloader/alignment dataset for path traversal guard
  rejecting short config paths like data/grpo_gsm8k
- Fix _resolve_config_paths to handle ../ relative paths in test fixtures
- Move TP=2 KV cache tests from integration/ to multi_gpu/kvcache/
- Add VRAM profiling at 4 GRPO training phase boundaries
- Ignore resolved test config files in .gitignore
…dlock

When generating with FSDP on multiple GPUs, different prompts on
different ranks produce different generation lengths.  The rank that
finishes early skips FSDP all-gathers, deadlocking the other rank
still in the decode loop.

Fix: communicate done status via all_reduce on the default process
group (separate from FSDP's PG) before each decode step.  When all
ranks report done, they break together.  Dummy forwards only pad
the delta between rank finish times, not the full max_new_tokens.
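The handshake described above can be illustrated without GPUs or process groups: each rank contributes a done flag, a MIN-reduction over the group decides jointly, and fast ranks issue dummy forwards only until the slowest rank finishes. A single-process simulation (names and structure invented for this sketch; the real code uses `torch.distributed.all_reduce` on the default process group):

```python
def simulate_decode(finish_steps, max_new_tokens):
    """Simulate the per-step done handshake across ranks.

    finish_steps[r] is the step at which rank r's sequences complete.
    Returns how many decode steps each rank actually executed: all ranks
    run until the slowest one is done, with finished ranks issuing dummy
    forwards so FSDP's collectives stay matched.
    """
    steps_run = [0] * len(finish_steps)
    for step in range(max_new_tokens):
        local_done = [step >= f for f in finish_steps]
        all_done = min(local_done)  # stand-in for all_reduce(MIN) on the flag
        if all_done:
            break                   # every rank breaks at the same step
        for r in range(len(finish_steps)):
            steps_run[r] += 1       # real forward on slow ranks, dummy on fast
    return steps_run


print(simulate_decode(finish_steps=[3, 7], max_new_tokens=64))  # [7, 7]
```

Note the fast rank runs 7 steps, not 64: the dummy forwards pad only the gap between finish times, exactly as the commit message describes.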