99import random
1010import string
1111import subprocess
12+ import threading
1213from typing import Any
1314
1415import jax
@@ -123,6 +124,50 @@ def _deploy_pathways_proxy_server(
123124 _logger .info ("Successfully deployed Pathways proxy." )
124125
125126
def _wait_for_placement(
    pod_name: str,
    num_slices: int,
    stream_logs_func=gke_utils.stream_pod_logs,
) -> None:
  """Waits for the placement to be complete by checking proxy logs.

  Streams the logs of the given pod and counts the lines containing the phrase
  "unplaced -> placed" (case-insensitive). Returns as soon as that count
  reaches `num_slices`. Lines mentioning any placement-related keyword are
  echoed to the module logger along the way.

  Args:
    pod_name: Name of the pod whose logs are streamed; passed unchanged to
      `stream_logs_func`.
    num_slices: Number of "unplaced -> placed" lines to observe before the
      placement is considered complete.
    stream_logs_func: Callable taking the pod name and returning a context
      manager that yields the log-streaming process. The yielded object must
      expose `stdout`, `terminate()` and `communicate()` (presumably a
      `subprocess.Popen` opened in text mode -- verify against
      `gke_utils.stream_pod_logs`).

  Raises:
    RuntimeError: If the streaming process exposes no stdout to read from.
  """
  _logger.info("Streaming proxy logs until the placement is complete...")
  with stream_logs_func(pod_name) as log_process:
    # Fail fast, before any other setup, when there is nothing to read.
    if not log_process.stdout:
      _logger.error("Log streaming process stdout is empty. Terminating.")
      log_process.terminate()
      _, stderr = log_process.communicate()
      raise RuntimeError(
          "Failed to stream proxy logs: stdout not available.\n "
          f"STDERR: { stderr } "
      )

    # Matching is case-insensitive, so lowercase the needles once here rather
    # than once per log line.
    keywords = tuple(
        keyword.lower()
        for keyword in (
            "placement",
            "Signaling to RM",
            "Transition slice",
            "FAILED_PRECONDITION",
        )
    )
    end_phrase = "unplaced -> placed".lower()
    placement_count = 0

    for line in log_process.stdout:
      line_lower = line.lower()
      if any(keyword in line_lower for keyword in keywords):
        _logger.info("Proxy log: %s", line.strip())

      if end_phrase not in line_lower:
        continue

      placement_count += 1
      if placement_count < num_slices:
        _logger.info(
            "TPU slice %d/%d placed!",
            placement_count,
            num_slices,
        )
      else:
        _logger.info("TPU placement for %d slice(s) complete!", num_slices)
        break
    else:
      # The stream was exhausted without reaching the expected count. Surface
      # this instead of returning as if the placement had completed.
      _logger.warning(
          "Proxy log stream ended before placement completed: %d/%d slice(s)"
          " placed.",
          placement_count,
          num_slices,
      )
169+
170+
126171def _restore_env_var (key : str , original_value : str | None ) -> None :
127172 """Restores an environment variable to its original value or unsets it."""
128173 if original_value is None :
@@ -147,6 +192,7 @@ class _ISCPathways:
147192 expected_tpu_instances: A dictionary mapping TPU machine types to the number
148193 of instances.
149194 proxy_job_name: The name to use for the deployed proxy.
195+ proxy_pod_name: The name of the proxy pod, assigned during deployment.
150196 proxy_server_image: The image to use for the proxy server.
151197 proxy_options: Configuration options for the Pathways proxy.
152198 """
@@ -171,6 +217,7 @@ def __init__(
171217 self .pathways_service = pathways_service
172218 self .expected_tpu_instances = expected_tpu_instances
173219 self ._proxy_job_name = proxy_job_name
220+ self .proxy_pod_name : str = ""
174221 self ._port_forward_process = None
175222 self ._proxy_port = None
176223 self .proxy_server_image = proxy_server_image
@@ -220,9 +267,11 @@ def __enter__(self):
220267 )
221268 _logger .info ("View proxy logs in Cloud Logging: %s" , cloud_logging_link )
222269
223- proxy_pod = gke_utils .wait_for_pod (self ._proxy_job_name )
270+ self . proxy_pod_name = gke_utils .wait_for_pod (self ._proxy_job_name )
224271 self ._proxy_port , self ._port_forward_process = (
225- gke_utils .enable_port_forwarding (proxy_pod , PROXY_SERVER_PORT )
272+ gke_utils .enable_port_forwarding (
273+ self .proxy_pod_name , PROXY_SERVER_PORT
274+ )
226275 )
227276
228277 # Update the JAX backend to use the proxy.
@@ -351,4 +400,20 @@ def connect(
351400 proxy_server_image = proxy_server_image ,
352401 proxy_options = proxy_options ,
353402 ) as t :
403+ if t .proxy_pod_name :
404+ num_slices = sum (t .expected_tpu_instances .values ())
405+ placement_thread = threading .Thread (
406+ target = _wait_for_placement ,
407+ args = (
408+ t .proxy_pod_name ,
409+ num_slices ,
410+ ),
411+ daemon = True ,
412+ )
413+ placement_thread .start ()
414+ else :
415+ _logger .warning (
416+ "proxy_pod_name not set on _ISCPathways instance, skipping background"
417+ " _wait_for_placement."
418+ )
354419 yield t
0 commit comments