|
| 1 | +r"""Run a TPU workload with Shared Pathways Service. |
| 2 | +
|
| 3 | +Run your TPU workload locally using Shared Pathways Service. The service will |
| 4 | +deploy a Pathways proxy to run the TPU-specific components of your workload on |
| 5 | +the requested TPU slices. |
| 6 | +
|
| 7 | +Example: |
| 8 | +python3 run_workload.py \ |
| 9 | + --cluster my-cluster \ |
| 10 | + --project my-project \ |
| 11 | + --region=us-central1 \ |
| 12 | + --gcs_bucket=my-gcs-bucket \ |
| 13 | + --pathways_service=pathways-head:8000 \ |
| 14 | + --tpu_type=tpuv6e:4x8 \ |
| 15 | + --tpu_count=1 \ |
| 16 | + --command "python3 my_workload.py ..." |
| 17 | +
|
| 18 | +""" |
| 19 | + |
| 20 | +from collections.abc import Callable, Sequence |
| 21 | +import os |
| 22 | +import shlex |
| 23 | +import subprocess |
| 24 | +from typing import Any, ContextManager |
| 25 | + |
| 26 | +from absl import app |
| 27 | +from absl import flags |
| 28 | +from absl import logging |
| 29 | +from pathwaysutils.experimental.shared_pathways_service import isc_pathways |
| 30 | + |
| 31 | + |
# --- Required flags identifying the cluster and the Pathways service. ---
_CLUSTER = flags.DEFINE_string(
    name="cluster",
    default=None,
    help="The name of the GKE cluster.",
    required=True,
)
_PROJECT = flags.DEFINE_string(
    name="project",
    default=None,
    help="The GCP project ID.",
    required=True,
)
_REGION = flags.DEFINE_string(
    name="region",
    default=None,
    help="The GCP region.",
    required=True,
)
_GCS_BUCKET = flags.DEFINE_string(
    name="gcs_bucket",
    default=None,
    help="The Google Cloud Storage bucket.",
    required=True,
)
_PATHWAYS_SERVICE = flags.DEFINE_string(
    name="pathways_service",
    default=None,
    help=(
        "The address and port of the Pathways Resource Manager. See"
        " https://github.com/AI-Hypercomputer/pathways-utils/tree/main/pathwaysutils/experimental/shared_pathways_service#4-find-the-pathways-service-address"
        " for instructions on how to get the Pathways service address."
    ),
    required=True,
)

# --- Optional flags describing the requested TPUs and the proxy. ---
_TPU_TYPE = flags.DEFINE_string(
    name="tpu_type",
    default="tpuv6e:2x2",
    help="The TPU machine type and topology.",
)
_TPU_COUNT = flags.DEFINE_integer(
    name="tpu_count",
    default=1,
    help="The number of TPU slices.",
)
_PROXY_SERVER_IMAGE = flags.DEFINE_string(
    name="proxy_server_image",
    default="",
    help=(
        "The proxy server image to use. If not provided, a default will be"
        " used."
    ),
)
_PROXY_OPTIONS = flags.DEFINE_list(
    name="proxy_options",
    default=[],
    help=(
        "Configuration options for the Pathways proxy. Specify entries in the"
        ' form "key:value". For example:'
        " --proxy_options=use_insecure_credentials:true"
    ),
)

# --- Required flag holding the workload command line. ---
_COMMAND = flags.DEFINE_string(
    name="command",
    default=None,
    help="The command to run on TPUs.",
    required=True,
)
| 70 | + |
def _is_key_value_pair(option: str) -> bool:
    """Return True if `option` is "key:value" with a non-empty key and value.

    Only the first ":" separates key from value, so the value itself may
    contain further colons.
    """
    key, separator, value = option.partition(":")
    return bool(separator and key and value)


def _all_key_value_pairs(options: Sequence[str]) -> bool:
    """Return True if every entry of `options` passes `_is_key_value_pair`."""
    return all(_is_key_value_pair(option) for option in options)


