diff --git a/areal/infra/workflow_executor.py b/areal/infra/workflow_executor.py index 949d2af0b..00a3decc3 100644 --- a/areal/infra/workflow_executor.py +++ b/areal/infra/workflow_executor.py @@ -869,14 +869,18 @@ async def _dump_trajectory( self.logger.warning( "Trajectory missing 'versions' field, defaulting to current inference engine version." ) - versions = [self.inference_engine.get_version()] + all_versions = None + default_version = self.inference_engine.get_version() else: - versions = traj["versions"].flatten().tolist() + all_versions = traj["versions"] - tail_version = max(versions) - head_version = min(versions) + global_tail = ( + all_versions.max().item() + if all_versions is not None + else default_version + ) # Create versioned directory - version_dir = os.path.join(dump_dir, str(tail_version)) + version_dir = os.path.join(dump_dir, str(global_tail)) await aiofiles.os.makedirs(version_dir, exist_ok=True) # Handle batched trajectories @@ -894,6 +898,25 @@ async def _dump_trajectory( if mask[-1] != 1: continue + if all_versions is not None: + sample_versions = all_versions[i, :seqlen].tolist() + output_versions = [ + v for v, m in zip(sample_versions, mask) if m == 1 + ] + else: + output_versions = [default_version] + + head_version = min(output_versions) if output_versions else -1 + tail_version = max(output_versions) if output_versions else -1 + + # RLE: [[version, count], ...] + version_rle = [] + for v in output_versions: + if version_rle and version_rle[-1][0] == v: + version_rle[-1][1] += 1 + else: + version_rle.append([v, 1]) + prompt_end = seqlen - sum(mask) prompt_ids = ids[:prompt_end] completion_ids = ids[prompt_end:] @@ -913,6 +936,7 @@ async def _dump_trajectory( "prompt_len": prompt_end, "head_version": head_version, "tail_version": tail_version, + "version_rle": version_rle, "reward": reward, "prompt": prompt_text, "completion": completion_text,