Skip to content

Commit b55bf8d

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Add Shared Pathways Service (ISC) integration tests
PiperOrigin-RevId: 859809622
1 parent c0147e6 commit b55bf8d

File tree

4 files changed

+64
-2
lines changed

4 files changed

+64
-2
lines changed

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Iterator, Mapping
44
import contextlib
5+
import gc
56
import logging
67
import os
78
import random
@@ -10,6 +11,7 @@
1011
from typing import Any
1112

1213
import jax
14+
import jax.extend.backend as jax_backend
1315
import pathwaysutils
1416
from pathwaysutils.experimental.shared_pathways_service import gke_utils
1517
from pathwaysutils.experimental.shared_pathways_service import validators
@@ -27,6 +29,7 @@
2729
_JAX_PLATFORM_PROXY = "proxy"
2830
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2931
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
32+
_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe"
3033

3134
_logger = logging.getLogger(__name__)
3235

@@ -36,6 +39,7 @@ def _deploy_pathways_proxy_server(
3639
proxy_job_name: str,
3740
expected_instances: Mapping[Any, Any],
3841
gcs_scratch_location: str,
42+
proxy_server_image: str,
3943
) -> None:
4044
"""Deploys the Pathways proxy pods to the GKE cluster.
4145
@@ -45,6 +49,7 @@ def _deploy_pathways_proxy_server(
4549
expected_instances: A dictionary mapping instance types to the number of
4650
instances.
4751
gcs_scratch_location: The Google Cloud Storage location to use.
52+
proxy_server_image: The image to use for the proxy server.
4853
4954
Raises:
5055
subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +75,7 @@ def _deploy_pathways_proxy_server(
7075
PATHWAYS_HEAD_PORT=pathways_head_port,
7176
EXPECTED_INSTANCES=instances_str,
7277
GCS_SCRATCH_LOCATION=gcs_scratch_location,
78+
PROXY_SERVER_IMAGE=proxy_server_image,
7379
)
7480

7581
_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
@@ -89,6 +95,8 @@ class _ISCPathways:
8995
pathways_service: The service name and port of the Pathways head pod.
9096
expected_tpu_instances: A dictionary mapping TPU machine types to the number
9197
of instances.
98+
proxy_job_name: The name to use for the deployed proxy.
99+
proxy_server_image: The image to use for the proxy server.
92100
"""
93101

