Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 150 additions & 23 deletions src/nsls2ptycho/core/ptycho_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,99 @@
import traceback
import time


import requests
import json

from .databroker_api import load_metadata, save_data
from .utils import use_mpi_machinefile, set_flush_early
from .ptycho.utils import save_config


class RemoteJobHandler:
def __init__(self):
self.url = "https://orion-api-staging.nsls2.bnl.gov/api/v1/compute/orion/jobs"
self.api_key = os.getenv("APIKEY")
self.headers = {
"Content-Type": "application/json",
"x-api-key": f"{self.api_key}"
}
self.params = {"expand_info": False}

def submit_job(self, remote_path, parent_module, param):

fname_full = os.path.join(remote_path,'ptycho_'+str(param.scan_num)+'_'+param.sign)
srun_command = "python " + "-W " + "ignore " + "-m " + parent_module + ".ptycho.recon_ptycho_gui " + fname_full
#srun_command = "python -W ignore -m nsls2ptycho.core.ptycho.recon_ptycho_gui /nsls2/users/skarakuzu1/ptycho_test/remote_orion/ptycho_320045_t1"

HOME = os.environ["HOME"]

overrides = {
"name": "trial",
"partition": "normal",
"tasks": f"{len(param.gpus)}",
"time_limit": 720,
"tres_per_task": "cpu=1,gres/gpu=1",
"standard_output": "trial.out",
"standard_error": "trial.err",
}

payload = {
"script": (
"#!/bin/bash -l\n"
"module load orion/gpu\n"
"module unload openmpi\n"
"conda activate /nsls2/conda/envs/2025-2.0-py311-tiled/\n"
"nvidia-smi\n"
"echo $HOME\n"
"echo $(pwd)\n"
"echo $(which mpicc)\n"
#f"mpirun -n 2 {srun_command}\n"
f"srun --mpi=pmix {srun_command}\n"
),
"working_dir_path": f"{remote_path}",
"environment": [
"PATH=/usr/bin:/bin:/usr/sbin:/sbin",
f"HOME={HOME}",
"SLURM_EXPORT_ENV=ALL",
],
"overrides": overrides,
}


sys.stdout.flush()

response = requests.post(self.url, headers=self.headers, json=payload)
resp_json = response.json()
#print("response is ", response.status_code, resp_json)

if response.status_code == 200:
self.remote_job_id = resp_json['job_id']
print("submitted job with id ", self.remote_job_id)


return response.status_code


def cancel_job(self):
response = requests.delete(f"{self.url}/{self.remote_job_id}", headers=self.headers, params=self.params)

resp_json = response.json()
#print("response is ", response.status_code, resp_json)

return response.status_code

def get_job_status(self):
response = requests.get(f"{self.url}/{self.remote_job_id}", headers=self.headers, params=self.params)

resp_json = response.json()
#print("response is ", response.status_code, resp_json)

if response.status_code == 200:
state = resp_json["jobs"][0]["state"][0]
return state


class PtychoReconRemote(QtCore.QThread):
update_signal = QtCore.pyqtSignal(int, object) # (interation number, chi arrays)

Expand All @@ -28,12 +117,14 @@ def __init__(self, param:Param=None, parent=None):
if not os.path.isdir(self.remote_path):
os.mkdir(self.remote_path)

self.msg_file = os.path.join(os.path.join(self.remote_path,'msg'))
if not os.path.isfile(self.msg_file):
with open(self.msg_file,'w') as f:
pass
self.msg = open(self.msg_file,'r')
self.msg.readlines()
#self.msg_file = os.path.join(os.path.join(self.remote_path,'msg'))
#if not os.path.isfile(self.msg_file):
# with open(self.msg_file,'w') as f:
# pass
#self.msg = open(self.msg_file,'r')
#self.msg.readlines()

self.remote_job_handler = RemoteJobHandler()

def _parse_message(self, tokens):
def _parser(current, upper_limit, target_list):
Expand Down Expand Up @@ -98,11 +189,21 @@ def clear_slurm_header(self):
os.remove(slurm_header)
except:
pass


