diff --git a/cloud/google/common.py b/cloud/google/common.py
index ff107c3a..ffb1bac6 100644
--- a/cloud/google/common.py
+++ b/cloud/google/common.py
@@ -14,6 +14,7 @@
 mkdir -p /etc/docker
 systemctl restart docker
 gcloud auth --quiet configure-docker
+gcloud auth --quiet configure-docker us-docker.pkg.dev,us-east1-docker.pkg.dev,us-central1-docker.pkg.dev,us-west1-docker.pkg.dev
 '''
 
 INSTALL_NVIDIA_DOCKER_CMD = '''
diff --git a/dags/google_api_helper.py b/dags/google_api_helper.py
index 554a183c..cf39fab5 100644
--- a/dags/google_api_helper.py
+++ b/dags/google_api_helper.py
@@ -223,21 +223,27 @@ def ramp_up_cluster(key, initial_size, total_size):
     run_metadata = Variable.get("run_metadata", deserialize_json=True, default_var={})
     if not run_metadata.get("manage_clusters", True):
         return
+    already_at_size = False
     try:
         target_sizes = Variable.get("cluster_target_size", deserialize_json=True)
+        already_at_size = target_sizes.get(key, 0) >= total_size
         target_sizes[key] = total_size
         Variable.set("cluster_target_size", target_sizes, serialize_json=True)
         slack_message(":information_source: ramping up cluster {} to {} instances, starting from {} instances".format(key, total_size, min(initial_size, total_size)))
         increase_instance_group_size(key, min(initial_size, total_size))
     except:
         increase_instance_group_size(key, total_size)
-    sleep(60)
+    if not already_at_size:
+        sleep(60)
     Variable.set("cluster_target_size", target_sizes, serialize_json=True)
 
 def ramp_down_cluster(key, total_size):
     run_metadata = Variable.get("run_metadata", deserialize_json=True, default_var={})
     if not run_metadata.get("manage_clusters", True):
         return
+    if Variable.get("batch_keep_cluster", default_var="false") == "true":
+        slack_message(f":recycle: Batch mode: keeping {key} cluster alive for next job")
+        return
     try:
         target_sizes = Variable.get("cluster_target_size", deserialize_json=True)
         target_sizes[key] = total_size
diff --git a/slackbot/pipeline_commands.py b/slackbot/pipeline_commands.py
index 6e1475c0..fb4b3a55 100644
--- a/slackbot/pipeline_commands.py
+++ b/slackbot/pipeline_commands.py
@@ -207,53 +207,62 @@ def handle_batch(task, msg):
 
     replyto(msg, "Batch jobs will reuse on the parameters from the first job unless new parameters are specified, *including those with default values*")
     default_param = json_obj[0]
-    for i, p in enumerate(json_obj):
-        if visible_messages(broker_url, "seuronbot_cmd") != 0:
-            cmd = get_message(broker_url, "seuronbot_cmd")
-            if cmd == "cancel":
-                replyto(msg, "Cancel batch process")
-                break
-
-        if p.get("INHERIT_PARAMETERS", True):
-            param = deepcopy(default_param)
-        else:
-            param = {}
-
-        if i > 0:
-            if 'NAME' in param:
-                del param['NAME']
-        for k in p:
-            param[k] = p[k]
-        supply_default_param(param)
-        replyto(msg, "*Sanity check: batch job {} out of {}*".format(i+1, len(json_obj)))
+    is_batch = len(json_obj) > 1
+    try:
+        for i, p in enumerate(json_obj):
+            if visible_messages(broker_url, "seuronbot_cmd") != 0:
+                cmd = get_message(broker_url, "seuronbot_cmd")
+                if cmd == "cancel":
+                    replyto(msg, "Cancel batch process")
+                    break
+
+            if p.get("INHERIT_PARAMETERS", True):
+                param = deepcopy(default_param)
+            else:
+                param = {}
+
+            if i > 0:
+                if 'NAME' in param:
+                    del param['NAME']
+            for k in p:
+                param[k] = p[k]
+            supply_default_param(param)
+            replyto(msg, "*Sanity check: batch job {} out of {}*".format(i+1, len(json_obj)))
+            state = "unknown"
+            current_task = guess_run_type(param)
+            if current_task == "seg_run":
+                set_variable('param', param, serialize_json=True)
+                state = run_dag("sanity_check", wait_for_completion=True).state
+            elif current_task == "inf_run":
+                set_variable('inference_param', param, serialize_json=True)
+                state = run_dag("chunkflow_generator", wait_for_completion=True).state
+            elif current_task == "syn_run":
+                set_variable("synaptor_param.json", param, serialize_json=True)
+                state = run_dag("synaptor_sanity_check", wait_for_completion=True).state
+
+            if state != "success":
+                replyto(msg, "*Sanity check failed, abort!*")
+                break
+
+            is_last_job = (i == len(json_obj) - 1)
         state = "unknown"
-        current_task = guess_run_type(param)
+            replyto(msg, "*Starting batch job {} out of {}*".format(i+1, len(json_obj)), broadcast=True)
+
         if current_task == "seg_run":
-            set_variable('param', param, serialize_json=True)
-            state = run_dag("sanity_check", wait_for_completion=True).state
+                state = run_dag('segmentation', wait_for_completion=True).state
         elif current_task == "inf_run":
-            set_variable('inference_param', param, serialize_json=True)
-            state = run_dag("chunkflow_generator", wait_for_completion=True).state
+                if is_batch and not is_last_job:
+                    set_variable("batch_keep_cluster", "true")
+                else:
+                    set_variable("batch_keep_cluster", "false")
+                state = run_dag("chunkflow_worker", wait_for_completion=True).state
         elif current_task == "syn_run":
-            set_variable("synaptor_param.json", param, serialize_json=True)
-            state = run_dag("synaptor_sanity_check", wait_for_completion=True).state
+                state = run_dag("synaptor", wait_for_completion=True).state
 
         if state != "success":
-            replyto(msg, "*Sanity check failed, abort!*")
+            replyto(msg, f"*Batch job failed, abort!* ({state})")
             break
-
-        state = "unknown"
-        replyto(msg, "*Starting batch job {} out of {}*".format(i+1, len(json_obj)), broadcast=True)
-
-        if current_task == "seg_run":
-            state = run_dag('segmentation', wait_for_completion=True).state
-        elif current_task == "inf_run":
-            state = run_dag("chunkflow_worker", wait_for_completion=True).state
-        elif current_task == "syn_run":
-            state = run_dag("synaptor", wait_for_completion=True).state
-
-        if state != "success":
-            replyto(msg, f"*Bach job failed, abort!* ({state})")
-            break
+    finally:
+        set_variable("batch_keep_cluster", "false")
 
     replyto(msg, "*Batch process finished*")