# Reject --proxy_options entries that are not of the form "key:value".
flags.register_validator(
    "proxy_options",
    _all_key_value_pairs,
    message='--proxy_options must be in the format "key:value".',
)
| 82 | + |
| 83 | + |
def run_command(
    *,
    cluster: str,
    project: str,
    region: str,
    gcs_bucket: str,
    pathways_service: str,
    tpu_type: str,
    tpu_count: int,
    command: str,
    proxy_server_image: str | None = None,
    proxy_options: Sequence[str] | None = None,
    connect_fn: Callable[..., ContextManager[Any]] = isc_pathways.connect,
) -> None:
    """Run the TPU workload within a Shared Pathways connection.

    The command string is tokenized and validated before the connection is
    opened, so a malformed or empty command is rejected without first
    establishing a connection.

    Args:
      cluster: The name of the GKE cluster.
      project: The GCP project ID.
      region: The GCP region.
      gcs_bucket: The Google Cloud Storage bucket.
      pathways_service: The address and port of the Pathways Resource Manager.
      tpu_type: The TPU machine type and topology.
      tpu_count: The number of TPU slices.
      command: The command to run on TPUs. It is split into arguments with
        shell-like quoting rules and executed without a shell.
      proxy_server_image: The proxy server image to use. When empty or None,
        `isc_pathways.DEFAULT_PROXY_IMAGE` is used.
      proxy_options: Configuration options for the Pathways proxy.
      connect_fn: The function to use for establishing the connection context,
        expected to be a callable that returns a context manager.

    Raises:
      ValueError: If `command` is empty or cannot be tokenized (for example,
        because of unbalanced quotes).
      subprocess.CalledProcessError: If the workload command fails.
    """
    # Tokenize up front: a bad command should fail before connecting, not
    # after the connection has already been set up.
    command_args = shlex.split(command)
    if not command_args:
        raise ValueError("The command to run must not be empty.")

    parsed_proxy_options = isc_pathways.ProxyOptions.from_list(proxy_options)

    logging.info("Connecting to Shared Pathways Service...")
    with connect_fn(
        cluster=cluster,
        project=project,
        region=region,
        gcs_bucket=gcs_bucket,
        pathways_service=pathways_service,
        expected_tpu_instances={tpu_type: tpu_count},
        # Fall back to the library default for both None and "".
        proxy_server_image=(
            proxy_server_image or isc_pathways.DEFAULT_PROXY_IMAGE
        ),
        proxy_options=parsed_proxy_options,
    ):
        logging.info("Connection established. Running command: %r", command)
        try:
            # The child gets a copy of this process's environment as it is
            # at this point, i.e. inside the connection context.
            subprocess.run(command_args, check=True, env=os.environ.copy())
        except subprocess.CalledProcessError:
            logging.error(
                "Command failed! Find the underlying error in the logs above,"
                " where the command is invoked."
            )
            raise
        finally:
            logging.info("Command execution finished.")
| 146 | + |
| 147 | + |
def main(argv: Sequence[str]) -> None:
    """Run the workload described by the command-line flags.

    Args:
      argv: The command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If any positional argument follows the program name.
    """
    unexpected_args = argv[1:]
    if unexpected_args:
        raise app.UsageError("Too many command-line arguments.")

    # Collect the flag values once, then forward them as keyword arguments.
    workload_kwargs = {
        "cluster": _CLUSTER.value,
        "project": _PROJECT.value,
        "region": _REGION.value,
        "gcs_bucket": _GCS_BUCKET.value,
        "pathways_service": _PATHWAYS_SERVICE.value,
        "tpu_type": _TPU_TYPE.value,
        "tpu_count": _TPU_COUNT.value,
        "command": _COMMAND.value,
        "proxy_server_image": _PROXY_SERVER_IMAGE.value,
        "proxy_options": _PROXY_OPTIONS.value,
    }
    run_command(**workload_kwargs)


if __name__ == "__main__":
    app.run(main)
0 commit comments