94102
def __init__(
@@ -100,6 +108,7 @@ def __init__(
100108
pathways_service: str,
101109
expected_tpu_instances: Mapping[Any, Any],
102110
proxy_job_name: str | None,
111+
proxy_server_image: str,
103112
):
104113
"""Initializes the TPU manager."""
105114
self.cluster = cluster
@@ -115,6 +124,7 @@ def __init__(
115124
self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}"
116125
self._port_forward_process = None
117126
self._proxy_port = None
127+
self.proxy_server_image = proxy_server_image
118128

119129
def __repr__(self):
120130
return (
@@ -133,6 +143,7 @@ def __enter__(self):
133143
proxy_job_name=self._proxy_job_name,
134144
expected_instances=self.expected_tpu_instances,
135145
gcs_scratch_location=self.bucket,
146+
proxy_server_image=self.proxy_server_image,
136147
)
137148
# Print a link to Cloud Logging
138149
cloud_logging_link = gke_utils.get_log_link(
@@ -172,7 +183,16 @@ def __exit__(self, exc_type, exc_value, traceback):
172183

173184
def _cleanup(self):
174185
"""Cleans up resources created by the ISCPathways context."""
186+
# 1. Clear JAX caches and run garbage collection.
187+
_logger.info("Starting Pathways proxy cleanup.")
188+
jax_backend.clear_backends()
189+
jax.clear_caches()
190+
gc.collect()
191+
_logger.info("Cleared JAX caches and ran garbage collection.")
192+
193+
# 2. Terminate the port forwarding process.
175194
if self._port_forward_process:
195+
_logger.info("Terminating port forwarding process...")
176196
self._port_forward_process.terminate()
177197
try:
178198
self._port_forward_process.wait(timeout=10)
@@ -183,8 +203,10 @@ def _cleanup(self):
183203
e,
184204
)
185205

186-
_logger.info("Deleting Pathways proxy")
206+
# 3. Delete the proxy GKE job.
207+
_logger.info("Deleting Pathways proxy...")
187208
gke_utils.delete_gke_job(self._proxy_job_name)
209+
_logger.info("Pathways proxy GKE job deletion complete.")
188210

189211

190212
@contextlib.contextmanager
@@ -196,6 +218,7 @@ def connect(
196218
pathways_service: str,
197219
expected_tpu_instances: Mapping[str, int],
198220
proxy_job_name: str | None = None,
221+
proxy_server_image: str | None = _DEFAULT_PROXY_IMAGE,
199222
) -> Iterator["_ISCPathways"]:
200223
"""Connects to a Pathways server if the cluster exists. If not, creates it.
201224
@@ -209,13 +232,16 @@ def connect(
209232
of instances. For example: {"tpuv6e:2x2": 2}
210233
proxy_job_name: The name to use for the deployed proxy. If not provided, a
211234
random name will be generated.
235+
proxy_server_image: The proxy server image to use. If not provided, a
236+
default will be used.
212237
213238
Yields:
214239
The Pathways manager.
215240
"""
216241
_logger.info("Validating Pathways service and TPU instances...")
217242
validators.validate_pathways_service(pathways_service)
218243
validators.validate_tpu_instances(expected_tpu_instances)
244+
validators.validate_proxy_server_image(proxy_server_image)
219245
_logger.info("Validation complete.")
220246
gke_utils.fetch_cluster_credentials(
221247
cluster_name=cluster, project_id=project, location=region
@@ -229,5 +255,6 @@ def connect(
229255
pathways_service=pathways_service,
230256
expected_tpu_instances=expected_tpu_instances,
231257
proxy_job_name=proxy_job_name,
258+
proxy_server_image=proxy_server_image,
232259
) as t:
233260
yield t

pathwaysutils/experimental/shared_pathways_service/run_connect_example.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
"tpu_type", "tpuv6e:2x2", "The TPU machine type and topology."
2525
)
2626
flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.")
27+
flags.DEFINE_string(
28+
"proxy_job_name",
29+
None,
30+
"The name to use for the deployed proxy. If not provided, a random name"
31+
" will be generated.",
32+
)
33+
flags.DEFINE_string(
34+
"proxy_server_image",
35+
None,
36+
"The proxy server image to use. If not provided, a default will be used.",
37+
)
2738

2839
flags.mark_flags_as_required([
2940
"cluster",
@@ -37,13 +48,21 @@
3748
def main(argv: Sequence[str]) -> None:
3849
if len(argv) > 1:
3950
raise app.UsageError("Too many command-line arguments.")
51+
52+
kwargs = {}
53+
if FLAGS.proxy_job_name:
54+
kwargs["proxy_job_name"] = FLAGS.proxy_job_name
55+
if FLAGS.proxy_server_image:
56+
kwargs["proxy_server_image"] = FLAGS.proxy_server_image
57+
4058
with isc_pathways.connect(
4159
cluster=FLAGS.cluster,
4260
project=FLAGS.project,
4361
region=FLAGS.region,
4462
gcs_bucket=FLAGS.gcs_bucket,
4563
pathways_service=FLAGS.pathways_service,
4664
expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
65+
**kwargs,
4766
):
4867
orig_matrix = jnp.zeros(5)
4968
result_matrix = orig_matrix + 1

pathwaysutils/experimental/shared_pathways_service/validators.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,19 @@ def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None:
8989

9090
inst = next(iter(expected_tpu_instances.keys()))
9191
_validate_tpu_supported(inst)
92+
93+
94+
def validate_proxy_server_image(proxy_server_image: str) -> None:
95+
"""Validates the proxy server image format."""
96+
if not proxy_server_image or not proxy_server_image.strip():
97+
raise ValueError("Proxy server image cannot be empty.")
98+
if "/" not in proxy_server_image:
99+
raise ValueError(
100+
f"Proxy server image '{proxy_server_image}' must contain '/', "
101+
"separating the registry or namespace from the final image name."
102+
)
103+
if ":" not in proxy_server_image and "@" not in proxy_server_image:
104+
raise ValueError(
105+
f"Proxy server image '{proxy_server_image}' must contain a tag with ':'"
106+
" or a digest with '@'."
107+
)

pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ spec:
1414
automountServiceAccountToken: false
1515
containers:
1616
- name: pathways-proxy
17-
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe
17+
image: ${PROXY_SERVER_IMAGE}
1818
imagePullPolicy: Always
1919
args:
2020
- --server_port=${PROXY_SERVER_PORT}

0 commit comments

Comments
 (0)