Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/cloudai/workloads/nixl_ep/nixl_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,33 @@ def _primary_launch_exit_error_message(content: str) -> str | None:

return f"The primary NIXL EP launch exited before phase {phase} completed."

@staticmethod
def _looks_like_planned_srun_termination(content: str) -> bool:
allowed_patterns = (
re.compile(r"^srun: error: .+: task \d+: Terminated$"),
re.compile(r"^srun: Terminating StepId=\S+$"),
re.compile(r"^srun: Force Terminated StepId=\S+$"),
)
lines = [line.strip() for line in content.splitlines() if line.strip()]
srun_lines = [line for line in lines if line.startswith("srun:")]
return (
bool(srun_lines)
and all(any(pattern.match(line) for pattern in allowed_patterns) for line in srun_lines)
and all(any(pattern.match(line) for line in srun_lines) for pattern in allowed_patterns)
)

def _has_planned_rank_removal(self) -> bool:
plans = self.cmd_args.plan if isinstance(self.cmd_args.plan, list) else [self.cmd_args.plan]
return any(rank < 0 for plan in plans for phase in NixlEPCmdArgs._parse_plan(plan) for rank in phase)

def _scan_log_for_failures(self, path: Path) -> JobStatusResult | None:
if not path.is_file():
return None

content = path.read_text(encoding="utf-8", errors="ignore")
if self._has_planned_rank_removal() and self._looks_like_planned_srun_termination(content):
content = "\n".join(line for line in content.splitlines() if not line.strip().startswith("srun:"))

launcher_failure_patterns = (
("python3: can't open file", "The benchmark entrypoint could not be opened."),
("Traceback (most recent call last):", "The benchmark launcher raised a Python traceback."),
Expand All @@ -164,7 +187,6 @@ def _scan_log_for_failures(self, path: Path) -> JobStatusResult | None:
("srun: error:", "Slurm reported an srun failure."),
("Exited with exit code", "A Slurm step exited with a non-zero status."),
)
content = path.read_text(encoding="utf-8", errors="ignore")
primary_launch_error = self._primary_launch_exit_error_message(content)
if primary_launch_error is not None:
tail = self._tail(path)
Expand Down
33 changes: 28 additions & 5 deletions src/cloudai/workloads/nixl_ep/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def _new_process_counts_by_phase(self) -> list[int]:
self._validate_requested_processes(counts)
return counts

def _has_planned_rank_removal(self) -> bool:
return any(rank < 0 for phase in self.tdef.cmd_args.parse_plan() for rank in phase)

def _validate_requested_processes(self, new_process_counts: list[int]) -> None:
total_requested_processes = sum(new_process_counts)
num_nodes, _ = self.get_cached_nodes_spec()
Expand Down Expand Up @@ -302,22 +305,42 @@ def _finish_with_rc_lines() -> list[str]:
"exit $rc",
]

