From 58620056824d48cd3cc99c28c2c9ae3af703249d Mon Sep 17 00:00:00 2001 From: acat-rw <892882856@qq.com> Date: Thu, 21 May 2026 03:36:10 +0000 Subject: [PATCH 1/2] Use omegaconf structured configuration system and migrate CLI --- docs/data/data_specification.md | 2 +- docs/developer_guides/extending_guide.md | 11 +- docs/overview/RL_Timeline_quickstart.md | 51 ++-- docs/overview/architecture.md | 3 +- docs/overview/gmm_heatmap_quickstart.md | 68 +++--- examples/gmm_exec.sh | 27 +-- examples/mstx_exec.sh | 13 +- examples/nvtx_exec.sh | 13 +- examples/torch_profiler_exec.sh | 13 +- rl_insight/config/__init__.py | 43 ++++ rl_insight/config/config.py | 103 ++++++++ rl_insight/config/config_loader.py | 225 ++++++++++++++++++ rl_insight/config/gmm.yaml | 22 ++ rl_insight/config/timeline.yaml | 19 ++ rl_insight/config/utils.py | 69 ++++++ rl_insight/main.py | 60 ++--- rl_insight/parser/__init__.py | 19 -- rl_insight/parser/gmm_parser.py | 16 +- rl_insight/parser/parser.py | 18 +- .../pipeline/offline_insight_pipeline.py | 38 +-- rl_insight/visualizer/__init__.py | 26 -- rl_insight/visualizer/gmm_visualizer.py | 13 +- rl_insight/visualizer/timeline_visualizer.py | 15 +- rl_insight/visualizer/visualizer.py | 6 +- tests/parser/test_cluster_analysis.py | 12 +- tests/parser/test_png_visualizer.py | 5 +- tests/special_e2e/test_gmm_e2e.py | 10 +- tests/special_e2e/test_mstx_e2e.py | 7 +- tests/special_e2e/test_nvtx_e2e.py | 9 +- tests/special_e2e/test_torch_e2e.py | 9 +- 30 files changed, 692 insertions(+), 253 deletions(-) create mode 100644 rl_insight/config/__init__.py create mode 100644 rl_insight/config/config.py create mode 100644 rl_insight/config/config_loader.py create mode 100644 rl_insight/config/gmm.yaml create mode 100644 rl_insight/config/timeline.yaml create mode 100644 rl_insight/config/utils.py diff --git a/docs/data/data_specification.md b/docs/data/data_specification.md index e5ecc8f..4606b20 100644 --- a/docs/data/data_specification.md +++ 
b/docs/data/data_specification.md @@ -236,7 +236,7 @@ python tests/data/check_verl_log.py data/verl_data/good_minimal_verl.log ## 5. GMM 专家负载dump数据 -GMM 热力图输入类型为 `DataEnum.GMM_DATA`(CLI:`--input-type gmm_data`、`--profiler-type gmm`)。**路径约定、参数与示意图**见 [`docs/overview/gmm_heatmap_quickstart.md`](../overview/gmm_heatmap_quickstart.md)。本节补充数据侧目录与文件格式说明。 +GMM 热力图输入类型为 `DataEnum.GMM_DATA`(CLI:`input.input_type=gmm_data`、`input.profiler_type=gmm`)。**路径约定、参数与示意图**见 [`docs/overview/gmm_heatmap_quickstart.md`](../overview/gmm_heatmap_quickstart.md)。本节补充数据侧目录与文件格式说明。 ### 5.1 目录结构 diff --git a/docs/developer_guides/extending_guide.md b/docs/developer_guides/extending_guide.md index 8048ca4..29477e8 100644 --- a/docs/developer_guides/extending_guide.md +++ b/docs/developer_guides/extending_guide.md @@ -18,14 +18,16 @@ 1. 新增模块,例如 `rl_insight/parser/my_parser.py`。 2. 继承 `BaseClusterParser`,实现 `run()` 方法。 3. `@register_cluster_parser("")`,保证 `get_cluster_parser_cls("")` 可用。 -4. 更新 `main.py` 中 `--profiler-type` 的 help 与相关用户文档。 +4. 若有配置参数,在 `rl_insight/config/config.py` 对应场景的 `ParserConfig` 中添加字段。 +5. 更新相关用户文档。 **Visualizer** 1. 新增模块,例如 `rl_insight/visualizer/my_visualizer.py`。 2. 继承 `BaseVisualizer`,实现 `run()` 方法。 3. `@register_cluster_visualizer("")`,保证 `get_cluster_visualizer_cls("")` 可用。 -4. 更新 `main.py` 中 `--vis-type` 的 help 与相关用户文档。 +4. 若有配置参数,在 `rl_insight/config/config.py` 对应场景的 `VisualizerConfig` 中添加字段。 +5. 更新相关用户文档。 若输入或中间数据形态变化,需同步按上一节扩展 **DataRule**。 @@ -34,6 +36,5 @@ 适用于:全新的处理范式(跳过步骤、插入预处理、多产物、在线多进程流程等)。 1. 在 `rl_insight/pipeline/` 新增类,实现 `__init__(self, config)`、`run(self)`,按需组合 `DataChecker`、`get_cluster_parser_cls`、`get_cluster_visualizer_cls` 等。 -2. 在 `main.py` 的 `SUPPORTED_PIPELINE_TYPES` 中注册,例如 `{"MyPipeline": MyPipeline}`。 -3. 更新 `--pipeline-type` 的 help,名称与 dict key 一致,并更新文档。 -4. 若数据解析或数据类型发生变化,同步扩展 **DataRule** / **Parser** / **Visualizer**。 +2. 在 `rl_insight/main.py` 的 `SUPPORTED_PIPELINE_TYPES` 中注册(例如 `{"MyPipeline": MyPipeline}`),并通过 `pipeline.pipeline_type=MyPipeline`(CLI 或 preset YAML)选用;未注册的类型会被 `validate_config` 拒绝。 +3. 
若数据解析或数据类型发生变化,同步扩展 **DataRule** / **Parser** / **Visualizer**。 diff --git a/docs/overview/RL_Timeline_quickstart.md b/docs/overview/RL_Timeline_quickstart.md index 3d9d58c..aa9bda9 100644 --- a/docs/overview/RL_Timeline_quickstart.md +++ b/docs/overview/RL_Timeline_quickstart.md @@ -56,9 +56,10 @@ pip install -e . ```bash python -m rl_insight.main \ - --input-path \ - --profiler-type mstx \ - --output-path + input.input_path= \ + input.profiler_type=mstx \ + input.input_type=multi_json_mstx \ + output.output_path= ``` 或修改并直接使用 `examples/mstx_exec.sh` 脚本: @@ -73,9 +74,10 @@ bash examples/mstx_exec.sh ```bash python -m rl_insight.main \ - --input-path \ - --profiler-type torch \ - --output-path + input.input_path= \ + input.profiler_type=torch \ + input.input_type=multi_json_torch \ + output.output_path= ``` 或修改并直接使用 `examples/torch_profiler_exec.sh` 脚本: @@ -90,9 +92,10 @@ bash examples/torch_profiler_exec.sh ```bash python -m rl_insight.main \ - --input-path \ - --profiler-type nvtx \ - --output-path + input.input_path= \ + input.profiler_type=nvtx \ + input.input_type=multi_json_nvtx \ + output.output_path= ``` 或修改并直接使用 `examples/nvtx_exec.sh` 脚本: @@ -103,17 +106,27 @@ bash examples/nvtx_exec.sh ## 4. 
命令行参数 -以下说明与 `python -m rl_insight.main --help` 保持一致;若有出入以命令行帮助为准。 +以下说明与 `python -m rl_insight.main -h` 保持一致;若有出入以命令行帮助为准。 + +### 4.1 公共参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `input.input_path` | (必填) | Profiling 数据的根目录路径 | +| `input.input_type` | `multi_json_mstx` | 输入数据类型(`multi_json_mstx`、`multi_json_torch`、`multi_json_nvtx`)| +| `input.profiler_type` | `mstx` | 性能数据种类:`mstx`、`torch`、`nvtx` | +| `input.rank_list` | `all` | Rank ID 列表,如 `0,1,2` 或 `all` | +| `output.output_path` | `output` | 输出目录 | +| `preset` | 自动推断 | 预设名称:`timeline`、`gmm`(根据 `profiler_type` 自动推断) | +| `config_path` | 无 | YAML 配置文件路径 | + +### 4.2 Timeline 参数 | 参数 | 默认值 | 说明 | -|------|--------|----| -| `--input-path` | (必填,无默认值) | Profiling 数据的根目录路径 | -| `--input-type` | `multi_json` | 输入数据类型(多目录 JSON 布局等)| -| `--profiler-type` | `mstx` | 性能数据种类:`mstx`、`torch`、`nvtx` | -| `--output-path` | `output` | 输出目录 | -| `--vis-type` | `html` | 可视化类型(当前支持 `html`、`png`) | -| `--rank-list` | `all` | Rank ID 列表(当前仅支持 `all`) | -| `--pipeline-type` | `OfflineInsightPipeline` | 流水线实现类型 | +|------|--------|------| +| `timeline.visualizer.vis_type` | `html` | 可视化类型:`html`、`png` | +| `timeline.visualizer.width` | `2000` | 图片宽度(仅 png) | +| `timeline.visualizer.scale` | `2` | 图片缩放因子(仅 png) | ## 5. 输出说明 @@ -145,7 +158,7 @@ bash examples/nvtx_exec.sh ## 6. 注意事项 -1. RL 分析功能当前仅支持处理所有 Rank(`--rank-list` 参数暂不支持过滤功能) +1. RL 分析功能当前仅支持处理所有 Rank(`input.rank_list` 参数暂不支持过滤功能) 2. 至少采集 level0 及以上数据(不支持 level_none 级数据) 3. 采用离散模式采集 `discrete=True` 4. 
MSTX 数据满足以下要求: diff --git a/docs/overview/architecture.md b/docs/overview/architecture.md index c3e5c7d..c880546 100644 --- a/docs/overview/architecture.md +++ b/docs/overview/architecture.md @@ -20,7 +20,8 @@ | Concept | Location | Role | |---------|----------|------| -| Entry | `rl_insight/main.py`, `rl_insight/pipeline/` | `main` 对接 CLI;`pipeline` 定义业务流程并选择 **Parser** / **Visualizer**。 | +| Entry | `rl_insight/main.py`, `rl_insight/pipeline/` | `main` 对接 CLI(`key=value` 格式);`pipeline` 定义业务流程并选择 **Parser** / **Visualizer**。 | +| Config | `rl_insight/config/config.py`, `rl_insight/config/config_loader.py` | 基于 OmegaConf 的结构化配置,dataclass 定义 schema 与默认值,支持YAML preset场景化覆盖。 | | DataRule | `rl_insight/data/data_checker.py`, `rl_insight/data/rules.py` | `DataEnum` 区分数据阶段;`DataChecker` 按类型执行对应的 `ValidationRule`。 | | Parser | `rl_insight/parser/parser.py`, `rl_insight/parser/*_parser.py` | 基于约定的 `input_type` 做解析;字段约定见 `rl_insight/utils/schema.py`(`DataMap`、`EventRow`、`Constant` 等)。 | | Visualizer | `rl_insight/visualizer/visualizer.py`, `rl_insight/visualizer/timeline_visualizer.py`, … | 消费 **Parser** 输出,基于约定的 `input_type` 做可视化。 | diff --git a/docs/overview/gmm_heatmap_quickstart.md b/docs/overview/gmm_heatmap_quickstart.md index 4fc628e..25e220b 100644 --- a/docs/overview/gmm_heatmap_quickstart.md +++ b/docs/overview/gmm_heatmap_quickstart.md @@ -59,9 +59,9 @@ gmm_dump/ 路径字段含义: -- `step_`:训练 step(对应 `--step` 过滤) -- ``:角色名(对应 `--role` 过滤) -- `rank`:rank id(对应 `--rank-list` 过滤) +- `step_`:训练 step(对应 `gmm.parser.step` 过滤) +- ``:角色名(对应 `gmm.parser.role` 过滤) +- `rank`:rank id(对应 `input.rank_list` 过滤) - `dump_tensor_data/*.group_list.pt`:MoE grouped_matmul 的专家负载;典型为一维整型张量,第 `i` 个元素表示第 `i` 个 expert 分到的 **token 数** ### 2.2 执行分析脚本 @@ -69,15 +69,15 @@ gmm_dump/ #### GMM 热力图使用示例 ```bash -python -m rl_insight.main - --input-path - --input-type gmm_data - --profiler-type gmm - --vis-type gmm_heatmap - --rank-list all - --step 1 - --role actor_compute_log_prob - --output-path +python 
-m rl_insight.main \ + input.input_path= \ + input.input_type=gmm_data \ + input.profiler_type=gmm \ + input.rank_list=all \ + gmm.visualizer.vis_type=gmm_heatmap \ + gmm.parser.step=1 \ + gmm.parser.role=actor_compute_log_prob \ + output.output_path= ``` 或修改并直接使用 `examples/gmm_exec.sh` 脚本: @@ -88,22 +88,30 @@ bash examples/gmm_exec.sh ## 三、命令行参数 -以下说明与 `python -m rl_insight.main --help` 保持一致;若有出入以命令行帮助为准。 +以下说明与 `python -m rl_insight.main -h` 保持一致;若有出入以命令行帮助为准。 + +### 3.1 公共参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `input.input_path` | (必填) | GMM 数据的根目录路径 | +| `input.input_type` | `multi_json_mstx` | 输入数据类型,GMM 功能需设置为 `gmm_data` | +| `input.profiler_type` | `mstx` | 性能数据种类,GMM 功能需设置为 `gmm` | +| `output.output_path` | `output` | 输出路径,若为文件夹则在其中生成 `gmm_heatmap.png` | +| `input.rank_list` | `all` | Rank ID 列表,默认 `all` 表示所有 rank,可指定多个 rank 用逗号分隔 | +| `preset` | 自动推断 | 预设名称:`timeline`、`gmm`(根据 `profiler_type` 自动推断) | +| `config_path` | 无 | YAML 配置文件路径 | + +### 3.2 GMM 参数 | 参数 | 默认值 | 说明 | -|------|--------|----| -| `--input-path` | (必填,无默认值) | GMM 数据的根目录路径 | -| `--input-type` | `multi_json_mstx` | 输入数据类型,GMM 功能需设置为 `gmm_data` | -| `--profiler-type` | `mstx` | 性能数据种类,GMM 功能需设置为 `gmm` | -| `--output-path` | `output` | 输出路径,若为文件夹则在其中生成 `gmm_heatmap.png` | -| `--vis-type` | `html` | 可视化类型,GMM 功能需设置为 `gmm_heatmap` | -| `--rank-list` | `all` | Rank ID 列表,默认 `all` 表示所有 rank,可指定多个 rank 用逗号分隔 | -| `--pipeline-type` | `OfflineInsightPipeline` | 流水线实现类型 | -| `--step` | 无默认值 | 特定的 step 进行可视化(可选,支持 `1` 或 `1,2`) | -| `--role` | 无默认值 | 特定的 role 进行可视化(可选) | -| `--gmm-per-layer` | 3 | 每个 MoE layer 前向阶段预期的 grouped_matmul 次数,用于 actor_update 前向截断判定 | -| `--dpi` | 150 | 热力图输出的 DPI(默认 150) | -| `--cmap` | viridis | 热力图的颜色映射(默认 viridis) | +|------|--------|------| +| `gmm.visualizer.vis_type` | `gmm_heatmap` | 可视化类型 | +| `gmm.parser.step` | 无 | 特定的 step 进行可视化(可选,支持 `1` 或 `1,2`) | +| `gmm.parser.role` | 无 | 特定的 role 进行可视化(可选) | +| `gmm.visualizer.dpi` | `200` | 热力图输出的 DPI | +| 
`gmm.visualizer.cmap` | `viridis` | 热力图的颜色映射 | +| `gmm.visualizer.gmm_per_layer` | `3` | 每个 MoE layer 前向阶段预期的 grouped_matmul 次数 | ## 四、输出说明 @@ -132,11 +140,11 @@ bash examples/gmm_exec.sh ## 五、注意事项 -1. GMM 热力图功能需要使用 `--input-type gmm_data` 和 `--profiler-type gmm` 参数 -2. 当 `--output-path` 只指定文件夹路径时,会在该文件夹中生成 `gmm_heatmap.png` 文件 -3. 当不指定 `--step`、`--role` 或 `--rank-list` 参数时,默认显示所有数据 +1. GMM 热力图功能需要使用 `input.input_type=gmm_data` 和 `input.profiler_type=gmm` 参数 +2. 当 `output.output_path` 只指定文件夹路径时,会在该文件夹中生成 `gmm_heatmap.png` 文件 +3. 当不指定 `gmm.parser.step`、`gmm.parser.role` 或 `input.rank_list` 参数时,默认显示所有数据 4. 对于大量数据,工具会自动调整图表大小和标签显示密度,确保可读性 5. 数据文件需包含有效的专家负载数据,包括 step、role、rank_id、stage、expert_index 和 load 等字段 -6. 若你的模型实现中每层 grouped_matmul 次数不等于 3,请显式设置 `--gmm-per-layer` 以获得更准确的 actor_update 前向阶段截断结果 +6. 若你的模型实现中每层 grouped_matmul 次数不等于 3,请显式设置 `gmm.visualizer.gmm_per_layer` 以获得更准确的 actor_update 前向阶段截断结果 目录与 JSON 字段的集中说明另见 [数据规格与格式说明](../data/data_specification.md)。运行时校验逻辑以 `rl_insight.data.DataChecker` 及 [`rl_insight/data/rules.py`](../../rl_insight/data/rules.py) 中的规则定义为准。 \ No newline at end of file diff --git a/examples/gmm_exec.sh b/examples/gmm_exec.sh index d22b7a9..16fd9a0 100644 --- a/examples/gmm_exec.sh +++ b/examples/gmm_exec.sh @@ -28,32 +28,29 @@ echo "==========================================" # Build command cmd="python -m rl_insight.main \ - --input-path \"${GMM_DATA_PATH}\" \ - --input-type \"gmm_data\" \ - --profiler-type \"gmm\" \ - --vis-type \"gmm_heatmap\" \ - --output-path \"${OUTPUT_PATH}\" \ - --rank-list \"${RANK_LIST}\" \ - --dpi \"${DPI}\" \ - --cmap \"${CMAP}\" \ - --gmm-per-layer \"${GMM_PER_LAYER}\"" - + input.input_path=\"${GMM_DATA_PATH}\" \ + input.input_type=gmm_data \ + input.profiler_type=gmm \ + output.output_path=\"${OUTPUT_PATH}\" \ + input.rank_list=\"${RANK_LIST}\" \ + gmm.visualizer.vis_type=gmm_heatmap \ + gmm.visualizer.dpi=\"${DPI}\" \ + gmm.visualizer.cmap=\"${CMAP}\" \ + gmm.visualizer.gmm_per_layer=\"${GMM_PER_LAYER}\"" 
# Add step and role parameters if specified if [ -n "${STEP}" ]; then cmd="${cmd} \ - --step \"${STEP}\"" + gmm.parser.step=\"${STEP}\"" fi if [ -n "${ROLE}" ]; then cmd="${cmd} \ - --role \"${ROLE}\"" + gmm.parser.role=\"${ROLE}\"" fi -# Execute the command echo ">>> Generating GMM expert load heatmap..." eval ${cmd} -# Check if the heatmap was generated successfully if [ -f "${OUTPUT_PATH}" ]; then echo "==========================================" echo ">>> Heatmap generated successfully!" @@ -64,4 +61,4 @@ else echo ">>> Failed to generate heatmap" echo "==========================================" exit 1 -fi \ No newline at end of file +fi diff --git a/examples/mstx_exec.sh b/examples/mstx_exec.sh index f485512..ea55c40 100644 --- a/examples/mstx_exec.sh +++ b/examples/mstx_exec.sh @@ -4,7 +4,6 @@ set -euo pipefail MSTX_PROFILER_DATA_PATH="${MSTX_PROFILER_DATA_PATH:-}" OUTPUT_PATH="${OUTPUT_PATH:-./output}" -PROFILER_TYPE="${PROFILER_TYPE:-mstx}" VIS_TYPE="${VIS_TYPE:-html}" RANK_LIST="${RANK_LIST:-all}" @@ -13,7 +12,6 @@ echo "MSTX Profiler Cluster Analysis" echo "==========================================" echo "Input Path: ${MSTX_PROFILER_DATA_PATH}" echo "Output Path: ${OUTPUT_PATH}" -echo "Profiler Type: ${PROFILER_TYPE}" echo "Vis Type: ${VIS_TYPE}" echo "Rank List: ${RANK_LIST}" echo "==========================================" @@ -27,11 +25,12 @@ python -m rl_insight.utils.mstx_preprocessing "${MSTX_PROFILER_DATA_PATH}" echo ">>> Mstx data preprocessing completed." 
python -m rl_insight.main \ - --input-path "${MSTX_PROFILER_DATA_PATH}" \ - --profiler-type "${PROFILER_TYPE}" \ - --output-path "${OUTPUT_PATH}" \ - --vis-type "${VIS_TYPE}" \ - --rank-list "${RANK_LIST}" + input.input_path="${MSTX_PROFILER_DATA_PATH}" \ + input.profiler_type=mstx \ + input.input_type=multi_json_mstx \ + input.rank_list="${RANK_LIST}" \ + output.output_path="${OUTPUT_PATH}" \ + timeline.visualizer.vis_type="${VIS_TYPE}" echo "==========================================" echo ">>> Analysis completed successfully!" diff --git a/examples/nvtx_exec.sh b/examples/nvtx_exec.sh index d164210..f46aa17 100644 --- a/examples/nvtx_exec.sh +++ b/examples/nvtx_exec.sh @@ -4,7 +4,6 @@ set -euo pipefail NVTX_PROFILER_DATA_PATH="${NVTX_PROFILER_DATA_PATH:-}" OUTPUT_PATH="${OUTPUT_PATH:-./output}" -PROFILER_TYPE="${PROFILER_TYPE:-nvtx}" VIS_TYPE="${VIS_TYPE:-html}" RANK_LIST="${RANK_LIST:-all}" @@ -13,17 +12,17 @@ echo "Nvtx Profiler Cluster Analysis" echo "==========================================" echo "Input Path: ${NVTX_PROFILER_DATA_PATH}" echo "Output Path: ${OUTPUT_PATH}" -echo "Profiler Type: ${PROFILER_TYPE}" echo "Vis Type: ${VIS_TYPE}" echo "Rank List: ${RANK_LIST}" echo "==========================================" python -m rl_insight.main \ - --input-path "${NVTX_PROFILER_DATA_PATH}" \ - --profiler-type "${PROFILER_TYPE}" \ - --output-path "${OUTPUT_PATH}" \ - --vis-type "${VIS_TYPE}" \ - --rank-list "${RANK_LIST}" + input.input_path="${NVTX_PROFILER_DATA_PATH}" \ + input.profiler_type=nvtx \ + input.input_type=multi_json_nvtx \ + input.rank_list="${RANK_LIST}" \ + output.output_path="${OUTPUT_PATH}" \ + timeline.visualizer.vis_type="${VIS_TYPE}" echo "==========================================" echo ">>> Analysis completed successfully!" 
diff --git a/examples/torch_profiler_exec.sh b/examples/torch_profiler_exec.sh index bddb944..ca9612d 100644 --- a/examples/torch_profiler_exec.sh +++ b/examples/torch_profiler_exec.sh @@ -4,7 +4,6 @@ set -euo pipefail TORCH_PROFILER_DATA_PATH="${TORCH_PROFILER_DATA_PATH:-}" OUTPUT_PATH="${OUTPUT_PATH:-./output}" -PROFILER_TYPE="${PROFILER_TYPE:-torch}" VIS_TYPE="${VIS_TYPE:-html}" RANK_LIST="${RANK_LIST:-all}" @@ -13,17 +12,17 @@ echo "Torch Profiler Cluster Analysis" echo "==========================================" echo "Input Path: ${TORCH_PROFILER_DATA_PATH}" echo "Output Path: ${OUTPUT_PATH}" -echo "Profiler Type: ${PROFILER_TYPE}" echo "Vis Type: ${VIS_TYPE}" echo "Rank List: ${RANK_LIST}" echo "==========================================" python -m rl_insight.main \ - --input-path "${TORCH_PROFILER_DATA_PATH}" \ - --profiler-type "${PROFILER_TYPE}" \ - --output-path "${OUTPUT_PATH}" \ - --vis-type "${VIS_TYPE}" \ - --rank-list "${RANK_LIST}" + input.input_path="${TORCH_PROFILER_DATA_PATH}" \ + input.profiler_type=torch \ + input.input_type=multi_json_torch \ + input.rank_list="${RANK_LIST}" \ + output.output_path="${OUTPUT_PATH}" \ + timeline.visualizer.vis_type="${VIS_TYPE}" echo "==========================================" echo ">>> Analysis completed successfully!" diff --git a/rl_insight/config/__init__.py b/rl_insight/config/__init__.py new file mode 100644 index 0000000..3f2c7d3 --- /dev/null +++ b/rl_insight/config/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import ( + AppConfig, + GmmConfig, + GmmParserConfig, + GmmVisualizerConfig, + InputConfig, + OutputConfig, + PipelineConfig, + TimelineConfig, + TimelineParserConfig, + TimelineVisualizerConfig, +) +from .config_loader import ConfigLoader +from .utils import get_config_value + +__all__ = [ + "AppConfig", + "GmmConfig", + "GmmParserConfig", + "GmmVisualizerConfig", + "InputConfig", + "OutputConfig", + "PipelineConfig", + "TimelineConfig", + "TimelineParserConfig", + "TimelineVisualizerConfig", + "ConfigLoader", + "get_config_value", +] diff --git a/rl_insight/config/config.py b/rl_insight/config/config.py new file mode 100644 index 0000000..16efb3e --- /dev/null +++ b/rl_insight/config/config.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from omegaconf import MISSING + + +@dataclass +class InputConfig: + """Input data configuration.""" + + input_path: str = MISSING # Path to profiling data (required) + input_type: str = "multi_json_mstx" # multi_json_mstx | multi_json_torch | multi_json_nvtx | gmm_data + profiler_type: str = "mstx" # mstx | torch | nvtx | gmm + rank_list: str = "all" # Rank id list, e.g. 
'0,1,2' or 'all' + + +@dataclass +class OutputConfig: + """Output configuration.""" + + output_path: str = "output" # Output directory path + + +@dataclass +class TimelineParserConfig: + """Timeline parser filter configuration.""" + + +@dataclass +class TimelineVisualizerConfig: + """Timeline visualizer configuration.""" + + vis_type: str = "html" # html | png + width: int = 2000 # Image width in pixels (png only) + scale: int = 2 # Image scale factor (png only) + + +@dataclass +class TimelineConfig: + """Timeline configuration.""" + + parser: TimelineParserConfig = field(default_factory=TimelineParserConfig) + visualizer: TimelineVisualizerConfig = field( + default_factory=TimelineVisualizerConfig + ) + + +@dataclass +class GmmParserConfig: + """GMM parser filter configuration.""" + + step: Optional[str] = None # Step filter, e.g. '1' or '1,2' + role: Optional[str] = None # Role filter + + +@dataclass +class GmmVisualizerConfig: + """GMM heatmap visualizer configuration.""" + + vis_type: str = "gmm_heatmap" # gmm_heatmap + dpi: int = 200 # DPI for heatmap PNG output + cmap: str = "viridis" # Matplotlib colormap name + gmm_per_layer: int = 3 # Grouped matmul count per MoE layer + + +@dataclass +class GmmConfig: + """GMM configuration.""" + + parser: GmmParserConfig = field(default_factory=GmmParserConfig) + visualizer: GmmVisualizerConfig = field(default_factory=GmmVisualizerConfig) + + +@dataclass +class PipelineConfig: + """Pipeline configuration.""" + + pipeline_type: str = "OfflineInsightPipeline" # OfflineInsightPipeline + + +@dataclass +class AppConfig: + """RL Insight configuration.""" + + pipeline: PipelineConfig = field(default_factory=PipelineConfig) + input: InputConfig = field(default_factory=InputConfig) + output: OutputConfig = field(default_factory=OutputConfig) + timeline: TimelineConfig = field(default_factory=TimelineConfig) + gmm: GmmConfig = field(default_factory=GmmConfig) diff --git a/rl_insight/config/config_loader.py 
b/rl_insight/config/config_loader.py new file mode 100644 index 0000000..597a3a7 --- /dev/null +++ b/rl_insight/config/config_loader.py @@ -0,0 +1,225 @@ +# Copyright (c) 2025 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import inspect +import sys +from pathlib import Path +from typing import Any, List, Optional, get_type_hints + +from omegaconf import DictConfig, MISSING, OmegaConf + +from .config import AppConfig + + +class _HelpRenderer: + """Generate CLI help text from structured config dataclasses.""" + + @staticmethod + def render(supported_presets: set[str]) -> str: + sections = [ + _HelpRenderer._header(supported_presets), + _HelpRenderer._config_keys(), + _HelpRenderer._examples(), + ] + return "\n".join(sections) + + @staticmethod + def _header(supported_presets: set[str]) -> str: + return ( + "Usage: python -m rl_insight.main [key=value] ... 
[config_path=PATH] [preset=NAME]\n" + "\n" + "RL Insight - Cluster scheduling visualization for RL training\n" + "\n" + "Special keys:\n" + f" {'config_path':<30s} Path to YAML config file\n" + f" {'preset':<30s} Preset name: {', '.join(sorted(supported_presets))}\n" + "\n" + "Configuration keys (key=value):\n" + ) + + @staticmethod + def _config_keys() -> str: + lines: list[str] = [] + _HelpRenderer._format_group(lines, AppConfig, prefix="") + return "\n".join(lines) + + @staticmethod + def _examples() -> str: + return ( + "\n" + "Examples:\n" + " python -m rl_insight.main input.input_path=./data/mstx_data/mstx_profile\n" + " python -m rl_insight.main input.input_path=./data/gmm_data input.profiler_type=gmm\n" + " python -m rl_insight.main config_path=my_config.yaml gmm.visualizer.dpi=300\n" + " python -m rl_insight.main preset=timeline timeline.visualizer.vis_type=png\n" + " python -m rl_insight.main preset=gmm input.input_path=./data/gmm_data\n" + ) + + @staticmethod + def _format_group(lines: list[str], cls_type: Any, prefix: str) -> None: + group_name = cls_type.__doc__.strip() if cls_type.__doc__ else cls_type.__name__ + lines.append(f" [{group_name}]") + + hints = get_type_hints(cls_type) + source_lines = inspect.getsource(cls_type).split("\n") + + for f in dataclasses.fields(cls_type): + if dataclasses.is_dataclass(f.type): + continue + + full_key = f"{prefix}{f.name}" if prefix else f.name + type_name = _HelpRenderer._type_name(hints.get(f.name, f.type)) + default_str = _HelpRenderer._default_str(f.default) + comment = _HelpRenderer._field_comment(f.name, source_lines) + + lines.append( + f" {full_key + ' (' + type_name + ')':<38s} " + f"{default_str:<25s}{comment}" + ) + + for f in dataclasses.fields(cls_type): + if dataclasses.is_dataclass(f.type): + sub_prefix = f"{prefix}{f.name}." if prefix else f"{f.name}." 
+ lines.append("") + _HelpRenderer._format_group(lines, f.type, sub_prefix) + + @staticmethod + def _type_name(hint) -> str: + return hint.__name__ if hasattr(hint, "__name__") else str(hint) + + @staticmethod + def _default_str(default) -> str: + if default is dataclasses.MISSING or default is MISSING: + return "REQUIRED" + if default is None: + return "null" + return repr(default) + + @staticmethod + def _field_comment(field_name: str, source_lines: list[str]) -> str: + for line in source_lines: + stripped = line.strip() + if stripped.startswith(f"{field_name}:") and "#" in stripped: + return " " + stripped.split("#", 1)[1].strip() + return "" + + +class ConfigLoader: + PRESETS_DIR = Path(__file__).parent + SUPPORTED_PRESETS = {"timeline", "gmm"} + + @classmethod + def load( + cls, + config_path: Optional[str] = None, + preset: Optional[str] = None, + cli_args: Optional[List[str]] = None, + ) -> DictConfig: + cfg = OmegaConf.structured(AppConfig) + + if preset: + cfg = cls._merge_preset(cfg, preset) + + if config_path: + cfg = cls._merge_yaml(cfg, config_path) + + if cli_args: + cli_cfg = OmegaConf.from_cli(cli_args) + cfg = OmegaConf.merge(cfg, cli_cfg) + + OmegaConf.resolve(cfg) + return cfg + + @classmethod + def load_from_cli(cls, argv: Optional[List[str]] = None) -> DictConfig: + if argv is None: + argv = sys.argv[1:] + + if "--help" in argv or "-h" in argv: + print(_HelpRenderer.render(cls.SUPPORTED_PRESETS)) + sys.exit(0) + + config_path, preset, remaining = cls._parse_special_args(argv) + + if preset is None and config_path is None: + preset = cls._infer_preset_from_args(remaining) or "timeline" + + return cls.load( + config_path=config_path, + preset=preset, + cli_args=remaining or None, + ) + + @classmethod + def load_from_yaml(cls, yaml_path: str) -> DictConfig: + path = Path(yaml_path) + if not path.exists(): + raise FileNotFoundError(f"YAML config file not found: {yaml_path}") + return OmegaConf.load(path) + + @classmethod + def save_to_yaml(cls, 
cfg: DictConfig, yaml_path: str) -> None: + path = Path(yaml_path) + path.parent.mkdir(parents=True, exist_ok=True) + OmegaConf.save(cfg, path) + + @classmethod + def get_default_config(cls) -> DictConfig: + return OmegaConf.structured(AppConfig) + + @classmethod + def _merge_preset(cls, cfg: DictConfig, preset: str) -> DictConfig: + if preset not in cls.SUPPORTED_PRESETS: + raise ValueError( + f"Unknown preset: {preset}. " + f"Supported presets: {', '.join(sorted(cls.SUPPORTED_PRESETS))}" + ) + preset_path = cls.PRESETS_DIR / f"{preset}.yaml" + if preset_path.exists(): + cfg = OmegaConf.merge(cfg, OmegaConf.load(preset_path)) + return cfg + + @classmethod + def _merge_yaml(cls, cfg: DictConfig, config_path: str) -> DictConfig: + path = Path(config_path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + return OmegaConf.merge(cfg, OmegaConf.load(path)) + + @staticmethod + def _parse_special_args( + argv: list[str], + ) -> tuple[Optional[str], Optional[str], list[str]]: + config_path: Optional[str] = None + preset: Optional[str] = None + remaining: list[str] = [] + + for arg in argv: + if arg.startswith("config_path="): + config_path = arg.split("=", 1)[1] + elif arg.startswith("preset="): + preset = arg.split("=", 1)[1] + else: + remaining.append(arg) + + return config_path, preset, remaining + + @staticmethod + def _infer_preset_from_args(args: list[str]) -> Optional[str]: + for arg in args: + if arg.startswith("input.profiler_type="): + profiler_type = arg.split("=", 1)[1] + return "gmm" if profiler_type == "gmm" else "timeline" + return None diff --git a/rl_insight/config/gmm.yaml b/rl_insight/config/gmm.yaml new file mode 100644 index 0000000..a8f3eaa --- /dev/null +++ b/rl_insight/config/gmm.yaml @@ -0,0 +1,22 @@ +# GMM visualization preset + +pipeline: + pipeline_type: OfflineInsightPipeline # OfflineInsightPipeline + +input: + input_type: gmm_data # gmm_data + profiler_type: gmm # gmm + rank_list: all # Rank id 
list, e.g. '0,1,2' or 'all' + +output: + output_path: output # Output directory path + +gmm: + parser: + step: null # Step filter, e.g. '1' or '1,2' + role: null # Role filter + visualizer: + vis_type: gmm_heatmap # gmm_heatmap + dpi: 200 # DPI for heatmap PNG output + cmap: viridis # Matplotlib colormap name + gmm_per_layer: 3 # Grouped matmul count per MoE layer diff --git a/rl_insight/config/timeline.yaml b/rl_insight/config/timeline.yaml new file mode 100644 index 0000000..20dcc30 --- /dev/null +++ b/rl_insight/config/timeline.yaml @@ -0,0 +1,19 @@ +# Timeline visualization preset + +pipeline: + pipeline_type: OfflineInsightPipeline # OfflineInsightPipeline + +input: + input_type: multi_json_mstx # multi_json_mstx | multi_json_torch | multi_json_nvtx + profiler_type: mstx # mstx | torch | nvtx + rank_list: all # Rank id list, e.g. '0,1,2' or 'all' + +output: + output_path: output # Output directory path + +timeline: + parser: {} # No parser-specific config yet + visualizer: + vis_type: html # html | png + width: 2000 # Image width in pixels (png only) + scale: 2 # Image scale factor (png only) \ No newline at end of file diff --git a/rl_insight/config/utils.py b/rl_insight/config/utils.py new file mode 100644 index 0000000..3a675a1 --- /dev/null +++ b/rl_insight/config/utils.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Union + +from omegaconf import DictConfig + + +def get_config_value( + config: Union[DictConfig, dict], key: str, default: Any = None +) -> Any: + """Retrieve a value from a DictConfig or dict using a dot-separated key. + + Supports both nested access (``output.output_path``) and flat key fallback + (``output_output_path``) for backward compatibility. + + Args: + config: Configuration object (DictConfig or plain dict). + key: Dot-separated key path, e.g. ``"output.output_path"``. + default: Value returned when the key is not found. + + Returns: + The resolved value, or *default* if the key does not exist. + """ + if isinstance(config, DictConfig): + flat_key = key.replace(".", "_") + if hasattr(config, flat_key): + return getattr(config, flat_key) + parts = key.split(".") + value = config + for part in parts: + if hasattr(value, part): + value = getattr(value, part) + else: + return default + return value + + flat_key = key.replace(".", "_") + if flat_key in config: + return config.get(flat_key) + if key in config: + return config.get(key) + + last_part = key.split(".")[-1] + if last_part in config: + return config.get(last_part) + + nested_parts = key.split(".") + if len(nested_parts) > 1: + value = config + for part in nested_parts: + if isinstance(value, dict) and part in value: + value = value[part] + else: + return default + return value + + return default diff --git a/rl_insight/main.py b/rl_insight/main.py index 49e574d..566e975 100644 --- a/rl_insight/main.py +++ b/rl_insight/main.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse +from omegaconf import DictConfig -from .parser import register_parser_specific_args +from .config import ConfigLoader from .pipeline.offline_insight_pipeline import OfflineInsightPipeline -from .visualizer import register_visualizer_specific_args SUPPORTED_PIPELINE_TYPES = {"OfflineInsightPipeline": OfflineInsightPipeline} -def run_pipeline(config, pipeline_class=None): +def run_pipeline(config: DictConfig, pipeline_class=None): if pipeline_class is None: raise ValueError("A pipeline class must be provided.") @@ -29,52 +28,23 @@ def run_pipeline(config, pipeline_class=None): runner.run() -def main(): - arg_parser = argparse.ArgumentParser(description="Cluster scheduling visualization") - arg_parser.add_argument( - "--input-path", required=True, help="Raw path of profiling data" - ) - arg_parser.add_argument( - "--input-type", - default="multi_json_mstx", - help=( - "Input data type. Supported: 'multi_json_mstx', 'multi_json_torch', " - "'multi_json_nvtx', 'gmm_data'" - ), - ) - arg_parser.add_argument( - "--profiler-type", - default="mstx", - help="Profiler type: mstx, torch, nvtx, gmm", - ) - arg_parser.add_argument("--output-path", default="output", help="Output path") - arg_parser.add_argument( - "--vis-type", - default="html", - help="Visualization type, supported: html, gmm_heatmap", - ) - arg_parser.add_argument("--rank-list", type=str, help="Rank id list", default="all") - arg_parser.add_argument( - "--pipeline-type", - type=str, - help="Tool pipeline type", - default="OfflineInsightPipeline", - ) - - register_parser_specific_args(arg_parser) - register_visualizer_specific_args(arg_parser) - config = arg_parser.parse_args() +def validate_config(cfg: DictConfig) -> None: + if cfg.input.input_path is None: + raise ValueError("input.input_path is required") - # Validate pipeline type - if config.pipeline_type not in SUPPORTED_PIPELINE_TYPES: + if cfg.pipeline.pipeline_type not in SUPPORTED_PIPELINE_TYPES: supported_types = ", 
".join(SUPPORTED_PIPELINE_TYPES.keys()) raise ValueError( - f"Unsupported pipeline type: {config.pipeline_type}. Supported types are: {supported_types}" + f"Unsupported pipeline type: {cfg.pipeline.pipeline_type}. " + f"Supported types are: {supported_types}" ) - # Run the pipeline - pipeline_class = SUPPORTED_PIPELINE_TYPES[config.pipeline_type] - run_pipeline(config, pipeline_class) + +def main(): + cfg = ConfigLoader.load_from_cli() + validate_config(cfg) + pipeline_class = SUPPORTED_PIPELINE_TYPES[cfg.pipeline.pipeline_type] + run_pipeline(cfg, pipeline_class) if __name__ == "__main__": diff --git a/rl_insight/parser/__init__.py b/rl_insight/parser/__init__.py index 4891689..b5f2b48 100644 --- a/rl_insight/parser/__init__.py +++ b/rl_insight/parser/__init__.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse - from .mstx_parser import MstxClusterParser from .torch_parser import TorchClusterParser from .nvtx_parser import NvtxClusterParser @@ -22,26 +20,10 @@ def get_cluster_parser_cls(name): if name == "gmm": - # Lazy import keeps optional gmm dependency off non-gmm paths. from . import gmm_parser # noqa: F401 return _get_cluster_parser_cls(name) -def register_parser_specific_args(arg_parser: argparse.ArgumentParser) -> None: - """Register optional parser CLI flags (additive). Safe for non-GMM runs; extras are ignored.""" - gmm_group = arg_parser.add_argument_group("GMM parser parameters") - gmm_group.add_argument( - "--step", - type=str, - help="Step filter for GMM parser, e.g. 
'1' or '1,2'", - ) - gmm_group.add_argument( - "--role", - type=str, - help="Role filter for GMM parser", - ) - - def __getattr__(name): if name == "GmmParser": from .gmm_parser import GmmParser @@ -53,7 +35,6 @@ def __getattr__(name): __all__ = [ "BaseClusterParser", "get_cluster_parser_cls", - "register_parser_specific_args", "MstxClusterParser", "TorchClusterParser", "NvtxClusterParser", diff --git a/rl_insight/parser/gmm_parser.py b/rl_insight/parser/gmm_parser.py index 835982a..01b0c17 100644 --- a/rl_insight/parser/gmm_parser.py +++ b/rl_insight/parser/gmm_parser.py @@ -15,13 +15,15 @@ from loguru import logger import re from pathlib import Path -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import pandas as pd import numpy as np import torch +from omegaconf import DictConfig +from rl_insight.config import get_config_value from rl_insight.parser.parser import BaseClusterParser, register_cluster_parser -from rl_insight.utils.schema import DataMap, Constant +from rl_insight.utils.schema import DataMap from rl_insight.data import DataEnum @@ -29,10 +31,10 @@ class GmmParser(BaseClusterParser): input_type = DataEnum.GMM_DATA - def __init__(self, params) -> None: + def __init__(self, params: Union[DictConfig, dict]) -> None: super().__init__(params) self.events_summary: Optional[pd.DataFrame] = None - rank_list = params.get(Constant.RANK_LIST, "all") + rank_list = get_config_value(params, "input.rank_list", "all") self._rank_list = ( rank_list if rank_list == "all" @@ -42,8 +44,7 @@ def __init__(self, params) -> None: if rank.strip().isdigit() ] ) - # Get step filter(s) if provided. Supports "1" or "1,2". - step = params.get("step", None) + step = get_config_value(params, "gmm.parser.step", None) if step is None: self._step_list: Optional[list[int]] = None elif isinstance(step, int): @@ -59,8 +60,7 @@ def __init__(self, params) -> None: "Will process all steps." 
) self._step_list = None - # Get role filter if provided - self._role = params.get("role", None) + self._role = get_config_value(params, "gmm.parser.role", None) @staticmethod def _normalize_path_text(path_value: str | Path) -> str: diff --git a/rl_insight/parser/parser.py b/rl_insight/parser/parser.py index b924d5d..90e8438 100644 --- a/rl_insight/parser/parser.py +++ b/rl_insight/parser/parser.py @@ -12,21 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from loguru import logger -import multiprocessing from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union +import multiprocessing + +from loguru import logger +from omegaconf import DictConfig import pandas as pd from rl_insight.utils.schema import Constant, DataMap class BaseClusterParser(ABC): - def __init__(self, params) -> None: + def __init__(self, params: Union[DictConfig, dict]) -> None: self.events_summary: Optional[pd.DataFrame] = None - rank_list = params.get(Constant.RANK_LIST, "all") + if isinstance(params, DictConfig): + rank_list = params.input.rank_list + else: + rank_list = params.get(Constant.RANK_LIST, "all") self._rank_list = ( rank_list if rank_list == "all" @@ -60,7 +65,6 @@ def mapper_func(self, data_maps: list[DataMap]): failed_ranks = [] with ProcessPoolExecutor(max_workers=max_workers) as executor: - # Submit all tasks future_to_rank = { executor.submit(self._mapper_func, data_map): data_map[Constant.RANK_ID] for data_map in data_maps @@ -104,8 +108,6 @@ def _mapper_func(self, data_map: DataMap) -> list[dict[str, Any]]: return self.parse_analysis_data(profiler_data_path, rank_id, role) def reducer_func(self, mapper_res): - """Process data collected from all ranks""" - # Flatten valid results from all ranks reduce_results: list[dict] = [] for result in mapper_res: if not result: diff 
--git a/rl_insight/pipeline/offline_insight_pipeline.py b/rl_insight/pipeline/offline_insight_pipeline.py index cbe29b9..4a83c57 100644 --- a/rl_insight/pipeline/offline_insight_pipeline.py +++ b/rl_insight/pipeline/offline_insight_pipeline.py @@ -12,48 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. +from omegaconf import DictConfig + from rl_insight.data import DataChecker, DataEnum from rl_insight.parser import get_cluster_parser_cls -from rl_insight.utils.schema import Constant from rl_insight.visualizer import get_cluster_visualizer_cls class OfflineInsightPipeline: - def __init__(self, config): + def __init__(self, config: DictConfig): self.config = config - # init data - self.input_data_type = DataEnum(self.config.input_type) + self.input_data_type = DataEnum(self.config.input.input_type) - # parser related parser_config = self._prepare_parser_config() - parser_cls = get_cluster_parser_cls(self.config.profiler_type) + parser_cls = get_cluster_parser_cls(self.config.input.profiler_type) self.parser = parser_cls(parser_config) - # visualizer related visualizer_config = self._prepare_visualizer_config() - visualizer_cls = get_cluster_visualizer_cls(self.config.vis_type) + visualizer_cls = get_cluster_visualizer_cls( + self.config.timeline.visualizer.vis_type + if self.config.input.profiler_type != "gmm" + else self.config.gmm.visualizer.vis_type + ) self.visualizer = visualizer_cls(visualizer_config) - def _prepare_parser_config(self): - config = vars(self.config).copy() - config[Constant.RANK_LIST] = config.get("rank_list", "all") - return config + def _prepare_parser_config(self) -> DictConfig: + return self.config - def _prepare_visualizer_config(self): - return vars(self.config).copy() + def _prepare_visualizer_config(self) -> DictConfig: + return self.config def run(self): if self.input_data_type != self.parser.input_type: raise ValueError( - f"Input data type {self.input_data_type} does not 
match parser input type {self.parser.input_type}" + f"Input data type {self.input_data_type} does not match " + f"parser input type {self.parser.input_type}" ) - # validate input data - DataChecker(self.input_data_type, self.config.input_path).run() - output_data = self.parser.run(self.config.input_path) + DataChecker(self.input_data_type, self.config.input.input_path).run() + + output_data = self.parser.run(self.config.input.input_path) - # validate output data DataChecker(self.visualizer.input_type, output_data).run() self.visualizer.run(output_data) diff --git a/rl_insight/visualizer/__init__.py b/rl_insight/visualizer/__init__.py index 78e965b..e62d902 100644 --- a/rl_insight/visualizer/__init__.py +++ b/rl_insight/visualizer/__init__.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse - from .timeline_visualizer import RLTimelineVisualizer from .timeline_visualizer import RLTimelinePNGVisualizer from .visualizer import ( @@ -23,33 +21,9 @@ from .gmm_visualizer import GmmVisualizer -def register_visualizer_specific_args(arg_parser: argparse.ArgumentParser) -> None: - """Register optional visualizer CLI flags (additive). 
Safe for html timeline; extras are ignored.""" - heatmap_group = arg_parser.add_argument_group("GMM heatmap parameters") - heatmap_group.add_argument( - "--dpi", - type=int, - default=150, - help="DPI for heatmap PNG output", - ) - heatmap_group.add_argument( - "--cmap", - type=str, - default="viridis", - help="Matplotlib colormap name", - ) - heatmap_group.add_argument( - "--gmm-per-layer", - type=int, - default=3, - help="Expected grouped_matmul count per MoE layer in forward pass", - ) - - __all__ = [ "BaseVisualizer", "get_cluster_visualizer_cls", - "register_visualizer_specific_args", "RLTimelineVisualizer", "RLTimelinePNGVisualizer", "GmmVisualizer", diff --git a/rl_insight/visualizer/gmm_visualizer.py b/rl_insight/visualizer/gmm_visualizer.py index 79610b0..a7387d2 100644 --- a/rl_insight/visualizer/gmm_visualizer.py +++ b/rl_insight/visualizer/gmm_visualizer.py @@ -19,6 +19,7 @@ import pandas as pd from loguru import logger +from rl_insight.config import get_config_value from rl_insight.visualizer.visualizer import BaseVisualizer, register_cluster_visualizer from rl_insight.data import DataEnum @@ -49,13 +50,15 @@ def _load_signature(stage_data: pd.DataFrame) -> np.ndarray: def run(self, data): """Run GMM heatmap visualization from parsed data.""" # Extract parameters from config - output_cfg = self.config.get( - "output_path", "./output/gmm_group_list_heatmap.png" + output_cfg = get_config_value( + self.config, "output.output_path", "./output/gmm_group_list_heatmap.png" ) output = self._resolve_output_path(output_cfg) - dpi = self.config.get("dpi", 150) - cmap = self.config.get("cmap", "viridis") - gmm_per_layer = int(self.config.get("gmm_per_layer", 3)) + dpi = get_config_value(self.config, "gmm.visualizer.dpi", 200) + cmap = get_config_value(self.config, "gmm.visualizer.cmap", "viridis") + gmm_per_layer = int( + get_config_value(self.config, "gmm.visualizer.gmm_per_layer", 3) + ) if not isinstance(data, pd.DataFrame): raise ValueError(f"Expected 
DataFrame, got {type(data).__name__}") diff --git a/rl_insight/visualizer/timeline_visualizer.py b/rl_insight/visualizer/timeline_visualizer.py index 788dcdd..5819d87 100644 --- a/rl_insight/visualizer/timeline_visualizer.py +++ b/rl_insight/visualizer/timeline_visualizer.py @@ -13,12 +13,15 @@ # limitations under the License. import os +from typing import Union import numpy as np import pandas as pd import plotly.graph_objects as go +from omegaconf import DictConfig from plotly.io import to_image +from rl_insight.config import get_config_value from rl_insight.data import DataEnum from rl_insight.utils.schema import FigureConfig @@ -47,9 +50,9 @@ class RLTimelineVisualizer(BaseVisualizer): input_type: DataEnum = DataEnum.SUMMARY_EVENT - def __init__(self, config: dict): + def __init__(self, config: Union[DictConfig, dict]): super().__init__(config) - self.output_path = config.get("output_path", None) + self.output_path = get_config_value(config, "output.output_path", None) def run(self, data): return self.generate_rl_timeline(data) @@ -405,11 +408,11 @@ class RLTimelinePNGVisualizer(BaseVisualizer): input_type: DataEnum = DataEnum.SUMMARY_EVENT - def __init__(self, config: dict): + def __init__(self, config: Union[DictConfig, dict]): super().__init__(config) - self.output_path = config.get("output_path", None) - self.width = config.get("width", 2000) - self.scale = config.get("scale", 2) + self.output_path = get_config_value(config, "output.output_path", None) + self.width = get_config_value(config, "timeline.visualizer.width", 2000) + self.scale = get_config_value(config, "timeline.visualizer.scale", 2) def run(self, data): return self.generate_rl_timeline_png(data) diff --git a/rl_insight/visualizer/visualizer.py b/rl_insight/visualizer/visualizer.py index 14d0fc5..c6d6fa4 100644 --- a/rl_insight/visualizer/visualizer.py +++ b/rl_insight/visualizer/visualizer.py @@ -13,7 +13,9 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Union + +from omegaconf import DictConfig from rl_insight.data import DataEnum @@ -21,7 +23,7 @@ class BaseVisualizer(ABC): input_type: DataEnum = DataEnum.SUMMARY_EVENT - def __init__(self, config: dict): + def __init__(self, config: Union[DictConfig, dict]): self.config = config @abstractmethod diff --git a/tests/parser/test_cluster_analysis.py b/tests/parser/test_cluster_analysis.py index d0a247e..57482fa 100644 --- a/tests/parser/test_cluster_analysis.py +++ b/tests/parser/test_cluster_analysis.py @@ -51,7 +51,10 @@ def _timeline_viz(**kwargs): - cfg = {"output_path": "/tmp", "vis_type": "html"} + cfg = { + "output": {"output_path": "/tmp"}, + "timeline": {"visualizer": {"vis_type": "html"}}, + } cfg.update(kwargs) return RLTimelineVisualizer(cfg) @@ -992,7 +995,10 @@ def test_full_pipeline_with_mock_data(self, mock_mstx_profiler_structure, tmp_pa side_effect=lambda frame, threshold_ms=10.0: frame, ): viz = RLTimelineVisualizer( - {"output_path": output_dir, "vis_type": "html"} + { + "output": {"output_path": output_dir}, + "timeline": {"visualizer": {"vis_type": "html"}}, + } ) viz.generate_rl_timeline(df) @@ -1001,7 +1007,7 @@ def test_full_pipeline_with_mock_data(self, mock_mstx_profiler_structure, tmp_pa @patch( "sys.argv", - ["main.py", "--input-path", "/tmp", "--profiler-type", "mstx"], + ["main.py", "input.input_path=/tmp", "input.profiler_type=mstx"], ) @patch("rl_insight.pipeline.offline_insight_pipeline.DataChecker.run") @patch("rl_insight.pipeline.offline_insight_pipeline.get_cluster_parser_cls") diff --git a/tests/parser/test_png_visualizer.py b/tests/parser/test_png_visualizer.py index 74a6bc8..8a81b3b 100644 --- a/tests/parser/test_png_visualizer.py +++ b/tests/parser/test_png_visualizer.py @@ -22,7 +22,10 @@ @pytest.fixture def visualizer(): """Initialize the visualizer instance for testing.""" - config = {"output_path": "test_output", "width": 2000, 
"scale": 2} + config = { + "output": {"output_path": "test_output"}, + "timeline": {"visualizer": {"width": 2000, "scale": 2}}, + } return RLTimelinePNGVisualizer(config) diff --git a/tests/special_e2e/test_gmm_e2e.py b/tests/special_e2e/test_gmm_e2e.py index 98190cf..397276e 100644 --- a/tests/special_e2e/test_gmm_e2e.py +++ b/tests/special_e2e/test_gmm_e2e.py @@ -34,11 +34,11 @@ def test_gmm_e2e_with_repo_sample_data(monkeypatch, tmp_path): test_args = [ "main.py", - f"--input-path={input_dir}", - f"--output-path={output_dir}", - "--profiler-type=gmm", - "--input-type=gmm_data", - "--vis-type=gmm_heatmap", + f"input.input_path={input_dir}", + f"output.output_path={output_dir}", + "input.profiler_type=gmm", + "input.input_type=gmm_data", + "gmm.visualizer.vis_type=gmm_heatmap", ] monkeypatch.setattr(sys, "argv", test_args) diff --git a/tests/special_e2e/test_mstx_e2e.py b/tests/special_e2e/test_mstx_e2e.py index dcbf73f..393aefa 100644 --- a/tests/special_e2e/test_mstx_e2e.py +++ b/tests/special_e2e/test_mstx_e2e.py @@ -29,12 +29,11 @@ def test_mstx_e2e_with_input_path(monkeypatch, tmp_path): # Ensure the input directory exists assert input_dir.exists(), f"Input directory {input_dir} does not exist" - # Set command line parameters test_args = [ "main.py", - f"--input-path={input_dir}", - f"--output-path={output_dir}", - "--profiler-type=mstx", + f"input.input_path={input_dir}", + f"output.output_path={output_dir}", + "input.profiler_type=mstx", ] monkeypatch.setattr(sys, "argv", test_args) diff --git a/tests/special_e2e/test_nvtx_e2e.py b/tests/special_e2e/test_nvtx_e2e.py index 4a1e476..831af89 100644 --- a/tests/special_e2e/test_nvtx_e2e.py +++ b/tests/special_e2e/test_nvtx_e2e.py @@ -29,13 +29,12 @@ def test_nvtx_e2e_with_input_path(monkeypatch, tmp_path): # Ensure the input directory exists assert input_dir.exists(), f"Input directory {input_dir} does not exist" - # Set command line parameters test_args = [ "main.py", - f"--input-path={input_dir}", - 
f"--output-path={output_dir}", - "--profiler-type=nvtx", - "--input-type=multi_json_nvtx", + f"input.input_path={input_dir}", + f"output.output_path={output_dir}", + "input.profiler_type=nvtx", + "input.input_type=multi_json_nvtx", ] monkeypatch.setattr(sys, "argv", test_args) diff --git a/tests/special_e2e/test_torch_e2e.py b/tests/special_e2e/test_torch_e2e.py index a72fa8d..b46e21f 100644 --- a/tests/special_e2e/test_torch_e2e.py +++ b/tests/special_e2e/test_torch_e2e.py @@ -29,13 +29,12 @@ def test_torch_e2e_with_input_path(monkeypatch, tmp_path): # Ensure the input directory exists assert input_dir.exists(), f"Input directory {input_dir} does not exist" - # Set command line parameters test_args = [ "main.py", - f"--input-path={input_dir}", - f"--output-path={output_dir}", - "--profiler-type=torch", - "--input-type=multi_json_torch", + f"input.input_path={input_dir}", + f"output.output_path={output_dir}", + "input.profiler_type=torch", + "input.input_type=multi_json_torch", ] monkeypatch.setattr(sys, "argv", test_args) From 4ee0c4311a87cda7ea66901a280f6dca4c6cffc3 Mon Sep 17 00:00:00 2001 From: acat-rw <892882856@qq.com> Date: Thu, 21 May 2026 10:28:35 +0000 Subject: [PATCH 2/2] Modular configuration of parameters and reduction of redundancy. 
--- docs/data/data_specification.md | 2 +- docs/developer_guides/extending_guide.md | 2 +- docs/overview/RL_Timeline_quickstart.md | 34 ++++++------ docs/overview/gmm_heatmap_quickstart.md | 52 +++++++++---------- examples/gmm_exec.sh | 23 ++++---- examples/mstx_exec.sh | 11 ++-- examples/nvtx_exec.sh | 9 ++-- examples/torch_profiler_exec.sh | 9 ++-- rl_insight/config/__init__.py | 12 ++--- rl_insight/config/config.py | 35 +++++++------ rl_insight/config/config_loader.py | 17 +++--- rl_insight/config/{gmm.yaml => heatmap.yaml} | 13 +++-- rl_insight/config/timeline.yaml | 11 ++-- rl_insight/config/utils.py | 6 +-- rl_insight/main.py | 10 ++-- rl_insight/parser/gmm_parser.py | 4 +- .../pipeline/offline_insight_pipeline.py | 26 ++++------ rl_insight/visualizer/gmm_visualizer.py | 8 +-- rl_insight/visualizer/timeline_visualizer.py | 4 +- tests/parser/test_cluster_analysis.py | 6 +-- tests/parser/test_png_visualizer.py | 2 +- tests/special_e2e/test_gmm_e2e.py | 8 ++- tests/special_e2e/test_mstx_e2e.py | 6 +-- tests/special_e2e/test_nvtx_e2e.py | 7 ++- tests/special_e2e/test_torch_e2e.py | 7 ++- 25 files changed, 152 insertions(+), 172 deletions(-) rename rl_insight/config/{gmm.yaml => heatmap.yaml} (60%) diff --git a/docs/data/data_specification.md b/docs/data/data_specification.md index 4606b20..0329d01 100644 --- a/docs/data/data_specification.md +++ b/docs/data/data_specification.md @@ -236,7 +236,7 @@ python tests/data/check_verl_log.py data/verl_data/good_minimal_verl.log ## 5. 
GMM 专家负载dump数据 -GMM 热力图输入类型为 `DataEnum.GMM_DATA`(CLI:`input.input_type=gmm_data`、`input.profiler_type=gmm`)。**路径约定、参数与示意图**见 [`docs/overview/gmm_heatmap_quickstart.md`](../overview/gmm_heatmap_quickstart.md)。本节补充数据侧目录与文件格式说明。 +GMM 热力图输入类型为 `DataEnum.GMM_DATA`。**路径约定、参数与示意图**见 [`docs/overview/gmm_heatmap_quickstart.md`](../overview/gmm_heatmap_quickstart.md)。本节补充数据侧目录与文件格式说明。 ### 5.1 目录结构 diff --git a/docs/developer_guides/extending_guide.md b/docs/developer_guides/extending_guide.md index 29477e8..3a20fe5 100644 --- a/docs/developer_guides/extending_guide.md +++ b/docs/developer_guides/extending_guide.md @@ -36,5 +36,5 @@ 适用于:全新的处理范式(跳过步骤、插入预处理、多产物、在线多进程流程等)。 1. 在 `rl_insight/pipeline/` 新增类,实现 `__init__(self, config)`、`run(self)`,按需组合 `DataChecker`、`get_cluster_parser_cls`、`get_cluster_visualizer_cls` 等。 -2. 在 `rl_insight/config/config.py` 的 `PipelineConfig.pipeline_type` 默认值或 preset YAML 中注册新 pipeline 类型。 +2. 在 `rl_insight/config/config.py` 的 `PipelineConfig.type` 默认值或 preset YAML 中注册新 pipeline 类型。 3. 若数据解析或数据类型发生变化,同步扩展 **DataRule** / **Parser** / **Visualizer**。 diff --git a/docs/overview/RL_Timeline_quickstart.md b/docs/overview/RL_Timeline_quickstart.md index aa9bda9..19176f6 100644 --- a/docs/overview/RL_Timeline_quickstart.md +++ b/docs/overview/RL_Timeline_quickstart.md @@ -56,10 +56,9 @@ pip install -e . 
```bash python -m rl_insight.main \ - input.input_path= \ - input.profiler_type=mstx \ - input.input_type=multi_json_nvtx \ - output.output_path= + input.path= \ + timeline.parser.type=mstx \ + output.path= ``` 或修改并直接使用 `examples/mstx_exec.sh` 脚本: @@ -74,10 +73,9 @@ bash examples/mstx_exec.sh ```bash python -m rl_insight.main \ - input.input_path= \ - input.profiler_type=torch \ - input.input_type=multi_json_torch \ - output.output_path= + input.path= \ + timeline.parser.type=torch \ + output.path= ``` 或修改并直接使用 `examples/torch_profiler_exec.sh` 脚本: @@ -92,10 +90,9 @@ bash examples/torch_profiler_exec.sh ```bash python -m rl_insight.main \ - input.input_path= \ - input.profiler_type=nvtx \ - input.input_type=multi_json_nvtx \ - output.output_path= + input.path= \ + timeline.parser.type=nvtx \ + output.path= ``` 或修改并直接使用 `examples/nvtx_exec.sh` 脚本: @@ -112,19 +109,18 @@ bash examples/nvtx_exec.sh | 参数 | 默认值 | 说明 | |------|--------|------| -| `input.input_path` | (必填) | Profiling 数据的根目录路径 | -| `input.input_type` | `multi_json_mstx` | 输入数据类型(`multi_json_mstx`、`multi_json_torch`、`multi_json_nvtx`)| -| `input.profiler_type` | `mstx` | 性能数据种类:`mstx`、`torch`、`nvtx` | +| `input.path` | (必填) | Profiling 数据的根目录路径 | | `input.rank_list` | `all` | Rank ID 列表,如 `0,1,2` 或 `all` | -| `output.output_path` | `output` | 输出目录 | -| `preset` | 自动推断 | 预设名称:`timeline`、`gmm`(根据 `profiler_type` 自动推断) | +| `output.path` | `output` | 输出目录 | +| `preset` | 自动推断 | 预设名称:`timeline`、`heatmap`(根据 CLI 参数自动推断) | | `config_path` | 无 | YAML 配置文件路径 | -### 4.2 Timeline 参数 +### 4.2 Timeline 专属参数 | 参数 | 默认值 | 说明 | |------|--------|------| -| `timeline.visualizer.vis_type` | `html` | 可视化类型:`html`、`png` | +| `timeline.parser.type` | `mstx` | 数据源类型:`mstx`、`torch`、`nvtx` | +| `timeline.visualizer.type` | `html` | 可视化类型:`html`、`png` | | `timeline.visualizer.width` | `2000` | 图片宽度(仅 png) | | `timeline.visualizer.scale` | `2` | 图片缩放因子(仅 png) | diff --git a/docs/overview/gmm_heatmap_quickstart.md 
b/docs/overview/gmm_heatmap_quickstart.md index 25e220b..0037179 100644 --- a/docs/overview/gmm_heatmap_quickstart.md +++ b/docs/overview/gmm_heatmap_quickstart.md @@ -59,25 +59,24 @@ gmm_dump/ 路径字段含义: -- `step_`:训练 step(对应 `gmm.parser.step` 过滤) -- ``:角色名(对应 `gmm.parser.role` 过滤) +- `step_`:训练 step(对应 `heatmap.parser.step` 过滤) +- ``:角色名(对应 `heatmap.parser.role` 过滤) - `rank`:rank id(对应 `input.rank_list` 过滤) - `dump_tensor_data/*.group_list.pt`:MoE grouped_matmul 的专家负载;典型为一维整型张量,第 `i` 个元素表示第 `i` 个 expert 分到的 **token 数** ### 2.2 执行分析脚本 -#### GMM 热力图使用示例 +#### 热力图使用示例 ```bash python -m rl_insight.main \ - input.input_path= \ - input.input_type=gmm_data \ - input.profiler_type=gmm \ + input.path= \ input.rank_list=all \ - gmm.visualizer.vis_type=gmm_heatmap \ - gmm.parser.step=1 \ - gmm.parser.role=actor_compute_log_prob \ - output.output_path= + heatmap.parser.type=gmm \ + heatmap.visualizer.type=gmm_heatmap \ + heatmap.parser.step=1 \ + heatmap.parser.role=actor_compute_log_prob \ + output.path= ``` 或修改并直接使用 `examples/gmm_exec.sh` 脚本: @@ -94,24 +93,23 @@ bash examples/gmm_exec.sh | 参数 | 默认值 | 说明 | |------|--------|------| -| `input.input_path` | (必填) | GMM 数据的根目录路径 | -| `input.input_type` | `multi_json_mstx` | 输入数据类型,GMM 功能需设置为 `gmm_data` | -| `input.profiler_type` | `mstx` | 性能数据种类,GMM 功能需设置为 `gmm` | -| `output.output_path` | `output` | 输出路径,若为文件夹则在其中生成 `gmm_heatmap.png` | +| `input.path` | (必填) | GMM 数据的根目录路径 | +| `output.path` | `output` | 输出路径,若为文件夹则在其中生成 `gmm_heatmap.png` | | `input.rank_list` | `all` | Rank ID 列表,默认 `all` 表示所有 rank,可指定多个 rank 用逗号分隔 | -| `preset` | 自动推断 | 预设名称:`timeline`、`gmm`(根据 `profiler_type` 自动推断) | +| `preset` | 自动推断 | 预设名称:`timeline`、`heatmap`(根据 CLI 参数自动推断) | | `config_path` | 无 | YAML 配置文件路径 | -### 3.2 GMM 参数 +### 3.2 Heatmap 专属参数 | 参数 | 默认值 | 说明 | |------|--------|------| -| `gmm.visualizer.vis_type` | `gmm_heatmap` | 可视化类型 | -| `gmm.parser.step` | 无 | 特定的 step 进行可视化(可选,支持 `1` 或 `1,2`) | -| `gmm.parser.role` | 无 | 特定的 role 进行可视化(可选) | -| 
`gmm.visualizer.dpi` | `200` | 热力图输出的 DPI | -| `gmm.visualizer.cmap` | `viridis` | 热力图的颜色映射 | -| `gmm.visualizer.gmm_per_layer` | `3` | 每个 MoE layer 前向阶段预期的 grouped_matmul 次数 | +| `heatmap.parser.type` | `gmm` | 解析器类型:`gmm` | +| `heatmap.visualizer.type` | `gmm_heatmap` | 可视化类型 | +| `heatmap.parser.step` | 无 | 特定的 step 进行可视化(可选,支持 `1` 或 `1,2`) | +| `heatmap.parser.role` | 无 | 特定的 role 进行可视化(可选) | +| `heatmap.visualizer.dpi` | `200` | 热力图输出的 DPI | +| `heatmap.visualizer.cmap` | `viridis` | 热力图的颜色映射 | +| `heatmap.visualizer.gmm_per_layer` | `3` | 每个 MoE layer 前向阶段预期的 grouped_matmul 次数 | ## 四、输出说明 @@ -140,11 +138,11 @@ bash examples/gmm_exec.sh ## 五、注意事项 -1. GMM 热力图功能需要使用 `input.input_type=gmm_data` 和 `input.profiler_type=gmm` 参数 -2. 当 `output.output_path` 只指定文件夹路径时,会在该文件夹中生成 `gmm_heatmap.png` 文件 -3. 当不指定 `gmm.parser.step`、`gmm.parser.role` 或 `input.rank_list` 参数时,默认显示所有数据 +1. 热力图功能使用 `preset=heatmap` 或在 CLI 中指定 `heatmap.` 开头的参数即可启用 +2. 当 `output.path` 只指定文件夹路径时,会在该文件夹中生成 `gmm_heatmap.png` 文件 +3. 当不指定 `heatmap.parser.step`、`heatmap.parser.role` 或 `input.rank_list` 参数时,默认显示所有数据 4. 对于大量数据,工具会自动调整图表大小和标签显示密度,确保可读性 5. 数据文件需包含有效的专家负载数据,包括 step、role、rank_id、stage、expert_index 和 load 等字段 -6. 若你的模型实现中每层 grouped_matmul 次数不等于 3,请显式设置 `gmm.visualizer.gmm_per_layer` 以获得更准确的 actor_update 前向阶段截断结果 +6. 
若你的模型实现中每层 grouped_matmul 次数不等于 3,请显式设置 `heatmap.visualizer.gmm_per_layer` 以获得更准确的 actor_update 前向阶段截断结果 -目录与 JSON 字段的集中说明另见 [数据规格与格式说明](../data/data_specification.md)。运行时校验逻辑以 `rl_insight.data.DataChecker` 及 [`rl_insight/data/rules.py`](../../rl_insight/data/rules.py) 中的规则定义为准。 \ No newline at end of file +目录与 JSON 字段的集中说明另见 [数据规格与格式说明](../data/data_specification.md)。运行时校验逻辑以 `rl_insight.data.DataChecker` 及 [`rl_insight/data/rules.py`](../../rl_insight/data/rules.py) 中的规则定义为准。 diff --git a/examples/gmm_exec.sh b/examples/gmm_exec.sh index 16fd9a0..64be66d 100644 --- a/examples/gmm_exec.sh +++ b/examples/gmm_exec.sh @@ -28,29 +28,30 @@ echo "==========================================" # Build command cmd="python -m rl_insight.main \ - input.input_path=\"${GMM_DATA_PATH}\" \ - input.input_type=gmm_data \ - input.profiler_type=gmm \ - output.output_path=\"${OUTPUT_PATH}\" \ + input.path=\"${GMM_DATA_PATH}\" \ + output.path=\"${OUTPUT_PATH}\" \ input.rank_list=\"${RANK_LIST}\" \ - gmm.visualizer.vis_type=gmm_heatmap \ - gmm.visualizer.dpi=\"${DPI}\" \ - gmm.visualizer.cmap=\"${CMAP}\" \ - gmm.visualizer.gmm_per_layer=\"${GMM_PER_LAYER}\"" -# Add step and role parameters if specified + heatmap.parser.type=gmm \ + heatmap.visualizer.type=gmm_heatmap \ + heatmap.visualizer.dpi=\"${DPI}\" \ + heatmap.visualizer.cmap=\"${CMAP}\" \ + heatmap.visualizer.gmm_per_layer=\"${GMM_PER_LAYER}\"" + if [ -n "${STEP}" ]; then cmd="${cmd} \ - gmm.parser.step=\"${STEP}\"" + heatmap.parser.step=\"${STEP}\"" fi if [ -n "${ROLE}" ]; then cmd="${cmd} \ - gmm.parser.role=\"${ROLE}\"" + heatmap.parser.role=\"${ROLE}\"" fi +# Execute the command echo ">>> Generating GMM expert load heatmap..." eval ${cmd} +# Check if the heatmap was generated successfully if [ -f "${OUTPUT_PATH}" ]; then echo "==========================================" echo ">>> Heatmap generated successfully!" 
diff --git a/examples/mstx_exec.sh b/examples/mstx_exec.sh index ea55c40..2d6ee20 100644 --- a/examples/mstx_exec.sh +++ b/examples/mstx_exec.sh @@ -25,12 +25,11 @@ python -m rl_insight.utils.mstx_preprocessing "${MSTX_PROFILER_DATA_PATH}" echo ">>> Mstx data preprocessing completed." python -m rl_insight.main \ - input.input_path="${MSTX_PROFILER_DATA_PATH}" \ - input.profiler_type=mstx \ - input.input_type=multi_json_mstx \ - input.rank_list="${RANK_LIST}" - output.output_path="${OUTPUT_PATH}" \ - timeline.visualizer.vis_type="${VIS_TYPE}" + input.path="${MSTX_PROFILER_DATA_PATH}" \ + timeline.parser.type=mstx \ + input.rank_list="${RANK_LIST}" \ + output.path="${OUTPUT_PATH}" \ + timeline.visualizer.type="${VIS_TYPE}" echo "==========================================" echo ">>> Analysis completed successfully!" diff --git a/examples/nvtx_exec.sh b/examples/nvtx_exec.sh index f46aa17..c085519 100644 --- a/examples/nvtx_exec.sh +++ b/examples/nvtx_exec.sh @@ -17,12 +17,11 @@ echo "Rank List: ${RANK_LIST}" echo "==========================================" python -m rl_insight.main \ - input.input_path="${NVTX_PROFILER_DATA_PATH}" \ - input.profiler_type=nvtx \ - input.input_type=multi_json_nvtx \ + input.path="${NVTX_PROFILER_DATA_PATH}" \ + timeline.parser.type=nvtx \ input.rank_list="${RANK_LIST}" \ - output.output_path="${OUTPUT_PATH}" \ - timeline.visualizer.vis_type="${VIS_TYPE}" + output.path="${OUTPUT_PATH}" \ + timeline.visualizer.type="${VIS_TYPE}" echo "==========================================" echo ">>> Analysis completed successfully!" 
diff --git a/examples/torch_profiler_exec.sh b/examples/torch_profiler_exec.sh index ca9612d..8a50db0 100644 --- a/examples/torch_profiler_exec.sh +++ b/examples/torch_profiler_exec.sh @@ -17,12 +17,11 @@ echo "Rank List: ${RANK_LIST}" echo "==========================================" python -m rl_insight.main \ - input.input_path="${TORCH_PROFILER_DATA_PATH}" \ - input.profiler_type=torch \ - input.input_type=multi_json_torch \ + input.path="${TORCH_PROFILER_DATA_PATH}" \ + timeline.parser.type=torch \ input.rank_list="${RANK_LIST}" \ - output.output_path="${OUTPUT_PATH}" \ - timeline.visualizer.vis_type="${VIS_TYPE}" + output.path="${OUTPUT_PATH}" \ + timeline.visualizer.type="${VIS_TYPE}" echo "==========================================" echo ">>> Analysis completed successfully!" diff --git a/rl_insight/config/__init__.py b/rl_insight/config/__init__.py index 3f2c7d3..b069cf9 100644 --- a/rl_insight/config/__init__.py +++ b/rl_insight/config/__init__.py @@ -14,9 +14,9 @@ from .config import ( AppConfig, - GmmConfig, - GmmParserConfig, - GmmVisualizerConfig, + HeatmapConfig, + HeatmapParserConfig, + HeatmapVisualizerConfig, InputConfig, OutputConfig, PipelineConfig, @@ -29,9 +29,9 @@ __all__ = [ "AppConfig", - "GmmConfig", - "GmmParserConfig", - "GmmVisualizerConfig", + "HeatmapConfig", + "HeatmapParserConfig", + "HeatmapVisualizerConfig", "InputConfig", "OutputConfig", "PipelineConfig", diff --git a/rl_insight/config/config.py b/rl_insight/config/config.py index 16efb3e..58263d3 100644 --- a/rl_insight/config/config.py +++ b/rl_insight/config/config.py @@ -22,9 +22,7 @@ class InputConfig: """Input data configuration.""" - input_path: str = MISSING # Path to profiling data (required) - input_type: str = "multi_json_mstx" # multi_json_mstx | multi_json_torch | multi_json_nvtx | gmm_data - profiler_type: str = "mstx" # mstx | torch | nvtx | gmm + path: str = MISSING # Path to profiling data (required) rank_list: str = "all" # Rank id list, e.g. 
'0,1,2' or 'all' @@ -32,19 +30,21 @@ class InputConfig: class OutputConfig: """Output configuration.""" - output_path: str = "output" # Output directory path + path: str = "output" # Output directory path @dataclass class TimelineParserConfig: - """Timeline parser filter configuration.""" + """Timeline parser configuration.""" + + type: Optional[str] = None # mstx | torch | nvtx @dataclass class TimelineVisualizerConfig: """Timeline visualizer configuration.""" - vis_type: str = "html" # html | png + type: str = "html" # html | png width: int = 2000 # Image width in pixels (png only) scale: int = 2 # Image scale factor (png only) @@ -60,36 +60,37 @@ class TimelineConfig: @dataclass -class GmmParserConfig: - """GMM parser filter configuration.""" +class HeatmapParserConfig: + """Heatmap parser configuration.""" + type: Optional[str] = None # gmm step: Optional[str] = None # Step filter, e.g. '1' or '1,2' role: Optional[str] = None # Role filter @dataclass -class GmmVisualizerConfig: - """GMM heatmap visualizer configuration.""" +class HeatmapVisualizerConfig: + """Heatmap visualizer configuration.""" - vis_type: str = "gmm_heatmap" # gmm_heatmap + type: str = "gmm_heatmap" # gmm_heatmap dpi: int = 200 # DPI for heatmap PNG output cmap: str = "viridis" # Matplotlib colormap name gmm_per_layer: int = 3 # Grouped matmul count per MoE layer @dataclass -class GmmConfig: - """GMM configuration.""" +class HeatmapConfig: + """Heatmap configuration.""" - parser: GmmParserConfig = field(default_factory=GmmParserConfig) - visualizer: GmmVisualizerConfig = field(default_factory=GmmVisualizerConfig) + parser: HeatmapParserConfig = field(default_factory=HeatmapParserConfig) + visualizer: HeatmapVisualizerConfig = field(default_factory=HeatmapVisualizerConfig) @dataclass class PipelineConfig: """Pipeline configuration.""" - pipeline_type: str = "OfflineInsightPipeline" # OfflineInsightPipeline + type: str = "OfflineInsightPipeline" # OfflineInsightPipeline @dataclass @@ -100,4 
+101,4 @@ class AppConfig: input: InputConfig = field(default_factory=InputConfig) output: OutputConfig = field(default_factory=OutputConfig) timeline: TimelineConfig = field(default_factory=TimelineConfig) - gmm: GmmConfig = field(default_factory=GmmConfig) + heatmap: HeatmapConfig = field(default_factory=HeatmapConfig) diff --git a/rl_insight/config/config_loader.py b/rl_insight/config/config_loader.py index 597a3a7..a24f780 100644 --- a/rl_insight/config/config_loader.py +++ b/rl_insight/config/config_loader.py @@ -60,11 +60,11 @@ def _examples() -> str: return ( "\n" "Examples:\n" - " python -m rl_insight.main input.input_path=./data/mstx_data/mstx_profile\n" - " python -m rl_insight.main input.input_path=./data/gmm_data input.profiler_type=gmm\n" - " python -m rl_insight.main config_path=my_config.yaml gmm.visualizer.dpi=300\n" - " python -m rl_insight.main preset=timeline timeline.visualizer.vis_type=png\n" - " python -m rl_insight.main preset=gmm input.input_path=./data/gmm_data\n" + " python -m rl_insight.main input.path=./data/mstx_data/mstx_profile\n" + " python -m rl_insight.main preset=heatmap input.path=./data/gmm_data\n" + " python -m rl_insight.main config_path=my_config.yaml heatmap.visualizer.dpi=300\n" + " python -m rl_insight.main preset=timeline timeline.visualizer.type=png\n" + " python -m rl_insight.main preset=timeline timeline.parser.type=torch\n" ) @staticmethod @@ -118,7 +118,7 @@ def _field_comment(field_name: str, source_lines: list[str]) -> str: class ConfigLoader: PRESETS_DIR = Path(__file__).parent - SUPPORTED_PRESETS = {"timeline", "gmm"} + SUPPORTED_PRESETS = {"timeline", "heatmap"} @classmethod def load( @@ -219,7 +219,6 @@ def _parse_special_args( @staticmethod def _infer_preset_from_args(args: list[str]) -> Optional[str]: for arg in args: - if arg.startswith("input.profiler_type="): - profiler_type = arg.split("=", 1)[1] - return "gmm" if profiler_type == "gmm" else "timeline" + if arg.startswith("heatmap."): + return "heatmap" 
return None diff --git a/rl_insight/config/gmm.yaml b/rl_insight/config/heatmap.yaml similarity index 60% rename from rl_insight/config/gmm.yaml rename to rl_insight/config/heatmap.yaml index a8f3eaa..664e455 100644 --- a/rl_insight/config/gmm.yaml +++ b/rl_insight/config/heatmap.yaml @@ -1,22 +1,21 @@ -# GMM visualization preset +# Heatmap visualization preset pipeline: - pipeline_type: OfflineInsightPipeline # OfflineInsightPipeline + type: OfflineInsightPipeline # OfflineInsightPipeline input: - input_type: gmm_data # gmm_data - profiler_type: gmm # gmm rank_list: all # Rank id list, e.g. '0,1,2' or 'all' output: - output_path: output # Output directory path + path: output # Output directory path -gmm: +heatmap: parser: + type: gmm # gmm step: null # Step filter, e.g. '1' or '1,2' role: null # Role filter visualizer: - vis_type: gmm_heatmap # gmm_heatmap + type: gmm_heatmap # gmm_heatmap dpi: 200 # DPI for heatmap PNG output cmap: viridis # Matplotlib colormap name gmm_per_layer: 3 # Grouped matmul count per MoE layer diff --git a/rl_insight/config/timeline.yaml b/rl_insight/config/timeline.yaml index 20dcc30..d92136b 100644 --- a/rl_insight/config/timeline.yaml +++ b/rl_insight/config/timeline.yaml @@ -1,19 +1,18 @@ # Timeline visualization preset pipeline: - pipeline_type: OfflineInsightPipeline # OfflineInsightPipeline + type: OfflineInsightPipeline # OfflineInsightPipeline input: - input_type: multi_json_mstx # multi_json_mstx | multi_json_torch | multi_json_nvtx - profiler_type: mstx # mstx | torch | nvtx rank_list: all # Rank id list, e.g. 
'0,1,2' or 'all' output: - output_path: output # Output directory path + path: output # Output directory path timeline: - parser: {} # No parser-specific config yet + parser: + type: mstx # mstx | torch | nvtx visualizer: - vis_type: html # html | png + type: html # html | png width: 2000 # Image width in pixels (png only) scale: 2 # Image scale factor (png only) \ No newline at end of file diff --git a/rl_insight/config/utils.py b/rl_insight/config/utils.py index 3a675a1..c9c7290 100644 --- a/rl_insight/config/utils.py +++ b/rl_insight/config/utils.py @@ -22,12 +22,12 @@ def get_config_value( ) -> Any: """Retrieve a value from a DictConfig or dict using a dot-separated key. - Supports both nested access (``output.output_path``) and flat key fallback - (``output_output_path``) for backward compatibility. + Supports both nested access (``output.path``) and flat key fallback + (``output_path``) for backward compatibility. Args: config: Configuration object (DictConfig or plain dict). - key: Dot-separated key path, e.g. ``"output.output_path"``. + key: Dot-separated key path, e.g. ``"output.path"``. default: Value returned when the key is not found. Returns: diff --git a/rl_insight/main.py b/rl_insight/main.py index 566e975..72ce0c1 100644 --- a/rl_insight/main.py +++ b/rl_insight/main.py @@ -29,13 +29,13 @@ def run_pipeline(config: DictConfig, pipeline_class=None): def validate_config(cfg: DictConfig) -> None: - if cfg.input.input_path is None: - raise ValueError("input.input_path is required") + if cfg.input.path is None: + raise ValueError("input.path is required") - if cfg.pipeline.pipeline_type not in SUPPORTED_PIPELINE_TYPES: + if cfg.pipeline.type not in SUPPORTED_PIPELINE_TYPES: supported_types = ", ".join(SUPPORTED_PIPELINE_TYPES.keys()) raise ValueError( - f"Unsupported pipeline type: {cfg.pipeline.pipeline_type}. " + f"Unsupported pipeline type: {cfg.pipeline.type}. 
" f"Supported types are: {supported_types}" ) @@ -43,7 +43,7 @@ def validate_config(cfg: DictConfig) -> None: def main(): cfg = ConfigLoader.load_from_cli() validate_config(cfg) - pipeline_class = SUPPORTED_PIPELINE_TYPES[cfg.pipeline.pipeline_type] + pipeline_class = SUPPORTED_PIPELINE_TYPES[cfg.pipeline.type] run_pipeline(cfg, pipeline_class) diff --git a/rl_insight/parser/gmm_parser.py b/rl_insight/parser/gmm_parser.py index 01b0c17..9558b7f 100644 --- a/rl_insight/parser/gmm_parser.py +++ b/rl_insight/parser/gmm_parser.py @@ -44,7 +44,7 @@ def __init__(self, params: Union[DictConfig, dict]) -> None: if rank.strip().isdigit() ] ) - step = get_config_value(params, "gmm.parser.step", None) + step = get_config_value(params, "heatmap.parser.step", None) if step is None: self._step_list: Optional[list[int]] = None elif isinstance(step, int): @@ -60,7 +60,7 @@ def __init__(self, params: Union[DictConfig, dict]) -> None: "Will process all steps." ) self._step_list = None - self._role = get_config_value(params, "gmm.parser.role", None) + self._role = get_config_value(params, "heatmap.parser.role", None) @staticmethod def _normalize_path_text(path_value: str | Path) -> str: diff --git a/rl_insight/pipeline/offline_insight_pipeline.py b/rl_insight/pipeline/offline_insight_pipeline.py index 4a83c57..e3744e6 100644 --- a/rl_insight/pipeline/offline_insight_pipeline.py +++ b/rl_insight/pipeline/offline_insight_pipeline.py @@ -14,7 +14,7 @@ from omegaconf import DictConfig -from rl_insight.data import DataChecker, DataEnum +from rl_insight.data import DataChecker from rl_insight.parser import get_cluster_parser_cls from rl_insight.visualizer import get_cluster_visualizer_cls @@ -23,18 +23,18 @@ class OfflineInsightPipeline: def __init__(self, config: DictConfig): self.config = config - self.input_data_type = DataEnum(self.config.input.input_type) + timeline_parser_type = config.timeline.parser.type + if timeline_parser_type is not None: + parser_cls = 
get_cluster_parser_cls(timeline_parser_type) + visualizer_cls = get_cluster_visualizer_cls(config.timeline.visualizer.type) + else: + parser_cls = get_cluster_parser_cls(config.heatmap.parser.type) + visualizer_cls = get_cluster_visualizer_cls(config.heatmap.visualizer.type) parser_config = self._prepare_parser_config() - parser_cls = get_cluster_parser_cls(self.config.input.profiler_type) self.parser = parser_cls(parser_config) visualizer_config = self._prepare_visualizer_config() - visualizer_cls = get_cluster_visualizer_cls( - self.config.timeline.visualizer.vis_type - if self.config.input.profiler_type != "gmm" - else self.config.gmm.visualizer.vis_type - ) self.visualizer = visualizer_cls(visualizer_config) def _prepare_parser_config(self) -> DictConfig: @@ -44,15 +44,9 @@ def _prepare_visualizer_config(self) -> DictConfig: return self.config def run(self): - if self.input_data_type != self.parser.input_type: - raise ValueError( - f"Input data type {self.input_data_type} does not match " - f"parser input type {self.parser.input_type}" - ) + DataChecker(self.parser.input_type, self.config.input.path).run() - DataChecker(self.input_data_type, self.config.input.input_path).run() - - output_data = self.parser.run(self.config.input.input_path) + output_data = self.parser.run(self.config.input.path) DataChecker(self.visualizer.input_type, output_data).run() diff --git a/rl_insight/visualizer/gmm_visualizer.py b/rl_insight/visualizer/gmm_visualizer.py index a7387d2..28fc901 100644 --- a/rl_insight/visualizer/gmm_visualizer.py +++ b/rl_insight/visualizer/gmm_visualizer.py @@ -51,13 +51,13 @@ def run(self, data): """Run GMM heatmap visualization from parsed data.""" # Extract parameters from config output_cfg = get_config_value( - self.config, "output.output_path", "./output/gmm_group_list_heatmap.png" + self.config, "output.path", "./output/gmm_group_list_heatmap.png" ) output = self._resolve_output_path(output_cfg) - dpi = get_config_value(self.config, 
"gmm.visualizer.dpi", 200) - cmap = get_config_value(self.config, "gmm.visualizer.cmap", "viridis") + dpi = get_config_value(self.config, "heatmap.visualizer.dpi", 200) + cmap = get_config_value(self.config, "heatmap.visualizer.cmap", "viridis") gmm_per_layer = int( - get_config_value(self.config, "gmm.visualizer.gmm_per_layer", 3) + get_config_value(self.config, "heatmap.visualizer.gmm_per_layer", 3) ) if not isinstance(data, pd.DataFrame): diff --git a/rl_insight/visualizer/timeline_visualizer.py b/rl_insight/visualizer/timeline_visualizer.py index 5819d87..e4f8913 100644 --- a/rl_insight/visualizer/timeline_visualizer.py +++ b/rl_insight/visualizer/timeline_visualizer.py @@ -52,7 +52,7 @@ class RLTimelineVisualizer(BaseVisualizer): def __init__(self, config: Union[DictConfig, dict]): super().__init__(config) - self.output_path = get_config_value(config, "output.output_path", None) + self.output_path = get_config_value(config, "output.path", None) def run(self, data): return self.generate_rl_timeline(data) @@ -410,7 +410,7 @@ class RLTimelinePNGVisualizer(BaseVisualizer): def __init__(self, config: Union[DictConfig, dict]): super().__init__(config) - self.output_path = get_config_value(config, "output.output_path", None) + self.output_path = get_config_value(config, "output.path", None) self.width = get_config_value(config, "timeline.visualizer.width", 2000) self.scale = get_config_value(config, "timeline.visualizer.scale", 2) diff --git a/tests/parser/test_cluster_analysis.py b/tests/parser/test_cluster_analysis.py index 57482fa..309a918 100644 --- a/tests/parser/test_cluster_analysis.py +++ b/tests/parser/test_cluster_analysis.py @@ -52,8 +52,8 @@ def _timeline_viz(**kwargs): cfg = { - "output": {"output_path": "/tmp"}, - "timeline": {"visualizer": {"vis_type": "html"}}, + "output": {"path": "/tmp"}, + "timeline": {"visualizer": {"type": "html"}}, } cfg.update(kwargs) return RLTimelineVisualizer(cfg) @@ -1007,7 +1007,7 @@ def 
test_full_pipeline_with_mock_data(self, mock_mstx_profiler_structure, tmp_pa @patch( "sys.argv", - ["main.py", "input.input_path=/tmp", "input.profiler_type=mstx"], + ["main.py", "input.path=/tmp", "timeline.parser.type=mstx"], ) @patch("rl_insight.pipeline.offline_insight_pipeline.DataChecker.run") @patch("rl_insight.pipeline.offline_insight_pipeline.get_cluster_parser_cls") diff --git a/tests/parser/test_png_visualizer.py b/tests/parser/test_png_visualizer.py index 8a81b3b..bbaeafe 100644 --- a/tests/parser/test_png_visualizer.py +++ b/tests/parser/test_png_visualizer.py @@ -23,7 +23,7 @@ def visualizer(): """Initialize the visualizer instance for testing.""" config = { - "output": {"output_path": "test_output"}, + "output": {"path": "test_output"}, "timeline": {"visualizer": {"width": 2000, "scale": 2}}, } return RLTimelinePNGVisualizer(config) diff --git a/tests/special_e2e/test_gmm_e2e.py b/tests/special_e2e/test_gmm_e2e.py index 397276e..919c681 100644 --- a/tests/special_e2e/test_gmm_e2e.py +++ b/tests/special_e2e/test_gmm_e2e.py @@ -34,11 +34,9 @@ def test_gmm_e2e_with_repo_sample_data(monkeypatch, tmp_path): test_args = [ "main.py", - f"input.input_path={input_dir}", - f"output.output_path={output_dir}", - "input.profiler_type=gmm", - "input.input_type=gmm_data", - "gmm.visualizer.vis_type=gmm_heatmap", + f"input.path={input_dir}", + f"output.path={output_dir}", + "heatmap.visualizer.type=gmm_heatmap", ] monkeypatch.setattr(sys, "argv", test_args) diff --git a/tests/special_e2e/test_mstx_e2e.py b/tests/special_e2e/test_mstx_e2e.py index 393aefa..e8a38d9 100644 --- a/tests/special_e2e/test_mstx_e2e.py +++ b/tests/special_e2e/test_mstx_e2e.py @@ -31,9 +31,9 @@ def test_mstx_e2e_with_input_path(monkeypatch, tmp_path): test_args = [ "main.py", - f"input.input_path={input_dir}", - f"output.output_path={output_dir}", - "input.profiler_type=mstx", + f"input.path={input_dir}", + f"output.path={output_dir}", + "timeline.parser.type=mstx", ] monkeypatch.setattr(sys, 
"argv", test_args) diff --git a/tests/special_e2e/test_nvtx_e2e.py b/tests/special_e2e/test_nvtx_e2e.py index 831af89..8dbd852 100644 --- a/tests/special_e2e/test_nvtx_e2e.py +++ b/tests/special_e2e/test_nvtx_e2e.py @@ -31,10 +31,9 @@ def test_nvtx_e2e_with_input_path(monkeypatch, tmp_path): test_args = [ "main.py", - f"input.input_path={input_dir}", - f"output.output_path={output_dir}", - "input.profiler_type=nvtx", - "input.input_type=multi_json_nvtx", + f"input.path={input_dir}", + f"output.path={output_dir}", + "timeline.parser.type=nvtx", ] monkeypatch.setattr(sys, "argv", test_args) diff --git a/tests/special_e2e/test_torch_e2e.py b/tests/special_e2e/test_torch_e2e.py index b46e21f..a019919 100644 --- a/tests/special_e2e/test_torch_e2e.py +++ b/tests/special_e2e/test_torch_e2e.py @@ -31,10 +31,9 @@ def test_torch_e2e_with_input_path(monkeypatch, tmp_path): test_args = [ "main.py", - f"input.input_path={input_dir}", - f"output.output_path={output_dir}", - "input.profiler_type=torch", - "input.input_type=multi_json_torch", + f"input.path={input_dir}", + f"output.path={output_dir}", + "timeline.parser.type=torch", ] monkeypatch.setattr(sys, "argv", test_args)