Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/cloudai/_core/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def run(self):

total_tests = len(self.test_scenario.test_runs)
dependency_free_trs = self.find_dependency_free_tests()
if total_tests and not dependency_free_trs:
raise ValueError(
f"No runnable tests found in scenario '{self.test_scenario.name}'. At least one test must have no "
"'start_post_init' or 'start_post_comp' dependencies."
)

for tr in dependency_free_trs:
self.submit_test(tr)

Expand Down
67 changes: 65 additions & 2 deletions src/cloudai/models/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,42 @@ class TestRunDependencyModel(BaseModel):
id: str


_START_BLOCKING_DEPENDENCY_TYPES = {"start_post_comp", "start_post_init"}


def _find_dependency_cycle(graph: dict[str, list[str]]) -> list[str]:
visiting: set[str] = set()
visited: set[str] = set()
path: list[str] = []

def visit(test_id: str) -> list[str]:
if test_id in visiting:
cycle_start = path.index(test_id)
return [*path[cycle_start:], test_id]
if test_id in visited:
return []

visiting.add(test_id)
path.append(test_id)

for dep_id in graph[test_id]:
cycle = visit(dep_id)
if cycle:
return cycle

path.pop()
visiting.remove(test_id)
visited.add(test_id)
return []

for test_id in graph:
cycle = visit(test_id)
if cycle:
return cycle

return []


class TestRunModel(BaseModel):
"""Model for test run in test scenario."""

Expand Down Expand Up @@ -192,10 +228,10 @@ class TestScenarioModel(BaseModel):

@model_validator(mode="after")
def check_no_self_dependency(self):
"""Check for circular dependencies in the test scenario."""
"""Check for direct non-start-blocking self dependencies in the test scenario."""
for test_run in self.tests:
for dep in test_run.dependencies:
if dep.id == test_run.id:
if dep.id == test_run.id and dep.type not in _START_BLOCKING_DEPENDENCY_TYPES:
raise ValueError(f"Test '{test_run.id}' must not depend on itself.")

return self
Expand All @@ -222,6 +258,33 @@ def check_all_dependencies_are_known(self):

return self

@model_validator(mode="after")
def check_start_blocking_dependencies_are_schedulable(self):
"""Check that start-blocking dependencies can be scheduled."""
graph = {
tr.id: [dep.id for dep in tr.dependencies if dep.type in _START_BLOCKING_DEPENDENCY_TYPES]
for tr in self.tests
}
runnable_roots = [test_id for test_id, dep_ids in graph.items() if not dep_ids]
cycle = _find_dependency_cycle(graph)

if cycle:
msg = f"Start-blocking dependency cycle detected: {' -> '.join(cycle)}."
if not runnable_roots:
msg += (
" No runnable root tests found; at least one test must have no "
"'start_post_init' or 'start_post_comp' dependencies."
)
raise ValueError(msg)

if not runnable_roots:
raise ValueError(
"No runnable root tests found; at least one test must have no "
"'start_post_init' or 'start_post_comp' dependencies."
)

return self

@field_validator("reports", mode="before")
@classmethod
def parse_reports(cls, value: dict[str, Any] | None) -> dict[str, ReportConfig] | None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ def test_both_failed_runner_status_reported(self, runner: MyRunner):
assert res == runner.runner_job_status_result


def test_run_raises_if_no_tests_are_runnable(runner: MyRunner):
test = runner.test_scenario.test_runs[0].test
tr_a = TestRun("A", test, 1, [])
tr_b = TestRun("B", test, 1, [])
tr_a.dependencies = {"start_post_comp": TestDependency(tr_b)}
tr_b.dependencies = {"start_post_comp": TestDependency(tr_a)}
runner.test_scenario.test_runs = [tr_a, tr_b]

with pytest.raises(ValueError) as exc_info:
runner.run()

assert "No runnable tests found in scenario 'Test Scenario'." in str(exc_info.value)
assert runner.submitted_trs == []


class TestHandleDependencies:
"""
Tests for BaseRunner.handle_dependencies method.
Expand Down
49 changes: 49 additions & 0 deletions tests/test_test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,55 @@ def test_raises_on_unknown_dependency() -> None:
assert exc_info.match("Dependency section 'dep-id' not found for test 'test-id'.")


def test_raises_on_start_blocking_dependency_cycle() -> None:
with pytest.raises(ValueError) as exc_info:
TestScenarioModel.model_validate(
{
"name": "test",
"Tests": [
{"id": "root", "test_name": "nccl"},
{
"id": "A",
"test_name": "nccl",
"dependencies": [{"type": "start_post_comp", "id": "B"}],
},
{
"id": "B",
"test_name": "nccl",
"dependencies": [{"type": "start_post_init", "id": "A"}],
},
],
}
)

assert "Start-blocking dependency cycle detected: A -> B -> A." in str(exc_info.value)


def test_raises_if_start_blocking_cycle_has_no_runnable_roots() -> None:
with pytest.raises(ValueError) as exc_info:
TestScenarioModel.model_validate(
{
"name": "test",
"Tests": [
{
"id": "A",
"test_name": "nccl",
"dependencies": [{"type": "start_post_comp", "id": "B"}],
},
{
"id": "B",
"test_name": "nccl",
"dependencies": [{"type": "start_post_comp", "id": "A"}],
},
],
}
)

error = str(exc_info.value)
assert "Start-blocking dependency cycle detected: A -> B -> A." in error
assert "No runnable root tests found" in error


def test_list_of_tests_must_not_be_empty() -> None:
with pytest.raises(ValueError) as exc_info:
TestScenarioModel.model_validate({"name": "name"})
Expand Down