diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check_pr_title.yml similarity index 98% rename from .github/workflows/check-pr-title.yml rename to .github/workflows/check_pr_title.yml index f9edb3f..1c78783 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check_pr_title.yml @@ -27,13 +27,14 @@ # - new workflow yaml is added to `.github/workflows` # - new tests are added to workflow mentioned in 2. +name: check_pr_title on: pull_request: types: [opened, edited, synchronize] jobs: - check-title: + check_title: runs-on: ubuntu-latest steps: - name: Checkout code diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre_commit.yml similarity index 97% rename from .github/workflows/pre-commit.yml rename to .github/workflows/pre_commit.yml index 4f6aa4b..45c2bbb 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre_commit.yml @@ -1,5 +1,5 @@ # c.f. https://github.com/pre-commit/action?tab=readme-ov-file#using-this-action -name: pre-commit +name: pre_commit # No need to avoid / cancel lightweight pre-commit jobs on: @@ -18,7 +18,7 @@ permissions: contents: read jobs: - pre-commit: + pre_commit: runs-on: ubuntu-latest strategy: matrix: diff --git a/.github/workflows/special_e2e.yml b/.github/workflows/special_e2e.yml index 28cb80e..d63066a 100644 --- a/.github/workflows/special_e2e.yml +++ b/.github/workflows/special_e2e.yml @@ -1,4 +1,4 @@ -name: profiling_data_analysis_st +name: special_e2e on: push: @@ -22,7 +22,7 @@ permissions: contents: read jobs: - profiling_data_analysis_st: + special_e2e: runs-on: ubuntu-latest timeout-minutes: 5 strategy: @@ -42,6 +42,6 @@ jobs: pip install -r requirements.txt pip install -e . - - name: Run profiling_data_analysis_st tests + - name: Run rl-insight e2e tests run: | pytest -s -x tests/special_e2e diff --git a/.github/workflows/cluster_analysis.yml b/.github/workflows/unit_test.yml similarity index 78% rename from .github/workflows/cluster_analysis.yml rename to .github/workflows/unit_test.yml index 6845171..a5a11d4 100644 --- a/.github/workflows/cluster_analysis.yml +++ b/.github/workflows/unit_test.yml @@ -1,4 +1,4 @@ -name: cluster_analyse +name: unit_test on: push: @@ -11,8 +11,8 @@ on: - v0.* paths: - "**/*.py" - - .github/workflows/cluster_analysis.yml - - "tests/cluster_analysis/**" + - .github/workflows/unit_test.yml + - "tests/**" concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -22,7 +22,7 @@ permissions: contents: read jobs: - cluster_analyse: + unit_test: runs-on: ubuntu-latest timeout-minutes: 5 strategy: @@ -42,6 +42,10 @@ jobs: pip install -r requirements.txt pip install -e . - - name: Run cluster_analyse tests + - name: Run parser tests run: | - pytest -s -x tests/cluster_analysis + pytest -s -x tests/parser + + - name: Run data_checker tests + run: | + pytest -s -x tests/data diff --git a/data/base.py b/data/base.py deleted file mode 100644 index 64019d6..0000000 --- a/data/base.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2025 verl-project authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -class BaseData: - def __init__(self, params) -> None: - pass - - @classmethod - def type_check(cls, params): - return True diff --git a/data/multi_json.py b/data/multi_json.py deleted file mode 100644 index 8be1d6c..0000000 --- a/data/multi_json.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025 verl-project authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/data/summary_event.py b/data/summary_event.py deleted file mode 100644 index 8be1d6c..0000000 --- a/data/summary_event.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025 verl-project authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/data/verl_log.py b/data/verl_log.py deleted file mode 100644 index 8be1d6c..0000000 --- a/data/verl_log.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025 verl-project authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/docs/cluster_analysis.md b/docs/cluster_analysis.md index ae6ff03..514a232 100644 --- a/docs/cluster_analysis.md +++ b/docs/cluster_analysis.md @@ -17,6 +17,7 @@ RL-Insight 是一个强化学习性能数据快速分析的可视化工具,基 - Pandas - Plotly - NumPy +- Loguru ## 二、快速使用 diff --git a/pyproject.toml b/pyproject.toml index fe1c72f..9c21ff4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "pandas", "plotly", "pytest", + "loguru" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index e343bf1..bdee79b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ numpy<2.0.0 pandas plotly -pytest \ No newline at end of file +pytest +loguru \ No newline at end of file diff --git a/tests/cluster_analysis/__init__.py b/rl_insight/data/__init__.py similarity index 81% rename from tests/cluster_analysis/__init__.py rename to rl_insight/data/__init__.py index 8be1d6c..509cfcf 100644 --- a/tests/cluster_analysis/__init__.py +++ b/rl_insight/data/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Data module for RL-Insight.""" + +from .data_checker import DataChecker, DataEnum + +__all__ = [ + "DataChecker", + "DataEnum", +] diff --git a/rl_insight/data/data_checker.py b/rl_insight/data/data_checker.py new file mode 100644 index 0000000..37fc55a --- /dev/null +++ b/rl_insight/data/data_checker.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base data definitions for RL-Insight.""" + +from typing import Any, List +from .rules import ValidationRule, PathExistsRule, DataValidationError +from enum import Enum +from loguru import logger + + +class DataEnum(Enum): + """Enum for data types in RL-Insight.""" + + # input data type of parser + MULTI_JSON = "multi_json" + VERL_LOG = "verl_log" + # output data type of parser, input data type of visualizer + SUMMARY_EVENT = "summary_event" + # other data type + UNKNOWN = "unknown" + + +class DataChecker: + """Base data class for RL-Insight.""" + + rules: dict[DataEnum, List[ValidationRule]] = { + DataEnum.MULTI_JSON: [PathExistsRule()], + DataEnum.VERL_LOG: [], + DataEnum.SUMMARY_EVENT: [], + DataEnum.UNKNOWN: [], + } + + def __init__(self, data_type: DataEnum, data: Any): + self.data_type = data_type + self.data = data + + def run(self): + """Validate the data""" + errors = [] + if self.data_type not in self.rules: + raise ValueError(f"Invalid data type: {self.data_type}") + rules = self.rules[self.data_type] + for rule in rules: + if not rule.check(self.data): + errors.append(rule.error_message) + if errors: + raise DataValidationError("Data validation failed", errors) + logger.info(f"Data validation passed for {self.data_type}") diff --git a/rl_insight/data/rules.py b/rl_insight/data/rules.py new file mode 100644 index 0000000..d6e069f --- /dev/null +++ b/rl_insight/data/rules.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Any +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + + +class DataValidationError(Exception): + """Exception raised when data validation fails.""" + + def __init__(self, message: str, errors: Optional[List[str]] = None): + super().__init__(message) + self.errors = errors or [] + + def __str__(self) -> str: + if self.errors: + return f"{super().__str__()}\n - " + "\n - ".join(self.errors) + return super().__str__() + + +class ValidationRule(ABC): + """Validation rule base class""" + + def __init__(self): + self._error_message: str = "" + + @abstractmethod + def check(self, data) -> bool: + pass + + @property + def error_message(self) -> str: + return self._error_message + + +class PathExistsRule(ValidationRule): + def check(self, data: Any) -> bool: + if not isinstance(data, str): + self._error_message = "Data object is not a path" + return False + try: + path = Path(data) + if not path.is_dir(): + self._error_message = ( + f"Source path is not a directory or does not exist: {data}" + ) + return False + return True + except TypeError as e: + self._error_message = f"Error checking path {data}: {e}" + return False diff --git a/rl_insight/main.py b/rl_insight/main.py index 078e562..83069f6 100644 --- a/rl_insight/main.py +++ b/rl_insight/main.py @@ -29,12 +29,17 @@ def run_pipeline(config, pipeline_class=None): def main(): arg_parser = argparse.ArgumentParser(description="Cluster scheduling visualization") arg_parser.add_argument( - "--input-path", default="test", help="Raw path of profiling data" + "--input-path", required=True, help="Raw path of profiling data" + ) + arg_parser.add_argument( + "--input-type", + default="multi_json", + help="Input data type. Supported: 'multi_json' (for nvtx/mstx/torch_profile from different directories).", ) arg_parser.add_argument( "--profiler-type", default="mstx", help="Profiler type, supported mstx/nvtx" ) - arg_parser.add_argument("--output-path", default="test", help="Output path") + arg_parser.add_argument("--output-path", default="output", help="Output path") arg_parser.add_argument( "--vis-type", default="html", help="Visualization type, supported html" ) diff --git a/rl_insight/parser/mstx_parser.py b/rl_insight/parser/mstx_parser.py index c8b7cc6..8cbcb4e 100644 --- a/rl_insight/parser/mstx_parser.py +++ b/rl_insight/parser/mstx_parser.py @@ -13,7 +13,7 @@ # limitations under the License. import json -import logging +from loguru import logger import os from collections import defaultdict from pathlib import Path @@ -21,13 +21,6 @@ from .parser import BaseClusterParser, register_cluster_parser from rl_insight.utils.schema import Constant, DataMap, EventRow -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - @register_cluster_parser("mstx") class MstxClusterParser(BaseClusterParser): diff --git a/rl_insight/parser/parser.py b/rl_insight/parser/parser.py index 622d98d..1c13f63 100644 --- a/rl_insight/parser/parser.py +++ b/rl_insight/parser/parser.py @@ -12,28 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +from loguru import logger import multiprocessing from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed from typing import Callable, Optional import pandas as pd -from rl_insight.utils.schema import Constant, DataMap, EventRow - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) +from rl_insight.data import DataEnum +from rl_insight.utils.schema import Constant, DataMap, EventRow class BaseClusterParser(ABC): + input_type: DataEnum = DataEnum.MULTI_JSON + def __init__(self, params) -> None: self.events_summary: Optional[pd.DataFrame] = None - self.input_path = params.get(Constant.INPUT_PATH, "") rank_list = params.get(Constant.RANK_LIST, "all") self._rank_list = ( rank_list @@ -41,9 +36,9 @@ def __init__(self, params) -> None: else [int(rank) for rank in rank_list.split(",") if rank.isdigit()] ) - def run(self) -> pd.DataFrame: + def run(self, input_data: str) -> pd.DataFrame: """Run parsing and return the parsed DataFrame.""" - _data_maps = self.allocate_prof_data(self.input_path) + _data_maps = self.allocate_prof_data(input_data) mapper_res = self.mapper_func(_data_maps) self.reducer_func(mapper_res) return self.get_data() @@ -131,12 +126,6 @@ def clean_data(self) -> None: def get_data(self) -> pd.DataFrame: return self.events_summary - def get_input_type(self): - pass - - def get_output_type(self): - return pd.DataFrame - @abstractmethod def allocate_prof_data(self, input_path: str) -> list[DataMap]: """ diff --git a/rl_insight/parser/torch_parser.py b/rl_insight/parser/torch_parser.py index 0fd5acb..d17fd09 100644 --- a/rl_insight/parser/torch_parser.py +++ b/rl_insight/parser/torch_parser.py @@ -14,7 +14,7 @@ import gzip import json -import logging +from loguru import logger import os from collections import defaultdict from pathlib import Path @@ -22,13 +22,6 @@ from .parser import BaseClusterParser, register_cluster_parser from rl_insight.utils.schema import Constant, DataMap, EventRow -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - @register_cluster_parser("torch") class TorchClusterParser(BaseClusterParser): diff --git a/rl_insight/pipeline/offline_insight_pipeline.py b/rl_insight/pipeline/offline_insight_pipeline.py index dbc364f..1816f44 100644 --- a/rl_insight/pipeline/offline_insight_pipeline.py +++ b/rl_insight/pipeline/offline_insight_pipeline.py @@ -12,58 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. +from rl_insight.data import DataChecker, DataEnum from rl_insight.parser import get_cluster_parser_cls from rl_insight.utils.schema import Constant from rl_insight.visualizer.visualizer import RLTimelineVisualizer -from data.base import BaseData class OfflineInsightPipeline: def __init__(self, config): - self.input_path = config.input_path - self.profiler_type = config.profiler_type - self.output_path = config.output_path - self.vis_type = config.vis_type - self.rank_list = config.rank_list + self.config = config + + # init data + self.input_data_type = DataEnum(self.config.input_type) # parser related - self.parser_config = self._prepare_parser_config() - self.parser_cls = get_cluster_parser_cls(self.profiler_type) - self.parser = self.parser_cls(self.parser_config) - self.parser_input_type = self.parser.get_input_type() - self.parser_output_type = self.parser.get_output_type() + parser_config = self._prepare_parser_config() + parser_cls = get_cluster_parser_cls(self.config.profiler_type) + self.parser = parser_cls(parser_config) # visualizer related - self.visualizer_config = self._prepare_visualizer_config() - self.visualizer = RLTimelineVisualizer(self.visualizer_config) - self.visualizer_input_type = self.visualizer.get_input_type() + visualizer_config = self._prepare_visualizer_config() + self.visualizer = RLTimelineVisualizer(visualizer_config) def _prepare_parser_config(self): return { - Constant.INPUT_PATH: self.input_path, - Constant.RANK_LIST: self.rank_list, + Constant.RANK_LIST: self.config.rank_list, } def _prepare_visualizer_config(self): return { - "output_path": self.output_path, - "vis_type": self.vis_type, + "output_path": self.config.output_path, + "vis_type": self.config.vis_type, } - def _input_data_check(self): - if not BaseData.type_check(self.parser_input_type): + def run(self): + if self.input_data_type != self.parser.input_type: raise ValueError( - f"Parser input type {self.parser_input_type} is not a valid BaseData type" + f"Input data type {self.input_data_type} does not match parser input type {self.parser.input_type}" ) + # validate input data + DataChecker(self.input_data_type, self.config.input_path).run() - def _inter_res_check(self): - if not isinstance(self.parser_output_type, type(self.visualizer_input_type)): - raise ValueError( - f"Parser output type {self.parser_output_type} does not match visualizer input type {self.visualizer_input_type}" - ) + output_data = self.parser.run(self.config.input_path) - def run(self): - self._input_data_check() - self._inter_res_check() - data = self.parser.run() - self.visualizer.run(data) + # validate output data + DataChecker(self.visualizer.input_type, output_data).run() + + self.visualizer.run(output_data) diff --git a/rl_insight/utils/mstx_preprocessing.py b/rl_insight/utils/mstx_preprocessing.py index 91ef722..1455a98 100644 --- a/rl_insight/utils/mstx_preprocessing.py +++ b/rl_insight/utils/mstx_preprocessing.py @@ -15,17 +15,10 @@ import os import sys import argparse -import logging +from loguru import logger import torch_npu from torch_npu.profiler.profiler import analyse -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - def main(): arg_parser = argparse.ArgumentParser(description="Run mstx offline analysis") diff --git a/rl_insight/visualizer/visualizer.py b/rl_insight/visualizer/visualizer.py index bf20e87..73d80d0 100644 --- a/rl_insight/visualizer/visualizer.py +++ b/rl_insight/visualizer/visualizer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +from loguru import logger import os from typing import Callable from abc import ABC, abstractmethod @@ -20,15 +20,9 @@ import numpy as np import pandas as pd import plotly.graph_objects as go +from rl_insight.data import DataEnum from rl_insight.utils.schema import FigureConfig -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - ClusterVisualizerFn = Callable[ [pd.DataFrame, str, dict], None, @@ -61,12 +55,11 @@ def decorator(func: ClusterVisualizerFn) -> ClusterVisualizerFn: class BaseVisualizer(ABC): + input_type: DataEnum = DataEnum.SUMMARY_EVENT + def __init__(self, config: dict): self.config = config - def get_input_type(self): - pass - @abstractmethod def run(self): raise NotImplementedError @@ -79,9 +72,6 @@ def __init__(self, config: dict): self.vis_type = config.get("vis_type", None) self.visualizer_fn = None - def get_input_type(self): - return pd.DataFrame - def run(self, data): self.visualizer_fn = get_cluster_visualizer_fn(self.vis_type) self.visualizer_fn(data, self.output_path, self.config) diff --git a/tests/data/test_data_checker.py b/tests/data/test_data_checker.py new file mode 100644 index 0000000..c10cc33 --- /dev/null +++ b/tests/data/test_data_checker.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from rl_insight.data.data_checker import DataChecker, DataEnum +from rl_insight.data.rules import DataValidationError + + +def test_data_checker_multi_json_path_exists(tmp_path): + checker = DataChecker(data_type=DataEnum.MULTI_JSON, data=str(tmp_path)) + checker.run() + + +def test_data_checker_multi_json_path_missing(): + checker = DataChecker( + data_type=DataEnum.MULTI_JSON, data="C:/definitely/not/exist/path" + ) + with pytest.raises(DataValidationError) as exc_info: + checker.run() + assert "Data validation failed" in str(exc_info.value) + + +def test_data_checker_summary_event_has_no_rule_with_dict_data(): + checker = DataChecker(data_type=DataEnum.SUMMARY_EVENT, data={"k": "v"}) + checker.run() diff --git a/tests/data/test_rules.py b/tests/data/test_rules.py new file mode 100644 index 0000000..d4e597b --- /dev/null +++ b/tests/data/test_rules.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 verl-project authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rl_insight.data.rules import DataValidationError, PathExistsRule + + +def test_path_exists_rule_accepts_existing_directory(tmp_path): + rule = PathExistsRule() + assert rule.check(str(tmp_path)) is True + + +def test_path_exists_rule_rejects_non_string_input(): + rule = PathExistsRule() + assert rule.check({"path": "x"}) is False + assert "not a path" in rule.error_message + + +def test_path_exists_rule_rejects_missing_directory(): + rule = PathExistsRule() + assert rule.check("C:/definitely/not/exist/path") is False + + +def test_data_validation_error_string_includes_error_details(): + err = DataValidationError("Data validation failed", ["line1", "line2"]) + text = str(err) + assert "Data validation failed" in text + assert "line1" in text + assert "line2" in text diff --git a/data/__init__.py b/tests/parser/__init__.py similarity index 100% rename from data/__init__.py rename to tests/parser/__init__.py diff --git a/tests/cluster_analysis/test_cluster_analysis.py b/tests/parser/test_cluster_analysis.py similarity index 98% rename from tests/cluster_analysis/test_cluster_analysis.py rename to tests/parser/test_cluster_analysis.py index 88eaad5..3a8c9c4 100644 --- a/tests/cluster_analysis/test_cluster_analysis.py +++ b/tests/parser/test_cluster_analysis.py @@ -29,6 +29,7 @@ import pandas as pd import pytest +from rl_insight.data import DataEnum from rl_insight.main import main from rl_insight.parser import MstxClusterParser from rl_insight.parser import ( @@ -594,7 +595,7 @@ def test_parse_full_pipeline(self, mock_mstx_profiler_structure): ) with patch("concurrent.futures.ProcessPoolExecutor"): - df = parser.run() + df = parser.run(mock_mstx_profiler_structure) assert df is not None assert len(df) >= 1 @@ -920,7 +921,7 @@ def test_full_pipeline_with_mock_data(self, mock_mstx_profiler_structure, tmp_pa ) with patch("concurrent.futures.ProcessPoolExecutor"): - df = parser.run() + df = parser.run(mock_mstx_profiler_structure) assert df is not None assert len(df) >= 1 @@ -946,10 +947,15 @@ def test_full_pipeline_with_mock_data(self, mock_mstx_profiler_structure, tmp_pa "sys.argv", ["main.py", "--input-path", "/tmp", "--profiler-type", "mstx"], ) + @patch("rl_insight.pipeline.offline_insight_pipeline.DataChecker.run") @patch("rl_insight.pipeline.offline_insight_pipeline.get_cluster_parser_cls") @patch("rl_insight.pipeline.offline_insight_pipeline.RLTimelineVisualizer") def test_main_function( - self, mock_visualizer_cls, mock_get_parser, mock_mstx_profiler_structure + self, + mock_visualizer_cls, + mock_get_parser, + mock_data_checker_run, + mock_mstx_profiler_structure, ): """Test main CLI entry point.""" # Mock parser @@ -968,14 +974,13 @@ def test_main_function( } ] ) - mock_parser_instance.get_output_type.return_value = pd.DataFrame - mock_parser_instance.get_input_type.return_value = None + mock_parser_instance.input_type = DataEnum.MULTI_JSON mock_parser.return_value = mock_parser_instance mock_get_parser.return_value = mock_parser # Mock visualizer mock_visualizer_instance = MagicMock() - mock_visualizer_instance.get_input_type.return_value = pd.DataFrame + mock_visualizer_instance.input_type = DataEnum.SUMMARY_EVENT mock_visualizer_cls.return_value = mock_visualizer_instance # Run main diff --git a/tests/cluster_analysis/test_torch_parser.py b/tests/parser/test_torch_parser.py similarity index 100% rename from tests/cluster_analysis/test_torch_parser.py rename to tests/parser/test_torch_parser.py diff --git a/tests/special_sanity/check_license.py b/tests/special_sanity/check_license.py index 8babb0a..07f5ae2 100644 --- a/tests/special_sanity/check_license.py +++ b/tests/special_sanity/check_license.py @@ -30,7 +30,8 @@ license_head_huawei = ( "Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved." ) -license_head_verl = "Copyright (c) 2025 verl-project authors." +license_head_verl_25 = "Copyright (c) 2025 verl-project authors." +license_head_verl_26 = "Copyright (c) 2026 verl-project authors." license_headers = [ license_head_bytedance, license_head_bytedance_25, @@ -44,7 +45,8 @@ license_head_facebook, license_head_meituan, license_head_huawei, - license_head_verl, + license_head_verl_25, + license_head_verl_26, ]