From 5a5ed2f1db8762ff9fc6373f85a3b2ca8a263b9b Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Sat, 16 May 2026 15:47:55 +0800 Subject: [PATCH 01/10] fix: improve parser robustness, cross-platform path handling, and test stability --- docs/overview/RL_Timeline_quickstart.md | 2 +- rl_insight/data/rules.py | 75 ++++-- rl_insight/parser/gmm_parser.py | 34 ++- rl_insight/parser/mstx_parser.py | 11 +- rl_insight/parser/parser.py | 7 + rl_insight/visualizer/gmm_visualizer.py | 333 +++++++++--------------- tests/data/test_data_checker.py | 13 +- tests/data/test_paths.py | 22 ++ tests/data/test_rules.py | 2 +- tests/doc/test_docs_urls.py | 15 +- tests/parser/test_cluster_analysis.py | 26 ++ tests/parser/test_gmm_parser.py | 28 ++ 12 files changed, 310 insertions(+), 258 deletions(-) create mode 100644 tests/data/test_paths.py create mode 100644 tests/parser/test_gmm_parser.py diff --git a/docs/overview/RL_Timeline_quickstart.md b/docs/overview/RL_Timeline_quickstart.md index 6935d9b..7115596 100644 --- a/docs/overview/RL_Timeline_quickstart.md +++ b/docs/overview/RL_Timeline_quickstart.md @@ -46,7 +46,7 @@ pip install -e . 使用 VeRL 框架采集性能数据,详细参考: -[VeRL NPU Profiling 教程](https://github.com/verl-project/verl/blob/main/docs/ascend_tutorial/profiling/ascend_profiling_zh.rst) +[VeRL NPU Profiling 文档目录](https://github.com/verl-project/verl/blob/main/docs/ascend_tutorial/dev_guide/performance/ascend_profiling_zh.rst) [VeRL GPU Profiling 教程](https://github.com/verl-project/verl/blob/main/docs/perf/nsight_profiling.md) diff --git a/rl_insight/data/rules.py b/rl_insight/data/rules.py index 8548c00..e471d2e 100644 --- a/rl_insight/data/rules.py +++ b/rl_insight/data/rules.py @@ -23,6 +23,17 @@ import pandas as pd +def _coerce_path(data: Any) -> Optional[Path]: + if isinstance(data, Path): + return data + if isinstance(data, str): + try: + return Path(data) + except TypeError: + return None + return None + + class DataValidationError(Exception): """Exception raised when data validation fails.""" @@ -53,19 +64,19 @@ def error_message(self) -> str: class PathExistsRule(ValidationRule): def check(self, data: Any) -> bool: - if not isinstance(data, str): + path = _coerce_path(data) + if path is None: self._error_message = "Data object is not a path" return False try: - path = Path(data) if not path.is_dir(): self._error_message = ( - f"Source path is not a directory or does not exist: {data}" + f"Source path is not a directory or does not exist: {path}" ) return False return True except TypeError as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {path}: {e}" return False @@ -73,15 +84,14 @@ class MstxJsonFileExistsRule(ValidationRule): """valid Mstx trace_view.json and profiler_info_*.json files is existed in "ASCEND_PROFILER_OUTPUT" path""" def check(self, data) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False self._error_message = "" try: - root_path = Path(data) - if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False ascend_profiler_output = "ASCEND_PROFILER_OUTPUT" @@ -93,7 +103,7 @@ def check(self, data) -> bool: ascend_pt_folders = glob.glob(ascend_pt_pattern) if not ascend_pt_folders: - self._error_message = f"No *_ascend_pt path in {data}" + self._error_message = f"No *_ascend_pt path in {root_path}" return False for ascend_pt_folder in ascend_pt_folders: @@ -121,7 +131,7 @@ def check(self, data) -> bool: return False return True except Exception as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {root_path}: {e}" return False @property @@ -133,21 +143,24 @@ class MstxJsonFieldValidRule(ValidationRule): """valid Mstx trace_view.json and profiler_info_*.json files JSON format""" def check(self, data) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False self._error_message = "" try: - root_path = Path(data) - if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False # get all *_ascend_pt path ascend_pt_pattern = str(root_path / "*" / "*_ascend_pt") ascend_pt_folders = glob.glob(ascend_pt_pattern) + if not ascend_pt_folders: + self._error_message = f"No *_ascend_pt path in {root_path}" + return False + for ascend_pt_folder in ascend_pt_folders: ascend_pt_path = Path(ascend_pt_folder) @@ -155,11 +168,22 @@ def check(self, data) -> bool: trace_view_path = ( ascend_pt_path / "ASCEND_PROFILER_OUTPUT" / "trace_view.json" ) + if not trace_view_path.exists(): + self._error_message = ( + f"Missing trace_view.json in: {trace_view_path.parent}" + ) + return False if os.path.getsize(trace_view_path) == 0: self._error_message = f"File is empty: {trace_view_path}" return False - with open(trace_view_path, "r", encoding="utf-8") as f: - trace_view_data = json.load(f) + try: + with open(trace_view_path, "r", encoding="utf-8") as f: + trace_view_data = json.load(f) + except Exception as exc: + self._error_message = ( + f"Failed to parse JSON file {trace_view_path}: {exc}" + ) + return False if len(trace_view_data) == 0: self._error_message = f"File is empty: {trace_view_path}" @@ -175,12 +199,21 @@ def check(self, data) -> bool: # valid profiler_info_*.json format profiler_pattern = str(ascend_pt_path / "profiler_info_*.json") profiler_info_files = glob.glob(profiler_pattern) + if not profiler_info_files: + self._error_message = ( + f"profiler_info_*.json does not exist in: {ascend_pt_path}" + ) + return False for file in profiler_info_files: - if os.path.getsize(trace_view_path) == 0: - self._error_message = f"File is empty: {trace_view_path}" + if os.path.getsize(file) == 0: + self._error_message = f"File is empty: {file}" + return False + try: + with open(file, "r", encoding="utf-8") as f: + profiler_info_data = json.load(f) + except Exception as exc: + self._error_message = f"Failed to parse JSON file {file}: {exc}" return False - with open(file, "r", encoding="utf-8") as f: - profiler_info_data = json.load(f) if len(profiler_info_data) == 0: self._error_message = f"File is empty: {file}" return False @@ -200,7 +233,7 @@ def check(self, data) -> bool: return False return True except Exception as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {root_path}: {e}" return False @property diff --git a/rl_insight/parser/gmm_parser.py b/rl_insight/parser/gmm_parser.py index fb612e3..5dde152 100644 --- a/rl_insight/parser/gmm_parser.py +++ b/rl_insight/parser/gmm_parser.py @@ -62,6 +62,22 @@ def __init__(self, params) -> None: # Get role filter if provided self._role = params.get("role", None) + @staticmethod + def _normalize_path_text(path_value: str | Path) -> str: + return str(path_value).replace("\\", "/") + + @classmethod + def _extract_rank_id_from_path(cls, path_value: str | Path) -> int: + normalized = cls._normalize_path_text(path_value) + m_rank = re.search(r"(?:^|/)rank(\d+)(?:/|$)", normalized) + return int(m_rank.group(1)) if m_rank else -1 + + @classmethod + def _extract_step_from_path(cls, path_value: str | Path) -> int: + normalized = cls._normalize_path_text(path_value) + m_step = re.search(r"(?:^|/)step_(\d+)(?:/|$)", normalized) + return int(m_step.group(1)) if m_step else -1 + def allocate_prof_data(self, input_path: str) -> List[DataMap]: """Allocate and organize GMM profiling data from the input path.""" data_maps: List[DataMap] = [] @@ -82,19 +98,12 @@ def allocate_prof_data(self, input_path: str) -> List[DataMap]: # Parse rank, step, stage from path parts = file_path.parts - text = str(file_path) - - # Extract rank - m_rank = re.search(r"/rank(\d+)/", text) - if not m_rank: + rank_id = self._extract_rank_id_from_path(file_path) + if rank_id < 0: continue - rank_id = int(m_rank.group(1)) - - # Extract step - m_step = re.search(r"/step_(\d+)/", text) - if not m_step: + step = self._extract_step_from_path(file_path) + if step < 0: continue - step = int(m_step.group(1)) # Extract stage stage = None @@ -144,7 +153,8 @@ def _load_group_list(self, file_path: str) -> np.ndarray: @staticmethod def _training_step_from_path(profiler_data_path: str) -> int: - m = re.search(r"/step_(\d+)/", profiler_data_path) + normalized = str(profiler_data_path).replace("\\", "/") + m = re.search(r"(?:^|/)step_(\d+)(?:/|$)", normalized) return int(m.group(1)) if m else 0 def parse_analysis_data( diff --git a/rl_insight/parser/mstx_parser.py b/rl_insight/parser/mstx_parser.py index 5fd5652..f3cda45 100644 --- a/rl_insight/parser/mstx_parser.py +++ b/rl_insight/parser/mstx_parser.py @@ -134,6 +134,15 @@ def _get_profiler_data_path(self, rank_id, data_path): data_path, Constant.ASCEND_PROFILER_OUTPUT, "trace_view.json" ) + @staticmethod + def _extract_timestamp_key(path_value: str) -> str: + """Extract the timestamp-like segment using the legacy underscore layout.""" + dir_name = Path(path_value).name + parts = dir_name.split("_") + if len(parts) >= 3: + return parts[-3] + return dir_name + def _get_rank_path_with_role(self, data_map) -> list[DataMap]: """Get json path information for all ranks. @@ -187,7 +196,7 @@ def _get_data_map(self, path_list) -> dict[tuple[str, int], list[str]]: rank_id_map[(task_role, rank_id)].append(dir_name) try: for map_key, dir_list in rank_id_map.items(): - dir_list.sort(key=lambda x: x.split("_")[-3]) + dir_list.sort(key=self._extract_timestamp_key) data_map[map_key] = dir_list except Exception as e: raise RuntimeError("Found invalid directory name!") from e diff --git a/rl_insight/parser/parser.py b/rl_insight/parser/parser.py index c033890..b924d5d 100644 --- a/rl_insight/parser/parser.py +++ b/rl_insight/parser/parser.py @@ -57,6 +57,7 @@ def mapper_func(self, data_maps: list[DataMap]): results = [] completed = 0 + failed_ranks = [] with ProcessPoolExecutor(max_workers=max_workers) as executor: # Submit all tasks @@ -77,11 +78,17 @@ def mapper_func(self, data_maps: list[DataMap]): f"Completed rank {rank_id}: {completed}/{total_ranks} ({progress:.1f}%)" ) except Exception as e: + failed_ranks.append(rank_id) logger.error(f"Failed to process rank {rank_id}: {e}") logger.info( f"Parallel processing completed: {completed}/{total_ranks} ranks processed" ) + if failed_ranks and not results: + logger.error( + "All rank parsing tasks failed during parallel processing. " + f"Failed ranks: {failed_ranks}" + ) return results def _mapper_func(self, data_map: DataMap) -> list[dict[str, Any]]: diff --git a/rl_insight/visualizer/gmm_visualizer.py b/rl_insight/visualizer/gmm_visualizer.py index 79610b0..7c2b2cf 100644 --- a/rl_insight/visualizer/gmm_visualizer.py +++ b/rl_insight/visualizer/gmm_visualizer.py @@ -14,10 +14,11 @@ from pathlib import Path from typing import Any, List, Tuple -import matplotlib.pyplot as plt + import numpy as np import pandas as pd from loguru import logger +from PIL import Image, ImageDraw from rl_insight.visualizer.visualizer import BaseVisualizer, register_cluster_visualizer from rl_insight.data import DataEnum @@ -36,8 +37,7 @@ def _resolve_output_path(output_cfg) -> Path: - Path with suffix (e.g., 'a/b/c.png') -> treat as explicit file path. """ output = Path(output_cfg) - is_dir_semantics = output.is_dir() or output.suffix == "" - if is_dir_semantics: + if output.is_dir() or output.suffix == "": output = output / "gmm_heatmap.png" return output @@ -48,23 +48,19 @@ def _load_signature(stage_data: pd.DataFrame) -> np.ndarray: def run(self, data): """Run GMM heatmap visualization from parsed data.""" - # Extract parameters from config output_cfg = self.config.get( "output_path", "./output/gmm_group_list_heatmap.png" ) output = self._resolve_output_path(output_cfg) - dpi = self.config.get("dpi", 150) - cmap = self.config.get("cmap", "viridis") gmm_per_layer = int(self.config.get("gmm_per_layer", 3)) if not isinstance(data, pd.DataFrame): raise ValueError(f"Expected DataFrame, got {type(data).__name__}") + if data.empty: + raise ValueError("No GMM data provided") logger.info(f"GmmVisualizer received DataFrame with {len(data)} rows") logger.info(f"DataFrame columns: {list(data.columns)}") - - if data.empty: - raise ValueError("No GMM data provided") logger.info("Visualizer consumes parser-filtered GMM summary data.") # For actor_update, filter out backward/recompute data by detecting @@ -77,111 +73,88 @@ def run(self, data): # This works regardless of whether gradient recomputation is enabled: # - With recomputation: forward runs of 3, then a run >3 triggers cutoff # - Without recomputation: forward runs of 3, then a run of 3+3=6 triggers cutoff - is_actor_update = "actor_update" in data["role"].unique() - if is_actor_update: + if "actor_update" in data["role"].unique(): grouped = data.groupby(["step", "role", "rank_id"]) filtered_data = [] - for name, group in grouped: - step_val, role_val, rank_val = name - if role_val == "actor_update": - sorted_group = group.sort_values("stage") - unique_stages = sorted(sorted_group["stage"].unique()) - - # Build load signature for each stage - stage_loads = {} - for stage in unique_stages: - stage_data = sorted_group[sorted_group["stage"] == stage] - load_sig = self._load_signature(stage_data) - stage_loads[stage] = load_sig - - # Scan forward: keep stages until a run exceeds gmm_per_layer. - forward_stages = [] - prev_load = None - consecutive = 0 - backward_detected = False - - for stage in unique_stages: - if backward_detected: - break - load = stage_loads[stage] - if prev_load is not None and np.array_equal(load, prev_load): - consecutive += 1 - else: - prev_load = load - consecutive = 1 - - if consecutive <= gmm_per_layer: - forward_stages.append(stage) - else: - backward_detected = True - - filtered_group = sorted_group[ - sorted_group["stage"].isin(forward_stages) - ] - filtered_data.append(filtered_group) - logger.info( - f"For actor_update (step={step_val}, rank={rank_val}): " - f"kept {len(forward_stages)} forward stages out of {len(unique_stages)} total " - f"(backward detected={backward_detected}, gmm_per_layer={gmm_per_layer})" - ) - else: + for (step_val, role_val, rank_val), group in grouped: + if role_val != "actor_update": filtered_data.append(group) - - if filtered_data: - data = pd.concat(filtered_data) + continue + + sorted_group = group.sort_values("stage") + unique_stages = sorted(sorted_group["stage"].unique()) + stage_loads = {} + for stage in unique_stages: + stage_data = sorted_group[sorted_group["stage"] == stage] + stage_loads[stage] = self._load_signature(stage_data) + + forward_stages = [] + prev_load = None + consecutive = 0 + backward_detected = False + for stage in unique_stages: + if backward_detected: + break + load = stage_loads[stage] + if prev_load is not None and np.array_equal(load, prev_load): + consecutive += 1 + else: + prev_load = load + consecutive = 1 + + if consecutive <= gmm_per_layer: + forward_stages.append(stage) + else: + backward_detected = True + + filtered_group = sorted_group[ + sorted_group["stage"].isin(forward_stages) + ] + filtered_data.append(filtered_group) logger.info( - f"After filtering actor_update forward-only data, now {len(data)} rows" + f"For actor_update (step={step_val}, rank={rank_val}): " + f"kept {len(forward_stages)} forward stages out of {len(unique_stages)} total " + f"(backward detected={backward_detected}, gmm_per_layer={gmm_per_layer})" ) - else: - logger.warning("No data left after filtering") + + if not filtered_data: raise ValueError("No data left after filtering") + data = pd.concat(filtered_data) + logger.info( + f"After filtering actor_update forward-only data, now {len(data)} rows" + ) - # Build matrix mat, rec_list, boundaries = self._build_matrix_from_data(data) logger.info(f"Built matrix with shape {mat.shape}") - segments = self._segment_labels(rec_list, boundaries) - - # Generate title - unique_ranks = sorted(data["rank_id"].unique()) - if len(unique_ranks) == 1: - rank_str = f" rank={unique_ranks[0]}" - else: - rank_str = f" ranks={len(unique_ranks)}" - title = f"GMM expert load (group_list){rank_str} — {len(rec_list)} snapshots, {mat.shape[0]} experts" - - # Plot heatmap - self._plot_heatmap(mat, rec_list, segments, title, output, dpi, cmap) - + self._plot_heatmap(mat, segments, output) return str(output) def _build_matrix_from_data( self, data: pd.DataFrame ) -> Tuple[np.ndarray, List[dict], List[int]]: """Build a matrix from the parsed data.""" - # Group data by step, role, rank_id, stage - # First sort the data to ensure consistent ordering + # Group data by step, role, rank_id, stage. + # First sort the data to ensure consistent ordering. sorted_data = data.sort_values(["step", "role", "rank_id", "stage"]) grouped = sorted_data.groupby(["step", "role", "rank_id", "stage"]) - # Get unique steps, roles, ranks and stages + # Get unique steps, roles, ranks and stages. steps = sorted(data["step"].unique()) roles = sorted(data["role"].unique()) ranks = sorted(data["rank_id"].unique()) stages = sorted(data["stage"].unique()) - max_expert = data["expert_index"].max() - + max_expert = int(data["expert_index"].max()) logger.info(f"Steps: {steps}") logger.info(f"Roles: {roles}") logger.info(f"Ranks: {ranks}") logger.info(f"Stages: {stages}") logger.info(f"Max expert index: {max_expert}") - # Build matrix and detect duplicate stages vecs = [] rec_list = [] - # Track layer mapping per (step, role, rank) group + # Track layer mapping per (step, role, rank) group. current_group = None seen_vectors: dict[tuple[Any, ...], int] = {} layer_counter = 0 @@ -191,11 +164,10 @@ def _build_matrix_from_data( logger.info( f"Processing step: {step}, role: {role}, rank: {rank}, stage: {stage_idx}" ) - - # Check if we're in a new (step, role, rank) group + # Check if we're in a new (step, role, rank) group. new_group = (step, role, rank) if new_group != current_group: - # Reset layer counter and seen vectors for new group + # Reset layer counter and seen vectors for the new group. current_group = new_group seen_vectors.clear() layer_counter = 0 @@ -203,23 +175,20 @@ def _build_matrix_from_data( f"New group detected: {new_group}, resetting layer counter to 0" ) - # Create a vector for this step, role, rank and stage + # Create a vector for this step, role, rank and stage. vec = np.full(max_expert + 1, np.nan, dtype=np.float64) for _, row in group.iterrows(): - expert_idx = row["expert_index"] - vec[expert_idx] = row["load"] + vec[int(row["expert_index"])] = row["load"] - # Convert vector to tuple for hashing (handle NaN values) + # Convert vector to tuple for hashing, replacing NaN to keep comparisons stable. vec_tuple = tuple(v if not np.isnan(v) else -1 for v in vec) - - # Check if this vector has been seen before in current group if vec_tuple not in seen_vectors: - # New layer + # New layer. seen_vectors[vec_tuple] = layer_counter layer_idx = layer_counter layer_counter += 1 else: - # Duplicate layer + # Duplicate layer. layer_idx = seen_vectors[vec_tuple] vecs.append(vec) @@ -230,7 +199,7 @@ def _build_matrix_from_data( "rank_id": rank, "stage": stage_idx, "op_index": stage_idx, # Original op index - "layer_idx": layer_idx, # Mapped layer index + "layer_idx": layer_idx, # Mapped layer index. } ) @@ -257,7 +226,6 @@ def _build_matrix_from_data( cur_key = new_key boundaries.append(mat.shape[1]) logger.info(f"Boundaries (step/role/rank): {boundaries}") - return mat, rec_list, boundaries def _segment_labels( @@ -276,133 +244,72 @@ def _segment_labels( def _plot_heatmap( self, mat: np.ndarray, - rec_list: List[dict], segments: List[Tuple[int, int, int, str, int]], - title: str, out_path: Path, - dpi: int, - cmap: str, ) -> None: """Plot the heatmap.""" n_exp, n_time = mat.shape - # Keep figure size readable when segment/time dimension is large. - # Use sub-linear growth for height to avoid overly tall and narrow figures. - fig_w = min(32, max(10, n_exp * 0.18)) - fig_h = min(22, max(8, 6 + np.sqrt(max(n_time, 1)) * 0.9)) - fig = plt.figure(figsize=(fig_w + 2.8, fig_h)) - gs = fig.add_gridspec(1, 2, width_ratios=[0.16, 1], wspace=0.05) - ax_bar = fig.add_subplot(gs[0, 0]) - ax = fig.add_subplot(gs[0, 1]) - - # Main heatmap is transposed to put experts on X axis. - # mat: [n_experts, n_time] -> heatmap_data: [n_time, n_experts] - heatmap_data = mat.T - ax_bar.set_ylim(-0.5, n_time - 0.5) - ax.set_ylim(-0.5, n_time - 0.5) - ax.set_xlim(-0.5, n_exp - 0.5) + cell_w = 28 + cell_h = 28 + left_bar_w = 120 + pad = 24 + + finite_vals = mat[np.isfinite(mat)] + vmin = float(finite_vals.min()) if finite_vals.size else 0.0 + vmax = float(finite_vals.max()) if finite_vals.size else 1.0 + scale = vmax - vmin if vmax > vmin else 1.0 + + img_w = pad * 2 + left_bar_w + n_exp * cell_w + img_h = pad * 2 + n_time * cell_h + image = Image.new("RGB", (img_w, img_h), "white") + draw = ImageDraw.Draw(image) # Segment bar: one color per (step, role, rank), shown on left side. - # Use viridis colormap for consistency with heatmap - palette = plt.cm.viridis(np.linspace(0, 1, len(segments))) - for i, (a, b, step, role, rank_id) in enumerate(segments): - color = palette[i] - ax_bar.axhspan( - a - 0.5, b - 0.5, facecolor=color, alpha=0.55, edgecolor="none" + segment_colors = [ + self._viridis_rgb(i / max(1, len(segments) - 1)) + for i in range(len(segments)) + ] + + for idx, (a, b, _step, _role, _rank) in enumerate(segments): + y0 = pad + a * cell_h + y1 = pad + b * cell_h + draw.rectangle( + [pad, y0, pad + left_bar_w - 1, y1 - 1], + fill=segment_colors[idx], ) - # Add separator lines between segments - for a, b, step, role, rank_id in segments: - if a > 0: - ax_bar.axhline(a - 0.5, color="white", linewidth=0.8, alpha=0.7) - # Add last separator line at the end - if n_time > 0: - ax_bar.axhline(n_time - 0.5, color="white", linewidth=0.8, alpha=0.7) - ax_bar.set_xlim(0, 1) - ax_bar.set_xticks([]) - ax_bar.set_yticks([]) - ax_bar.set_title( - "Row: layerK (K = merged layer index)\nstep · role · rank", - fontsize=10, - pad=8, - ) - im = ax.imshow( - heatmap_data, - aspect="auto", - cmap=cmap, - interpolation="nearest", - origin="upper", - ) - ax.set_xlabel("Expert index") - ax.set_title(title) - - # Horizontal lines at every segment boundary (includes step / role / rank changes) - for a, b, step, role, rank_id in segments: - ax.axhline(a - 0.5, color="white", linewidth=0.8, alpha=0.7) - ax.axhline(n_time - 0.5, color="white", linewidth=0.8, alpha=0.7) - - # Y axis: mark each layer only once - layer_positions = [] - layer_labels = [] - if n_time > 0: - current_layer = rec_list[0]["layer_idx"] - layer_positions.append(0) - layer_labels.append(f"layer{current_layer}") - - for j in range(1, n_time): - if rec_list[j]["layer_idx"] != current_layer: - current_layer = rec_list[j]["layer_idx"] - layer_positions.append(j) - layer_labels.append(f"layer{current_layer}") - - # Add the last position if needed - if n_time > 0 and layer_positions[-1] != n_time - 1: - layer_positions.append(n_time - 1) - layer_labels.append(f"layer{rec_list[-1]['layer_idx']}") - - # Downsample layer ticks when snapshots are too many. - max_layer_labels = 40 - if len(layer_positions) > max_layer_labels: - sel_idx = np.linspace( - 0, len(layer_positions) - 1, max_layer_labels, dtype=int - ) - layer_positions = [layer_positions[i] for i in sel_idx] - layer_labels = [layer_labels[i] for i in sel_idx] - - ax.set_yticks(layer_positions) - ax.set_yticklabels(layer_labels, fontsize=6) - ax.set_ylabel("") - - x_stride = max(1, n_exp // 40) - ax.set_xticks(list(range(0, n_exp, x_stride))) - - cbar = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.01) - cbar.set_label("Tokens per expert (group_list)") - - def _seg_legend_label(s: Tuple[int, int, int, str, int]) -> str: - _, _, st, rl, rk = s - rshort = (rl[:14] + "…") if len(str(rl)) > 14 else str(rl) - return f"st{st} · {rshort} · r{rk}" - - # Render step/role/rank directly inside segment blocks (centered). - if segments: - for i, (a, b, step, role, rank_id) in enumerate(segments): - label = _seg_legend_label((a, b, step, role, rank_id)) - seg_h = max(1.0, b - a) - # Adaptive label size based on segment height. - font_size = min(11.5, max(5.5, 4.8 + 0.45 * seg_h)) - ax_bar.text( - 0.5, - a + (b - a - 1) / 2, - label, - fontsize=font_size, - va="center", - ha="center", - rotation=0, - color="black", - clip_on=True, + # Main heatmap is rendered as a stable bitmap to avoid backend-specific crashes. + for t in range(n_time): + for e in range(n_exp): + value = mat[e, t] + if np.isnan(value): + color = (235, 235, 235) + else: + color = self._viridis_rgb((float(value) - vmin) / scale) + x0 = pad + left_bar_w + e * cell_w + y0 = pad + t * cell_h + draw.rectangle( + [x0, y0, x0 + cell_w - 1, y0 + cell_h - 1], + fill=color, + outline=(255, 255, 255), ) - fig.tight_layout() out_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(out_path, dpi=dpi, bbox_inches="tight") - plt.close(fig) + image.save(out_path) + + @staticmethod + def _viridis_rgb(x: float) -> tuple[int, int, int]: + anchors = [ + (68, 1, 84), + (59, 82, 139), + (33, 145, 140), + (94, 201, 98), + (253, 231, 37), + ] + x = min(1.0, max(0.0, x)) + pos = x * (len(anchors) - 1) + left = int(pos) + right = min(left + 1, len(anchors) - 1) + frac = pos - left + c0, c1 = anchors[left], anchors[right] + return tuple(int(c0[i] + (c1[i] - c0[i]) * frac) for i in range(3)) diff --git a/tests/data/test_data_checker.py b/tests/data/test_data_checker.py index a1e9c3a..5fc172c 100644 --- a/tests/data/test_data_checker.py +++ b/tests/data/test_data_checker.py @@ -13,19 +13,18 @@ # limitations under the License. import os -from pathlib import Path import pandas as pd import pytest from rl_insight.data.data_checker import DataChecker, DataEnum from rl_insight.data.rules import DataValidationError - -CURRENT_FILE = Path(__file__).resolve() -PROJECT_ROOT = CURRENT_FILE.parents[2] -MSTX_PROFILE_PATH = PROJECT_ROOT / "data/mstx_data/mstx_profile" -NVTX_PROFILE_PATH = PROJECT_ROOT / "data/nvtx_data/nvtx_profile" -TORCH_PROFILE_PATH = PROJECT_ROOT / "data/torch_data/torch_profile" +from tests.data.test_paths import ( + MSTX_PROFILE_PATH, + NVTX_PROFILE_PATH, + PROJECT_ROOT, + TORCH_PROFILE_PATH, +) def test_data_checker_multi_json_path_exists(): diff --git a/tests/data/test_paths.py b/tests/data/test_paths.py new file mode 100644 index 0000000..3fcdb6b --- /dev/null +++ b/tests/data/test_paths.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + + +CURRENT_FILE = Path(__file__).resolve() +PROJECT_ROOT = CURRENT_FILE.parents[2] +MSTX_PROFILE_PATH = PROJECT_ROOT / "data/mstx_data/mstx_profile" +NVTX_PROFILE_PATH = PROJECT_ROOT / "data/nvtx_data/nvtx_profile" +TORCH_PROFILE_PATH = PROJECT_ROOT / "data/torch_data/torch_profile" diff --git a/tests/data/test_rules.py b/tests/data/test_rules.py index 54d5da0..f6eb585 100644 --- a/tests/data/test_rules.py +++ b/tests/data/test_rules.py @@ -23,7 +23,7 @@ NvtxJsonFieldValidRule, ) from rl_insight.data.verl_log_rules import VerlLogExistRule, VerlLogKeyParamsRule -from test_data_checker import MSTX_PROFILE_PATH, TORCH_PROFILE_PATH, NVTX_PROFILE_PATH +from tests.data.test_paths import MSTX_PROFILE_PATH, NVTX_PROFILE_PATH, TORCH_PROFILE_PATH def test_path_exists_rule_accepts_existing_directory(): diff --git a/tests/doc/test_docs_urls.py b/tests/doc/test_docs_urls.py index 336538a..db1adc5 100644 --- a/tests/doc/test_docs_urls.py +++ b/tests/doc/test_docs_urls.py @@ -26,9 +26,10 @@ import os import re -import requests from pathlib import Path +import requests + # Configuration Constants CURRENT_FILE = Path(__file__).resolve() PROJECT_ROOT = CURRENT_FILE.parents[2] @@ -67,7 +68,17 @@ def is_url_valid(url: str) -> bool: """ try: response = requests.head(url, timeout=TIMEOUT, allow_redirects=True) - return 200 <= response.status_code < 400 + if 200 <= response.status_code < 400: + return True + if response.status_code in {403, 405, 429}: + response = requests.get( + url, + timeout=TIMEOUT, + allow_redirects=True, + stream=True, + ) + return 200 <= response.status_code < 400 + return False except requests.exceptions.RequestException: return False diff --git a/tests/parser/test_cluster_analysis.py b/tests/parser/test_cluster_analysis.py index 2f5c44e..be67eec 100644 --- a/tests/parser/test_cluster_analysis.py +++ b/tests/parser/test_cluster_analysis.py @@ -439,6 +439,32 @@ def test_get_data_map(self, mock_mstx_profiler_structure): assert ("rollout_generate", 0) in data_map assert len(data_map[("rollout_generate", 0)]) == 1 + def test_get_data_map_sorts_by_legacy_underscore_segment(self, tmp_path): + """Directory ordering should follow the legacy third-from-last segment.""" + parser = MstxClusterParser( + { + Constant.INPUT_PATH: str(tmp_path), + Constant.RANK_LIST: "all", + } + ) + + first = tmp_path / "role_a" / "20250101_110000_ascend_pt" + second = tmp_path / "role_a" / "20250102_120000_ascend_pt" + first.mkdir(parents=True) + second.mkdir(parents=True) + (first / "profiler_info_0.json").write_text('{"rank_id": 0}') + (second / "profiler_info_0.json").write_text('{"rank_id": 0}') + + path_list = [ + {"role": "role_a", "path": str(second)}, + {"role": "role_a", "path": str(first)}, + ] + + data_map = parser._get_data_map(path_list) + sorted_dirs = data_map[("role_a", 0)] + + assert sorted_dirs == [str(first), str(second)] + # ============================================================================= # BaseClusterParser Tests diff --git a/tests/parser/test_gmm_parser.py b/tests/parser/test_gmm_parser.py new file mode 100644 index 0000000..8362aca --- /dev/null +++ b/tests/parser/test_gmm_parser.py @@ -0,0 +1,28 @@ +# Copyright (c) 2026 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rl_insight.parser.gmm_parser import GmmParser +from rl_insight.utils.schema import Constant + + +def test_gmm_path_parsing_is_cross_platform(): + parser = GmmParser({Constant.RANK_LIST: "all"}) + + windows_style_path = ( + r"C:\workspace\gmm_dump\step_1\actor_update\rank0\dump_tensor_data" + r"\NPU.npu_grouped_matmul.0.forward.kwargs.group_list.pt" + ) + + assert parser._extract_rank_id_from_path(windows_style_path) == 0 + assert parser._extract_step_from_path(windows_style_path) == 1 From f29469994ff55d7c11c5155e279acb02207565ac Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Sat, 16 May 2026 16:39:45 +0800 Subject: [PATCH 02/10] docs: refine RL timeline quickstart profiling link text --- docs/overview/RL_Timeline_quickstart.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/overview/RL_Timeline_quickstart.md b/docs/overview/RL_Timeline_quickstart.md index 7115596..3d9d58c 100644 --- a/docs/overview/RL_Timeline_quickstart.md +++ b/docs/overview/RL_Timeline_quickstart.md @@ -46,7 +46,7 @@ pip install -e . 使用 VeRL 框架采集性能数据,详细参考: -[VeRL NPU Profiling 文档目录](https://github.com/verl-project/verl/blob/main/docs/ascend_tutorial/dev_guide/performance/ascend_profiling_zh.rst) +[VeRL NPU Profiling 教程](https://github.com/verl-project/verl/blob/main/docs/ascend_tutorial/dev_guide/performance/ascend_profiling_zh.rst) [VeRL GPU Profiling 教程](https://github.com/verl-project/verl/blob/main/docs/perf/nsight_profiling.md) From 0d531fbb9a16420e1bb0952c81990dfa9a2b0829 Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Sat, 16 May 2026 17:14:08 +0800 Subject: [PATCH 03/10] fix: improve GMM parsing and restore scalable heatmap metadata --- rl_insight/parser/gmm_parser.py | 5 +- rl_insight/visualizer/gmm_visualizer.py | 288 ++++++++++++++++++++++-- tests/parser/test_gmm_parser.py | 1 + tests/visualizer/test_gmm_visualizer.py | 72 ++++++ 4 files changed, 338 insertions(+), 28 deletions(-) create mode 100644 tests/visualizer/test_gmm_visualizer.py diff --git a/rl_insight/parser/gmm_parser.py b/rl_insight/parser/gmm_parser.py index 5dde152..835982a 100644 --- a/rl_insight/parser/gmm_parser.py +++ b/rl_insight/parser/gmm_parser.py @@ -153,9 +153,8 @@ def _load_group_list(self, file_path: str) -> np.ndarray: @staticmethod def _training_step_from_path(profiler_data_path: str) -> int: - normalized = str(profiler_data_path).replace("\\", "/") - m = re.search(r"(?:^|/)step_(\d+)(?:/|$)", normalized) - return int(m.group(1)) if m else 0 + step = GmmParser._extract_step_from_path(profiler_data_path) + return step if step >= 0 else 0 def parse_analysis_data( self, profiler_data_path: str, rank_id: int, role: str diff --git a/rl_insight/visualizer/gmm_visualizer.py b/rl_insight/visualizer/gmm_visualizer.py index 7c2b2cf..62dffd7 100644 --- a/rl_insight/visualizer/gmm_visualizer.py +++ b/rl_insight/visualizer/gmm_visualizer.py @@ -18,7 +18,7 @@ import numpy as np import pandas as pd from loguru import logger -from PIL import Image, ImageDraw +from PIL import Image, ImageDraw, ImageFont from rl_insight.visualizer.visualizer import BaseVisualizer, register_cluster_visualizer from rl_insight.data import DataEnum @@ -127,7 +127,7 @@ def run(self, data): mat, rec_list, boundaries = self._build_matrix_from_data(data) logger.info(f"Built matrix with shape {mat.shape}") segments = self._segment_labels(rec_list, boundaries) - self._plot_heatmap(mat, segments, output) + self._plot_heatmap(mat, rec_list, segments, output) return str(output) def _build_matrix_from_data( @@ -244,25 +244,56 @@ def _segment_labels( def _plot_heatmap( self, mat: np.ndarray, + rec_list: List[dict], segments: List[Tuple[int, int, int, str, int]], out_path: Path, ) -> None: """Plot the heatmap.""" n_exp, n_time = mat.shape - cell_w = 28 - cell_h = 28 - left_bar_w = 120 - pad = 24 + layout = self._compute_layout(n_exp, n_time) + pad = layout["pad"] + title_h = layout["title_h"] + left_bar_w = layout["left_bar_w"] + layer_axis_w = layout["layer_axis_w"] + colorbar_gap = layout["colorbar_gap"] + colorbar_w = layout["colorbar_w"] + heatmap_w = layout["heatmap_w"] + heatmap_h = layout["heatmap_h"] + img_w = layout["img_w"] + img_h = layout["img_h"] + title = self._build_title(rec_list, n_exp) finite_vals = mat[np.isfinite(mat)] vmin = float(finite_vals.min()) if finite_vals.size else 0.0 vmax = float(finite_vals.max()) if finite_vals.size else 1.0 scale = vmax - vmin if vmax > vmin else 1.0 - img_w = pad * 2 + left_bar_w + n_exp * cell_w - img_h = pad * 2 + n_time * cell_h image = Image.new("RGB", (img_w, img_h), "white") draw = ImageDraw.Draw(image) + font = ImageFont.load_default() + title_font = ImageFont.load_default() + heatmap_x0 = pad + left_bar_w + layer_axis_w + heatmap_y0 = pad + title_h + heatmap_x1 = heatmap_x0 + heatmap_w + heatmap_y1 = heatmap_y0 + heatmap_h + colorbar_x0 = heatmap_x1 + colorbar_gap + colorbar_x1 = colorbar_x0 + colorbar_w + + draw.text((pad, pad), title, fill="black", font=title_font) + draw.text((pad, pad + 16), "step | role | rank", fill="black", font=font) + draw.text( + (heatmap_x0 + max(0, heatmap_w // 2 - 35), heatmap_y1 + 22), + "Expert index", + fill="black", + font=font, + ) + self._draw_rotated_text( + image, + (pad + left_bar_w + 8, heatmap_y0 + max(0, heatmap_h // 2 - 28)), + "Layer index", + font, + "black", + ) # Segment bar: one color per (step, role, rank), shown on left side. segment_colors = [ @@ -270,33 +301,240 @@ def _plot_heatmap( for i in range(len(segments)) ] - for idx, (a, b, _step, _role, _rank) in enumerate(segments): - y0 = pad + a * cell_h - y1 = pad + b * cell_h + for idx, segment in enumerate(segments): + a, b, _step, _role, _rank = segment + y0 = self._scaled_position(a, n_time, heatmap_h, heatmap_y0) + y1 = self._scaled_position(b, n_time, heatmap_h, heatmap_y0) + if y1 <= y0: + y1 = min(heatmap_y1, y0 + 1) draw.rectangle( [pad, y0, pad + left_bar_w - 1, y1 - 1], fill=segment_colors[idx], ) + label = self._segment_legend_label(segment) + label = self._fit_text(draw, label, left_bar_w - 10, font) + if label and (y1 - y0) >= 12: + text_bbox = draw.textbbox((0, 0), label, font=font) + text_y = y0 + max(0, (y1 - y0 - (text_bbox[3] - text_bbox[1])) // 2) + draw.text((pad + 4, text_y), label, fill="black", font=font) + + draw.rectangle( + [heatmap_x0 - 1, heatmap_y0 - 1, heatmap_x1, heatmap_y1], + outline=(200, 200, 200), + ) + + layer_ticks = self._layer_ticks(rec_list) + for pos, label in layer_ticks: + y = self._scaled_position(pos, n_time, heatmap_h, heatmap_y0) + draw.line([(heatmap_x0 - 6, y), (heatmap_x0 - 1, y)], fill="black", width=1) + draw.text((pad + left_bar_w + 14, max(heatmap_y0, y - 6)), label, fill="black", font=font) # Main heatmap is rendered as a stable bitmap to avoid backend-specific crashes. - for t in range(n_time): - for e in range(n_exp): - value = mat[e, t] - if np.isnan(value): - color = (235, 235, 235) - else: - color = self._viridis_rgb((float(value) - vmin) / scale) - x0 = pad + left_bar_w + e * cell_w - y0 = pad + t * cell_h - draw.rectangle( - [x0, y0, x0 + cell_w - 1, y0 + cell_h - 1], - fill=color, - outline=(255, 255, 255), - ) + heatmap_rgb = self._heatmap_rgb(mat, vmin, scale) + heatmap_image = Image.fromarray(heatmap_rgb, mode="RGB") + if heatmap_image.size != (heatmap_w, heatmap_h): + heatmap_image = heatmap_image.resize( + (heatmap_w, heatmap_h), resample=Image.Resampling.NEAREST + ) + image.paste(heatmap_image, (heatmap_x0, heatmap_y0)) + + self._draw_expert_ticks(draw, font, heatmap_x0, heatmap_y1, heatmap_w, n_exp) + self._draw_colorbar( + draw, + font, + colorbar_x0, + colorbar_x1, + heatmap_y0, + heatmap_y1, + vmin, + vmax, + ) out_path.parent.mkdir(parents=True, exist_ok=True) image.save(out_path) + def _compute_layout(self, n_exp: int, n_time: int) -> dict[str, int]: + pad = 24 + title_h = 46 + bottom_h = 58 + left_bar_w = 150 + layer_axis_w = 62 + colorbar_gap = 16 + colorbar_w = 44 + target_cell_w = int(self.config.get("cell_width", 28)) + target_cell_h = int(self.config.get("cell_height", 28)) + max_img_w = int(self.config.get("max_image_width", 4096)) + max_img_h = int(self.config.get("max_image_height", 8192)) + + available_w = max( + 1, + max_img_w + - (pad * 2 + left_bar_w + layer_axis_w + colorbar_gap + colorbar_w), + ) + available_h = max(1, max_img_h - (pad * 2 + title_h + bottom_h)) + + heatmap_w = min(available_w, max(1, n_exp * target_cell_w)) + heatmap_h = min(available_h, max(1, n_time * target_cell_h)) + + img_w = pad * 2 + left_bar_w + layer_axis_w + heatmap_w + colorbar_gap + colorbar_w + img_h = pad * 2 + title_h + heatmap_h + bottom_h + return { + "pad": pad, + "title_h": title_h, + "bottom_h": bottom_h, + "left_bar_w": left_bar_w, + "layer_axis_w": layer_axis_w, + "colorbar_gap": colorbar_gap, + "colorbar_w": colorbar_w, + "heatmap_w": heatmap_w, + "heatmap_h": heatmap_h, + "img_w": img_w, + "img_h": img_h, + } + + @staticmethod + def _scaled_position(index: int, total: int, extent: int, offset: int) -> int: + if total <= 0: + return offset + return offset + int(round(index * extent / total)) + + @staticmethod + def _segment_legend_label(segment: Tuple[int, int, int, str, int]) -> str: + _, _, step, role, rank_id = segment + return f"st{step} | {role} | r{rank_id}" + + @staticmethod + def _build_title(rec_list: List[dict], n_exp: int) -> str: + ranks = sorted({rec["rank_id"] for rec in rec_list}) + snapshots = len(rec_list) + if len(ranks) == 1: + rank_text = f"rank={ranks[0]}" + else: + rank_text = f"ranks={len(ranks)}" + return f"GMM expert load ({rank_text}, {snapshots} snapshots, {n_exp} experts)" + + @staticmethod + def _fit_text( + draw: ImageDraw.ImageDraw, + text: str, + max_width: int, + font: ImageFont.ImageFont, + ) -> str: + if draw.textlength(text, font=font) <= max_width: + return text + suffix = "..." + trimmed = text + while trimmed and draw.textlength(trimmed + suffix, font=font) > max_width: + trimmed = trimmed[:-1] + return (trimmed + suffix) if trimmed else "" + + @staticmethod + def _draw_rotated_text( + image: Image.Image, + position: tuple[int, int], + text: str, + font: ImageFont.ImageFont, + fill: str | tuple[int, int, int], + ) -> None: + tmp = Image.new("RGBA", (160, 32), (255, 255, 255, 0)) + tmp_draw = ImageDraw.Draw(tmp) + tmp_draw.text((0, 0), text, fill=fill, font=font) + rotated = tmp.rotate(90, expand=True) + image.paste(rotated, position, rotated) + + def _heatmap_rgb( + self, + mat: np.ndarray, + vmin: float, + scale: float, + ) -> np.ndarray: + heatmap = mat.T + rgb = np.full((heatmap.shape[0], heatmap.shape[1], 3), 235, dtype=np.uint8) + finite_mask = np.isfinite(heatmap) + if np.any(finite_mask): + normalized = np.clip((heatmap[finite_mask] - vmin) / scale, 0.0, 1.0) + palette_idx = np.rint(normalized * 255).astype(np.uint8) + palette = self._viridis_palette() + rgb[finite_mask] = palette[palette_idx] + return rgb + + @staticmethod + def _layer_ticks(rec_list: List[dict]) -> List[Tuple[int, str]]: + if not rec_list: + return [] + + positions = [0] + labels = [f"layer{rec_list[0]['layer_idx']}"] + current_layer = rec_list[0]["layer_idx"] + for idx, rec in enumerate(rec_list[1:], start=1): + if rec["layer_idx"] != current_layer: + current_layer = rec["layer_idx"] + positions.append(idx) + labels.append(f"layer{current_layer}") + + if positions[-1] != len(rec_list) - 1: + positions.append(len(rec_list) - 1) + labels.append(f"layer{rec_list[-1]['layer_idx']}") + + max_labels = 40 + if len(positions) > max_labels: + selected = np.linspace(0, len(positions) - 1, max_labels, dtype=int) + positions = [positions[idx] for idx in selected] + labels = [labels[idx] for idx in selected] + return list(zip(positions, labels)) + + def _draw_expert_ticks( + self, + draw: ImageDraw.ImageDraw, + font: ImageFont.ImageFont, + heatmap_x0: int, + heatmap_y1: int, + heatmap_w: int, + n_exp: int, + ) -> None: + if n_exp <= 0: + return + + tick_count = min(6, n_exp) + tick_indices = np.linspace(0, n_exp - 1, tick_count, dtype=int) + seen: set[int] = set() + for expert_idx in tick_indices: + if int(expert_idx) in seen: + continue + seen.add(int(expert_idx)) + x = heatmap_x0 + int(round((int(expert_idx) + 0.5) * heatmap_w / n_exp)) + draw.line([(x, heatmap_y1), (x, heatmap_y1 + 5)], fill="black", width=1) + label = str(int(expert_idx)) + bbox = draw.textbbox((0, 0), label, font=font) + draw.text((x - (bbox[2] - bbox[0]) // 2, heatmap_y1 + 8), label, fill="black", font=font) + + def _draw_colorbar( + self, + draw: ImageDraw.ImageDraw, + font: ImageFont.ImageFont, + x0: int, + x1: int, + y0: int, + y1: int, + vmin: float, + vmax: float, + ) -> None: + palette = self._viridis_palette() + height = max(1, y1 - y0) + for offset in range(height): + idx = min(255, max(0, int(round((1 - offset / max(1, height - 1)) * 255)))) + color = tuple(int(v) for v in palette[idx]) + draw.line([(x0, y0 + offset), (x1, y0 + offset)], fill=color, width=1) + + draw.rectangle([x0, y0, x1, y1], outline=(120, 120, 120)) + draw.text((x0 - 2, max(0, y0 - 18)), f"{vmax:.2f}", fill="black", font=font) + draw.text((x0 - 2, y1 + 4), f"{vmin:.2f}", fill="black", font=font) + draw.text((x0 - 6, max(0, y0 - 34)), "Load", fill="black", font=font) + + @classmethod + def _viridis_palette(cls) -> np.ndarray: + return np.array([cls._viridis_rgb(i / 255.0) for i in range(256)], dtype=np.uint8) + @staticmethod def _viridis_rgb(x: float) -> tuple[int, int, int]: anchors = [ diff --git a/tests/parser/test_gmm_parser.py b/tests/parser/test_gmm_parser.py index 8362aca..dcdbddd 100644 --- a/tests/parser/test_gmm_parser.py +++ b/tests/parser/test_gmm_parser.py @@ -26,3 +26,4 @@ def test_gmm_path_parsing_is_cross_platform(): assert parser._extract_rank_id_from_path(windows_style_path) == 0 assert parser._extract_step_from_path(windows_style_path) == 1 + assert parser._training_step_from_path(windows_style_path) == 1 diff --git a/tests/visualizer/test_gmm_visualizer.py b/tests/visualizer/test_gmm_visualizer.py new file mode 100644 index 0000000..ef87733 --- /dev/null +++ b/tests/visualizer/test_gmm_visualizer.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +from PIL import Image + +from rl_insight.visualizer.gmm_visualizer import GmmVisualizer + + +def _build_gmm_dataframe( + num_steps: int = 1, + num_stages: int = 16, + num_experts: int = 32, +) -> pd.DataFrame: + rows = [] + for step in range(num_steps): + for stage in range(num_stages): + for expert_index in range(num_experts): + rows.append( + { + "role": "actor_update", + "rank_id": 0, + "step": step, + "stage": stage, + "expert_index": expert_index, + "load": float((stage + expert_index) % 11), + } + ) + return pd.DataFrame(rows) + + +def test_gmm_visualizer_caps_large_output_and_renders_metadata(tmp_path): + output_dir = tmp_path / "gmm_output" + data = _build_gmm_dataframe(num_stages=160, num_experts=96) + visualizer = GmmVisualizer( + { + "output_path": str(output_dir), + "max_image_width": 720, + "max_image_height": 720, + "gmm_per_layer": 1, + } + ) + + output_path = visualizer.run(data) + + with Image.open(output_path) as image: + pixels = np.asarray(image) + + assert pixels.shape[1] <= 720 + assert pixels.shape[0] <= 720 + + top_strip = pixels[:50] + bottom_strip = pixels[-70:] + right_strip = pixels[:, -80:] + center_strip = pixels[:, 120:220] + + assert np.any(np.any(top_strip != 255, axis=2)) + assert np.any(np.any(bottom_strip != 255, axis=2)) + assert np.any(np.any(right_strip != 255, axis=2)) + assert np.any(np.any(center_strip != 255, axis=2)) From caa3bf49fe2d094417b4f3f48dee80807b56b4f2 Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Sat, 16 May 2026 18:17:35 +0800 Subject: [PATCH 04/10] fix: restore matplotlib gmm visualizer --- rl_insight/data/rules.py | 5 +- rl_insight/visualizer/gmm_visualizer.py | 570 +++++++++--------------- tests/visualizer/test_gmm_visualizer.py | 72 --- 3 files changed, 214 insertions(+), 433 deletions(-) delete mode 100644 tests/visualizer/test_gmm_visualizer.py diff --git a/rl_insight/data/rules.py b/rl_insight/data/rules.py index e471d2e..fd50198 100644 --- a/rl_insight/data/rules.py +++ b/rl_insight/data/rules.py @@ -27,10 +27,7 @@ def _coerce_path(data: Any) -> Optional[Path]: if isinstance(data, Path): return data if isinstance(data, str): - try: - return Path(data) - except TypeError: - return None + return Path(data) return None diff --git a/rl_insight/visualizer/gmm_visualizer.py b/rl_insight/visualizer/gmm_visualizer.py index 62dffd7..11fbd93 100644 --- a/rl_insight/visualizer/gmm_visualizer.py +++ b/rl_insight/visualizer/gmm_visualizer.py @@ -15,13 +15,13 @@ from pathlib import Path from typing import Any, List, Tuple +import matplotlib.pyplot as plt import numpy as np import pandas as pd from loguru import logger -from PIL import Image, ImageDraw, ImageFont -from rl_insight.visualizer.visualizer import BaseVisualizer, register_cluster_visualizer from rl_insight.data import DataEnum +from rl_insight.visualizer.visualizer import BaseVisualizer, register_cluster_visualizer @register_cluster_visualizer("gmm_heatmap") @@ -37,7 +37,8 @@ def _resolve_output_path(output_cfg) -> Path: - Path with suffix (e.g., 'a/b/c.png') -> treat as explicit file path. """ output = Path(output_cfg) - if output.is_dir() or output.suffix == "": + is_dir_semantics = output.is_dir() or output.suffix == "" + if is_dir_semantics: output = output / "gmm_heatmap.png" return output @@ -48,19 +49,23 @@ def _load_signature(stage_data: pd.DataFrame) -> np.ndarray: def run(self, data): """Run GMM heatmap visualization from parsed data.""" + # Extract parameters from config output_cfg = self.config.get( "output_path", "./output/gmm_group_list_heatmap.png" ) output = self._resolve_output_path(output_cfg) + dpi = self.config.get("dpi", 150) + cmap = self.config.get("cmap", "viridis") gmm_per_layer = int(self.config.get("gmm_per_layer", 3)) if not isinstance(data, pd.DataFrame): raise ValueError(f"Expected DataFrame, got {type(data).__name__}") - if data.empty: - raise ValueError("No GMM data provided") logger.info(f"GmmVisualizer received DataFrame with {len(data)} rows") logger.info(f"DataFrame columns: {list(data.columns)}") + + if data.empty: + raise ValueError("No GMM data provided") logger.info("Visualizer consumes parser-filtered GMM summary data.") # For actor_update, filter out backward/recompute data by detecting @@ -73,88 +78,111 @@ def run(self, data): # This works regardless of whether gradient recomputation is enabled: # - With recomputation: forward runs of 3, then a run >3 triggers cutoff # - Without recomputation: forward runs of 3, then a run of 3+3=6 triggers cutoff - if "actor_update" in data["role"].unique(): + is_actor_update = "actor_update" in data["role"].unique() + if is_actor_update: grouped = data.groupby(["step", "role", "rank_id"]) filtered_data = [] - for (step_val, role_val, rank_val), group in grouped: - if role_val != "actor_update": + for name, group in grouped: + step_val, role_val, rank_val = name + if role_val == "actor_update": + sorted_group = group.sort_values("stage") + unique_stages = sorted(sorted_group["stage"].unique()) + + # Build load signature for each stage + stage_loads = {} + for stage in unique_stages: + stage_data = sorted_group[sorted_group["stage"] == stage] + load_sig = self._load_signature(stage_data) + stage_loads[stage] = load_sig + + # Scan forward: keep stages until a run exceeds gmm_per_layer. + forward_stages = [] + prev_load = None + consecutive = 0 + backward_detected = False + + for stage in unique_stages: + if backward_detected: + break + load = stage_loads[stage] + if prev_load is not None and np.array_equal(load, prev_load): + consecutive += 1 + else: + prev_load = load + consecutive = 1 + + if consecutive <= gmm_per_layer: + forward_stages.append(stage) + else: + backward_detected = True + + filtered_group = sorted_group[ + sorted_group["stage"].isin(forward_stages) + ] + filtered_data.append(filtered_group) + logger.info( + f"For actor_update (step={step_val}, rank={rank_val}): " + f"kept {len(forward_stages)} forward stages out of {len(unique_stages)} total " + f"(backward detected={backward_detected}, gmm_per_layer={gmm_per_layer})" + ) + else: filtered_data.append(group) - continue - - sorted_group = group.sort_values("stage") - unique_stages = sorted(sorted_group["stage"].unique()) - stage_loads = {} - for stage in unique_stages: - stage_data = sorted_group[sorted_group["stage"] == stage] - stage_loads[stage] = self._load_signature(stage_data) - - forward_stages = [] - prev_load = None - consecutive = 0 - backward_detected = False - for stage in unique_stages: - if backward_detected: - break - load = stage_loads[stage] - if prev_load is not None and np.array_equal(load, prev_load): - consecutive += 1 - else: - prev_load = load - consecutive = 1 - - if consecutive <= gmm_per_layer: - forward_stages.append(stage) - else: - backward_detected = True - - filtered_group = sorted_group[ - sorted_group["stage"].isin(forward_stages) - ] - filtered_data.append(filtered_group) + + if filtered_data: + data = pd.concat(filtered_data) logger.info( - f"For actor_update (step={step_val}, rank={rank_val}): " - f"kept {len(forward_stages)} forward stages out of {len(unique_stages)} total " - f"(backward detected={backward_detected}, gmm_per_layer={gmm_per_layer})" + f"After filtering actor_update forward-only data, now {len(data)} rows" ) - - if not filtered_data: + else: + logger.warning("No data left after filtering") raise ValueError("No data left after filtering") - data = pd.concat(filtered_data) - logger.info( - f"After filtering actor_update forward-only data, now {len(data)} rows" - ) + # Build matrix mat, rec_list, boundaries = self._build_matrix_from_data(data) logger.info(f"Built matrix with shape {mat.shape}") + segments = self._segment_labels(rec_list, boundaries) - self._plot_heatmap(mat, rec_list, segments, output) + + # Generate title + unique_ranks = sorted(data["rank_id"].unique()) + if len(unique_ranks) == 1: + rank_str = f" rank={unique_ranks[0]}" + else: + rank_str = f" ranks={len(unique_ranks)}" + title = f"GMM expert load (group_list){rank_str} - {len(rec_list)} snapshots, {mat.shape[0]} experts" + + # Plot heatmap + self._plot_heatmap(mat, rec_list, segments, title, output, dpi, cmap) + return str(output) def _build_matrix_from_data( self, data: pd.DataFrame ) -> Tuple[np.ndarray, List[dict], List[int]]: """Build a matrix from the parsed data.""" - # Group data by step, role, rank_id, stage. - # First sort the data to ensure consistent ordering. + # Group data by step, role, rank_id, stage + # First sort the data to ensure consistent ordering sorted_data = data.sort_values(["step", "role", "rank_id", "stage"]) grouped = sorted_data.groupby(["step", "role", "rank_id", "stage"]) - # Get unique steps, roles, ranks and stages. + # Get unique steps, roles, ranks and stages steps = sorted(data["step"].unique()) roles = sorted(data["role"].unique()) ranks = sorted(data["rank_id"].unique()) stages = sorted(data["stage"].unique()) - max_expert = int(data["expert_index"].max()) + max_expert = data["expert_index"].max() + logger.info(f"Steps: {steps}") logger.info(f"Roles: {roles}") logger.info(f"Ranks: {ranks}") logger.info(f"Stages: {stages}") logger.info(f"Max expert index: {max_expert}") + # Build matrix and detect duplicate stages vecs = [] rec_list = [] - # Track layer mapping per (step, role, rank) group. + # Track layer mapping per (step, role, rank) group current_group = None seen_vectors: dict[tuple[Any, ...], int] = {} layer_counter = 0 @@ -164,10 +192,11 @@ def _build_matrix_from_data( logger.info( f"Processing step: {step}, role: {role}, rank: {rank}, stage: {stage_idx}" ) - # Check if we're in a new (step, role, rank) group. + + # Check if we're in a new (step, role, rank) group new_group = (step, role, rank) if new_group != current_group: - # Reset layer counter and seen vectors for the new group. + # Reset layer counter and seen vectors for new group current_group = new_group seen_vectors.clear() layer_counter = 0 @@ -175,20 +204,23 @@ def _build_matrix_from_data( f"New group detected: {new_group}, resetting layer counter to 0" ) - # Create a vector for this step, role, rank and stage. + # Create a vector for this step, role, rank and stage vec = np.full(max_expert + 1, np.nan, dtype=np.float64) for _, row in group.iterrows(): - vec[int(row["expert_index"])] = row["load"] + expert_idx = row["expert_index"] + vec[expert_idx] = row["load"] - # Convert vector to tuple for hashing, replacing NaN to keep comparisons stable. + # Convert vector to tuple for hashing (handle NaN values) vec_tuple = tuple(v if not np.isnan(v) else -1 for v in vec) + + # Check if this vector has been seen before in current group if vec_tuple not in seen_vectors: - # New layer. + # New layer seen_vectors[vec_tuple] = layer_counter layer_idx = layer_counter layer_counter += 1 else: - # Duplicate layer. + # Duplicate layer layer_idx = seen_vectors[vec_tuple] vecs.append(vec) @@ -199,7 +231,7 @@ def _build_matrix_from_data( "rank_id": rank, "stage": stage_idx, "op_index": stage_idx, # Original op index - "layer_idx": layer_idx, # Mapped layer index. + "layer_idx": layer_idx, # Mapped layer index } ) @@ -226,6 +258,7 @@ def _build_matrix_from_data( cur_key = new_key boundaries.append(mat.shape[1]) logger.info(f"Boundaries (step/role/rank): {boundaries}") + return mat, rec_list, boundaries def _segment_labels( @@ -246,308 +279,131 @@ def _plot_heatmap( mat: np.ndarray, rec_list: List[dict], segments: List[Tuple[int, int, int, str, int]], + title: str, out_path: Path, + dpi: int, + cmap: str, ) -> None: """Plot the heatmap.""" n_exp, n_time = mat.shape - layout = self._compute_layout(n_exp, n_time) - pad = layout["pad"] - title_h = layout["title_h"] - left_bar_w = layout["left_bar_w"] - layer_axis_w = layout["layer_axis_w"] - colorbar_gap = layout["colorbar_gap"] - colorbar_w = layout["colorbar_w"] - heatmap_w = layout["heatmap_w"] - heatmap_h = layout["heatmap_h"] - img_w = layout["img_w"] - img_h = layout["img_h"] - title = self._build_title(rec_list, n_exp) - - finite_vals = mat[np.isfinite(mat)] - vmin = float(finite_vals.min()) if finite_vals.size else 0.0 - vmax = float(finite_vals.max()) if finite_vals.size else 1.0 - scale = vmax - vmin if vmax > vmin else 1.0 - - image = Image.new("RGB", (img_w, img_h), "white") - draw = ImageDraw.Draw(image) - font = ImageFont.load_default() - title_font = ImageFont.load_default() - heatmap_x0 = pad + left_bar_w + layer_axis_w - heatmap_y0 = pad + title_h - heatmap_x1 = heatmap_x0 + heatmap_w - heatmap_y1 = heatmap_y0 + heatmap_h - colorbar_x0 = heatmap_x1 + colorbar_gap - colorbar_x1 = colorbar_x0 + colorbar_w - - draw.text((pad, pad), title, fill="black", font=title_font) - draw.text((pad, pad + 16), "step | role | rank", fill="black", font=font) - draw.text( - (heatmap_x0 + max(0, heatmap_w // 2 - 35), heatmap_y1 + 22), - "Expert index", - fill="black", - font=font, - ) - self._draw_rotated_text( - image, - (pad + left_bar_w + 8, heatmap_y0 + max(0, heatmap_h // 2 - 28)), - "Layer index", - font, - "black", - ) + # Keep figure size readable when segment/time dimension is large. + # Use sub-linear growth for height to avoid overly tall and narrow figures. + fig_w = min(32, max(10, n_exp * 0.18)) + fig_h = min(22, max(8, 6 + np.sqrt(max(n_time, 1)) * 0.9)) + fig = plt.figure(figsize=(fig_w + 2.8, fig_h)) + gs = fig.add_gridspec(1, 2, width_ratios=[0.16, 1], wspace=0.05) + ax_bar = fig.add_subplot(gs[0, 0]) + ax = fig.add_subplot(gs[0, 1]) + + # Main heatmap is transposed to put experts on X axis. + # mat: [n_experts, n_time] -> heatmap_data: [n_time, n_experts] + heatmap_data = mat.T + ax_bar.set_ylim(-0.5, n_time - 0.5) + ax.set_ylim(-0.5, n_time - 0.5) + ax.set_xlim(-0.5, n_exp - 0.5) # Segment bar: one color per (step, role, rank), shown on left side. - segment_colors = [ - self._viridis_rgb(i / max(1, len(segments) - 1)) - for i in range(len(segments)) - ] - - for idx, segment in enumerate(segments): - a, b, _step, _role, _rank = segment - y0 = self._scaled_position(a, n_time, heatmap_h, heatmap_y0) - y1 = self._scaled_position(b, n_time, heatmap_h, heatmap_y0) - if y1 <= y0: - y1 = min(heatmap_y1, y0 + 1) - draw.rectangle( - [pad, y0, pad + left_bar_w - 1, y1 - 1], - fill=segment_colors[idx], + # Use viridis colormap for consistency with heatmap + palette = plt.cm.viridis(np.linspace(0, 1, len(segments))) + for i, (a, b, step, role, rank_id) in enumerate(segments): + color = palette[i] + ax_bar.axhspan( + a - 0.5, b - 0.5, facecolor=color, alpha=0.55, edgecolor="none" ) - label = self._segment_legend_label(segment) - label = self._fit_text(draw, label, left_bar_w - 10, font) - if label and (y1 - y0) >= 12: - text_bbox = draw.textbbox((0, 0), label, font=font) - text_y = y0 + max(0, (y1 - y0 - (text_bbox[3] - text_bbox[1])) // 2) - draw.text((pad + 4, text_y), label, fill="black", font=font) - - draw.rectangle( - [heatmap_x0 - 1, heatmap_y0 - 1, heatmap_x1, heatmap_y1], - outline=(200, 200, 200), - ) - layer_ticks = self._layer_ticks(rec_list) - for pos, label in layer_ticks: - y = self._scaled_position(pos, n_time, heatmap_h, heatmap_y0) - draw.line([(heatmap_x0 - 6, y), (heatmap_x0 - 1, y)], fill="black", width=1) - draw.text((pad + left_bar_w + 14, max(heatmap_y0, y - 6)), label, fill="black", font=font) - - # Main heatmap is rendered as a stable bitmap to avoid backend-specific crashes. - heatmap_rgb = self._heatmap_rgb(mat, vmin, scale) - heatmap_image = Image.fromarray(heatmap_rgb, mode="RGB") - if heatmap_image.size != (heatmap_w, heatmap_h): - heatmap_image = heatmap_image.resize( - (heatmap_w, heatmap_h), resample=Image.Resampling.NEAREST - ) - image.paste(heatmap_image, (heatmap_x0, heatmap_y0)) - - self._draw_expert_ticks(draw, font, heatmap_x0, heatmap_y1, heatmap_w, n_exp) - self._draw_colorbar( - draw, - font, - colorbar_x0, - colorbar_x1, - heatmap_y0, - heatmap_y1, - vmin, - vmax, + # Add separator lines between segments + for a, b, step, role, rank_id in segments: + if a > 0: + ax_bar.axhline(a - 0.5, color="white", linewidth=0.8, alpha=0.7) + # Add last separator line at the end + if n_time > 0: + ax_bar.axhline(n_time - 0.5, color="white", linewidth=0.8, alpha=0.7) + ax_bar.set_xlim(0, 1) + ax_bar.set_xticks([]) + ax_bar.set_yticks([]) + ax_bar.set_title( + "Row: layerK (K = merged layer index)\nstep | role | rank", + fontsize=10, + pad=8, ) - - out_path.parent.mkdir(parents=True, exist_ok=True) - image.save(out_path) - - def _compute_layout(self, n_exp: int, n_time: int) -> dict[str, int]: - pad = 24 - title_h = 46 - bottom_h = 58 - left_bar_w = 150 - layer_axis_w = 62 - colorbar_gap = 16 - colorbar_w = 44 - target_cell_w = int(self.config.get("cell_width", 28)) - target_cell_h = int(self.config.get("cell_height", 28)) - max_img_w = int(self.config.get("max_image_width", 4096)) - max_img_h = int(self.config.get("max_image_height", 8192)) - - available_w = max( - 1, - max_img_w - - (pad * 2 + left_bar_w + layer_axis_w + colorbar_gap + colorbar_w), + im = ax.imshow( + heatmap_data, + aspect="auto", + cmap=cmap, + interpolation="nearest", + origin="upper", ) - available_h = max(1, max_img_h - (pad * 2 + title_h + bottom_h)) - - heatmap_w = min(available_w, max(1, n_exp * target_cell_w)) - heatmap_h = min(available_h, max(1, n_time * target_cell_h)) - - img_w = pad * 2 + left_bar_w + layer_axis_w + heatmap_w + colorbar_gap + colorbar_w - img_h = pad * 2 + title_h + heatmap_h + bottom_h - return { - "pad": pad, - "title_h": title_h, - "bottom_h": bottom_h, - "left_bar_w": left_bar_w, - "layer_axis_w": layer_axis_w, - "colorbar_gap": colorbar_gap, - "colorbar_w": colorbar_w, - "heatmap_w": heatmap_w, - "heatmap_h": heatmap_h, - "img_w": img_w, - "img_h": img_h, - } - - @staticmethod - def _scaled_position(index: int, total: int, extent: int, offset: int) -> int: - if total <= 0: - return offset - return offset + int(round(index * extent / total)) - - @staticmethod - def _segment_legend_label(segment: Tuple[int, int, int, str, int]) -> str: - _, _, step, role, rank_id = segment - return f"st{step} | {role} | r{rank_id}" - - @staticmethod - def _build_title(rec_list: List[dict], n_exp: int) -> str: - ranks = sorted({rec["rank_id"] for rec in rec_list}) - snapshots = len(rec_list) - if len(ranks) == 1: - rank_text = f"rank={ranks[0]}" - else: - rank_text = f"ranks={len(ranks)}" - return f"GMM expert load ({rank_text}, {snapshots} snapshots, {n_exp} experts)" - - @staticmethod - def _fit_text( - draw: ImageDraw.ImageDraw, - text: str, - max_width: int, - font: ImageFont.ImageFont, - ) -> str: - if draw.textlength(text, font=font) <= max_width: - return text - suffix = "..." - trimmed = text - while trimmed and draw.textlength(trimmed + suffix, font=font) > max_width: - trimmed = trimmed[:-1] - return (trimmed + suffix) if trimmed else "" - - @staticmethod - def _draw_rotated_text( - image: Image.Image, - position: tuple[int, int], - text: str, - font: ImageFont.ImageFont, - fill: str | tuple[int, int, int], - ) -> None: - tmp = Image.new("RGBA", (160, 32), (255, 255, 255, 0)) - tmp_draw = ImageDraw.Draw(tmp) - tmp_draw.text((0, 0), text, fill=fill, font=font) - rotated = tmp.rotate(90, expand=True) - image.paste(rotated, position, rotated) - - def _heatmap_rgb( - self, - mat: np.ndarray, - vmin: float, - scale: float, - ) -> np.ndarray: - heatmap = mat.T - rgb = np.full((heatmap.shape[0], heatmap.shape[1], 3), 235, dtype=np.uint8) - finite_mask = np.isfinite(heatmap) - if np.any(finite_mask): - normalized = np.clip((heatmap[finite_mask] - vmin) / scale, 0.0, 1.0) - palette_idx = np.rint(normalized * 255).astype(np.uint8) - palette = self._viridis_palette() - rgb[finite_mask] = palette[palette_idx] - return rgb - - @staticmethod - def _layer_ticks(rec_list: List[dict]) -> List[Tuple[int, str]]: - if not rec_list: - return [] - - positions = [0] - labels = [f"layer{rec_list[0]['layer_idx']}"] - current_layer = rec_list[0]["layer_idx"] - for idx, rec in enumerate(rec_list[1:], start=1): - if rec["layer_idx"] != current_layer: - current_layer = rec["layer_idx"] - positions.append(idx) - labels.append(f"layer{current_layer}") - - if positions[-1] != len(rec_list) - 1: - positions.append(len(rec_list) - 1) - labels.append(f"layer{rec_list[-1]['layer_idx']}") - - max_labels = 40 - if len(positions) > max_labels: - selected = np.linspace(0, len(positions) - 1, max_labels, dtype=int) - positions = [positions[idx] for idx in selected] - labels = [labels[idx] for idx in selected] - return list(zip(positions, labels)) - - def _draw_expert_ticks( - self, - draw: ImageDraw.ImageDraw, - font: ImageFont.ImageFont, - heatmap_x0: int, - heatmap_y1: int, - heatmap_w: int, - n_exp: int, - ) -> None: - if n_exp <= 0: - return - - tick_count = min(6, n_exp) - tick_indices = np.linspace(0, n_exp - 1, tick_count, dtype=int) - seen: set[int] = set() - for expert_idx in tick_indices: - if int(expert_idx) in seen: - continue - seen.add(int(expert_idx)) - x = heatmap_x0 + int(round((int(expert_idx) + 0.5) * heatmap_w / n_exp)) - draw.line([(x, heatmap_y1), (x, heatmap_y1 + 5)], fill="black", width=1) - label = str(int(expert_idx)) - bbox = draw.textbbox((0, 0), label, font=font) - draw.text((x - (bbox[2] - bbox[0]) // 2, heatmap_y1 + 8), label, fill="black", font=font) - - def _draw_colorbar( - self, - draw: ImageDraw.ImageDraw, - font: ImageFont.ImageFont, - x0: int, - x1: int, - y0: int, - y1: int, - vmin: float, - vmax: float, - ) -> None: - palette = self._viridis_palette() - height = max(1, y1 - y0) - for offset in range(height): - idx = min(255, max(0, int(round((1 - offset / max(1, height - 1)) * 255)))) - color = tuple(int(v) for v in palette[idx]) - draw.line([(x0, y0 + offset), (x1, y0 + offset)], fill=color, width=1) - - draw.rectangle([x0, y0, x1, y1], outline=(120, 120, 120)) - draw.text((x0 - 2, max(0, y0 - 18)), f"{vmax:.2f}", fill="black", font=font) - draw.text((x0 - 2, y1 + 4), f"{vmin:.2f}", fill="black", font=font) - draw.text((x0 - 6, max(0, y0 - 34)), "Load", fill="black", font=font) - - @classmethod - def _viridis_palette(cls) -> np.ndarray: - return np.array([cls._viridis_rgb(i / 255.0) for i in range(256)], dtype=np.uint8) + ax.set_xlabel("Expert index") + ax.set_title(title) + + # Horizontal lines at every segment boundary (includes step / role / rank changes) + for a, b, step, role, rank_id in segments: + ax.axhline(a - 0.5, color="white", linewidth=0.8, alpha=0.7) + ax.axhline(n_time - 0.5, color="white", linewidth=0.8, alpha=0.7) + + # Y axis: mark each layer only once + layer_positions = [] + layer_labels = [] + if n_time > 0: + current_layer = rec_list[0]["layer_idx"] + layer_positions.append(0) + layer_labels.append(f"layer{current_layer}") + + for j in range(1, n_time): + if rec_list[j]["layer_idx"] != current_layer: + current_layer = rec_list[j]["layer_idx"] + layer_positions.append(j) + layer_labels.append(f"layer{current_layer}") + + # Add the last position if needed + if n_time > 0 and layer_positions[-1] != n_time - 1: + layer_positions.append(n_time - 1) + layer_labels.append(f"layer{rec_list[-1]['layer_idx']}") + + # Downsample layer ticks when snapshots are too many. + max_layer_labels = 40 + if len(layer_positions) > max_layer_labels: + sel_idx = np.linspace( + 0, len(layer_positions) - 1, max_layer_labels, dtype=int + ) + layer_positions = [layer_positions[i] for i in sel_idx] + layer_labels = [layer_labels[i] for i in sel_idx] + + ax.set_yticks(layer_positions) + ax.set_yticklabels(layer_labels, fontsize=6) + ax.set_ylabel("") + + x_stride = max(1, n_exp // 40) + ax.set_xticks(list(range(0, n_exp, x_stride))) + + cbar = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.01) + cbar.set_label("Tokens per expert (group_list)") + + def _seg_legend_label(s: Tuple[int, int, int, str, int]) -> str: + _, _, st, rl, rk = s + rshort = (rl[:14] + "...") if len(str(rl)) > 14 else str(rl) + return f"st{st} | {rshort} | r{rk}" + + # Render step/role/rank directly inside segment blocks (centered). + if segments: + for i, (a, b, step, role, rank_id) in enumerate(segments): + label = _seg_legend_label((a, b, step, role, rank_id)) + seg_h = max(1.0, b - a) + # Adaptive label size based on segment height. + font_size = min(11.5, max(5.5, 4.8 + 0.45 * seg_h)) + ax_bar.text( + 0.5, + a + (b - a - 1) / 2, + label, + fontsize=font_size, + va="center", + ha="center", + rotation=0, + color="black", + clip_on=True, + ) - @staticmethod - def _viridis_rgb(x: float) -> tuple[int, int, int]: - anchors = [ - (68, 1, 84), - (59, 82, 139), - (33, 145, 140), - (94, 201, 98), - (253, 231, 37), - ] - x = min(1.0, max(0.0, x)) - pos = x * (len(anchors) - 1) - left = int(pos) - right = min(left + 1, len(anchors) - 1) - frac = pos - left - c0, c1 = anchors[left], anchors[right] - return tuple(int(c0[i] + (c1[i] - c0[i]) * frac) for i in range(3)) + fig.tight_layout() + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=dpi, bbox_inches="tight") + plt.close(fig) diff --git a/tests/visualizer/test_gmm_visualizer.py b/tests/visualizer/test_gmm_visualizer.py deleted file mode 100644 index ef87733..0000000 --- a/tests/visualizer/test_gmm_visualizer.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) 2026 verl-project authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pandas as pd -from PIL import Image - -from rl_insight.visualizer.gmm_visualizer import GmmVisualizer - - -def _build_gmm_dataframe( - num_steps: int = 1, - num_stages: int = 16, - num_experts: int = 32, -) -> pd.DataFrame: - rows = [] - for step in range(num_steps): - for stage in range(num_stages): - for expert_index in range(num_experts): - rows.append( - { - "role": "actor_update", - "rank_id": 0, - "step": step, - "stage": stage, - "expert_index": expert_index, - "load": float((stage + expert_index) % 11), - } - ) - return pd.DataFrame(rows) - - -def test_gmm_visualizer_caps_large_output_and_renders_metadata(tmp_path): - output_dir = tmp_path / "gmm_output" - data = _build_gmm_dataframe(num_stages=160, num_experts=96) - visualizer = GmmVisualizer( - { - "output_path": str(output_dir), - "max_image_width": 720, - "max_image_height": 720, - "gmm_per_layer": 1, - } - ) - - output_path = visualizer.run(data) - - with Image.open(output_path) as image: - pixels = np.asarray(image) - - assert pixels.shape[1] <= 720 - assert pixels.shape[0] <= 720 - - top_strip = pixels[:50] - bottom_strip = pixels[-70:] - right_strip = pixels[:, -80:] - center_strip = pixels[:, 120:220] - - assert np.any(np.any(top_strip != 255, axis=2)) - assert np.any(np.any(bottom_strip != 255, axis=2)) - assert np.any(np.any(right_strip != 255, axis=2)) - assert np.any(np.any(center_strip != 255, axis=2)) From 4331ffad258ca07cb9eef7c00e0d4c6303224dbd Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Sat, 16 May 2026 18:23:33 +0800 Subject: [PATCH 05/10] fix: restore original gmm visualizer --- rl_insight/visualizer/gmm_visualizer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/rl_insight/visualizer/gmm_visualizer.py b/rl_insight/visualizer/gmm_visualizer.py index 11fbd93..79610b0 100644 --- a/rl_insight/visualizer/gmm_visualizer.py +++ b/rl_insight/visualizer/gmm_visualizer.py @@ -14,14 +14,13 @@ from pathlib import Path from typing import Any, List, Tuple - import matplotlib.pyplot as plt import numpy as np import pandas as pd from loguru import logger -from rl_insight.data import DataEnum from rl_insight.visualizer.visualizer import BaseVisualizer, register_cluster_visualizer +from rl_insight.data import DataEnum @register_cluster_visualizer("gmm_heatmap") @@ -149,7 +148,7 @@ def run(self, data): rank_str = f" rank={unique_ranks[0]}" else: rank_str = f" ranks={len(unique_ranks)}" - title = f"GMM expert load (group_list){rank_str} - {len(rec_list)} snapshots, {mat.shape[0]} experts" + title = f"GMM expert load (group_list){rank_str} — {len(rec_list)} snapshots, {mat.shape[0]} experts" # Plot heatmap self._plot_heatmap(mat, rec_list, segments, title, output, dpi, cmap) @@ -322,7 +321,7 @@ def _plot_heatmap( ax_bar.set_xticks([]) ax_bar.set_yticks([]) ax_bar.set_title( - "Row: layerK (K = merged layer index)\nstep | role | rank", + "Row: layerK (K = merged layer index)\nstep · role · rank", fontsize=10, pad=8, ) @@ -381,8 +380,8 @@ def _plot_heatmap( def _seg_legend_label(s: Tuple[int, int, int, str, int]) -> str: _, _, st, rl, rk = s - rshort = (rl[:14] + "...") if len(str(rl)) > 14 else str(rl) - return f"st{st} | {rshort} | r{rk}" + rshort = (rl[:14] + "…") if len(str(rl)) > 14 else str(rl) + return f"st{st} · {rshort} · r{rk}" # Render step/role/rank directly inside segment blocks (centered). if segments: From c0bbe166824d5165103bceb268d897ed94f37734 Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Sun, 17 May 2026 00:56:16 +0800 Subject: [PATCH 06/10] fix: improve MSTX ordering and harden docs URL validation --- rl_insight/parser/mstx_parser.py | 2 ++ tests/doc/test_docs_urls.py | 6 +++--- tests/parser/test_cluster_analysis.py | 8 ++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/rl_insight/parser/mstx_parser.py b/rl_insight/parser/mstx_parser.py index f3cda45..1647ea2 100644 --- a/rl_insight/parser/mstx_parser.py +++ b/rl_insight/parser/mstx_parser.py @@ -139,6 +139,8 @@ def _extract_timestamp_key(path_value: str) -> str: """Extract the timestamp-like segment using the legacy underscore layout.""" dir_name = Path(path_value).name parts = dir_name.split("_") + if len(parts) >= 4: + return "_".join(parts[-4:-2]) if len(parts) >= 3: return parts[-3] return dir_name diff --git a/tests/doc/test_docs_urls.py b/tests/doc/test_docs_urls.py index db1adc5..e153aa9 100644 --- a/tests/doc/test_docs_urls.py +++ b/tests/doc/test_docs_urls.py @@ -71,13 +71,13 @@ def is_url_valid(url: str) -> bool: if 200 <= response.status_code < 400: return True if response.status_code in {403, 405, 429}: - response = requests.get( + with requests.get( url, timeout=TIMEOUT, allow_redirects=True, stream=True, - ) - return 200 <= response.status_code < 400 + ) as response: + return 200 <= response.status_code < 400 return False except requests.exceptions.RequestException: return False diff --git a/tests/parser/test_cluster_analysis.py b/tests/parser/test_cluster_analysis.py index be67eec..d0a247e 100644 --- a/tests/parser/test_cluster_analysis.py +++ b/tests/parser/test_cluster_analysis.py @@ -439,8 +439,8 @@ def test_get_data_map(self, mock_mstx_profiler_structure): assert ("rollout_generate", 0) in data_map assert len(data_map[("rollout_generate", 0)]) == 1 - def test_get_data_map_sorts_by_legacy_underscore_segment(self, tmp_path): - """Directory ordering should follow the legacy third-from-last segment.""" + def test_get_data_map_sorts_by_legacy_datetime_segment(self, tmp_path): + """Directory ordering should follow the legacy date+time underscore layout.""" parser = MstxClusterParser( { Constant.INPUT_PATH: str(tmp_path), @@ -448,8 +448,8 @@ def test_get_data_map_sorts_by_legacy_underscore_segment(self, tmp_path): } ) - first = tmp_path / "role_a" / "20250101_110000_ascend_pt" - second = tmp_path / "role_a" / "20250102_120000_ascend_pt" + first = tmp_path / "role_a" / "20250101_230000_ascend_pt" + second = tmp_path / "role_a" / "20250102_010000_ascend_pt" first.mkdir(parents=True) second.mkdir(parents=True) (first / "profiler_info_0.json").write_text('{"rank_id": 0}') From 4202f91e18f649effa981b08631d9341f1a26a7f Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Mon, 18 May 2026 15:58:15 +0800 Subject: [PATCH 07/10] fix: accept Path inputs in validators and normalize GMM paths --- rl_insight/data/rules.py | 109 +++++++++++++------------------- rl_insight/parser/gmm_parser.py | 2 +- tests/data/test_rules.py | 24 +++++++ tests/parser/test_gmm_parser.py | 7 ++ 4 files changed, 77 insertions(+), 65 deletions(-) diff --git a/rl_insight/data/rules.py b/rl_insight/data/rules.py index fd50198..bc2f5aa 100644 --- a/rl_insight/data/rules.py +++ b/rl_insight/data/rules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 verl-project authors. +# Copyright (c) 2025 verl-project authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import glob import gzip import json -import os from typing import Any, List, Optional from abc import ABC, abstractmethod from pathlib import Path @@ -96,16 +94,13 @@ def check(self, data) -> bool: profiler_info_filename = "profiler_info_*.json" # get all *_ascend_pt path - ascend_pt_pattern = str(root_path / "*" / "*_ascend_pt") - ascend_pt_folders = glob.glob(ascend_pt_pattern) + ascend_pt_folders = list(root_path.glob("*/*_ascend_pt")) if not ascend_pt_folders: self._error_message = f"No *_ascend_pt path in {root_path}" return False - for ascend_pt_folder in ascend_pt_folders: - ascend_pt_path = Path(ascend_pt_folder) - + for ascend_pt_path in ascend_pt_folders: if not ascend_pt_path.is_dir(): continue @@ -118,8 +113,7 @@ def check(self, data) -> bool: return False # get profiler_info_*.json file path - profiler_pattern = str(ascend_pt_path / profiler_info_filename) - profiler_files = glob.glob(profiler_pattern) + profiler_files = list(ascend_pt_path.glob(profiler_info_filename)) if not profiler_files: self._error_message = ( @@ -151,15 +145,13 @@ def check(self, data) -> bool: return False # get all *_ascend_pt path - ascend_pt_pattern = str(root_path / "*" / "*_ascend_pt") - ascend_pt_folders = glob.glob(ascend_pt_pattern) + ascend_pt_folders = list(root_path.glob("*/*_ascend_pt")) if not ascend_pt_folders: self._error_message = f"No *_ascend_pt path in {root_path}" return False - for ascend_pt_folder in ascend_pt_folders: - ascend_pt_path = Path(ascend_pt_folder) + for ascend_pt_path in ascend_pt_folders: # valid trace_view.json format trace_view_path = ( @@ -170,7 +162,7 @@ def check(self, data) -> bool: f"Missing trace_view.json in: {trace_view_path.parent}" ) return False - if os.path.getsize(trace_view_path) == 0: + if trace_view_path.stat().st_size == 0: self._error_message = f"File is empty: {trace_view_path}" return False try: @@ -194,25 +186,26 @@ def check(self, data) -> bool: return False # valid profiler_info_*.json format - profiler_pattern = str(ascend_pt_path / "profiler_info_*.json") - profiler_info_files = glob.glob(profiler_pattern) + profiler_info_files = list(ascend_pt_path.glob("profiler_info_*.json")) if not profiler_info_files: self._error_message = ( f"profiler_info_*.json does not exist in: {ascend_pt_path}" ) return False - for file in profiler_info_files: - if os.path.getsize(file) == 0: - self._error_message = f"File is empty: {file}" + for file_path in profiler_info_files: + if file_path.stat().st_size == 0: + self._error_message = f"File is empty: {file_path}" return False try: - with open(file, "r", encoding="utf-8") as f: + with open(file_path, "r", encoding="utf-8") as f: profiler_info_data = json.load(f) except Exception as exc: - self._error_message = f"Failed to parse JSON file {file}: {exc}" + self._error_message = ( + f"Failed to parse JSON file {file_path}: {exc}" + ) return False if len(profiler_info_data) == 0: - self._error_message = f"File is empty: {file}" + self._error_message = f"File is empty: {file_path}" return False required_keys = { "config", @@ -225,7 +218,7 @@ def check(self, data) -> bool: missing_keys = required_keys - set(profiler_info_data.keys()) if missing_keys: self._error_message = ( - f"File field is missing: {missing_keys} in FilePath: {file}" + f"File field is missing: {missing_keys} in FilePath: {file_path}" ) return False return True @@ -277,17 +270,17 @@ class TorchJsonFileExistsRule(ValidationRule): """valid Torch *.json.gz files is existed in 'torch_profile' sub path""" def check(self, data) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False self._error_message = "" try: - root_path = Path(data) # 路径:torch_profile is_success = True sub_dirs_no_json: List = [] if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False for subdir in root_path.iterdir(): if subdir.is_dir(): @@ -301,7 +294,7 @@ def check(self, data) -> bool: return is_success except Exception as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {root_path}: {e}" return False @property @@ -313,27 +306,19 @@ class TorchJsonFieldValidRule(ValidationRule): """valid torch *.json.gz files JSON format""" def check(self, data) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False self._error_message = "" try: - root_path = Path(data) - if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False - for item in os.listdir(root_path): - item_path = os.path.join(root_path, item) - # 检查是否为目录 - if os.path.isdir(item_path): - # 查找该子目录下所有.json.gz文件 - json_gz_pattern = os.path.join(item_path, "*.json.gz") - json_gz_files = glob.glob(json_gz_pattern) - for json_gz_file in json_gz_files: - # 打开并读取.json.gz文件 + for item_path in root_path.iterdir(): + if item_path.is_dir(): + for json_gz_file in item_path.glob("*.json.gz"): with gzip.open(json_gz_file, "rt", encoding="utf-8") as f: - # 加载JSON数据 json_data = json.load(f) if len(json_data) == 0: self._error_message = f"File is empty: {json_gz_file}" @@ -369,7 +354,7 @@ def check(self, data) -> bool: return True except Exception as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {root_path}: {e}" return False @property @@ -381,31 +366,29 @@ class NvtxJsonFileExistsRule(ValidationRule): """valid worker_process.*.*.jsonl files is existed in 'nvtx_profile' sub path""" def check(self, data) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False self._error_message = "" try: - root_path = Path(data) - if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False profiler_info_filename = "worker_process_*.*.jsonl" - worker_pattern = str(root_path / profiler_info_filename) - worker_files = glob.glob(worker_pattern) + worker_files = list(root_path.glob(profiler_info_filename)) if not worker_files: self._error_message = ( - f"No worker_process_*.*.jsonl file found in: {data}" + f"No worker_process_*.*.jsonl file found in: {root_path}" ) return False return True except Exception as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {root_path}: {e}" return False @property @@ -417,21 +400,19 @@ class NvtxJsonFieldValidRule(ValidationRule): """valid nvtx worker_process_*.*.jsonl files JSON format""" def check(self, data) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False self._error_message = "" try: - root_path = Path(data) - if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False profiler_info_filename = "worker_process_*.*.jsonl" - worker_pattern = str(root_path / profiler_info_filename) - worker_files = glob.glob(worker_pattern) + worker_files = list(root_path.glob(profiler_info_filename)) required_for_event = {"start", "end", "textId"} @@ -478,7 +459,7 @@ def check(self, data) -> bool: return True except Exception as e: - self._error_message = f"Error checking path {data}: {e}" + self._error_message = f"Error checking path {root_path}: {e}" return False @property @@ -490,26 +471,25 @@ class GmmDataRule(ValidationRule): """Validation rule for GMM data.""" def check(self, data: Any) -> bool: - if not isinstance(data, str): + root_path = _coerce_path(data) + if root_path is None: self._error_message = "Data object is not a path" return False try: - root_path = Path(data) - if not root_path.exists(): - self._error_message = f"Source path does not exist: {data}" + self._error_message = f"Source path does not exist: {root_path}" return False group_list_files = list(root_path.rglob("*group_list.pt")) if not group_list_files: - self._error_message = f"No group_list.pt files found in: {data}" + self._error_message = f"No group_list.pt files found in: {root_path}" return False valid_files = [f for f in group_list_files if "dump_tensor_data" in f.parts] if not valid_files: self._error_message = ( "No group_list.pt files found in dump_tensor_data directories " - f"under: {data}" + f"under: {root_path}" ) return False @@ -517,3 +497,4 @@ def check(self, data: Any) -> bool: except Exception as e: self._error_message = f"Error checking GMM data: {e}" return False + diff --git a/rl_insight/parser/gmm_parser.py b/rl_insight/parser/gmm_parser.py index 835982a..83ec416 100644 --- a/rl_insight/parser/gmm_parser.py +++ b/rl_insight/parser/gmm_parser.py @@ -64,7 +64,7 @@ def __init__(self, params) -> None: @staticmethod def _normalize_path_text(path_value: str | Path) -> str: - return str(path_value).replace("\\", "/") + return Path(path_value).as_posix() @classmethod def _extract_rank_id_from_path(cls, path_value: str | Path) -> int: diff --git a/tests/data/test_rules.py b/tests/data/test_rules.py index f6eb585..ce1ace7 100644 --- a/tests/data/test_rules.py +++ b/tests/data/test_rules.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from rl_insight.data.rules import ( DataValidationError, PathExistsRule, @@ -57,6 +59,11 @@ def test_mstx_jsonfile_exists(): assert file_rule.check(str(MSTX_PROFILE_PATH)) is True +def test_mstx_jsonfile_exists_accepts_path_object(): + file_rule = MstxJsonFileExistsRule() + assert file_rule.check(MSTX_PROFILE_PATH) is True + + def test_mstx_jsonfile_exists_with_fake_path(): file_rule = MstxJsonFileExistsRule() fake_path = "fake_path" @@ -147,6 +154,11 @@ def test_torch_jsonfile_exists(): assert file_rule.check(str(TORCH_PROFILE_PATH)) is True +def test_torch_jsonfile_exists_accepts_path_object(): + file_rule = TorchJsonFileExistsRule() + assert file_rule.check(TORCH_PROFILE_PATH) is True + + def test_torch_jsonfile_exists_with_fake_path(): file_rule = TorchJsonFileExistsRule() fake_path = "fake_path" @@ -167,12 +179,24 @@ def test_nvtx_jsonfile_exists(): assert file_rule.check(str(NVTX_PROFILE_PATH)) is True +def test_nvtx_jsonfile_exists_accepts_path_object(): + file_rule = NvtxJsonFileExistsRule() + assert file_rule.check(NVTX_PROFILE_PATH) is True + + def test_nvtx_jsonfile_exists_with_fake_path(): file_rule = NvtxJsonFileExistsRule() fake_path = "fake_path" assert file_rule.check(fake_path) is False +def test_gmm_data_rule_accepts_path_object(): + from rl_insight.data.rules import GmmDataRule + + rule = GmmDataRule() + assert rule.check(Path("data/gmm_data")) is True + + def test_nvtx_json_fields_valid(): path_rule = PathExistsRule() field_rule = NvtxJsonFieldValidRule() diff --git a/tests/parser/test_gmm_parser.py b/tests/parser/test_gmm_parser.py index dcdbddd..49d7bc4 100644 --- a/tests/parser/test_gmm_parser.py +++ b/tests/parser/test_gmm_parser.py @@ -27,3 +27,10 @@ def test_gmm_path_parsing_is_cross_platform(): assert parser._extract_rank_id_from_path(windows_style_path) == 0 assert parser._extract_step_from_path(windows_style_path) == 1 assert parser._training_step_from_path(windows_style_path) == 1 + + +def test_gmm_normalize_path_text_returns_posix(): + parser = GmmParser({Constant.RANK_LIST: "all"}) + assert parser._normalize_path_text(r"C:\workspace\gmm_dump\step_1") == ( + "C:/workspace/gmm_dump/step_1" + ) From e0d214aaf6ed3954f2a7599ad5bced8c7083fa3b Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Mon, 18 May 2026 20:43:33 +0800 Subject: [PATCH 08/10] fix: apply pre-commit cleanup for validator tests --- rl_insight/data/rules.py | 8 ++------ tests/data/test_data_checker.py | 2 -- tests/data/test_rules.py | 6 +++++- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/rl_insight/data/rules.py b/rl_insight/data/rules.py index bc2f5aa..1a2e597 100644 --- a/rl_insight/data/rules.py +++ b/rl_insight/data/rules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 verl-project authors. +# Copyright (c) 2025 verl-project authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -152,7 +152,6 @@ def check(self, data) -> bool: return False for ascend_pt_path in ascend_pt_folders: - # valid trace_view.json format trace_view_path = ( ascend_pt_path / "ASCEND_PROFILER_OUTPUT" / "trace_view.json" @@ -217,9 +216,7 @@ def check(self, data) -> bool: } missing_keys = required_keys - set(profiler_info_data.keys()) if missing_keys: - self._error_message = ( - f"File field is missing: {missing_keys} in FilePath: {file_path}" - ) + self._error_message = f"File field is missing: {missing_keys} in FilePath: {file_path}" return False return True except Exception as e: @@ -497,4 +494,3 @@ def check(self, data: Any) -> bool: except Exception as e: self._error_message = f"Error checking GMM data: {e}" return False - diff --git a/tests/data/test_data_checker.py b/tests/data/test_data_checker.py index 5fc172c..cd1a6de 100644 --- a/tests/data/test_data_checker.py +++ b/tests/data/test_data_checker.py @@ -21,9 +21,7 @@ from rl_insight.data.rules import DataValidationError from tests.data.test_paths import ( MSTX_PROFILE_PATH, - NVTX_PROFILE_PATH, PROJECT_ROOT, - TORCH_PROFILE_PATH, ) diff --git a/tests/data/test_rules.py b/tests/data/test_rules.py index ce1ace7..27bd2df 100644 --- a/tests/data/test_rules.py +++ b/tests/data/test_rules.py @@ -25,7 +25,11 @@ NvtxJsonFieldValidRule, ) from rl_insight.data.verl_log_rules import VerlLogExistRule, VerlLogKeyParamsRule -from tests.data.test_paths import MSTX_PROFILE_PATH, NVTX_PROFILE_PATH, TORCH_PROFILE_PATH +from tests.data.test_paths import ( + MSTX_PROFILE_PATH, + NVTX_PROFILE_PATH, + TORCH_PROFILE_PATH, +) def test_path_exists_rule_accepts_existing_directory(): From 4955015132e012b32a2a7b39afaffa6ad74e95c3 Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Tue, 19 May 2026 15:44:33 +0800 Subject: [PATCH 09/10] update requirements.txt --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ba611a..a7d420e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ plotly matplotlib pytest loguru -kaleido \ No newline at end of file +kaleido +torch==2.7.1 \ No newline at end of file From e820d0d25191ed143916b217f65d488b578698c8 Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Tue, 19 May 2026 16:52:13 +0800 Subject: [PATCH 10/10] fix ut error & update requirements --- requirements.txt | 3 ++- rl_insight/parser/gmm_parser.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index a7d420e..cc8b83e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ matplotlib pytest loguru kaleido -torch==2.7.1 \ No newline at end of file +torch==2.7.1 +requests \ No newline at end of file diff --git a/rl_insight/parser/gmm_parser.py b/rl_insight/parser/gmm_parser.py index 83ec416..835982a 100644 --- a/rl_insight/parser/gmm_parser.py +++ b/rl_insight/parser/gmm_parser.py @@ -64,7 +64,7 @@ def __init__(self, params) -> None: @staticmethod def _normalize_path_text(path_value: str | Path) -> str: - return Path(path_value).as_posix() + return str(path_value).replace("\\", "/") @classmethod def _extract_rank_id_from_path(cls, path_value: str | Path) -> int: