Conversation
Fix 8 bugs blocking the paged attention rollout path in GRPO training (sketches of two of these below):
- Switch model to eval mode during rollout so the block_kv_cache_manager path activates inside LanguageModel.forward (guarded by `not self.training`)
- Fix get_layer_kv_gathered returning 0 tokens when token_positions=0 but blocks are already written during multi-layer prefill; fall back to num_valid_blocks * block_size
- Fix LanguageModel.forward returning 1 value instead of 2 when the block cache is active; use a has_cache flag for the return condition
- Fix dtype mismatch (float32 vs bfloat16) in batched decode scatter
- Fix torch.full receiving a tensor instead of a scalar for the pad value
- Guard flash_attn_varlen_func against seq_len_kv <= 0
- Fix total_mem -> total_memory for the PyTorch 2.6 API
- Add per-sequence cache_position from the block cache for batched decode
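A minimal sketch of the first two fixes, assuming hypothetical stand-ins (`model`, `prompts`, the `generate_rollouts_paged` call shape, and the `gathered_length` helper are illustrations, not the actual `ironcore` API):

```python
# Fix 1 (eval-mode gating): the paged branch in LanguageModel.forward is
# guarded by `not self.training`, so rollouts must run in eval mode.
was_training = model.training
model.eval()
try:
    rollouts = generate_rollouts_paged(model, prompts)
finally:
    model.train(was_training)  # restore the original mode even on exceptions


def gathered_length(bkv, seq_id: int) -> int:
    """Hypothetical helper mirroring the get_layer_kv_gathered fix.

    During multi-layer prefill, token_positions may still read 0 even
    though blocks are already written (positions advance only after all
    layers write), so fall back to counting fully written blocks.
    """
    num_tokens = int(bkv.token_positions[seq_id])
    if num_tokens == 0:
        num_tokens = int(bkv.num_valid_blocks[seq_id]) * bkv.block_size
    return num_tokens
```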
Add automated integration tests for GRPO training (baseline and paged rollout) that run via subprocess/torchrun and validate successful completion. Also add examples/train_grpo.py as a minimal standalone entry point for manual GRPO training.
The path resolver incorrectly removed .yaml extensions, breaking data.config_path which needs the full file path to exist.
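A sketch of the corrected behavior; `resolve_config_path` is a hypothetical helper, and the real resolver's name and signature may differ:

```python
from pathlib import Path

def resolve_config_path(raw: str, base_dir: Path) -> Path:
    """Resolve a config reference without stripping its extension."""
    path = (base_dir / raw).resolve()
    if path.suffix == "":
        # Only ever add an extension, never remove one: callers such as
        # data.config_path require the full file path to exist on disk.
        path = path.with_suffix(".yaml")
    if not path.exists():
        raise FileNotFoundError(f"config not found: {path}")
    return path
```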
Code Review
This pull request implements a block-based paged KV cache system to optimize Group Relative Policy Optimization (GRPO) rollouts, including a new cache manager and integration with the trainer. The review identifies critical bugs in sequence ID mapping during decoding and cache freeing, as well as redundant block allocations across transformer layers. Additionally, feedback highlights performance bottlenecks caused by serialized prefill steps and frequent cache re-initialization, alongside a logic error in prompt padding for block alignment.
```python
unwrapped.initialize_cache(
    prompt_ids.size(0) + prompt_ids.size(0) * chunk_group_size,
    prompt_ids.device,
)
```
initialize_cache is called inside the training loop for every rollout chunk. In BlockKVCacheManager, this method re-allocates the entire physical KV cache pool (large tensors). This will cause massive memory churn, significant performance degradation, and likely OOM due to memory fragmentation. The KV cache pool should be initialized once outside the training loop (e.g., in _initialize or _post_checkpoint_load).
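A sketch of the suggested restructuring, assuming hypothetical trainer attributes (`batch_size`, `group_size`, `device`, `unwrapped_model`):

```python
class GRPOTrainer:
    def _initialize(self) -> None:
        # Allocate the physical KV pool exactly once: B prompt slots plus
        # B * G completion slots must coexist during a rollout chunk.
        max_sequences = self.batch_size + self.batch_size * self.group_size
        self.unwrapped_model.initialize_cache(max_sequences, self.device)

    def _rollout_chunk(self, prompt_ids):
        # No initialize_cache() here: the pool is reused across chunks,
        # and per-sequence state is released via free_sequence() instead.
        ...
```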
```python
pad = torch.full(
    (B, padded_len - prompt_len),
    prompt_ids[0, -1].item(),
    dtype=prompt_ids.dtype,
    device=device,
)
padded_prompts = torch.cat([prompt_ids, pad], dim=1)
```
Padding the prompt with prompt_ids[0, -1] (the last token of the first prompt) to align with block boundaries is incorrect. These padding tokens will be processed by the model during prefill, which will alter the hidden states and KV cache of the actual prompt, leading to incorrect generation. If padding is required for block alignment, it should be handled using a proper pad_token_id and an attention mask, or the paged cache should be updated to handle partial blocks without requiring prompt modification.
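A sketch of the mask-based alternative; `pad_to_block_multiple` is a hypothetical helper, and it assumes the tokenizer exposes a real `pad_token_id`:

```python
import torch

def pad_to_block_multiple(prompt_ids: torch.Tensor, block_size: int,
                          pad_token_id: int):
    """Pad with a real pad token and return a matching attention mask.

    The mask keeps the model from attending to alignment padding,
    unlike reusing a genuine vocabulary token from the prompt.
    """
    B, prompt_len = prompt_ids.shape
    padded_len = ((prompt_len + block_size - 1) // block_size) * block_size
    pad = torch.full((B, padded_len - prompt_len), pad_token_id,
                     dtype=prompt_ids.dtype, device=prompt_ids.device)
    attention_mask = torch.cat(
        [torch.ones_like(prompt_ids), torch.zeros_like(pad)], dim=1
    )
    return torch.cat([prompt_ids, pad], dim=1), attention_mask
```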
```python
for i in range(B):
    # Allocate blocks for this prompt
    bkv.allocate_blocks(seq_id=i, count=blocks_per_prompt)

    # Forward through all layers (seq_id routes to paged cache path)
    single_prompt = padded_prompts[i : i + 1]  # [1, padded_prompt_len]
    logits, _ = unwrapped_model.forward(single_prompt, labels=None, seq_id=i)

    # Advance position after all layers have written
    bkv.advance_position(seq_id=i, tokens=padded_prompt_len)

    prefill_logits_list.append(logits[:, -1, :])  # [1, vocab]

prefill_logits = torch.cat(prefill_logits_list, dim=0)  # [B, vocab]
```
The prefill step is performed one sequence at a time in a Python loop. This is highly inefficient on GPUs as it serializes the computation and fails to leverage batch parallelism. Prefill should be batched across all prompts in padded_prompts by passing a list of seq_ids to a single model.forward() call.
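A sketch of the batched variant, assuming `LanguageModel.forward` accepts a list of seq_ids (the transformer's paged branch shown further below already handles batched writes):

```python
# One batched forward replaces the per-prompt Python loop.
seq_ids = list(range(B))
for sid in seq_ids:
    bkv.allocate_blocks(seq_id=sid, count=blocks_per_prompt)

# Passing the whole batch with a list of seq_ids lets the paged branch
# write each row's KV into that row's own block table.
logits, _ = unwrapped_model.forward(padded_prompts, labels=None, seq_id=seq_ids)

for sid in seq_ids:
    bkv.advance_position(seq_id=sid, tokens=padded_prompt_len)

prefill_logits = logits[:, -1, :]  # [B, vocab]
```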
Pull request overview
Adds a vLLM-style block/paged KV cache and a new GRPO rollout path intended to improve rollout-generation efficiency via on-demand block allocation and prefix sharing.
Changes:
- Introduces `BlockKVCacheManager` and paged KV gather utilities to support block-based caching.
- Adds `generate_rollouts_paged()` and wires paged-cache plumbing through `LanguageModel` and `Transformer` forward paths.
- Adds GRPO smoke/integration coverage plus configs and an example entrypoint script.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| `ironcore/layers/block_kv_cache.py` | New block-based KV cache manager with allocation, refcounted prefix sharing, and gather helpers. |
| `ironcore/layers/paged_attention.py` | New utilities to gather non-contiguous KV blocks into contiguous tensors (single + batched). |
| `ironcore/models/transformer.py` | Adds paged-cache branch in transformer layer forward to write/gather KV via block cache. |
| `ironcore/language_model.py` | Adds block cache manager, forwards seq_id, and attempts to support paged generation/stateful cache interactions. |
| `ironcore/alignment/rollout.py` | Adds `generate_rollouts_paged()` for GRPO rollouts using block cache + prefix sharing. |
| `ironcore/layers/attention.py` | Guards flash-attn path against zero-length KV. |
| `ironcore/trainers/grpo_trainer.py` | Adds GRPO-specific initialization and toggles paged rollout path via config. |
| `ironcore/config/config_model.py` | Extends KV cache config with paged-cache settings. |
| `ironcore/config/config_alignment.py` | Adds `grpo_use_paged_rollout` toggle. |
| `tests/unit/kvcache/test_block_kv_cache.py` | Adds unit tests for block cache allocation, writes, sharing, and gather behavior. |
| `tests/integration/test_grpo_smoke.py` | Adds subprocess-based E2E GRPO smoke tests (baseline + paged). |
| `tests/fixtures/configs/grpo_paged_smoke.yaml` | Adds paged rollout smoke config. |
| `tests/fixtures/configs/grpo_baseline_smoke.yaml` | Adds baseline rollout smoke config. |
| `examples/train_grpo.py` | Adds standalone GRPO training example runner. |
| `configs/data/grpo_qwen_keyword_toy.yaml` | Adds toy GRPO dataset config for smoke/example. |
```python
# Determine cache position
# For batched paged decode, use per-sequence positions from block cache
if cache_position is None:
    cache_position = 0
    bkv = block_kv_cache_manager
    if bkv is None and self.block_kv_cache_manager is not None and not self.training:
        bkv = self.block_kv_cache_manager
    if bkv is not None and seq_id is not None and isinstance(seq_id, list):
        cache_position = torch.tensor(
            [bkv.token_positions[sid].item() for sid in seq_id],
            dtype=torch.long,
            device=input_ids.device,
        )
    else:
        cache_position = 0
```
```python
use_stateful = self.kv_cache_manager is not None
use_paged = self.block_kv_cache_manager is not None and not use_stateful

if use_stateful:
    self.initialize_cache(batch_size, input_ids.device)
elif use_paged:
    self.initialize_cache(batch_size, input_ids.device)

for step in range(max_new_tokens):
    cur_input = input_ids if step == 0 else next_token
    cur_cache_pos = self.kv_cache_manager.get_cache_position() if use_stateful else None

    out = self.forward(
        cur_input,
        labels=None,
        use_cache=not use_stateful,
        past_key_values=past_key_values,
        cache_position=cur_cache_pos,
    )

    if use_stateful:
        cur_cache_pos = self.kv_cache_manager.get_cache_position()
        out = self.forward(
            cur_input,
            labels=None,
            use_cache=False,
            cache_position=cur_cache_pos,
        )
        logits, _ = out
    elif use_paged:
        out = self.forward(
            cur_input,
            labels=None,
            use_cache=False,
            seq_id=0,
        )
        logits, _ = out
        # Advance position after all layers have written
        tokens_written = cur_input.size(1)
        self.advance_cache_position(0, tokens_written)
    else:
```
```python
def reset_cache(self, batch_indices: list[int] | None = None):
    if self.kv_cache_manager is not None:
        self.kv_cache_manager.reset(batch_indices)
    if self.block_kv_cache_manager is not None:
        self.block_kv_cache_manager.free_sequence(batch_indices[0] if batch_indices else None)
```
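The quoted version frees only `batch_indices[0]`, which the review's note on "cache freeing" bugs likely refers to. A sketch of a variant that frees every requested sequence, assuming `free_sequence(None)` releases all sequences:

```python
def reset_cache(self, batch_indices: list[int] | None = None):
    if self.kv_cache_manager is not None:
        self.kv_cache_manager.reset(batch_indices)
    if self.block_kv_cache_manager is not None:
        if batch_indices is None:
            self.block_kv_cache_manager.free_sequence(None)  # assumed: frees all
        else:
            # Free each requested sequence, not just the first index.
            for idx in batch_indices:
                self.block_kv_cache_manager.free_sequence(idx)
```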
```python
elif (
    block_kv_cache_manager is not None
    and seq_id is not None
    and block_kv_cache_manager.is_initialized
):
    # Paged KV cache path (supports single int or batched list)
    is_batched = isinstance(seq_id, list)
    if is_batched:
        if seq_len > 1:
            for i, sid in enumerate(seq_id):
                block_kv_cache_manager.write_prefill(
                    self.layer_idx, sid, key[i : i + 1], value[i : i + 1]
                )
        else:
            block_kv_cache_manager.write_decode_batched(self.layer_idx, seq_id, key, value)
        full_key, full_value = block_kv_cache_manager.get_layer_kv_gathered_batched(
            self.layer_idx, seq_id
        )
    else:
        if seq_len > 1:
            block_kv_cache_manager.write_prefill(self.layer_idx, seq_id, key, value)
        else:
            block_kv_cache_manager.write_decode(self.layer_idx, seq_id, key, value)
        full_key, full_value = block_kv_cache_manager.get_layer_kv_gathered(
            self.layer_idx, seq_id
        )
    attn_output = self.self_attention(query, full_key, full_value, attention_mask)
    new_kv = None
```
```python
elif remainder == 0:
    # All blocks are fully filled
    block_indices = self.block_tables[seq_id, :num_valid].long()
    key = self.physical_key_caches[layer_idx][block_indices].reshape(
        1, total_tokens, -1, self.head_dim
    )
    value = self.physical_value_caches[layer_idx][block_indices].reshape(
        1, total_tokens, -1, self.head_dim
    )
```
```python
for dst_id in dst_seq_ids:
    # Copy block table entries and token position
    self.block_tables[dst_id, :src_num_blocks] = self.block_tables[
        src_seq_id, :src_num_blocks
    ]
    self.num_valid_blocks[dst_id] = src_num_blocks
    self.token_positions[dst_id] = self.token_positions[src_seq_id]
```
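For context, a hypothetical GRPO usage of this routine: after prefilling prompt i once, its G completions adopt the same physical blocks instead of recomputing the prefix. A full implementation would also bump per-block reference counts so freeing one completion does not release shared blocks.

```python
# Assumed ID layout: prompts occupy [0, B), completions [B, B + B*G),
# matching the completion_seq_ids fix described in the commits below.
for i in range(B):
    completion_ids = [B + i * G + g for g in range(G)]
    bkv.share_prefix(src_seq_id=i, dst_seq_ids=completion_ids)
```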
```python
config_dir = config_file.parent
temp_file = config_dir / f"resolved_{config_file.stem}.yaml"
with open(temp_file, "w") as f:
    yaml.dump(config, f, default_flow_style=False)
return str(temp_file)
```
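Since these resolved files are written next to the fixtures, the smoke tests clean them up after the subprocess run. A sketch, where `resolve_config` stands in for the enclosing helper and the torchrun CLI flags are assumptions:

```python
import subprocess
from pathlib import Path

resolved_path = resolve_config("tests/fixtures/configs/grpo_paged_smoke.yaml")
try:
    subprocess.run(
        ["torchrun", "--nproc_per_node=2", "examples/train_grpo.py",
         "--config", resolved_path],
        check=True,
    )
finally:
    # Remove the resolved_*.yaml artifact even if training fails.
    Path(resolved_path).unlink(missing_ok=True)
```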
…ecks
- Fix seq_id offset bug: decode/free loops now use completion_seq_ids [B, B+B*G) instead of [0, B*G), which was wrong for B>1
- Wrap generate_rollouts_paged in try/finally to restore model.training on exceptions
- Initialize bkv unconditionally in LanguageModel.forward() to prevent UnboundLocalError when cache_position is passed explicitly
- Add safety check in share_prefix() to prevent block leaks when dst already has allocated blocks
- Fix scatter indexing: use active_indices not active_seq_ids for tensor indexing
- Clean up temp resolved_*.yaml files after subprocess in smoke tests
- Use try/finally for monkeypatch restoration in block cache tests
…profiling
- Fix train_grpo.py API mismatch (load_trainer_config, forward_step, loss_fn)
- Add fallback in dataloader/alignment dataset for path traversal guard rejecting short config paths like data/grpo_gsm8k
- Fix _resolve_config_paths to handle ../ relative paths in test fixtures
- Move TP=2 KV cache tests from integration/ to multi_gpu/kvcache/
- Add VRAM profiling at 4 GRPO training phase boundaries
- Ignore resolved test config files in .gitignore
…dlock
When generating with FSDP on multiple GPUs, different prompts on different ranks produce different generation lengths. The rank that finishes early skips FSDP all-gathers, deadlocking the other rank still in the decode loop. Fix: communicate done status via all_reduce on the default process group (separate from FSDP's PG) before each decode step. When all ranks report done, they break together. Dummy forwards only pad the delta between rank finish times, not the full max_new_tokens.
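A sketch of that rendezvous; `sequences_finished` and `decode_step` are hypothetical stand-ins for the local completion check and one decode forward:

```python
import torch
import torch.distributed as dist

for step in range(max_new_tokens):
    # Agree on completion BEFORE entering this step's FSDP all-gathers.
    # Uses the default process group, separate from FSDP's own PG.
    done = torch.tensor([1.0 if sequences_finished() else 0.0], device=device)
    dist.all_reduce(done, op=dist.ReduceOp.MIN)
    if done.item() == 1.0:
        break  # every rank reports done; all break together
    # Ranks that finished early still run a dummy forward here, but only
    # until the slowest rank catches up, not for all max_new_tokens steps.
    next_token = decode_step(model, next_token)
```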
Summary
- Block-based paged KV cache (`BlockKVCacheManager`) with on-demand block allocation, reference counting, and prefix sharing for efficient GRPO rollout generation
- `generate_rollouts_paged()` — vLLM-style paged attention with batched decode across all B×G sequences in parallel
- Fixes to `LanguageModel.forward()` for the paged cache path: per-sequence cache_position, correct return tuple, eval mode gating
- Minimal standalone GRPO training entry point (`examples/train_grpo.py`)

Changes
- `ironcore/layers/block_kv_cache.py`
- `ironcore/layers/paged_attention.py`
- `ironcore/alignment/rollout.py`: `generate_rollouts_paged()` with block cache integration
- `ironcore/language_model.py`
- `ironcore/layers/attention.py`
- `ironcore/models/transformer.py`
- `ironcore/trainers/grpo_trainer.py`
- `tests/unit/kvcache/test_block_kv_cache.py`
- `tests/integration/test_grpo_smoke.py`
- `examples/train_grpo.py`

Test plan
- `pytest tests/unit/kvcache/test_block_kv_cache.py` — 32 tests pass
- `pytest tests/integration/test_grpo_smoke.py -m rlvr` — 2/2 E2E tests pass (2 GPUs)
- `pytest tests/ -m "not cuda and not mp"` — 521 pass, 1 pre-existing failure, 15 skipped
- `ruff check` and `ruff format` clean

The 1 failure (`test_reward_manager_config_trains`) is pre-existing on main — caused by a stale `../data/grpo_gsm8k` config path in `grpo_gsm8k_smoke_fsdp.yaml`, unrelated to this PR.