Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,6 +2435,9 @@ class TrackioConfig:
space_id: str | None = None
"""HF Space ID for remote dashboard deployment (e.g. "user/my-space").
When set, metrics are also pushed to the specified Hugging Face Space."""
max_rollout_traces_per_step: int = 32
"""Maximum rollout/eval trajectories to log as Trackio traces per step.
Set to 0 or a negative value to disable trace logging."""

def __post_init__(self):
"""Validate Trackio configuration."""
Expand Down
16 changes: 15 additions & 1 deletion areal/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,12 @@ def train(
group_size=config.gconfig.n_samples,
dynamic_bs=self.config.dynamic_bs,
)
self.stats_logger.log_rollout_traces(
rollout_batch,
split="rollout",
global_step=global_step,
tokenizer=self.tokenizer,
)
if self._should_offload_rollout:
self._offload_rollout()

Expand Down Expand Up @@ -1135,6 +1141,7 @@ def _evaluate_fn(
self,
eval_workflow: WorkflowLike,
eval_workflow_kwargs,
global_step: int,
):
if self.actor.is_data_parallel_head():
cnt = 0
Expand All @@ -1148,7 +1155,13 @@ def _evaluate_fn(
is_eval=True,
)
cnt += 1
self.eval_rollout.wait(cnt, timeout=None)
eval_batch = self.eval_rollout.wait(cnt, timeout=None)
self.stats_logger.log_rollout_traces(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wait method maybe return with None standing for rejected trajectories. Inside _trajectory_to_trackio_traces the first line is trajectory.get("input_ids"), which will AttributeError on any None element. Could you either filter Nones out in log_rollout_traces, or guard at the top of
_trajectory_to_trackio_traces?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also agree this should be fixed. It would be best to add test cases for the None trajectory scenario.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this, I handled rejected None trajectories from rollout/eval results, updated trajectory type hints, and added tests for rejected trajectories

eval_batch,
split="eval-rollout",
global_step=global_step,
tokenizer=self.tokenizer,
)

dist.barrier(group=self.actor.cpu_group)
current_platform.synchronize()
Expand All @@ -1172,6 +1185,7 @@ def _evaluate(
self._evaluate_fn,
eval_workflow=eval_workflow,
eval_workflow_kwargs=eval_workflow_kwargs,
global_step=global_step,
),
epoch,
epoch_step,
Expand Down
176 changes: 176 additions & 0 deletions areal/utils/stats_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
from dataclasses import asdict
from typing import Any

