diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 73b98c21f..9640cc611 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -1115,7 +1115,8 @@ def _save_hf(self, epoch: int, epoch_step: int, global_step: int): # Async mode: synchronization handled by AsyncCheckpointManager if not self.saver.is_async: dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + current_platform.synchronize() def _save_recover_checkpoint(self, epoch: int, epoch_step: int, global_step: int): # Save recoverable checkpoints @@ -1140,7 +1141,8 @@ def _save_recover_checkpoint(self, epoch: int, epoch_step: int, global_step: int ) dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + current_platform.synchronize() def _evaluate_fn( self, @@ -1162,7 +1164,8 @@ def _evaluate_fn( self.eval_rollout.wait(cnt, timeout=None) dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + current_platform.synchronize() def _evaluate( self, @@ -1189,7 +1192,8 @@ def _evaluate( global_step, ) dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + current_platform.synchronize() def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int): # Upload statistics to the logger (e.g., wandb) @@ -1200,7 +1204,8 @@ def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int self.stats_logger.commit(epoch, epoch_step, global_step, stats) dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + current_platform.synchronize() def _validate_cfg(self): """validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models."""