From 109352e77bb12c28ba19ef2fd492fd3ef22c8469 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Fri, 12 Jun 2026 19:23:30 +0200 Subject: [PATCH] Validate scenario dependency scheduling --- src/cloudai/_core/base_runner.py | 6 +++ src/cloudai/models/scenario.py | 67 +++++++++++++++++++++++++++++++- tests/test_base_runner.py | 15 +++++++ tests/test_test_scenario.py | 49 +++++++++++++++++++++++ 4 files changed, 135 insertions(+), 2 deletions(-) diff --git a/src/cloudai/_core/base_runner.py b/src/cloudai/_core/base_runner.py index 10f6bc01c..495e864e2 100644 --- a/src/cloudai/_core/base_runner.py +++ b/src/cloudai/_core/base_runner.py @@ -86,6 +86,12 @@ def run(self): total_tests = len(self.test_scenario.test_runs) dependency_free_trs = self.find_dependency_free_tests() + if total_tests and not dependency_free_trs: + raise ValueError( + f"No runnable tests found in scenario '{self.test_scenario.name}'. At least one test must have no " + "'start_post_init' or 'start_post_comp' dependencies." + ) + for tr in dependency_free_trs: self.submit_test(tr) diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index 57234df23..d01fa8ba2 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -62,6 +62,42 @@ class TestRunDependencyModel(BaseModel): id: str +_START_BLOCKING_DEPENDENCY_TYPES = {"start_post_comp", "start_post_init"} + + +def _find_dependency_cycle(graph: dict[str, list[str]]) -> list[str]: + visiting: set[str] = set() + visited: set[str] = set() + path: list[str] = [] + + def visit(test_id: str) -> list[str]: + if test_id in visiting: + cycle_start = path.index(test_id) + return [*path[cycle_start:], test_id] + if test_id in visited: + return [] + + visiting.add(test_id) + path.append(test_id) + + for dep_id in graph[test_id]: + cycle = visit(dep_id) + if cycle: + return cycle + + path.pop() + visiting.remove(test_id) + visited.add(test_id) + return [] + + for test_id in graph: + cycle = visit(test_id) + if cycle: + return cycle + + return [] + + class TestRunModel(BaseModel): """Model for test run in test scenario.""" @@ -192,10 +228,10 @@ class TestScenarioModel(BaseModel): @model_validator(mode="after") def check_no_self_dependency(self): - """Check for circular dependencies in the test scenario.""" + """Check for direct non-start-blocking self dependencies in the test scenario.""" for test_run in self.tests: for dep in test_run.dependencies: - if dep.id == test_run.id: + if dep.id == test_run.id and dep.type not in _START_BLOCKING_DEPENDENCY_TYPES: raise ValueError(f"Test '{test_run.id}' must not depend on itself.") return self @@ -222,6 +258,33 @@ def check_all_dependencies_are_known(self): return self + @model_validator(mode="after") + def check_start_blocking_dependencies_are_schedulable(self): + """Check that start-blocking dependencies can be scheduled.""" + graph = { + tr.id: [dep.id for dep in tr.dependencies if dep.type in _START_BLOCKING_DEPENDENCY_TYPES] + for tr in self.tests + } + runnable_roots = [test_id for test_id, dep_ids in graph.items() if not dep_ids] + cycle = _find_dependency_cycle(graph) + + if cycle: + msg = f"Start-blocking dependency cycle detected: {' -> '.join(cycle)}." + if not runnable_roots: + msg += ( + " No runnable root tests found; at least one test must have no " + "'start_post_init' or 'start_post_comp' dependencies." + ) + raise ValueError(msg) + + if not runnable_roots: + raise ValueError( + "No runnable root tests found; at least one test must have no " + "'start_post_init' or 'start_post_comp' dependencies." + ) + + return self + @field_validator("reports", mode="before") @classmethod def parse_reports(cls, value: dict[str, Any] | None) -> dict[str, ReportConfig] | None: diff --git a/tests/test_base_runner.py b/tests/test_base_runner.py index 2a671a851..10a0384d4 100644 --- a/tests/test_base_runner.py +++ b/tests/test_base_runner.py @@ -113,6 +113,21 @@ def test_both_failed_runner_status_reported(self, runner: MyRunner): assert res == runner.runner_job_status_result +def test_run_raises_if_no_tests_are_runnable(runner: MyRunner): + test = runner.test_scenario.test_runs[0].test + tr_a = TestRun("A", test, 1, []) + tr_b = TestRun("B", test, 1, []) + tr_a.dependencies = {"start_post_comp": TestDependency(tr_b)} + tr_b.dependencies = {"start_post_comp": TestDependency(tr_a)} + runner.test_scenario.test_runs = [tr_a, tr_b] + + with pytest.raises(ValueError) as exc_info: + runner.run() + + assert "No runnable tests found in scenario 'Test Scenario'." in str(exc_info.value) + assert runner.submitted_trs == [] + + class TestHandleDependencies: """ Tests for BaseRunner.handle_dependencies method. diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index 9da396b8a..1681405db 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -223,6 +223,55 @@ def test_raises_on_unknown_dependency() -> None: assert exc_info.match("Dependency section 'dep-id' not found for test 'test-id'.") +def test_raises_on_start_blocking_dependency_cycle() -> None: + with pytest.raises(ValueError) as exc_info: + TestScenarioModel.model_validate( + { + "name": "test", + "Tests": [ + {"id": "root", "test_name": "nccl"}, + { + "id": "A", + "test_name": "nccl", + "dependencies": [{"type": "start_post_comp", "id": "B"}], + }, + { + "id": "B", + "test_name": "nccl", + "dependencies": [{"type": "start_post_init", "id": "A"}], + }, + ], + } + ) + + assert "Start-blocking dependency cycle detected: A -> B -> A." in str(exc_info.value) + + +def test_raises_if_start_blocking_cycle_has_no_runnable_roots() -> None: + with pytest.raises(ValueError) as exc_info: + TestScenarioModel.model_validate( + { + "name": "test", + "Tests": [ + { + "id": "A", + "test_name": "nccl", + "dependencies": [{"type": "start_post_comp", "id": "B"}], + }, + { + "id": "B", + "test_name": "nccl", + "dependencies": [{"type": "start_post_comp", "id": "A"}], + }, + ], + } + ) + + error = str(exc_info.value) + assert "Start-blocking dependency cycle detected: A -> B -> A." in error + assert "No runnable root tests found" in error + + def test_list_of_tests_must_not_be_empty() -> None: with pytest.raises(ValueError) as exc_info: TestScenarioModel.model_validate({"name": "name"})