From d0f76061b816a9bd99a03c3e138e20f0609f5083 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 21 May 2026 15:25:45 -0700 Subject: [PATCH 1/3] Add Trackio rollout trace logging --- areal/api/cli_args.py | 3 + areal/trainer/rl_trainer.py | 16 +++- areal/utils/stats_logger.py | 100 ++++++++++++++++++++++ docs/en/cli_reference.md | 13 +-- docs/en/reference/metrics_tracking.md | 13 +++ tests/test_trackio_backend.py | 114 ++++++++++++++++++++++++++ 6 files changed, 252 insertions(+), 7 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index ea55f557a8..a3c2f19bfd 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -2382,6 +2382,9 @@ class TrackioConfig: space_id: str | None = None """HF Space ID for remote dashboard deployment (e.g. "user/my-space"). When set, metrics are also pushed to the specified Hugging Face Space.""" + max_rollout_traces_per_step: int = 32 + """Maximum rollout/eval trajectories to log as Trackio traces per step. + Set to 0 or a negative value to disable trace logging.""" def __post_init__(self): """Validate Trackio configuration.""" diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index c7249eb281..d45cd7bd57 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -578,6 +578,12 @@ def train( group_size=config.gconfig.n_samples, dynamic_bs=self.config.dynamic_bs, ) + self.stats_logger.log_rollout_traces( + rollout_batch, + split="rollout", + global_step=global_step, + tokenizer=self.tokenizer, + ) if self._should_offload_rollout: self._offload_rollout() @@ -1135,6 +1141,7 @@ def _evaluate_fn( self, eval_workflow: WorkflowLike, eval_workflow_kwargs, + global_step: int, ): if self.actor.is_data_parallel_head(): cnt = 0 @@ -1148,7 +1155,13 @@ def _evaluate_fn( is_eval=True, ) cnt += 1 - self.eval_rollout.wait(cnt, timeout=None) + eval_batch = self.eval_rollout.wait(cnt, timeout=None) + self.stats_logger.log_rollout_traces( + eval_batch, + split="eval-rollout", + global_step=global_step, + tokenizer=self.tokenizer, + ) dist.barrier(group=self.actor.cpu_group) current_platform.synchronize() @@ -1172,6 +1185,7 @@ def _evaluate( self._evaluate_fn, eval_workflow=eval_workflow, eval_workflow_kwargs=eval_workflow_kwargs, + global_step=global_step, ), epoch, epoch_step, diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 220e6d3920..ea92a2c331 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -4,6 +4,7 @@ import os import time from dataclasses import asdict +from typing import Any import swanlab import torch.distributed as dist @@ -158,6 +159,105 @@ def commit(self, epoch: int, step: int, global_step: int, data: dict | list[dict self.summary_writer.add_scalar(f"{key}", val, log_step + i) self._last_commit_step = log_step + len(data) - 1 + def log_rollout_traces( + self, + trajectories: list[dict[str, Any]], + *, + split: str, + global_step: int, + tokenizer, + ) -> None: + if dist.is_initialized() and dist.get_rank() != 0: + return + if not getattr(self, "_trackio_enabled", False): + return + + max_traces = self.config.trackio.max_rollout_traces_per_step + if max_traces <= 0: + return + + traces = [] + for trajectory_index, trajectory in enumerate(trajectories): + if len(traces) >= max_traces: + break + traces.extend( + self._trajectory_to_trackio_traces( + trajectory, + split=split, + global_step=global_step, + trajectory_index=trajectory_index, + tokenizer=tokenizer, + remaining=max_traces - len(traces), + ) + ) + + if traces: + trackio.log({f"{split}/trajectories": traces}, step=global_step) + + def _trajectory_to_trackio_traces( + self, + trajectory: dict[str, Any], + *, + split: str, + global_step: int, + trajectory_index: int, + tokenizer, + remaining: int, + ) -> list[Any]: + input_ids = trajectory.get("input_ids") + loss_mask = trajectory.get("loss_mask") + attention_mask = trajectory.get("attention_mask") + rewards = trajectory.get("rewards") + versions = trajectory.get("versions") + + if input_ids is None or loss_mask is None or attention_mask is None: + return [] + + traces = [] + batch_size = input_ids.shape[0] + for sample_index in range(batch_size): + if len(traces) >= remaining: + break + seqlen = int(attention_mask[sample_index].sum().item()) + if seqlen <= 0: + continue + + ids = input_ids[sample_index, :seqlen].detach().cpu().tolist() + mask = loss_mask[sample_index, :seqlen].detach().cpu().tolist() + if not mask or mask[-1] != 1: + continue + + prompt_len = seqlen - sum(mask) + prompt = tokenizer.decode(ids[:prompt_len], skip_special_tokens=False) + completion = tokenizer.decode(ids[prompt_len:], skip_special_tokens=False) + metadata = { + "split": split, + "global_step": global_step, + "trajectory_index": trajectory_index, + "sample_index": sample_index, + "seqlen": seqlen, + "prompt_len": prompt_len, + } + if rewards is not None: + metadata["reward"] = float(rewards[sample_index].item()) + if versions is not None: + sample_versions = ( + versions[sample_index, :seqlen].detach().cpu().tolist() + ) + metadata["head_version"] = min(sample_versions) + metadata["tail_version"] = max(sample_versions) + + traces.append( + trackio.Trace( + messages=[ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": completion}, + ], + metadata=metadata, + ) + ) + return traces + def print_stats(self, stats: dict[str, float]): logger.info("\n" + tabulate_stats(stats)) diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 0b217a2673..85a28a4d3f 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -835,12 +835,13 @@ Spaces. See: https://github.com/gradio-app/trackio ``` -| Parameter | Type | Default | Description | -| ---------- | -------------- | ------------ | ----------- | -| `mode` | string | `"disabled"` | - | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `space_id` | string \| None | `None` | - | +| Parameter | Type | Default | Description | +| ----------------------------- | -------------- | ------------ | ---------------------------------------------------------- | +| `mode` | string | `"disabled"` | - | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `space_id` | string \| None | `None` | - | +| `max_rollout_traces_per_step` | integer | `32` | Maximum rollout/eval trajectories to log as Trackio traces | (section-wand-b)= diff --git a/docs/en/reference/metrics_tracking.md b/docs/en/reference/metrics_tracking.md index a83fe253aa..22da6174da 100644 --- a/docs/en/reference/metrics_tracking.md +++ b/docs/en/reference/metrics_tracking.md @@ -254,6 +254,7 @@ sends aggregated metrics to external logging backends. It is automatically manag | **Weights & Biases** | `config.stats_logger.wandb` | Cloud-based experiment tracking | | **SwanLab** | `config.stats_logger.swanlab` | Alternative experiment tracking | | **TensorBoard** | `config.stats_logger.tensorboard` | Local visualization | +| **Trackio** | `config.stats_logger.trackio` | Local-first metrics and traces | ### Integration with PPOTrainer @@ -288,11 +289,17 @@ def commit(self, epoch, step, global_step, data): # Log to all backends wandb.log(data, step=global_step) swanlab.log(data, step=global_step) + trackio.log(data, step=global_step) if self.summary_writer: for key, val in data.items(): self.summary_writer.add_scalar(key, val, global_step) ``` +When Trackio is enabled, `PPOTrainer` also logs rollout and evaluation trajectories as +`trackio.Trace` records. These traces decode each tensor trajectory into a user prompt +and assistant completion, with metadata such as `global_step`, trajectory index, +sample index, sequence length, prompt length, reward, and head/tail model versions. + ### Configuration Configure logging backends in your experiment config: @@ -314,6 +321,12 @@ stats_logger: tensorboard: path: "/path/to/tensorboard/logs" # null to disable + + trackio: + mode: "online" # "online", "local", or "disabled" + project: "my-project" + name: "run_001" + max_rollout_traces_per_step: 32 # <=0 disables trace logging ``` ## Best Practices diff --git a/tests/test_trackio_backend.py b/tests/test_trackio_backend.py index 1460ec199e..a6f96b2d7f 100644 --- a/tests/test_trackio_backend.py +++ b/tests/test_trackio_backend.py @@ -3,6 +3,8 @@ from dataclasses import fields from unittest.mock import MagicMock, patch +import torch + from areal.api.cli_args import ( StatsLoggerConfig, TrackioConfig, @@ -23,6 +25,7 @@ def test_default_optional_fields_are_none(self): assert config.project is None assert config.name is None assert config.space_id is None + assert config.max_rollout_traces_per_step == 32 def test_custom_values(self): """TrackioConfig should accept custom values.""" @@ -31,11 +34,13 @@ def test_custom_values(self): project="my-project", name="my-run", space_id="user/my-space", + max_rollout_traces_per_step=8, ) assert config.mode == "online" assert config.project == "my-project" assert config.name == "my-run" assert config.space_id == "user/my-space" + assert config.max_rollout_traces_per_step == 8 def test_invalid_mode_raises_error(self): """TrackioConfig should reject invalid mode values.""" @@ -193,3 +198,112 @@ def test_trackio_not_logged_when_disabled( data = {"loss/avg": 0.5} logger.commit(epoch=0, step=0, global_step=0, data=data) mock_trackio.log.assert_not_called() + + @patch("areal.utils.stats_logger.trackio") + @patch("areal.utils.stats_logger.wandb") + @patch("areal.utils.stats_logger.swanlab") + @patch("areal.utils.stats_logger.dist") + def test_trackio_trace_logging_from_rollout_tensors( + self, mock_dist, mock_swanlab, mock_wandb, mock_trackio + ): + """Rollout tensors should be decoded and logged as trackio.Trace records.""" + mock_dist.is_initialized.return_value = False + mock_trackio.Trace.side_effect = lambda messages, metadata: { + "_type": "trackio.trace", + "messages": messages, + "metadata": metadata, + } + + from areal.utils.stats_logger import StatsLogger + + config = _make_test_config(TrackioConfig(mode="online")) + logger = StatsLogger(config, _make_ft_spec()) + mock_trackio.log.reset_mock() + + tokenizer = MagicMock() + tokenizer.decode.side_effect = lambda ids, skip_special_tokens=False: " ".join( + str(i) for i in ids + ) + trajectory = { + "input_ids": torch.tensor([[1, 2, 3, 4]], dtype=torch.int64), + "attention_mask": torch.tensor([[1, 1, 1, 1]], dtype=torch.bool), + "loss_mask": torch.tensor([[0, 0, 1, 1]], dtype=torch.int64), + "rewards": torch.tensor([0.75], dtype=torch.float32), + "versions": torch.tensor([[0, 0, 2, 2]], dtype=torch.int64), + } + + logger.log_rollout_traces( + [trajectory], + split="rollout", + global_step=3, + tokenizer=tokenizer, + ) + + mock_trackio.Trace.assert_called_once() + trace_kwargs = mock_trackio.Trace.call_args.kwargs + assert trace_kwargs["messages"] == [ + {"role": "user", "content": "1 2"}, + {"role": "assistant", "content": "3 4"}, + ] + assert trace_kwargs["metadata"] == { + "split": "rollout", + "global_step": 3, + "trajectory_index": 0, + "sample_index": 0, + "seqlen": 4, + "prompt_len": 2, + "reward": 0.75, + "head_version": 0, + "tail_version": 2, + } + mock_trackio.log.assert_called_once() + assert mock_trackio.log.call_args.args[0]["rollout/trajectories"] == [ + { + "_type": "trackio.trace", + "messages": trace_kwargs["messages"], + "metadata": trace_kwargs["metadata"], + } + ] + assert mock_trackio.log.call_args.kwargs == {"step": 3} + + @patch("areal.utils.stats_logger.trackio") + @patch("areal.utils.stats_logger.wandb") + @patch("areal.utils.stats_logger.swanlab") + @patch("areal.utils.stats_logger.dist") + def test_trackio_trace_logging_respects_cap( + self, mock_dist, mock_swanlab, mock_wandb, mock_trackio + ): + """Trace logging should respect max_rollout_traces_per_step.""" + mock_dist.is_initialized.return_value = False + mock_trackio.Trace.side_effect = lambda messages, metadata: { + "messages": messages, + "metadata": metadata, + } + + from areal.utils.stats_logger import StatsLogger + + config = _make_test_config( + TrackioConfig(mode="online", max_rollout_traces_per_step=1) + ) + logger = StatsLogger(config, _make_ft_spec()) + mock_trackio.log.reset_mock() + + tokenizer = MagicMock() + tokenizer.decode.side_effect = lambda ids, skip_special_tokens=False: " ".join( + str(i) for i in ids + ) + trajectory = { + "input_ids": torch.tensor([[1, 2], [3, 4]], dtype=torch.int64), + "attention_mask": torch.tensor([[1, 1], [1, 1]], dtype=torch.bool), + "loss_mask": torch.tensor([[0, 1], [0, 1]], dtype=torch.int64), + "rewards": torch.tensor([1.0, 0.0], dtype=torch.float32), + } + + logger.log_rollout_traces( + [trajectory], + split="rollout", + global_step=0, + tokenizer=tokenizer, + ) + + assert mock_trackio.Trace.call_count == 1 From 44a964492eb1752368f50556e1c677c26cc6e849 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 25 May 2026 10:21:16 -0700 Subject: [PATCH 2/3] docs: update generated cli reference --- docs/en/cli_reference.md | 14 +++++++------- docs/en/reference/metrics_tracking.md | 4 ++-- docs/zh/cli_reference.md | 13 +++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 85a28a4d3f..ef09907aaf 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -835,13 +835,13 @@ Spaces. See: https://github.com/gradio-app/trackio ``` -| Parameter | Type | Default | Description | -| ----------------------------- | -------------- | ------------ | ---------------------------------------------------------- | -| `mode` | string | `"disabled"` | - | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `space_id` | string \| None | `None` | - | -| `max_rollout_traces_per_step` | integer | `32` | Maximum rollout/eval trajectories to log as Trackio traces | +| Parameter | Type | Default | Description | +| ----------------------------- | -------------- | ------------ | ----------- | +| `mode` | string | `"disabled"` | - | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `space_id` | string \| None | `None` | - | +| `max_rollout_traces_per_step` | integer | `32` | - | (section-wand-b)= diff --git a/docs/en/reference/metrics_tracking.md b/docs/en/reference/metrics_tracking.md index 22da6174da..f01d794c34 100644 --- a/docs/en/reference/metrics_tracking.md +++ b/docs/en/reference/metrics_tracking.md @@ -297,8 +297,8 @@ def commit(self, epoch, step, global_step, data): When Trackio is enabled, `PPOTrainer` also logs rollout and evaluation trajectories as `trackio.Trace` records. These traces decode each tensor trajectory into a user prompt -and assistant completion, with metadata such as `global_step`, trajectory index, -sample index, sequence length, prompt length, reward, and head/tail model versions. +and assistant completion, with metadata such as `global_step`, trajectory index, sample +index, sequence length, prompt length, reward, and head/tail model versions. ### Configuration diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index e9e6f11180..a9ee72211c 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -833,12 +833,13 @@ Spaces. See: https://github.com/gradio-app/trackio ``` -| Parameter | Type | Default | Description | -| ---------- | -------------- | ------------ | ----------- | -| `mode` | string | `"disabled"` | - | -| `project` | string \| None | `None` | - | -| `name` | string \| None | `None` | - | -| `space_id` | string \| None | `None` | - | +| Parameter | Type | Default | Description | +| ----------------------------- | -------------- | ------------ | ----------- | +| `mode` | string | `"disabled"` | - | +| `project` | string \| None | `None` | - | +| `name` | string \| None | `None` | - | +| `space_id` | string \| None | `None` | - | +| `max_rollout_traces_per_step` | integer | `32` | - | (section-wand-b)= From c65866e8a7bb421ef83ac37d470b26804b70deb2 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 28 May 2026 12:26:19 -0700 Subject: [PATCH 3/3] Address Trackio trace review comments --- areal/utils/stats_logger.py | 114 +++++++++++++++++++++++++------ tests/test_trackio_backend.py | 125 ++++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 19 deletions(-) diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index ea92a2c331..b1410b3ac0 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -161,7 +161,7 @@ def commit(self, epoch: int, step: int, global_step: int, data: dict | list[dict def log_rollout_traces( self, - trajectories: list[dict[str, Any]], + trajectories: list[dict[str, Any] | None], *, split: str, global_step: int, @@ -196,7 +196,7 @@ def log_rollout_traces( def _trajectory_to_trackio_traces( self, - trajectory: dict[str, Any], + trajectory: dict[str, Any] | None, *, split: str, global_step: int, @@ -204,6 +204,9 @@ def _trajectory_to_trackio_traces( tokenizer, remaining: int, ) -> list[Any]: + if trajectory is None: + return [] + input_ids = trajectory.get("input_ids") loss_mask = trajectory.get("loss_mask") attention_mask = trajectory.get("attention_mask") @@ -215,49 +218,122 @@ def _trajectory_to_trackio_traces( traces = [] batch_size = input_ids.shape[0] + input_ids_cpu = input_ids.detach().cpu() + loss_mask_cpu = loss_mask.detach().cpu() + attention_mask_cpu = attention_mask.detach().cpu() + rewards_cpu = rewards.detach().cpu() if rewards is not None else None + versions_cpu = versions.detach().cpu() if versions is not None else None for sample_index in range(batch_size): if len(traces) >= remaining: break - seqlen = int(attention_mask[sample_index].sum().item()) + seqlen = int(attention_mask_cpu[sample_index].sum().item()) if seqlen <= 0: continue - ids = input_ids[sample_index, :seqlen].detach().cpu().tolist() - mask = loss_mask[sample_index, :seqlen].detach().cpu().tolist() + ids = input_ids_cpu[sample_index, :seqlen].tolist() + mask = loss_mask_cpu[sample_index, :seqlen].tolist() if not mask or mask[-1] != 1: continue - prompt_len = seqlen - sum(mask) - prompt = tokenizer.decode(ids[:prompt_len], skip_special_tokens=False) - completion = tokenizer.decode(ids[prompt_len:], skip_special_tokens=False) + messages = self._trajectory_messages( + trajectory, + sample_index=sample_index, + ids=ids, + mask=mask, + tokenizer=tokenizer, + ) + if not messages: + continue + metadata = { "split": split, "global_step": global_step, "trajectory_index": trajectory_index, "sample_index": sample_index, "seqlen": seqlen, - "prompt_len": prompt_len, + "prompt_len": mask.index(1), } - if rewards is not None: - metadata["reward"] = float(rewards[sample_index].item()) - if versions is not None: - sample_versions = ( - versions[sample_index, :seqlen].detach().cpu().tolist() - ) + if rewards_cpu is not None: + metadata["reward"] = self._metadata_value(rewards_cpu[sample_index]) + if versions_cpu is not None: + sample_versions = versions_cpu[sample_index, :seqlen].tolist() metadata["head_version"] = min(sample_versions) metadata["tail_version"] = max(sample_versions) traces.append( trackio.Trace( - messages=[ - {"role": "user", "content": prompt}, - {"role": "assistant", "content": completion}, - ], + messages=messages, metadata=metadata, ) ) return traces + @staticmethod + def _metadata_value(value): + if hasattr(value, "numel"): + if value.numel() == 1: + return value.item() + return value.tolist() + return value + + def _trajectory_messages( + self, + trajectory: dict[str, Any], + *, + sample_index: int, + ids: list[int], + mask: list[int], + tokenizer, + ) -> list[dict[str, str]]: + messages = self._structured_messages(trajectory, sample_index) + if messages is not None: + return messages + + trace_messages = [] + span_start = 0 + for span_end in range(1, len(ids) + 1): + if span_end < len(ids) and bool(mask[span_end]) == bool(mask[span_start]): + continue + + content = tokenizer.decode( + ids[span_start:span_end], skip_special_tokens=False + ) + if content: + if bool(mask[span_start]): + role = "assistant" + else: + role = "user" if not trace_messages else "tool" + trace_messages.append({"role": role, "content": content}) + span_start = span_end + + return trace_messages + + @staticmethod + def _structured_messages( + trajectory: dict[str, Any], sample_index: int + ) -> list[dict[str, str]] | None: + for key in ("messages", "conversation", "conversation_text"): + value = trajectory.get(key) + if value is None: + continue + if ( + isinstance(value, list) + and len(value) > sample_index + and isinstance(value[sample_index], list) + ): + value = value[sample_index] + if isinstance(value, list) and all( + isinstance(message, dict) for message in value + ): + return [ + { + "role": str(message.get("role", "user")), + "content": str(message.get("content", "")), + } + for message in value + ] + return None + def print_stats(self, stats: dict[str, float]): logger.info("\n" + tabulate_stats(stats)) diff --git a/tests/test_trackio_backend.py b/tests/test_trackio_backend.py index a6f96b2d7f..761dbf1447 100644 --- a/tests/test_trackio_backend.py +++ b/tests/test_trackio_backend.py @@ -307,3 +307,128 @@ def test_trackio_trace_logging_respects_cap( ) assert mock_trackio.Trace.call_count == 1 + + @patch("areal.utils.stats_logger.trackio") + @patch("areal.utils.stats_logger.wandb") + @patch("areal.utils.stats_logger.swanlab") + @patch("areal.utils.stats_logger.dist") + def test_trackio_trace_logging_skips_none_trajectories( + self, mock_dist, mock_swanlab, mock_wandb, mock_trackio + ): + """Rejected rollout trajectories should be skipped.""" + mock_dist.is_initialized.return_value = False + mock_trackio.Trace.side_effect = lambda messages, metadata: { + "messages": messages, + "metadata": metadata, + } + + from areal.utils.stats_logger import StatsLogger + + config = _make_test_config(TrackioConfig(mode="online")) + logger = StatsLogger(config, _make_ft_spec()) + mock_trackio.log.reset_mock() + + tokenizer = MagicMock() + tokenizer.decode.side_effect = lambda ids, skip_special_tokens=False: " ".join( + str(i) for i in ids + ) + trajectory = { + "input_ids": torch.tensor([[1, 2]], dtype=torch.int64), + "attention_mask": torch.tensor([[1, 1]], dtype=torch.bool), + "loss_mask": torch.tensor([[0, 1]], dtype=torch.int64), + } + + logger.log_rollout_traces( + [None, trajectory], + split="eval-rollout", + global_step=4, + tokenizer=tokenizer, + ) + + mock_trackio.Trace.assert_called_once() + assert mock_trackio.Trace.call_args.kwargs["metadata"]["trajectory_index"] == 1 + mock_trackio.log.assert_called_once() + + @patch("areal.utils.stats_logger.trackio") + @patch("areal.utils.stats_logger.wandb") + @patch("areal.utils.stats_logger.swanlab") + @patch("areal.utils.stats_logger.dist") + def test_trackio_trace_logging_reconstructs_multiturn_tool_messages( + self, mock_dist, mock_swanlab, mock_wandb, mock_trackio + ): + """Interleaved loss masks should become multi-message traces.""" + mock_dist.is_initialized.return_value = False + + from areal.utils.stats_logger import StatsLogger + + config = _make_test_config(TrackioConfig(mode="online")) + logger = StatsLogger(config, _make_ft_spec()) + mock_trackio.log.reset_mock() + + tokenizer = MagicMock() + tokenizer.decode.side_effect = lambda ids, skip_special_tokens=False: "|".join( + str(i) for i in ids + ) + trajectory = { + "input_ids": torch.tensor([[1, 2, 3, 4, 5, 6, 7]], dtype=torch.int64), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.bool), + "loss_mask": torch.tensor([[0, 0, 1, 1, 0, 0, 1]], dtype=torch.int64), + } + + logger.log_rollout_traces( + [trajectory], + split="rollout", + global_step=5, + tokenizer=tokenizer, + ) + + assert mock_trackio.Trace.call_args.kwargs["messages"] == [ + {"role": "user", "content": "1|2"}, + {"role": "assistant", "content": "3|4"}, + {"role": "tool", "content": "5|6"}, + {"role": "assistant", "content": "7"}, + ] + + @patch("areal.utils.stats_logger.trackio") + @patch("areal.utils.stats_logger.wandb") + @patch("areal.utils.stats_logger.swanlab") + @patch("areal.utils.stats_logger.dist") + def test_trackio_trace_logging_prefers_structured_messages( + self, mock_dist, mock_swanlab, mock_wandb, mock_trackio + ): + """Structured conversation data should be logged directly when present.""" + mock_dist.is_initialized.return_value = False + + from areal.utils.stats_logger import StatsLogger + + config = _make_test_config(TrackioConfig(mode="online")) + logger = StatsLogger(config, _make_ft_spec()) + mock_trackio.log.reset_mock() + + tokenizer = MagicMock() + trajectory = { + "input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64), + "attention_mask": torch.tensor([[1, 1, 1]], dtype=torch.bool), + "loss_mask": torch.tensor([[0, 1, 1]], dtype=torch.int64), + "messages": [ + [ + {"role": "system", "content": "Use tools when needed."}, + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "I will calculate it."}, + {"role": "tool", "content": "calculator: 4"}, + {"role": "assistant", "content": "4"}, + ] + ], + } + + logger.log_rollout_traces( + [trajectory], + split="rollout", + global_step=6, + tokenizer=tokenizer, + ) + + assert ( + mock_trackio.Trace.call_args.kwargs["messages"] == trajectory["messages"][0] + ) + tokenizer.decode.assert_not_called()