Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 56 additions & 27 deletions src/experimaestro/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ def cleanup(self):
if not self.cleaned:
self.cleaned = True
logger.info("Cleaning up")
rmfile(self.pidfile)
env = taskglobals.Env.instance()

if not env.slave:
rmfile(self.pidfile)

# Load MockJob for state tracking
self._mock_job = self._load_mock_job()
Expand Down Expand Up @@ -471,7 +474,8 @@ def cleanup(self):
self._carbon_tracker = None

# Write status.json while still holding locks
self._write_status()
if not env.slave:
self._write_status()

# Release IPC locks
for lock in self.locks:
Expand All @@ -486,7 +490,7 @@ def cleanup(self):
# Note: dynamic dependency locks are released via context manager
# in the run() method, not here

if self.started:
if self.started and not env.slave:
# Report final state: "error" if .failed exists, "done" otherwise
final_state = "error" if self.failedpath.exists() else "done"
report_eoj(final_state)
Expand All @@ -502,10 +506,13 @@ def handle_error(self, code, frame_type, reason: str = "failed", message: str =
message: Optional message with details
"""
logger.info("Error handler: finished with code %d, reason=%s", code, reason)
failure_info = {"code": code, "reason": reason}
if message:
failure_info["message"] = message
self.failedpath.write_text(json.dumps(failure_info))
env = taskglobals.Env.instance()
if not env.slave:
failure_info = {"code": code, "reason": reason}
if message:
failure_info["message"] = message
self.failedpath.write_text(json.dumps(failure_info))

self.cleanup()
logger.info("Exiting")
delayed_shutdown(60, exit_code=code)
Expand All @@ -515,10 +522,12 @@ def _background_cleanup(self, reason: str, message: str = ""):
"""Run framework cleanup in background thread."""
try:
logger.info("Background cleanup: reason=%s", reason)
failure_info = {"code": signal.SIGTERM, "reason": reason}
if message:
failure_info["message"] = message
self.failedpath.write_text(json.dumps(failure_info))
env = taskglobals.Env.instance()
if not env.slave:
failure_info = {"code": signal.SIGTERM, "reason": reason}
if message:
failure_info["message"] = message
self.failedpath.write_text(json.dumps(failure_info))
self.cleanup()
logger.info("Background cleanup finished")
except Exception:
Expand Down Expand Up @@ -579,15 +588,31 @@ def remove_signal_handlers(remove_cleanup=True):
os.getpid()
logger.info("Working in directory %s", workdir)

for lockfile in self.lockfiles:
fullpath = str(Path(lockfile).resolve())
logger.info("Locking %s", fullpath)
lock = create_file_lock(fullpath)
# MAYBE: should have a clever way to lock
# Problem = slurm would have a job doing nothing...
# Fix = maybe with two files
lock.acquire()
self.locks.append(lock)
# Identify non-zero ranks in distributed settings (SLURM, DDP, etc.)
# A process is a slave if any distributed rank environment variable is > 0
def get_rank(name):
val = os.environ.get(name)
logger.debug("Rank detection: %s=%s", name, val)
return int(val) if val is not None else 0

rank = max(get_rank("SLURM_PROCID"), get_rank("RANK"), get_rank("LOCAL_RANK"))
env = taskglobals.Env.instance()
if rank > 0:
logger.info("Non-zero rank (%d): marking as slave process", rank)
env.slave = True
else:
logger.debug("Rank 0 or no distributed environment detected (rank=%d)", rank)

if not env.slave:
for lockfile in self.lockfiles:
fullpath = str(Path(lockfile).resolve())
logger.info("Locking %s", fullpath)
lock = create_file_lock(fullpath)
# MAYBE: should have a clever way to lock
# Problem = slurm would have a job doing nothing...
# Fix = maybe with two files
lock.acquire()
self.locks.append(lock)

# Load and setup dynamic dependency locks from locks.json
locks_path = workdir / "locks.json"
Expand All @@ -610,11 +635,12 @@ def remove_signal_handlers(remove_cleanup=True):
rmfile(self.failedpath)
self.started = True

# Update status.json to "running" before writing events
self._update_status_running()
if not env.slave:
# Update status.json to "running" before writing events
self._update_status_running()

# Notify that the job has started (writes event file)
start_of_job()
# Notify that the job has started (writes event file)
start_of_job()

# Initialize carbon tracking
self._job_start_time = datetime.now()
Expand All @@ -633,7 +659,7 @@ def remove_signal_handlers(remove_cleanup=True):
except Exception as e:
logger.debug("Failed to load params for carbon tracking: %s", e)

if workspace_path:
if workspace_path and not env.slave:
# Load previous carbon metrics from status.json for accumulation
try:
init_mock = self._load_mock_job()
Expand All @@ -658,6 +684,7 @@ def remove_signal_handlers(remove_cleanup=True):
)

# Acquire dynamic dependency locks while running the task
# Non-slave processes acquire actual locks, slaves get dummy locks
with self.dynamic_locks.dependency_locks():
run(workdir / "params.json")

Expand Down Expand Up @@ -690,9 +717,11 @@ def remove_signal_handlers(remove_cleanup=True):
self.handle_error(1, None)

except SystemExit as e:
env = taskglobals.Env.instance()
if e.code == 0:
# Normal exit, just create the ".done" file
self.donepath.touch()
if not env.slave:
# Normal exit, just create the ".done" file
self.donepath.touch()

# ... and finish the exit process
raise
Expand Down
85 changes: 85 additions & 0 deletions src/experimaestro/tests/test_rank_awareness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import pytest
from unittest.mock import patch
from experimaestro.run import TaskRunner
import experimaestro.taskglobals as taskglobals

@pytest.fixture
def env():
    """Provide the global task Env singleton with `slave` reset.

    `taskglobals.Env.instance()` is a process-wide singleton, so the flag
    must be reset both before AND after the test: without teardown, a test
    that sets (or fails while) `slave=True` would leak that state into
    every later test sharing the singleton.
    """
    instance = taskglobals.Env.instance()
    instance.slave = False
    yield instance
    # Teardown: restore a clean state for subsequent tests
    instance.slave = False

@patch("experimaestro.run.atexit.register")
@patch("experimaestro.run.signal.signal")
@patch("experimaestro.run.os.chdir")
@patch("experimaestro.run.os.register_at_fork")
@patch("experimaestro.run.create_file_lock")
@patch("experimaestro.run.start_of_job")
@patch("experimaestro.run.run")
@patch("experimaestro.run.TaskRunner._update_status_running")
@patch("experimaestro.run.TaskRunner._load_mock_job")
def test_task_runner_rank_detection(
    mock_load_mock_job,
    mock_update_status_running,
    mock_run_task,
    mock_start_of_job,
    mock_create_file_lock,
    mock_register_at_fork,
    mock_os_chdir,
    mock_signal,
    mock_atexit,
    env,
    tmp_path
):
    """Verify TaskRunner rank detection across distributed env variables.

    Three scenarios are exercised: rank 0 (main process, acquires locks and
    reports status), non-zero LOCAL_RANK and non-zero SLURM_PROCID (both
    slave processes that must skip locking and status reporting).
    """
    script_path = tmp_path / "test.py"
    script_path.touch()
    lockfiles = [str(tmp_path / "test.lock")]

    def run_with(overrides):
        # Pin ALL rank variables to "0" by default so values leaking from
        # the host environment (e.g. a CI runner with RANK set) cannot
        # change the detected rank; `overrides` then selects the scenario.
        env_vars = {"SLURM_PROCID": "0", "RANK": "0", "LOCAL_RANK": "0"}
        env_vars.update(overrides)
        with patch.dict(os.environ, env_vars):
            runner = TaskRunner(str(script_path), lockfiles)
            try:
                # run() may terminate via SystemExit; swallow it so the
                # post-run assertions below still execute
                runner.run()
            except SystemExit:
                pass

    def reset():
        # Reset call records AND the shared singleton flag between cases,
        # so each case's assert_called / assert_not_called is meaningful
        mock_create_file_lock.reset_mock()
        mock_start_of_job.reset_mock()
        mock_update_status_running.reset_mock()
        env.slave = False

    # Case 1: main process (rank 0 everywhere)
    run_with({})
    assert env.slave is False
    mock_create_file_lock.assert_called()
    mock_start_of_job.assert_called()
    mock_update_status_running.assert_called()

    # Case 2: slave process (rank > 0 via LOCAL_RANK)
    reset()
    run_with({"LOCAL_RANK": "1"})
    assert env.slave is True
    mock_create_file_lock.assert_not_called()
    mock_start_of_job.assert_not_called()
    mock_update_status_running.assert_not_called()

    # Case 3: slave process (rank > 0 via SLURM_PROCID)
    reset()
    run_with({"SLURM_PROCID": "2"})
    assert env.slave is True
    mock_create_file_lock.assert_not_called()
    mock_start_of_job.assert_not_called()
    mock_update_status_running.assert_not_called()
Loading