fix: rollout version dump - filter by loss_mask and add version_rle #1350
pyq623 wants to merge 1 commit into
- head_version/tail_version are now computed per-sample, filtered by loss_mask==1
- fixes head_version always being -1 due to input token placeholders
- adds version_rle field (run-length encoded per-token version list)
Code Review
This pull request updates the _dump_trajectory function in areal/infra/workflow_executor.py to provide more granular version tracking, including the addition of a Run-Length Encoding (version_rle) for model versions within each trajectory sample. Feedback from the review suggests filtering the global_tail calculation by the loss mask to prevent placeholder values from being used in directory names. Additionally, the reviewer pointed out a correctness issue where the version_rle should reflect the total generated token count when explicit version data is missing, and recommended using tensor operations for better efficiency.
    global_tail = (
        all_versions.max().item()
        if all_versions is not None
        else default_version
    )
The global_tail calculation should also be filtered by loss_mask to ensure that the version directory name is not set to -1 (the placeholder for input tokens) if the versions tensor contains such values. This ensures the directory name reflects the actual model version used for generation.
Suggested change:

    - global_tail = (
    -     all_versions.max().item()
    -     if all_versions is not None
    -     else default_version
    - )
    + if all_versions is not None:
    +     # Filter by loss_mask to avoid -1 placeholders in the version directory name
    +     valid_versions = all_versions[loss_mask == 1]
    +     global_tail = (
    +         int(valid_versions.max().item())
    +         if valid_versions.numel() > 0
    +         else default_version
    +     )
    + else:
    +     global_tail = default_version
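To make the effect of the filtered variant concrete, here is a minimal pure-Python sketch of the same logic on nested lists rather than tensors. The helper name `masked_global_tail` and the sample data are hypothetical; only `loss_mask`, `default_version`, and the -1 placeholder come from the discussion above.

```python
def masked_global_tail(all_versions, loss_mask, default_version):
    """Max version over positions where loss_mask == 1.

    Hypothetical helper: all_versions and loss_mask are equally shaped
    lists of lists (batch x seqlen), or all_versions is None.
    """
    if all_versions is None:
        return default_version
    valid = [
        v
        for row_v, row_m in zip(all_versions, loss_mask)
        for v, m in zip(row_v, row_m)
        if m == 1
    ]
    # Fall back to the default when no generated token is present
    return max(valid) if valid else default_version


versions = [[-1, -1, 2, 2], [-1, 2, 3, -1]]
mask = [[0, 0, 1, 1], [0, 1, 1, 0]]
print(masked_global_tail(versions, mask, default_version=0))  # 3

# A batch that contains only placeholder tokens
print(masked_global_tail([[-1, -1]], [[0, 0]], default_version=5))  # 5
```

In the second call an unfiltered `max()` would return the placeholder -1, whereas the filtered form falls back to `default_version`.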
    if all_versions is not None:
        sample_versions = all_versions[i, :seqlen].tolist()
        output_versions = [
            v for v, m in zip(sample_versions, mask) if m == 1
        ]
    else:
        output_versions = [default_version]

    head_version = min(output_versions) if output_versions else -1
    tail_version = max(output_versions) if output_versions else -1

    # RLE: [[version, count], ...]
    version_rle = []
    for v in output_versions:
        if version_rle and version_rle[-1][0] == v:
            version_rle[-1][1] += 1
        else:
            version_rle.append([v, 1])
There are two issues here:
- Correctness: When all_versions is missing, version_rle is set to [[default_version, 1]], which incorrectly reports only one token for the entire completion. It should reflect the total number of generated tokens (i.e., sum(mask)).
- Efficiency: The list comprehension for filtering versions can be replaced with more efficient tensor operations.

Additionally, the if output_versions checks are redundant because the preceding if mask[-1] != 1: continue guarantees that at least one token with loss_mask == 1 is present.
    if all_versions is not None:
        # Filter versions by loss_mask using tensor operations for efficiency
        output_versions = all_versions[i, :seqlen][loss_mask[i, :seqlen] == 1].tolist()
        head_version = min(output_versions)
        tail_version = max(output_versions)
        # RLE: [[version, count], ...]
        version_rle = []
        for v in output_versions:
            if version_rle and version_rle[-1][0] == v:
                version_rle[-1][1] += 1
            else:
                version_rle.append([v, 1])
    else:
        head_version = tail_version = default_version
        # If versions are missing, the entire completion is assumed to be default_version
        version_rle = [[default_version, int(sum(mask))]]
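The per-sample logic in this suggestion can be exercised in isolation. The sketch below is a pure-Python approximation, not the code in workflow_executor.py: the function name `summarize_versions` is hypothetical, plain lists stand in for tensor rows, and `itertools.groupby` replaces the explicit RLE loop. It assumes, as the review notes, that at least one mask entry equals 1.

```python
from itertools import groupby


def summarize_versions(sample_versions, mask, default_version):
    """Return (head_version, tail_version, version_rle) for one sample.

    Hypothetical helper. sample_versions is a list of per-token versions
    or None; mask is the per-token loss mask (1 = generated token).
    """
    if sample_versions is None:
        # No per-token data: attribute every generated token to the default
        n_generated = sum(1 for m in mask if m == 1)
        return default_version, default_version, [[default_version, n_generated]]
    output_versions = [v for v, m in zip(sample_versions, mask) if m == 1]
    # groupby collapses consecutive equal versions into [version, count] runs
    rle = [[v, len(list(run))] for v, run in groupby(output_versions)]
    return min(output_versions), max(output_versions), rle


with_versions = summarize_versions([-1, -1, 2, 2, 3], [0, 0, 1, 1, 1], 0)
print(with_versions)  # (2, 3, [[2, 2], [3, 1]])

without_versions = summarize_versions(None, [0, 1, 1, 1], 4)
print(without_versions)  # (4, 4, [[4, 3]])
```

The second call shows the correctness point from the review: with no version data, the single run covers all three generated tokens instead of reporting a count of 1.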
Summary
- head_version/tail_version are now computed per-sample, filtered by loss_mask==1 only
- Fixes head_version always being -1 due to input token placeholders (-1) polluting min()

Context
When dumping rollout trajectories, the versions tensor includes -1 for input tokens (system prompt, tool results). Previously head_version = min(versions) always returned -1, making cross-version detection useless.

Now only model-generated tokens (where loss_mask == 1) are considered.

The new version_rle field records version transitions during generation, e.g. [[2, 63632], [3, 13147]] means 63632 tokens generated by model v2, then 13147 tokens by v3.
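As an illustration of both points, the following sketch first shows how the -1 placeholders pollute an unfiltered `min()`, then how a consumer might read a dumped `version_rle`. The helper names `tokens_per_version` and `is_cross_version` are hypothetical and not part of this PR; the RLE value is the example quoted above.

```python
# Placeholder pollution: input tokens carry -1, generated tokens a real version
versions = [-1, -1, 2, 2, 3]
loss_mask = [0, 0, 1, 1, 1]
unfiltered_head = min(versions)
filtered_head = min(v for v, m in zip(versions, loss_mask) if m == 1)
print(unfiltered_head, filtered_head)  # -1 2


def tokens_per_version(version_rle):
    """Total generated tokens per model version (hypothetical helper)."""
    totals = {}
    for version, count in version_rle:
        totals[version] = totals.get(version, 0) + count
    return totals


def is_cross_version(version_rle):
    """True if the sample was generated by more than one model version."""
    return len({version for version, _ in version_rle}) > 1


rle = [[2, 63632], [3, 13147]]
print(tokens_per_version(rle))             # {2: 63632, 3: 13147}
print(is_cross_version(rle))               # True
print(sum(count for _, count in rle))      # 76779
```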
Test plan