def cleanup(self):
if self.fname_full and os.path.exists(self.fname_full):
os.remove(self.fname_full)
self.fname_full = None
if os.path.exists(os.path.join(self.remote_path,'prb_live.npy')):
os.remove(os.path.join(self.remote_path,'prb_live.npy'))
if os.path.exists(os.path.join(self.remote_path,'obj_live.npy')):
os.remove(os.path.join(self.remote_path,'obj_live.npy'))

def recon_remote(self, param:Param, update_fcn=None):

self.fname_full = os.path.join(self.remote_path,'ptycho_'+str(param.scan_num)+'_'+param.sign)

self.parent_module = '.'.join(self.__module__.rsplit('.', 2)[:-1]) # get parent module name to run the correct recon worker

if param.working_directory:
param.working_directory = os.path.realpath(param.working_directory)+'/'
if param.prb_dir:
Expand All @@ -115,26 +216,39 @@ def recon_remote(self, param:Param, update_fcn=None):
param.obj_path = os.path.realpath(param.obj_path)

save_config(self.fname_full,param)
self.export_slurm_header()
#self.export_slurm_header()

self.return_value = 0 # Assume the recon will succeed unless later detects failure and modify it.

# try:
status = self.remote_job_handler.submit_job(self.remote_path, self.parent_module, param)
print(f"Submitted job from the gui with status code {status} and reserved job id {self.remote_job_handler.remote_job_id}")

time.sleep(1)
out = self.msg.readlines()
while not out:
while self.remote_job_handler.get_job_status() != "RUNNING":
print('Waiting for remote worker on %s to take the recon task...'%param.remote_srv)
time.sleep(1)


file_name = f"slurm-{self.remote_job_handler.remote_job_id}.out"
self.msg_file = os.path.join(self.remote_path, file_name)

self.msg = open(self.msg_file, "r")
out = self.msg.readlines()
pos = 0

time.sleep(1)
while not out:
print('Waiting for remote worker on %s to start writing...'%param.remote_srv)
out = self.msg.readlines()
if os.path.isfile(os.path.join(self.remote_path,'abort')):
os.remove(os.path.join(self.remote_path,'abort'))
if os.path.isfile(os.path.join(self.remote_path,'msg')):
os.remove(os.path.join(self.remote_path,'msg'))
if os.path.isfile(self.fname_full):
os.remove(self.fname_full)
raise Exception('Remote recon aborted...')
time.sleep(1)


time.sleep(1)
while True:
self.msg.seek(pos)
out = self.msg.readlines()
pos = self.msg.tell() # remember where we stopped

for line in out:
print(line, end='') # because the line already ends with '\n'
tokens = line.split()
Expand All @@ -145,19 +259,27 @@ def recon_remote(self, param:Param, update_fcn=None):
#print(result['probe_chi'])
if 'aborted' in line:
self.return_value = 1 # Aborted

if not os.path.isfile(self.fname_full):
break

# ask Slurm about job status
status = self.remote_job_handler.get_job_status()

# stop when job is no longer running AND there was no new data
if status != "RUNNING":
size = os.path.getsize(self.msg_file)
if size == pos:
break

time.sleep(0.1)
out = self.msg.readlines()

# except:
# pass
# finally:
# pass

def run(self):
print('Ptycho thread started')
print('Ptycho thread started helloooo***')
try:
self.recon_remote(self.param, self.update_signal.emit)
except IndexError:
Expand All @@ -172,10 +294,15 @@ def run(self):
self.update_signal.emit(self.param.n_iterations+1,None)

finally:
self.clear_slurm_header()
print('finally?')
#self.clear_slurm_header()
status = self.remote_job_handler.cancel_job()
print(f"Cancelled job with id {self.remote_job_handler.remote_job_id} from the gui with status code {status}")
self.cleanup()

def kill(self):
self.remote_job_handler.cancel_job()
print(f"Cancelled job with id {self.remote_job_handler.remote_job_id} from the gui with status code {status}")
if os.path.isdir(self.remote_path):
with open(os.path.join(self.remote_path,'abort'),'w') as f:
pass
Expand Down