feat: enable v2 training pipeline with controller parity #1363
Code Review
This pull request introduces 'v2' controller support, integrating RolloutControllerV2 and implementing RL parity methods like connect_engine and update_weights in the GatewayTrainController. It also refactors weight update logic, removes legacy reward utilities, and updates example configurations and tests. Feedback identifies a critical AttributeError in connect_engine due to an undefined attribute and notes that update_weights incorrectly invokes asynchronous methods synchronously while containing redundant generation control logic that may conflict with the trainer's state.
)
ctrl.initialize()

inference_urls: list[str] = rollout.inference_worker_urls
The attribute inference_worker_urls is not defined in RolloutControllerV2, which will cause an AttributeError at runtime. You should use the internal _inf_addrs attribute or add a public property to RolloutControllerV2 to expose these URLs.
Suggested change:
- inference_urls: list[str] = rollout.inference_worker_urls
+ inference_urls: list[str] = rollout._inf_addrs
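The alternative mentioned in the comment, a public property instead of reaching into `_inf_addrs`, might look roughly like the sketch below. `RolloutControllerV2Sketch` is a hypothetical stand-in written for illustration and is not the real `RolloutControllerV2`; only the `_inf_addrs` attribute name and the `inference_worker_urls` name come from the review, and the example URLs are made up.

```python
class RolloutControllerV2Sketch:
    """Hypothetical stand-in; models only the inference address bookkeeping."""

    def __init__(self, inf_addrs: list[str]) -> None:
        # Assumption: the real controller stores worker addresses in _inf_addrs.
        self._inf_addrs = list(inf_addrs)

    @property
    def inference_worker_urls(self) -> list[str]:
        # Return a copy so callers cannot mutate the controller's internal list.
        return list(self._inf_addrs)


# Made-up addresses, for illustration only.
rollout = RolloutControllerV2Sketch(["http://10.0.0.1:8000", "http://10.0.0.2:8000"])
inference_urls: list[str] = rollout.inference_worker_urls
```

A property keeps the call site in `connect_engine` unchanged and avoids coupling the train controller to a private attribute.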
self.rollout.pause_generation()
assert meta.version is not None and meta.version > 0, (
    f"meta.version must be a positive integer, got {meta.version}"
)
result = self._weight_update_ctrl.update_weights(version=meta.version)
self.rollout.continue_generation()
This block has multiple issues:
- The calls to pause_generation() and continue_generation() are asynchronous but called synchronously, returning coroutines without executing them.
- These calls are redundant and potentially harmful here. PPOTrainer.train already manages the rollout pause/resume state. In RolloutControllerV2, pause() correctly stops generation. Resuming it here via continue_generation() would break the trainer's expectation that inference remains paused during the subsequent save and evaluation steps.
- The assert should be replaced with a proper runtime check, as assertions can be disabled in production.
if meta.version is None or meta.version <= 0:
    raise ValueError(f"meta.version must be a positive integer, got {meta.version}")
result = self._weight_update_ctrl.update_weights(version=meta.version)

ce21c49 to 2e4dffd
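The first and third points of the review above can be reproduced in isolation. The sketch below assumes, as the reviewer states, that `pause_generation()` is a coroutine function; `FakeRollout` and `validate_version` are hypothetical names introduced only for this illustration and do not exist in the PR.

```python
import asyncio
import inspect


class FakeRollout:
    """Hypothetical stand-in with an async pause method, as described in the review."""

    def __init__(self) -> None:
        self.paused = False

    async def pause_generation(self) -> None:
        self.paused = True


def validate_version(version):
    # Explicit runtime check: unlike assert, this still runs under `python -O`.
    if version is None or version <= 0:
        raise ValueError(f"meta.version must be a positive integer, got {version}")
    return version


rollout = FakeRollout()

# Calling an async method without awaiting only creates a coroutine object;
# the method body has not run yet, so the flag is still False.
pending = rollout.pause_generation()
is_coroutine = inspect.iscoroutine(pending)
paused_before_run = rollout.paused

# Driving the coroutine to completion is what actually executes the body.
asyncio.run(pending)
paused_after_run = rollout.paused
```

In a synchronous caller such as `update_weights`, the un-awaited call would therefore be a silent no-op (Python typically emits a "coroutine ... was never awaited" warning when the object is discarded), which is a further reason to drop the calls rather than wrap them.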
Bring GatewayTrainController and RolloutControllerV2 to full parity with v1 controllers for RL training paths.

Key changes:
- Route to RolloutControllerV2 when config._version == "v2"
- Add version management, connect_engine, clear_batches to GatewayTrainController
- Unify HTTP client session in GatewayTrainController (follows PR #1354)
- Switch default workflow to MathAgent in example configs
- Add agent config section to all example YAML files
- Remove obsolete get_custom_reward_fn from reward module
2e4dffd to 2541bbd
Changes
- GatewayTrainController: add version management (set_version/get_version), connect_engine, clear_batches; persist guard addresses for later port allocation; unify HTTP client session (follow-up to feat: controller v2 refactor #1354).
- rl_trainer, sglang_remote, vllm_remote: route to RolloutControllerV2 when config._version == "v2".
- connect method for the v2 path.
- Add agent config section to all examples/math/*.yaml; switch the default workflow in gsm8k_rl.py to MathAgent.
- Remove get_custom_reward_fn from areal/reward/__init__.py.
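The version-based routing listed above can be pictured with a minimal sketch. Everything here is an assumption made for illustration: `StubConfig`, `StubRolloutController`, `StubRolloutControllerV2`, and `make_rollout_controller` are hypothetical names, and the real selection logic in rl_trainer, sglang_remote, and vllm_remote may be structured differently. Only the `config._version == "v2"` condition comes from the PR description.

```python
from dataclasses import dataclass


@dataclass
class StubConfig:
    # Assumption: configs carry a private _version field defaulting to "v1".
    _version: str = "v1"


class StubRolloutController:
    """Hypothetical placeholder for the v1 rollout controller."""


class StubRolloutControllerV2:
    """Hypothetical placeholder for RolloutControllerV2."""


def make_rollout_controller(config):
    # Route to the v2 controller only on an explicit "v2" marker;
    # a missing or different value falls back to the v1 path.
    if getattr(config, "_version", "v1") == "v2":
        return StubRolloutControllerV2()
    return StubRolloutController()


v1_ctrl = make_rollout_controller(StubConfig())
v2_ctrl = make_rollout_controller(StubConfig(_version="v2"))
```

Defaulting to the v1 path means existing configs that never set `_version` keep their current behavior.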