Skip to content

Commit a53cf44

Browse files
authored
Merge pull request #218 from VictorMorand/feat-slave-processes
Feat: detect and slave processes
2 parents 8157974 + 064ae8c commit a53cf44

2 files changed

Lines changed: 141 additions & 27 deletions

File tree

src/experimaestro/run.py

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,10 @@ def cleanup(self):
431431
if not self.cleaned:
432432
self.cleaned = True
433433
logger.info("Cleaning up")
434-
rmfile(self.pidfile)
434+
env = taskglobals.Env.instance()
435+
436+
if not env.slave:
437+
rmfile(self.pidfile)
435438

436439
# Load MockJob for state tracking
437440
self._mock_job = self._load_mock_job()
@@ -471,7 +474,8 @@ def cleanup(self):
471474
self._carbon_tracker = None
472475

473476
# Write status.json while still holding locks
474-
self._write_status()
477+
if not env.slave:
478+
self._write_status()
475479

476480
# Release IPC locks
477481
for lock in self.locks:
@@ -486,7 +490,7 @@ def cleanup(self):
486490
# Note: dynamic dependency locks are released via context manager
487491
# in the run() method, not here
488492

489-
if self.started:
493+
if self.started and not env.slave:
490494
# Report final state: "error" if .failed exists, "done" otherwise
491495
final_state = "error" if self.failedpath.exists() else "done"
492496
report_eoj(final_state)
@@ -502,10 +506,13 @@ def handle_error(self, code, frame_type, reason: str = "failed", message: str =
502506
message: Optional message with details
503507
"""
504508
logger.info("Error handler: finished with code %d, reason=%s", code, reason)
505-
failure_info = {"code": code, "reason": reason}
506-
if message:
507-
failure_info["message"] = message
508-
self.failedpath.write_text(json.dumps(failure_info))
509+
env = taskglobals.Env.instance()
510+
if not env.slave:
511+
failure_info = {"code": code, "reason": reason}
512+
if message:
513+
failure_info["message"] = message
514+
self.failedpath.write_text(json.dumps(failure_info))
515+
509516
self.cleanup()
510517
logger.info("Exiting")
511518
delayed_shutdown(60, exit_code=code)
@@ -515,10 +522,12 @@ def _background_cleanup(self, reason: str, message: str = ""):
515522
"""Run framework cleanup in background thread."""
516523
try:
517524
logger.info("Background cleanup: reason=%s", reason)
518-
failure_info = {"code": signal.SIGTERM, "reason": reason}
519-
if message:
520-
failure_info["message"] = message
521-
self.failedpath.write_text(json.dumps(failure_info))
525+
env = taskglobals.Env.instance()
526+
if not env.slave:
527+
failure_info = {"code": signal.SIGTERM, "reason": reason}
528+
if message:
529+
failure_info["message"] = message
530+
self.failedpath.write_text(json.dumps(failure_info))
522531
self.cleanup()
523532
logger.info("Background cleanup finished")
524533
except Exception:
@@ -579,15 +588,31 @@ def remove_signal_handlers(remove_cleanup=True):
579588
os.getpid()
580589
logger.info("Working in directory %s", workdir)
581590

582-
for lockfile in self.lockfiles:
583-
fullpath = str(Path(lockfile).resolve())
584-
logger.info("Locking %s", fullpath)
585-
lock = create_file_lock(fullpath)
586-
# MAYBE: should have a clever way to lock
587-
# Problem = slurm would have a job doing nothing...
588-
# Fix = maybe with two files
589-
lock.acquire()
590-
self.locks.append(lock)
591+
# Identify non-zero ranks in distributed settings (SLURM, DDP, etc.)
592+
# A process is a slave if any distributed rank environment variable is > 0
593+
def get_rank(name):
    """Return the integer rank stored in the environment variable ``name``.

    Args:
        name: Name of the environment variable to read (the caller passes
            ``SLURM_PROCID``, ``RANK`` and ``LOCAL_RANK``).

    Returns:
        The variable parsed as an ``int``; 0 when the variable is unset,
        empty or not a valid integer.
    """
    val = os.environ.get(name)
    logger.debug("Rank detection: %s=%s", name, val)
    if val is None:
        return 0
    try:
        return int(val)
    except ValueError:
        # A malformed rank variable must not abort the job before it has
        # even started: fall back to rank 0 and leave a trace in the logs
        logger.warning("Ignoring non-integer rank value %s=%r", name, val)
        return 0
597+
598+
rank = max(get_rank("SLURM_PROCID"), get_rank("RANK"), get_rank("LOCAL_RANK"))
599+
env = taskglobals.Env.instance()
600+
if rank > 0:
601+
logger.info("Non-zero rank (%d): marking as slave process", rank)
602+
env.slave = True
603+
else:
604+
logger.debug("Rank 0 or no distributed environment detected (rank=%d)", rank)
605+
606+
if not env.slave:
607+
for lockfile in self.lockfiles:
608+
fullpath = str(Path(lockfile).resolve())
609+
logger.info("Locking %s", fullpath)
610+
lock = create_file_lock(fullpath)
611+
# MAYBE: should have a clever way to lock
612+
# Problem = slurm would have a job doing nothing...
613+
# Fix = maybe with two files
614+
lock.acquire()
615+
self.locks.append(lock)
591616

592617
# Load and setup dynamic dependency locks from locks.json
593618
locks_path = workdir / "locks.json"
@@ -610,11 +635,12 @@ def remove_signal_handlers(remove_cleanup=True):
610635
rmfile(self.failedpath)
611636
self.started = True
612637

613-
# Update status.json to "running" before writing events
614-
self._update_status_running()
638+
if not env.slave:
639+
# Update status.json to "running" before writing events
640+
self._update_status_running()
615641

616-
# Notify that the job has started (writes event file)
617-
start_of_job()
642+
# Notify that the job has started (writes event file)
643+
start_of_job()
618644

619645
# Initialize carbon tracking
620646
self._job_start_time = datetime.now()
@@ -633,7 +659,7 @@ def remove_signal_handlers(remove_cleanup=True):
633659
except Exception as e:
634660
logger.debug("Failed to load params for carbon tracking: %s", e)
635661

636-
if workspace_path:
662+
if workspace_path and not env.slave:
637663
# Load previous carbon metrics from status.json for accumulation
638664
try:
639665
init_mock = self._load_mock_job()
@@ -658,6 +684,7 @@ def remove_signal_handlers(remove_cleanup=True):
658684
)
659685

660686
# Acquire dynamic dependency locks while running the task
687+
# Non-slave processes acquire actual locks, slaves get dummy locks
661688
with self.dynamic_locks.dependency_locks():
662689
run(workdir / "params.json")
663690

@@ -690,9 +717,11 @@ def remove_signal_handlers(remove_cleanup=True):
690717
self.handle_error(1, None)
691718

692719
except SystemExit as e:
720+
env = taskglobals.Env.instance()
693721
if e.code == 0:
694-
# Normal exit, just create the ".done" file
695-
self.donepath.touch()
722+
if not env.slave:
723+
# Normal exit, just create the ".done" file
724+
self.donepath.touch()
696725

697726
# ... and finish the exit process
698727
raise
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import os
2+
import pytest
3+
from unittest.mock import patch
4+
from experimaestro.run import TaskRunner
5+
import experimaestro.taskglobals as taskglobals
6+
7+
@pytest.fixture
def env():
    """Yield the shared ``taskglobals.Env`` instance with ``slave`` cleared.

    ``Env.instance()`` is shared by every test in the process, and the test
    using this fixture leaves ``slave`` set to ``True``. The previous value
    is therefore restored on teardown so the flag does not leak into tests
    that run afterwards.
    """
    instance = taskglobals.Env.instance()
    previous_slave = instance.slave
    # Start every test from the "main process" state
    instance.slave = False
    try:
        yield instance
    finally:
        instance.slave = previous_slave
13+
14+
@patch("experimaestro.run.atexit.register")
@patch("experimaestro.run.signal.signal")
@patch("experimaestro.run.os.chdir")
@patch("experimaestro.run.os.register_at_fork")
@patch("experimaestro.run.create_file_lock")
@patch("experimaestro.run.start_of_job")
@patch("experimaestro.run.run")
@patch("experimaestro.run.TaskRunner._update_status_running")
@patch("experimaestro.run.TaskRunner._load_mock_job")
def test_task_runner_rank_detection(
    mock_load_mock_job,
    mock_update_status_running,
    mock_run_task,
    mock_start_of_job,
    mock_create_file_lock,
    mock_register_at_fork,
    mock_os_chdir,
    mock_signal,
    mock_atexit,
    env,
    tmp_path,
):
    """Check that ``TaskRunner.run()`` flags non-zero ranks as slave processes.

    Three scenarios are run one after the other:

    1. every rank variable is 0: ``env.slave`` stays ``False`` and the file
       lock, start-of-job notification and status update are all triggered;
    2. ``LOCAL_RANK`` is 1: ``env.slave`` becomes ``True`` and none of those
       three calls happen;
    3. ``SLURM_PROCID`` is 2: same expectations as scenario 2.
    """
    script_path = tmp_path / "test.py"
    script_path.touch()
    lockfiles = [str(tmp_path / "test.lock")]

    # Variables inspected by the rank detection in TaskRunner.run()
    rank_variables = ("SLURM_PROCID", "RANK", "LOCAL_RANK")
    # Calls that only the main (rank 0) process is expected to perform
    main_only_mocks = (
        mock_create_file_lock,
        mock_start_of_job,
        mock_update_status_running,
    )

    def run_with_ranks(ranks):
        """Run a fresh TaskRunner with exactly ``ranks`` as rank variables."""
        for mock in main_only_mocks:
            mock.reset_mock()
        env.slave = False

        with patch.dict(os.environ, ranks):
            # patch.dict only overlays ``ranks``: drop rank variables that
            # the host environment may define (e.g. when the test suite is
            # itself launched through srun/torchrun). The original
            # environment is restored when the context exits.
            for name in rank_variables:
                if name not in ranks:
                    os.environ.pop(name, None)

            runner = TaskRunner(str(script_path), lockfiles)
            # run() may terminate through SystemExit: this is not a failure
            try:
                runner.run()
            except SystemExit:
                pass

    # Case 1: Main process (rank 0)
    run_with_ranks({"SLURM_PROCID": "0", "LOCAL_RANK": "0"})
    assert env.slave is False
    for mock in main_only_mocks:
        mock.assert_called()

    # Case 2: Slave process (rank > 0 via LOCAL_RANK)
    run_with_ranks({"SLURM_PROCID": "0", "LOCAL_RANK": "1"})
    assert env.slave is True
    for mock in main_only_mocks:
        mock.assert_not_called()

    # Case 3: Slave process (rank > 0 via SLURM_PROCID)
    run_with_ranks({"SLURM_PROCID": "2"})
    assert env.slave is True
    for mock in main_only_mocks:
        mock.assert_not_called()

0 commit comments

Comments
 (0)