import swanlab
import torch.distributed as dist
Expand Down Expand Up @@ -158,6 +159,181 @@ def commit(self, epoch: int, step: int, global_step: int, data: dict | list[dict
self.summary_writer.add_scalar(f"{key}", val, log_step + i)
self._last_commit_step = log_step + len(data) - 1

def log_rollout_traces(
self,
trajectories: list[dict[str, Any] | None],
*,
split: str,
global_step: int,
tokenizer,
) -> None:
if dist.is_initialized() and dist.get_rank() != 0:
return
if not getattr(self, "_trackio_enabled", False):
return

max_traces = self.config.trackio.max_rollout_traces_per_step
if max_traces <= 0:
return

traces = []
for trajectory_index, trajectory in enumerate(trajectories):
if len(traces) >= max_traces:
break
traces.extend(
self._trajectory_to_trackio_traces(
trajectory,
split=split,
global_step=global_step,
trajectory_index=trajectory_index,
tokenizer=tokenizer,
remaining=max_traces - len(traces),
)
)

if traces:
trackio.log({f"{split}/trajectories": traces}, step=global_step)

def _trajectory_to_trackio_traces(
self,
trajectory: dict[str, Any] | None,
*,
split: str,
global_step: int,
trajectory_index: int,
tokenizer,
remaining: int,
) -> list[Any]:
if trajectory is None:
return []

input_ids = trajectory.get("input_ids")
loss_mask = trajectory.get("loss_mask")
attention_mask = trajectory.get("attention_mask")
rewards = trajectory.get("rewards")
versions = trajectory.get("versions")

if input_ids is None or loss_mask is None or attention_mask is None:
return []

traces = []
batch_size = input_ids.shape[0]
input_ids_cpu = input_ids.detach().cpu()
loss_mask_cpu = loss_mask.detach().cpu()
attention_mask_cpu = attention_mask.detach().cpu()
rewards_cpu = rewards.detach().cpu() if rewards is not None else None
versions_cpu = versions.detach().cpu() if versions is not None else None
for sample_index in range(batch_size):
if len(traces) >= remaining:
break
seqlen = int(attention_mask_cpu[sample_index].sum().item())
if seqlen <= 0:
continue

ids = input_ids_cpu[sample_index, :seqlen].tolist()
mask = loss_mask_cpu[sample_index, :seqlen].tolist()
if not mask or mask[-1] != 1:
continue

messages = self._trajectory_messages(
trajectory,
sample_index=sample_index,
ids=ids,
mask=mask,
tokenizer=tokenizer,
)
if not messages:
continue

metadata = {
"split": split,
"global_step": global_step,
"trajectory_index": trajectory_index,
"sample_index": sample_index,
"seqlen": seqlen,
"prompt_len": mask.index(1),
}
if rewards_cpu is not None:
metadata["reward"] = self._metadata_value(rewards_cpu[sample_index])
if versions_cpu is not None:
sample_versions = versions_cpu[sample_index, :seqlen].tolist()
metadata["head_version"] = min(sample_versions)
metadata["tail_version"] = max(sample_versions)

traces.append(
trackio.Trace(
messages=messages,
metadata=metadata,
)
)
return traces

@staticmethod
def _metadata_value(value):
if hasattr(value, "numel"):
if value.numel() == 1:
return value.item()
return value.tolist()
return value

def _trajectory_messages(
self,
trajectory: dict[str, Any],
*,
sample_index: int,
ids: list[int],
mask: list[int],
tokenizer,
) -> list[dict[str, str]]:
messages = self._structured_messages(trajectory, sample_index)
if messages is not None:
return messages

trace_messages = []
span_start = 0
for span_end in range(1, len(ids) + 1):
if span_end < len(ids) and bool(mask[span_end]) == bool(mask[span_start]):
continue

content = tokenizer.decode(
ids[span_start:span_end], skip_special_tokens=False
)
if content:
if bool(mask[span_start]):
role = "assistant"
else:
role = "user" if not trace_messages else "tool"
trace_messages.append({"role": role, "content": content})
span_start = span_end

return trace_messages

@staticmethod
def _structured_messages(
trajectory: dict[str, Any], sample_index: int
) -> list[dict[str, str]] | None:
for key in ("messages", "conversation", "conversation_text"):
value = trajectory.get(key)
if value is None:
continue
if (
isinstance(value, list)
and len(value) > sample_index
and isinstance(value[sample_index], list)
):
value = value[sample_index]
if isinstance(value, list) and all(
isinstance(message, dict) for message in value
):
return [
{
"role": str(message.get("role", "user")),
"content": str(message.get("content", "")),
}
for message in value
]
return None

def print_stats(self, stats: dict[str, float]):
logger.info("\n" + tabulate_stats(stats))

Expand Down
13 changes: 7 additions & 6 deletions docs/en/cli_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -838,12 +838,13 @@ Spaces.
See: https://github.com/gradio-app/trackio
```

| Parameter | Type | Default | Description |
| ---------- | -------------- | ------------ | ----------- |
| `mode` | string | `"disabled"` | - |
| `project` | string \| None | `None` | - |
| `name` | string \| None | `None` | - |
| `space_id` | string \| None | `None` | - |
| Parameter | Type | Default | Description |
| ----------------------------- | -------------- | ------------ | ----------- |
| `mode` | string | `"disabled"` | - |
| `project` | string \| None | `None` | - |
| `name` | string \| None | `None` | - |
| `space_id` | string \| None | `None` | - |
| `max_rollout_traces_per_step` | integer | `32` | - |

(section-wand-b)=

Expand Down
13 changes: 13 additions & 0 deletions docs/en/reference/metrics_tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ sends aggregated metrics to external logging backends. It is automatically manag
| **Weights & Biases** | `config.stats_logger.wandb` | Cloud-based experiment tracking |
| **SwanLab** | `config.stats_logger.swanlab` | Alternative experiment tracking |
| **TensorBoard** | `config.stats_logger.tensorboard` | Local visualization |
| **Trackio** | `config.stats_logger.trackio` | Local-first metrics and traces |

### Integration with PPOTrainer

Expand Down Expand Up @@ -288,11 +289,17 @@ def commit(self, epoch, step, global_step, data):
# Log to all backends
wandb.log(data, step=global_step)
swanlab.log(data, step=global_step)
trackio.log(data, step=global_step)
if self.summary_writer:
for key, val in data.items():
self.summary_writer.add_scalar(key, val, global_step)
```

When Trackio is enabled, `PPOTrainer` also logs rollout and evaluation trajectories as
`trackio.Trace` records. These traces decode each tensor trajectory into a user prompt
and assistant completion, with metadata such as `global_step`, trajectory index, sample
index, sequence length, prompt length, reward, and head/tail model versions.

### Configuration

Configure logging backends in your experiment config:
Expand All @@ -314,6 +321,12 @@ stats_logger:

tensorboard:
path: "/path/to/tensorboard/logs" # null to disable

trackio:
mode: "online" # "online", "local", or "disabled"
project: "my-project"
name: "run_001"
max_rollout_traces_per_step: 32 # <=0 disables trace logging
```

## Best Practices
Expand Down
13 changes: 7 additions & 6 deletions docs/zh/cli_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -836,12 +836,13 @@ Spaces.
See: https://github.com/gradio-app/trackio
```

| Parameter | Type | Default | Description |
| ---------- | -------------- | ------------ | ----------- |
| `mode` | string | `"disabled"` | - |
| `project` | string \| None | `None` | - |
| `name` | string \| None | `None` | - |
| `space_id` | string \| None | `None` | - |
| Parameter | Type | Default | Description |
| ----------------------------- | -------------- | ------------ | ----------- |
| `mode` | string | `"disabled"` | - |
| `project` | string \| None | `None` | - |
| `name` | string \| None | `None` | - |
| `space_id` | string \| None | `None` | - |
| `max_rollout_traces_per_step` | integer | `32` | - |

(section-wand-b)=

Expand Down
Loading
Loading