22
33from collections .abc import Iterator , Mapping
44import contextlib
5+ import gc
56import logging
67import os
78import random
1011from typing import Any
1112
1213import jax
14+ import jax .extend .backend as jax_backend
1315import pathwaysutils
1416from pathwaysutils .experimental .shared_pathways_service import gke_utils
1517from pathwaysutils .experimental .shared_pathways_service import validators
2729_JAX_PLATFORM_PROXY = "proxy"
2830_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2931_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
32+ _DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe"
3033
3134_logger = logging .getLogger (__name__ )
3235
@@ -36,6 +39,7 @@ def _deploy_pathways_proxy_server(
3639 proxy_job_name : str ,
3740 expected_instances : Mapping [Any , Any ],
3841 gcs_scratch_location : str ,
42+ proxy_server_image : str ,
3943) -> None :
4044 """Deploys the Pathways proxy pods to the GKE cluster.
4145
@@ -45,6 +49,7 @@ def _deploy_pathways_proxy_server(
4549 expected_instances: A dictionary mapping instance types to the number of
4650 instances.
4751 gcs_scratch_location: The Google Cloud Storage location to use.
52+ proxy_server_image: The image to use for the proxy server.
4853
4954 Raises:
5055 subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +75,7 @@ def _deploy_pathways_proxy_server(
7075 PATHWAYS_HEAD_PORT = pathways_head_port ,
7176 EXPECTED_INSTANCES = instances_str ,
7277 GCS_SCRATCH_LOCATION = gcs_scratch_location ,
78+ PROXY_SERVER_IMAGE = proxy_server_image ,
7379 )
7480
7581 _logger .info ("Deploying Pathways proxy: %s" , proxy_job_name )
@@ -89,6 +95,8 @@ class _ISCPathways:
8995 pathways_service: The service name and port of the Pathways head pod.
9096 expected_tpu_instances: A dictionary mapping TPU machine types to the number
9197 of instances.
98+ proxy_job_name: The name to use for the deployed proxy.
99+ proxy_server_image: The image to use for the proxy server.
92100 """
93101
94102 def __init__ (
@@ -100,6 +108,7 @@ def __init__(
100108 pathways_service : str ,
101109 expected_tpu_instances : Mapping [Any , Any ],
102110 proxy_job_name : str | None ,
111+ proxy_server_image : str ,
103112 ):
104113 """Initializes the TPU manager."""
105114 self .cluster = cluster
@@ -115,6 +124,7 @@ def __init__(
115124 self ._proxy_job_name = proxy_job_name or f"isc-proxy-{ user } -{ suffix } "
116125 self ._port_forward_process = None
117126 self ._proxy_port = None
127+ self .proxy_server_image = proxy_server_image
118128
119129 def __repr__ (self ):
120130 return (
@@ -133,6 +143,7 @@ def __enter__(self):
133143 proxy_job_name = self ._proxy_job_name ,
134144 expected_instances = self .expected_tpu_instances ,
135145 gcs_scratch_location = self .bucket ,
146+ proxy_server_image = self .proxy_server_image ,
136147 )
137148 # Print a link to Cloud Logging
138149 cloud_logging_link = gke_utils .get_log_link (
@@ -172,7 +183,16 @@ def __exit__(self, exc_type, exc_value, traceback):
172183
173184 def _cleanup (self ):
174185 """Cleans up resources created by the ISCPathways context."""
186+ # 1. Clear JAX caches and run garbage collection.
187+ _logger .info ("Starting Pathways proxy cleanup." )
188+ jax_backend .clear_backends ()
189+ jax .clear_caches ()
190+ gc .collect ()
191+ _logger .info ("Cleared JAX caches and ran garbage collection." )
192+
193+ # 2. Terminate the port forwarding process.
175194 if self ._port_forward_process :
195+ _logger .info ("Terminating port forwarding process..." )
176196 self ._port_forward_process .terminate ()
177197 try :
178198 self ._port_forward_process .wait (timeout = 10 )
@@ -183,8 +203,10 @@ def _cleanup(self):
183203 e ,
184204 )
185205
186- _logger .info ("Deleting Pathways proxy" )
206+ # 3. Delete the proxy GKE job.
207+ _logger .info ("Deleting Pathways proxy..." )
187208 gke_utils .delete_gke_job (self ._proxy_job_name )
209+ _logger .info ("Pathways proxy GKE job deletion complete." )
188210
189211
190212@contextlib .contextmanager
@@ -196,6 +218,7 @@ def connect(
196218 pathways_service : str ,
197219 expected_tpu_instances : Mapping [str , int ],
198220 proxy_job_name : str | None = None ,
221+ proxy_server_image : str | None = _DEFAULT_PROXY_IMAGE ,
199222) -> Iterator ["_ISCPathways" ]:
200223 """Connects to a Pathways server if the cluster exists. If not, creates it.
201224
@@ -209,13 +232,16 @@ def connect(
209232 of instances. For example: {"tpuv6e:2x2": 2}
210233 proxy_job_name: The name to use for the deployed proxy. If not provided, a
211234 random name will be generated.
235+ proxy_server_image: The proxy server image to use. If not provided, a
236+ default will be used.
212237
213238 Yields:
214239 The Pathways manager.
215240 """
216241 _logger .info ("Validating Pathways service and TPU instances..." )
217242 validators .validate_pathways_service (pathways_service )
218243 validators .validate_tpu_instances (expected_tpu_instances )
244+ validators .validate_proxy_server_image (proxy_server_image )
219245 _logger .info ("Validation complete." )
220246 gke_utils .fetch_cluster_credentials (
221247 cluster_name = cluster , project_id = project , location = region
@@ -229,5 +255,6 @@ def connect(
229255 pathways_service = pathways_service ,
230256 expected_tpu_instances = expected_tpu_instances ,
231257 proxy_job_name = proxy_job_name ,
258+ proxy_server_image = proxy_server_image ,
232259 ) as t :
233260 yield t
0 commit comments