@classmethod
def _wait_for_workers_lines(cls) -> list[str]:
return [
def _wait_for_workers_lines(self) -> list[str]:
allow_planned_removal_143 = "1" if self._has_planned_rank_removal() else "0"
final_phase = len(self.tdef.cmd_args.parse_plan()) - 1
lines = [
"",
f"allow_planned_removal_143={allow_planned_removal_143}",
"ignored_planned_removal_143=0",
"rc=0",
'while [ "$active_srun_count" -gt 0 ]; do',
" wait -n",
" wait_rc=$?",
" active_srun_count=$((active_srun_count - 1))",
' if [ "$wait_rc" -ne 0 ] && [ "$rc" -eq 0 ]; then',
' if [ "$allow_planned_removal_143" -eq 1 ] && [ "$wait_rc" -eq 143 ]; then',
' echo "Ignoring provisional NIXL EP planned-rank-removal exit 143"',
" ignored_planned_removal_143=1",
' elif [ "$wait_rc" -ne 0 ] && [ "$rc" -eq 0 ]; then',
" rc=$wait_rc",
" fi",
"done",
"",
*cls._finish_with_rc_lines(),
]
if self._has_planned_rank_removal():
final_phase_wait = (
f' wait_for_phase_completion "{final_phase}" "{self.node_log_path(0).absolute()}" "$primary_pid" '
"|| rc=143"
)
lines.extend(
[
'if [ "$ignored_planned_removal_143" -eq 1 ] && [ "$rc" -eq 0 ]; then',
final_phase_wait,
"fi",
"",
]
)
lines.extend(self._finish_with_rc_lines())
return lines

@staticmethod
def _has_follower_launches(stages: list[NixlEPStage]) -> bool:
Expand Down
11 changes: 10 additions & 1 deletion tests/ref_data/nixl-ep-launch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,25 @@ echo "Starting launches for phase 3..."
srun --export=ALL --mpi=pmix --container-image=docker.io/nvidia/nixl-ep:latest --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__INSTALL_DIR__:/cloudai_install,__OUTPUT_DIR__/output --overlap --nodelist="${nodes_array[2]}" --ntasks-per-node=1 --ntasks=1 -N1 --open-mode=append --output=__OUTPUT_DIR__/output/nixl-ep-node-2.log --error=__OUTPUT_DIR__/output/nixl-ep-node-2.log bash -c "source __OUTPUT_DIR__/output/env_vars.sh; python3 /workspace/nixl/examples/device/ep/tests/elastic/elastic.py --plan __OUTPUT_DIR__/output/nixl-ep-plan.json --num-processes 2 --tcp-server $master_ip --disable-ll-nvlink --hidden-dim 8192 --kineto --num-experts-per-rank 4 --num-tokens 256 --num-topk 6" &
active_srun_count=$((active_srun_count + 1))

allow_planned_removal_143=1
ignored_planned_removal_143=0
rc=0
while [ "$active_srun_count" -gt 0 ]; do
wait -n
wait_rc=$?
active_srun_count=$((active_srun_count - 1))
if [ "$wait_rc" -ne 0 ] && [ "$rc" -eq 0 ]; then
if [ "$allow_planned_removal_143" -eq 1 ] && [ "$wait_rc" -eq 143 ]; then
echo "Ignoring provisional NIXL EP planned-rank-removal exit 143"
ignored_planned_removal_143=1
elif [ "$wait_rc" -ne 0 ] && [ "$rc" -eq 0 ]; then
rc=$wait_rc
fi
done

if [ "$ignored_planned_removal_143" -eq 1 ] && [ "$rc" -eq 0 ]; then
wait_for_phase_completion "3" "__OUTPUT_DIR__/output/nixl-ep-node-0.log" "$primary_pid" || rc=143
fi

if [ "$rc" -eq 0 ]; then
echo "All NIXL EP launches completed successfully"
fi
Expand Down
96 changes: 60 additions & 36 deletions tests/workloads/nixl_ep/test_command_gen_strategy_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
# limitations under the License.

import json
import re
from importlib.metadata import version
from pathlib import Path

import pytest

Expand Down Expand Up @@ -102,23 +99,6 @@ def nixl_ep_tr(nixl_ep: NixlEPTestDefinition, slurm_system: SlurmSystem) -> Test
)


def normalize_sbatch(content: str, test_run: TestRun, slurm_system: SlurmSystem) -> str:
normalized = content.replace(str(slurm_system.install_path.absolute()), "__INSTALL_DIR__").replace(
str(test_run.output_path.parent.absolute()), "__OUTPUT_DIR__"
)
normalized = re.sub(
r"^#SBATCH --job-name=.*$",
"#SBATCH --job-name=__JOB_NAME__",
normalized,
flags=re.MULTILINE,
)
return normalized.replace(version("cloudai"), "__CLOUDAI_VERSION__")


def significant_sbatch_lines(content: str) -> list[str]:
return [line for line in content.splitlines() if line.strip() and not line.lstrip().startswith("echo ")]


