feat(distillation): add on-policy distillation using RolloutEngine #1376

Open
zahrayousefijamarani wants to merge 5 commits into areal-project:main from zahrayousefijamarani:on_policy_distillation

Conversation

@zahrayousefijamarani

Description

Summary

This PR enables on-policy distillation with a dedicated teacher rollout/inference engine (vLLM/SGLang), instead of relying on a train-engine teacher path.
The goal is to reduce memory overhead and provide a clean inference-side token log-prob scoring API used by distillation losses.

Motivation

In on-policy distillation, the teacher is used only for teacher_logp scoring.
A full train-engine teacher can allocate training-state memory it never uses (optimizer and gradient-related structures), while a rollout/inference teacher is lighter and better matched to the actual use case.

Related Issue

Fixes #1367

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

Additional Context

What changed

1) Teacher config refactor

  • TeacherConfig no longer inherits from the train-actor config.
  • Added explicit teacher fields:
    • rollout: InferenceEngineConfig
    • path: str
    • offload: bool
  • Retains:
    • rl_loss_weight
    • distill_loss_weight
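
For reviewers, a minimal sketch of the shape this refactor describes. The field names `rollout`, `path`, `offload`, `rl_loss_weight`, and `distill_loss_weight` come from the list above; the stand-in `InferenceEngineConfig` body, its `backend` field, the float type of the weights, and every default value are assumptions made for illustration, not the PR's actual definitions.

```python
from dataclasses import dataclass, field


@dataclass
class InferenceEngineConfig:
    """Stand-in for the real rollout config; its contents here are hypothetical."""

    backend: str = "vllm"  # hypothetical field: "vllm" or "sglang"


@dataclass
class TeacherConfig:
    """Sketch of a teacher config that does not inherit a train-actor config."""

    rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
    path: str = ""  # teacher checkpoint path
    offload: bool = False  # default is an assumption
    rl_loss_weight: float = 1.0  # type and default are assumptions
    distill_loss_weight: float = 1.0  # type and default are assumptions


teacher = TeacherConfig(
    path="/path/to/teacher",
    rollout=InferenceEngineConfig(backend="sglang"),
)
```

Keeping the rollout settings in a nested config means the teacher carries no optimizer or gradient options at all, which matches the memory motivation above.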

2) Inference scoring API

  • Added InferenceEngine.compute_logp(...) API.
  • Extended remote backend protocol with:
    • build_score_request(...)
    • parse_score_response(...)

3) Remote scoring implementation

Implemented compute_logp in RemoteInfEngine:

  • Sends backend score requests
  • Parses token log-prob outputs
  • Returns per-trajectory tensors aligned to masked token positions

Added passthrough implementations in:

  • RemotevLLMEngine
  • RemoteSGLangEngine

Backend-specific score request/response logic added for:

  • vLLM
  • SGLang
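
As an illustration of "aligned to masked token positions", here is one possible reading in plain Python. The helper name `align_to_mask`, the use of lists instead of tensors, and the choice to write `0.0` at unmasked positions are all assumptions; the PR may instead return only the masked entries.

```python
from typing import List, Optional


def align_to_mask(
    token_logps: List[Optional[float]], loss_mask: List[int]
) -> List[float]:
    """Keep log-probs where loss_mask is 1 and write 0.0 elsewhere.

    Scoring backends typically have no log-prob for the first token, so
    None is tolerated at unmasked positions only.
    """
    if len(token_logps) != len(loss_mask):
        raise ValueError("token_logps and loss_mask must have the same length")
    aligned: List[float] = []
    for logp, keep in zip(token_logps, loss_mask):
        if keep:
            if logp is None:
                raise ValueError("missing log-prob at a masked position")
            aligned.append(float(logp))
        else:
            aligned.append(0.0)
    return aligned


example = align_to_mask([None, -0.5, -1.25], [0, 1, 1])
```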

4) Controller integration

  • Added RolloutController.compute_logp(...) so the trainer can call scoring in controller mode.
  • Requests are sharded across workers.
  • Results are merged in input order.
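
The shard-then-merge behavior can be sketched as follows. This is a simplified, synchronous illustration: `shard_indices`, `scatter_gather`, and the round-robin assignment are hypothetical and not taken from the PR, and the workers are plain callables rather than remote engines.

```python
from typing import Callable, List, Optional, Sequence


def shard_indices(n_items: int, n_workers: int) -> List[List[int]]:
    """Assign item indices to workers round-robin."""
    shards: List[List[int]] = [[] for _ in range(n_workers)]
    for i in range(n_items):
        shards[i % n_workers].append(i)
    return shards


def scatter_gather(
    items: Sequence[float],
    workers: Sequence[Callable[[List[float]], List[float]]],
) -> List[Optional[float]]:
    """Send each shard to one worker and place results back by original index."""
    shards = shard_indices(len(items), len(workers))
    results: List[Optional[float]] = [None] * len(items)
    for worker, idxs in zip(workers, shards):
        if not idxs:
            continue  # more workers than items
        outs = worker([items[i] for i in idxs])
        if len(outs) != len(idxs):
            raise RuntimeError("worker returned a different number of results")
        for i, out in zip(idxs, outs):
            results[i] = out
    return results


def score(xs: List[float]) -> List[float]:
    # Stand-in for a remote compute_logp call.
    return [x * 10 for x in xs]


merged = scatter_gather([1, 2, 3, 4, 5], [score, score])
```

Writing each result back by its original index, rather than concatenating shard outputs, is what keeps the merged list in input order regardless of how the shards were formed.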

5) Trainer integration

  • RLTrainer now supports dedicated teacher rollout initialization via _init_teacher_rollout(...).
  • Training loop now consumes:
    • teacher.compute_logp(rollout_batch)
  • Attaches:
    • teacher_logp
    • rl_loss_weight
    • distill_loss_weight
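
The PR description does not show how these three fields enter the loss. As a hedged sketch only, one common on-policy distillation formulation uses the per-token difference `student_logp - teacher_logp` on sampled tokens as a single-sample reverse-KL estimate and mixes it with the RL loss using the two weights. The function name `combined_loss` and this particular formula are assumptions, not the PR's implementation.

```python
from typing import List


def combined_loss(
    rl_loss: float,
    student_logp: List[float],
    teacher_logp: List[float],
    loss_mask: List[int],
    rl_loss_weight: float,
    distill_loss_weight: float,
) -> float:
    """Weighted sum of an RL loss and a masked reverse-KL-style distillation term."""
    if not (len(student_logp) == len(teacher_logp) == len(loss_mask)):
        raise ValueError("all per-token inputs must have the same length")
    n_masked = sum(loss_mask)
    if n_masked == 0:
        distill = 0.0
    else:
        # Average (student - teacher) log-prob over the masked tokens only.
        distill = sum(
            (s - t) * m for s, t, m in zip(student_logp, teacher_logp, loss_mask)
        ) / n_masked
    return rl_loss_weight * rl_loss + distill_loss_weight * distill


total = combined_loss(
    rl_loss=2.0,
    student_logp=[-1.0, -2.0],
    teacher_logp=[-0.5, -1.0],
    loss_mask=[1, 1],
    rl_loss_weight=0.5,
    distill_loss_weight=2.0,
)
```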

Added compatibility guards:

  • Ensure teacher teardown in close()

6) Docs / examples

  • Updated distillation example config to new schema:
    • teacher.path
    • teacher.rollout
  • Added a new plot to the results section.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a rollout-based teacher engine type for inference-only teacher distillation using vLLM or SGLang, deprecating the legacy train-engine teacher path. It implements token log-probability computation across remote engines, controllers, and the trainer. The review feedback identifies a missing pipeline parallel size parameter in the SGLang teacher configuration and recommends adding defensive checks when parsing API responses from vLLM and SGLang servers to prevent potential runtime errors.

Comment thread areal/trainer/rl_trainer.py
Comment thread areal/engine/vllm_remote.py Outdated
Comment thread areal/engine/sglang_remote.py Outdated
@HwVanICI
Collaborator

Small changes to add:

  1. We could add a warning when multiple engine types are given in TeacherConfig, to avoid confusion.
  2. Previously we saw that prompt_logprobs=1 in vLLM may not always return the student token at the first index. Did you verify this? Maybe we can add a check that the student token matches the prompt logprob id.
    Otherwise looks good to me.
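
A minimal sketch of the check suggested in point 2: look up each student token by id instead of taking the first entry, and fail loudly if it is absent. The response shape assumed here (a list with `None` for the first position, then one dict per position keyed by the token id as a string, each value holding a `"logprob"` field) and the helper name `extract_prompt_logp` are assumptions to verify against the actual server output.

```python
from typing import Dict, List, Optional


def extract_prompt_logp(
    input_ids: List[int],
    prompt_logprobs: List[Optional[Dict[str, Dict[str, float]]]],
) -> List[Optional[float]]:
    """Return the log-prob of each actual input token, selected by token id."""
    if len(input_ids) != len(prompt_logprobs):
        raise ValueError("input_ids and prompt_logprobs must have the same length")
    out: List[Optional[float]] = []
    for pos, (tok, entry) in enumerate(zip(input_ids, prompt_logprobs)):
        if entry is None:
            out.append(None)  # no log-prob available at this position
            continue
        key = str(tok)
        if key not in entry:
            raise ValueError(
                f"position {pos}: token {tok} missing from prompt_logprobs"
            )
        out.append(entry[key]["logprob"])
    return out


ids = [5, 7, 9]
plp = [
    None,
    {"3": {"logprob": -0.1}, "7": {"logprob": -2.3}},  # top-1 is not the student token
    {"9": {"logprob": -0.4}},
]
logps = extract_prompt_logp(ids, plp)
```

In the second position the highest-probability token (id 3) comes first, so indexing by position would silently return the wrong value; the id lookup returns the student token's log-prob instead.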

Development

Successfully merging this pull request may close these issues.

[Feature] Use RolloutEngine as distillation teacher to reduce GPU memory vs TrainEngine teacher