Skip to content

Commit b577f39

Browse files
guptaaka authored and copybara-github committed
Add CLI mode to Shared Pathways Service
This commit adds a new script, `run_workload.py`, which lets users pass a command to the Shared Pathways Service. Users only need to add `pathwaysutils.initialize()` to their script and then run that script through the `--command` flag of `run_workload.py`.

PiperOrigin-RevId: 893772232
1 parent 9264000 commit b577f39

File tree

3 files changed

+184
-13
lines changed

3 files changed

+184
-13
lines changed

pathwaysutils/experimental/shared_pathways_service/gke_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import socket
55
import subprocess
6+
import time
67
import urllib.parse
78

89
import portpicker
@@ -189,6 +190,7 @@ def wait_for_pod(job_name: str) -> str:
189190
RuntimeError: If the pod is not ready.
190191
"""
191192
_logger.info("Waiting for pod to be created...")
193+
time.sleep(1)
192194
pod_name = get_pod_from_job(job_name)
193195

194196
_logger.info(

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,10 @@ def __repr__(self):
192192

193193
def __enter__(self):
194194
"""Enters the context manager, ensuring cluster exists."""
195-
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY)
196-
self._old_jax_backend_target = os.environ.get(_JAX_BACKEND_TARGET_KEY)
195+
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper())
196+
self._old_jax_backend_target = os.environ.get(
197+
_JAX_BACKEND_TARGET_KEY.upper()
198+
)
197199
self._old_jax_platforms_config = getattr(
198200
jax.config, _JAX_PLATFORMS_KEY, None
199201
)
@@ -224,16 +226,14 @@ def __enter__(self):
224226
)
225227

226228
# Update the JAX backend to use the proxy.
227-
os.environ[_JAX_PLATFORMS_KEY] = _JAX_PLATFORM_PROXY
228-
os.environ[
229-
_JAX_BACKEND_TARGET_KEY
230-
] = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}"
231-
229+
jax_backend_target = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}"
230+
# Update the JAX config for the inline mode of Shared Pathways Service.
232231
jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY)
233-
jax.config.update(
234-
_JAX_BACKEND_TARGET_KEY,
235-
f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}",
236-
)
232+
jax.config.update(_JAX_BACKEND_TARGET_KEY, jax_backend_target)
233+
# Update the environment variables for the CLI mode of Shared Pathways
234+
# Service.
235+
os.environ[_JAX_PLATFORMS_KEY.upper()] = _JAX_PLATFORM_PROXY
236+
os.environ[_JAX_BACKEND_TARGET_KEY.upper()] = jax_backend_target
237237

238238
pathwaysutils.initialize()
239239
_logger.info(
@@ -281,8 +281,10 @@ def _cleanup(self) -> None:
281281

282282
# 4. Restore JAX variables.
283283
_logger.info("Restoring JAX env and config variables...")
284-
_restore_env_var(_JAX_PLATFORMS_KEY, self._old_jax_platforms)
285-
_restore_env_var(_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target)
284+
_restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms)
285+
_restore_env_var(
286+
_JAX_BACKEND_TARGET_KEY.upper(), self._old_jax_backend_target
287+
)
286288
jax.config.update(_JAX_PLATFORMS_KEY, self._old_jax_platforms_config)
287289
jax.config.update(
288290
_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target_config
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
r"""Run a TPU workload with Shared Pathways Service.
2+
3+
Run your TPU workload locally using Shared Pathways Service. The service will
4+
deploy a Pathways proxy to run the TPU-specific components of your workload on
5+
the requested TPU slices.
6+
7+
Example:
8+
python3 run_workload.py \
9+
--cluster my-cluster \
10+
--project my-project \
11+
--region=us-central1 \
12+
--gcs_bucket=my-gcs-bucket \
13+
--pathways_service=pathways-head:8000 \
14+
--tpu_type=tpuv6e:4x8 \
15+
--tpu_count=1 \
16+
--command "python3 my_workload.py ..."
17+
18+
"""
19+
20+
from collections.abc import Callable, Sequence
21+
import os
22+
import shlex
23+
import subprocess
24+
from typing import Any, ContextManager
25+
26+
from absl import app
27+
from absl import flags
28+
from absl import logging
29+
from pathwaysutils.experimental.shared_pathways_service import isc_pathways
30+
31+
32+
# --- Required flags identifying the cluster and the Pathways service. ---
_CLUSTER = flags.DEFINE_string(
    name="cluster",
    default=None,
    help="The name of the GKE cluster.",
    required=True,
)
_PROJECT = flags.DEFINE_string(
    name="project",
    default=None,
    help="The GCP project ID.",
    required=True,
)
_REGION = flags.DEFINE_string(
    name="region",
    default=None,
    help="The GCP region.",
    required=True,
)
_GCS_BUCKET = flags.DEFINE_string(
    name="gcs_bucket",
    default=None,
    help="The Google Cloud Storage bucket.",
    required=True,
)
_PATHWAYS_SERVICE = flags.DEFINE_string(
    name="pathways_service",
    default=None,
    help=(
        "The address and port of the Pathways Resource Manager. See"
        " https://github.com/AI-Hypercomputer/pathways-utils/tree/main/pathwaysutils/experimental/shared_pathways_service#4-find-the-pathways-service-address"
        " for instructions on how to get the Pathways service address."
    ),
    required=True,
)

# --- Optional flags describing the requested TPUs and the proxy. ---
_TPU_TYPE = flags.DEFINE_string(
    name="tpu_type",
    default="tpuv6e:2x2",
    help="The TPU machine type and topology.",
)
_TPU_COUNT = flags.DEFINE_integer(
    name="tpu_count",
    default=1,
    help="The number of TPU slices.",
)
_PROXY_SERVER_IMAGE = flags.DEFINE_string(
    name="proxy_server_image",
    default="",
    help=(
        "The proxy server image to use. If not provided, a default will be"
        " used."
    ),
)
_PROXY_OPTIONS = flags.DEFINE_list(
    name="proxy_options",
    default=[],
    help=(
        "Configuration options for the Pathways proxy. Specify entries in the"
        ' form "key:value". For example:'
        " --proxy_options=use_insecure_credentials:true"
    ),
)

# --- The workload itself. ---
_COMMAND = flags.DEFINE_string(
    name="command",
    default=None,
    help="The command to run on TPUs.",
    required=True,
)


def _is_key_value_option(item: str) -> bool:
  """Returns True if `item` is "key:value" with a non-empty key and value.

  Only the first colon separates key from value, so the value may itself
  contain colons.
  """
  key, separator, value = item.partition(":")
  return bool(separator) and bool(key) and bool(value)


flags.register_validator(
    flag_name="proxy_options",
    checker=lambda options: all(
        _is_key_value_option(option) for option in options
    ),
    message='--proxy_options must be in the format "key:value".',
)
82+
83+
84+
def run_command(
    *,
    cluster: str,
    project: str,
    region: str,
    gcs_bucket: str,
    pathways_service: str,
    tpu_type: str,
    tpu_count: int,
    command: str,
    proxy_server_image: str | None = None,
    proxy_options: Sequence[str] | None = None,
    connect_fn: Callable[..., ContextManager[Any]] = isc_pathways.connect,
) -> None:
  """Run the TPU workload within a Shared Pathways connection.

  The command is tokenized and validated before the connection is opened, so
  a malformed or empty command is rejected without first setting up the
  connection. The command is then executed as a subprocess (no shell) while
  the connection context is active.

  Args:
    cluster: The name of the GKE cluster.
    project: The GCP project ID.
    region: The GCP region.
    gcs_bucket: The Google Cloud Storage bucket.
    pathways_service: The address and port of the Pathways Resource Manager.
    tpu_type: The TPU machine type and topology.
    tpu_count: The number of TPU slices.
    command: The command to run on TPUs. It is split into arguments with
      shell-like quoting rules (`shlex.split`); it is not run through a shell.
    proxy_server_image: The proxy server image to use. A falsy value (None or
      an empty string) selects `isc_pathways.DEFAULT_PROXY_IMAGE`.
    proxy_options: Configuration options for the Pathways proxy.
    connect_fn: The function to use for establishing the connection context,
      expected to be a callable that returns a context manager.

  Raises:
    ValueError: If the command cannot be tokenized (for example, unbalanced
      quotes) or contains no arguments.
    subprocess.CalledProcessError: If the workload command fails.
  """
  # Fail fast: parse the command before connecting so that an unusable
  # command never triggers the (expensive) connection setup.
  command_args = shlex.split(command)
  if not command_args:
    raise ValueError(
        f"command must contain at least one argument; got {command!r}."
    )

  parsed_proxy_options = isc_pathways.ProxyOptions.from_list(proxy_options)

  logging.info("Connecting to Shared Pathways Service...")
  with connect_fn(
      cluster=cluster,
      project=project,
      region=region,
      gcs_bucket=gcs_bucket,
      pathways_service=pathways_service,
      expected_tpu_instances={tpu_type: tpu_count},
      proxy_server_image=(
          proxy_server_image or isc_pathways.DEFAULT_PROXY_IMAGE
      ),
      proxy_options=parsed_proxy_options,
  ):
    logging.info("Connection established. Running command: %r", command)
    try:
      # Snapshot the environment at launch time so the child sees whatever
      # variables are set in this process while the connection is active.
      subprocess.run(command_args, check=True, env=os.environ.copy())
    except subprocess.CalledProcessError:
      logging.error(
          "Command failed! Find the underlying error in the logs above, where"
          " the command is invoked."
      )
      raise
    finally:
      logging.info("Command execution finished.")
146+
147+
148+
def main(argv: Sequence[str]) -> None:
  """Entry point: forwards the parsed flag values to `run_command`.

  Args:
    argv: Command-line arguments handed over by `app.run`. Anything beyond
      the first element is rejected.

  Raises:
    app.UsageError: If `argv` holds more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Collect every flag value under the keyword name `run_command` expects.
  workload_kwargs = {
      "cluster": _CLUSTER.value,
      "project": _PROJECT.value,
      "region": _REGION.value,
      "gcs_bucket": _GCS_BUCKET.value,
      "pathways_service": _PATHWAYS_SERVICE.value,
      "tpu_type": _TPU_TYPE.value,
      "tpu_count": _TPU_COUNT.value,
      "command": _COMMAND.value,
      "proxy_server_image": _PROXY_SERVER_IMAGE.value,
      "proxy_options": _PROXY_OPTIONS.value,
  }
  run_command(**workload_kwargs)
164+
165+
166+
if __name__ == "__main__":
167+
app.run(main)

0 commit comments

Comments
 (0)