def normalize_stages(strategy: NixlEPSlurmCommandGenStrategy) -> list[tuple[int, tuple[int, ...]]]:
num_nodes, _ = strategy.get_cached_nodes_spec()
normalized_stages: list[tuple[int, tuple[int, ...]]] = []
Expand Down Expand Up @@ -747,6 +727,66 @@ def test_gen_srun_command_multi_node_public_single_expansion_waits_for_phase_bef
assert launcher_script.count("--open-mode=append") == 1


def test_gen_srun_command_planned_rank_removal_tolerates_143_after_final_phase(
slurm_system: SlurmSystem,
) -> None:
tdef = NixlEPTestDefinition(
name="nixl_ep",
description="NIXL Elastic EP benchmark",
test_template_name="NixlEP",
cmd_args=NixlEPCmdArgs(
docker_image_url="docker.io/nvidia/nixl-ep:latest",
plan=json.dumps([[0, 1], [0, 1, 2, 3], [0, -2, 3], [0, 1, 2, 3]]),
num_processes_per_node=3,
),
)
test_run = TestRun(
name="nixl-ep",
num_nodes=2,
nodes=[],
test=tdef,
output_path=slurm_system.output_path,
)
strategy = NixlEPSlurmCommandGenStrategy(slurm_system, test_run)

launcher_script = read_launcher_script(strategy)

assert "allow_planned_removal_143=1" in launcher_script
assert 'if [ "$allow_planned_removal_143" -eq 1 ] && [ "$wait_rc" -eq 143 ]; then' in launcher_script
assert "Ignoring provisional NIXL EP planned-rank-removal exit 143" in launcher_script
assert 'wait_for_phase_completion "3"' in launcher_script
assert "|| rc=143" in launcher_script


def test_gen_srun_command_without_planned_rank_removal_keeps_143_fatal(
slurm_system: SlurmSystem,
) -> None:
tdef = NixlEPTestDefinition(
name="nixl_ep",
description="NIXL Elastic EP benchmark",
test_template_name="NixlEP",
cmd_args=NixlEPCmdArgs(
docker_image_url="docker.io/nvidia/nixl-ep:latest",
plan=SINGLE_EXPANSION_PLAN_STR,
num_processes_per_node=4,
),
)
test_run = TestRun(
name="nixl-ep",
num_nodes=2,
nodes=[],
test=tdef,
output_path=slurm_system.output_path,
)
strategy = NixlEPSlurmCommandGenStrategy(slurm_system, test_run)

launcher_script = read_launcher_script(strategy)

assert "allow_planned_removal_143=0" in launcher_script
assert "Ignoring provisional NIXL EP planned-rank-removal exit 143" in launcher_script
assert 'wait_for_phase_completion "1"' not in launcher_script


def test_gen_srun_command_multi_node_single_stage_starts_followers(
slurm_system: SlurmSystem,
) -> None:
Expand Down Expand Up @@ -826,19 +866,3 @@ def test_gen_srun_command_single_launch_reports_success(
assert 'echo "All NIXL EP launches completed successfully"' in launcher_script
assert 'if [ "$rc" -eq 0 ]; then' in launcher_script
assert "exit $rc" in launcher_script


def test_gen_exec_command_matches_reference(nixl_ep_tr: TestRun, slurm_system: SlurmSystem) -> None:
slurm_system.container_mount_home = True
strategy = NixlEPSlurmCommandGenStrategy(slurm_system, nixl_ep_tr)

sbatch_cmd = strategy.gen_exec_command()

assert sbatch_cmd == f"sbatch {nixl_ep_tr.output_path / 'cloudai_sbatch_script.sh'}"

content = (nixl_ep_tr.output_path / "cloudai_sbatch_script.sh").read_text().strip()
content = normalize_sbatch(content, nixl_ep_tr, slurm_system)

ref = (Path(__file__).parents[2] / "ref_data" / "nixl-ep.sbatch").read_text().strip()
ref = normalize_sbatch(ref, nixl_ep_tr, slurm_system)
assert significant_sbatch_lines(content) == significant_sbatch_lines(ref)
Loading
Loading