Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions areal/infra/workflow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,14 +869,18 @@ async def _dump_trajectory(
self.logger.warning(
"Trajectory missing 'versions' field, defaulting to current inference engine version."
)
versions = [self.inference_engine.get_version()]
all_versions = None
default_version = self.inference_engine.get_version()
else:
versions = traj["versions"].flatten().tolist()
all_versions = traj["versions"]

tail_version = max(versions)
head_version = min(versions)
global_tail = (
all_versions.max().item()
if all_versions is not None
else default_version
)
Comment on lines +877 to +881
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The global_tail calculation should also be filtered by loss_mask to ensure that the version directory name is not set to -1 (the placeholder for input tokens) if the versions tensor contains such values. This ensures the directory name reflects the actual model version used for generation.

Suggested change
global_tail = (
all_versions.max().item()
if all_versions is not None
else default_version
)
if all_versions is not None:
# Filter by loss_mask to avoid -1 placeholders in the version directory name
valid_versions = all_versions[loss_mask == 1]
global_tail = (
int(valid_versions.max().item())
if valid_versions.numel() > 0
else default_version
)
else:
global_tail = default_version

# Create versioned directory
version_dir = os.path.join(dump_dir, str(tail_version))
version_dir = os.path.join(dump_dir, str(global_tail))
await aiofiles.os.makedirs(version_dir, exist_ok=True)

# Handle batched trajectories
Expand All @@ -894,6 +898,25 @@ async def _dump_trajectory(
if mask[-1] != 1:
continue

if all_versions is not None:
sample_versions = all_versions[i, :seqlen].tolist()
output_versions = [
v for v, m in zip(sample_versions, mask) if m == 1
]
else:
output_versions = [default_version]

head_version = min(output_versions) if output_versions else -1
tail_version = max(output_versions) if output_versions else -1

# RLE: [[version, count], ...]
version_rle = []
for v in output_versions:
if version_rle and version_rle[-1][0] == v:
version_rle[-1][1] += 1
else:
version_rle.append([v, 1])
Comment on lines +901 to +918
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There are two issues here:

  1. Correctness: When all_versions is missing, version_rle is set to [[default_version, 1]], which incorrectly reports only one token for the entire completion. It should reflect the total number of generated tokens (i.e., sum(mask)).
  2. Efficiency: The list comprehension for filtering versions can be replaced with more efficient tensor operations.

Additionally, the if output_versions checks are redundant because the preceding if mask[-1] != 1: continue guarantees that at least one token with loss_mask == 1 is present.

                if all_versions is not None:
                    # Filter versions by loss_mask using tensor operations for efficiency
                    output_versions = all_versions[i, :seqlen][loss_mask[i, :seqlen] == 1].tolist()
                    head_version = min(output_versions)
                    tail_version = max(output_versions)
                    
                    # RLE: [[version, count], ...]
                    version_rle = []
                    for v in output_versions:
                        if version_rle and version_rle[-1][0] == v:
                            version_rle[-1][1] += 1
                        else:
                            version_rle.append([v, 1])
                else:
                    head_version = tail_version = default_version
                    # If versions are missing, the entire completion is assumed to be default_version
                    version_rle = [[default_version, int(sum(mask))]]


prompt_end = seqlen - sum(mask)
prompt_ids = ids[:prompt_end]
completion_ids = ids[prompt_end:]
Expand All @@ -913,6 +936,7 @@ async def _dump_trajectory(
"prompt_len": prompt_end,
"head_version": head_version,
"tail_version": tail_version,
"version_rle": version_rle,
"reward": reward,
"prompt": prompt_text,
"completion": completion_text,
Expand Down
Loading