diff --git a/src/experimaestro/run.py b/src/experimaestro/run.py
index 409049b6..e38af302 100644
--- a/src/experimaestro/run.py
+++ b/src/experimaestro/run.py
@@ -431,7 +431,10 @@ def cleanup(self):
         if not self.cleaned:
             self.cleaned = True
             logger.info("Cleaning up")
-            rmfile(self.pidfile)
+            env = taskglobals.Env.instance()
+
+            if not env.slave:
+                rmfile(self.pidfile)
 
             # Load MockJob for state tracking
             self._mock_job = self._load_mock_job()
@@ -471,7 +474,8 @@ def cleanup(self):
                 self._carbon_tracker = None
 
             # Write status.json while still holding locks
-            self._write_status()
+            if not env.slave:
+                self._write_status()
 
             # Release IPC locks
             for lock in self.locks:
@@ -486,7 +490,7 @@ def cleanup(self):
             # Note: dynamic dependency locks are released via context manager
             # in the run() method, not here
 
-            if self.started:
+            if self.started and not env.slave:
                 # Report final state: "error" if .failed exists, "done" otherwise
                 final_state = "error" if self.failedpath.exists() else "done"
                 report_eoj(final_state)
@@ -502,10 +506,13 @@ def handle_error(self, code, frame_type, reason: str = "failed", message: str =
             message: Optional message with details
         """
         logger.info("Error handler: finished with code %d, reason=%s", code, reason)
-        failure_info = {"code": code, "reason": reason}
-        if message:
-            failure_info["message"] = message
-        self.failedpath.write_text(json.dumps(failure_info))
+        env = taskglobals.Env.instance()
+        if not env.slave:
+            failure_info = {"code": code, "reason": reason}
+            if message:
+                failure_info["message"] = message
+            self.failedpath.write_text(json.dumps(failure_info))
+
         self.cleanup()
         logger.info("Exiting")
         delayed_shutdown(60, exit_code=code)
@@ -515,10 +522,12 @@ def _background_cleanup(self, reason: str, message: str = ""):
         """Run framework cleanup in background thread."""
         try:
             logger.info("Background cleanup: reason=%s", reason)
-            failure_info = {"code": signal.SIGTERM, "reason": reason}
-            if message:
-                failure_info["message"] = message
-            self.failedpath.write_text(json.dumps(failure_info))
+            env = taskglobals.Env.instance()
+            if not env.slave:
+                failure_info = {"code": signal.SIGTERM, "reason": reason}
+                if message:
+                    failure_info["message"] = message
+                self.failedpath.write_text(json.dumps(failure_info))
             self.cleanup()
             logger.info("Background cleanup finished")
         except Exception:
@@ -579,15 +588,36 @@ def remove_signal_handlers(remove_cleanup=True):
             os.getpid()
             logger.info("Working in directory %s", workdir)
 
-            for lockfile in self.lockfiles:
-                fullpath = str(Path(lockfile).resolve())
-                logger.info("Locking %s", fullpath)
-                lock = create_file_lock(fullpath)
-                # MAYBE: should have a clever way to lock
-                # Problem = slurm would have a job doing nothing...
-                # Fix = maybe with two files
-                lock.acquire()
-                self.locks.append(lock)
+            # Identify non-zero ranks in distributed settings (SLURM, DDP, etc.)
+            # A process is a slave if any distributed rank environment variable is > 0
+            # Unset, empty or non-integer values are treated as rank 0
+            def get_rank(name):
+                val = os.environ.get(name)
+                logger.debug("Rank detection: %s=%s", name, val)
+                try:
+                    return int(val) if val else 0
+                except ValueError:
+                    logger.warning("Ignoring non-integer rank %s=%r", name, val)
+                    return 0
+
+            rank = max(get_rank("SLURM_PROCID"), get_rank("RANK"), get_rank("LOCAL_RANK"))
+            env = taskglobals.Env.instance()
+            if rank > 0:
+                logger.info("Non-zero rank (%d): marking as slave process", rank)
+                env.slave = True
+            else:
+                logger.debug("Rank 0 or no distributed environment detected (rank=%d)", rank)
+
+            if not env.slave:
+                for lockfile in self.lockfiles:
+                    fullpath = str(Path(lockfile).resolve())
+                    logger.info("Locking %s", fullpath)
+                    lock = create_file_lock(fullpath)
+                    # MAYBE: should have a clever way to lock
+                    # Problem = slurm would have a job doing nothing...
+                    # Fix = maybe with two files
+                    lock.acquire()
+                    self.locks.append(lock)
 
             # Load and setup dynamic dependency locks from locks.json
             locks_path = workdir / "locks.json"
@@ -610,11 +640,12 @@ def remove_signal_handlers(remove_cleanup=True):
                 rmfile(self.failedpath)
             self.started = True
 
-            # Update status.json to "running" before writing events
-            self._update_status_running()
+            if not env.slave:
+                # Update status.json to "running" before writing events
+                self._update_status_running()
 
-            # Notify that the job has started (writes event file)
-            start_of_job()
+                # Notify that the job has started (writes event file)
+                start_of_job()
 
             # Initialize carbon tracking
             self._job_start_time = datetime.now()
@@ -633,7 +664,7 @@ def remove_signal_handlers(remove_cleanup=True):
             except Exception as e:
                 logger.debug("Failed to load params for carbon tracking: %s", e)
 
-            if workspace_path:
+            if workspace_path and not env.slave:
                 # Load previous carbon metrics from status.json for accumulation
                 try:
                     init_mock = self._load_mock_job()
@@ -658,6 +689,7 @@ def remove_signal_handlers(remove_cleanup=True):
                 )
 
             # Acquire dynamic dependency locks while running the task
+            # Non-slave processes acquire actual locks, slaves get dummy locks
             with self.dynamic_locks.dependency_locks():
                 run(workdir / "params.json")
 
@@ -690,9 +722,11 @@ def remove_signal_handlers(remove_cleanup=True):
                 self.handle_error(1, None)
 
         except SystemExit as e:
+            env = taskglobals.Env.instance()
             if e.code == 0:
-                # Normal exit, just create the ".done" file
-                self.donepath.touch()
+                if not env.slave:
+                    # Normal exit, just create the ".done" file
+                    self.donepath.touch()
 
             # ... and finish the exit process
             raise
diff --git a/src/experimaestro/tests/test_rank_awareness.py b/src/experimaestro/tests/test_rank_awareness.py
new file mode 100644
index 00000000..d99e0440
--- /dev/null
+++ b/src/experimaestro/tests/test_rank_awareness.py
@@ -0,0 +1,90 @@
+import os
+import pytest
+from unittest.mock import patch
+from experimaestro.run import TaskRunner
+import experimaestro.taskglobals as taskglobals
+
+
+@pytest.fixture
+def env():
+    """Yield the Env.instance() object with `slave` cleared before and after the test."""
+    instance = taskglobals.Env.instance()
+    instance.slave = False
+    yield instance
+    # The test flips the flag on this same object; clear it again for later tests
+    instance.slave = False
+
+
+@patch("experimaestro.run.atexit.register")
+@patch("experimaestro.run.signal.signal")
+@patch("experimaestro.run.os.chdir")
+@patch("experimaestro.run.os.register_at_fork")
+@patch("experimaestro.run.create_file_lock")
+@patch("experimaestro.run.start_of_job")
+@patch("experimaestro.run.run")
+@patch("experimaestro.run.TaskRunner._update_status_running")
+@patch("experimaestro.run.TaskRunner._load_mock_job")
+def test_task_runner_rank_detection(
+    mock_load_mock_job,
+    mock_update_status_running,
+    mock_run_task,
+    mock_start_of_job,
+    mock_create_file_lock,
+    mock_register_at_fork,
+    mock_os_chdir,
+    mock_signal,
+    mock_atexit,
+    env,
+    tmp_path,
+):
+    """Rank 0 takes file locks and reports job start; non-zero ranks are marked slave."""
+    script_path = tmp_path / "test.py"
+    script_path.touch()
+    lockfiles = [str(tmp_path / "test.lock")]
+
+    # Case 1: Main process (rank 0)
+    with patch.dict(os.environ, {"SLURM_PROCID": "0", "RANK": "0", "LOCAL_RANK": "0"}):
+        runner = TaskRunner(str(script_path), lockfiles)
+        # We need to stop run() from exiting or failing
+        # We'll mock its internal exit/cleanup calls if necessary,
+        # or just catch the SystemExit
+        try:
+            runner.run()
+        except SystemExit:
+            pass
+
+        assert env.slave is False
+        mock_create_file_lock.assert_called()
+        mock_start_of_job.assert_called()
+        mock_update_status_running.assert_called()
+
+    # Reset mocks for Case 2
+    mock_create_file_lock.reset_mock()
+    mock_start_of_job.reset_mock()
+    mock_update_status_running.reset_mock()
+    env.slave = False
+
+    # Case 2: Slave process (rank > 0 via LOCAL_RANK)
+    with patch.dict(os.environ, {"SLURM_PROCID": "0", "RANK": "0", "LOCAL_RANK": "1"}):
+        runner = TaskRunner(str(script_path), lockfiles)
+        try:
+            runner.run()
+        except SystemExit:
+            pass
+
+        assert env.slave is True
+        mock_create_file_lock.assert_not_called()
+        mock_start_of_job.assert_not_called()
+        mock_update_status_running.assert_not_called()
+
+    # Case 3: Slave process (rank > 0 via SLURM_PROCID)
+    env.slave = False
+    with patch.dict(os.environ, {"SLURM_PROCID": "2", "RANK": "0", "LOCAL_RANK": "0"}):
+        runner = TaskRunner(str(script_path), lockfiles)
+        try:
+            runner.run()
+        except SystemExit:
+            pass
+
+        assert env.slave is True
+        mock_create_file_lock.assert_not_called()