Skip to content

Commit 9f74bb3

Browse files
guptaakacopybara-github
authored andcommitted
Add background log streaming to detect TPU placement completion
A background thread watches for specific log messages indicating that the proxy pod is waiting for placement until the TPU placement process has finished. This allows for better tracking of the Pathways service readiness. Continued "waiting" messages from proxy might indicate that the Pathways service doesn't have enough TPU availability to process the request. PiperOrigin-RevId: 888835053
1 parent a228b31 commit 9f74bb3

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

pathwaysutils/experimental/shared_pathways_service/gke_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,32 @@ def enable_port_forwarding(
298298
return (port_available, port_forward_process)
299299

300300

301+
def stream_pod_logs(pod_name: str) -> subprocess.Popen[str]:
302+
"""Streams logs from the given pod.
303+
304+
Args:
305+
pod_name: The name of the pod.
306+
307+
Returns:
308+
The process for streaming the logs.
309+
310+
Raises:
311+
Exception: If the log streaming fails.
312+
"""
313+
command = ["kubectl", "logs", "-f", pod_name]
314+
try:
315+
return subprocess.Popen(
316+
command,
317+
stdout=subprocess.PIPE,
318+
stderr=subprocess.STDOUT,
319+
text=True,
320+
bufsize=1, # Line buffered
321+
)
322+
except Exception as _:
323+
_logger.exception("Error streaming logs for pod %s", pod_name)
324+
raise
325+
326+
301327
def delete_gke_job(job_name: str) -> None:
302328
"""Deletes the given job from the GKE cluster.
303329

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import random
1010
import string
1111
import subprocess
12+
import threading
1213
from typing import Any
1314

1415
import jax
@@ -123,6 +124,50 @@ def _deploy_pathways_proxy_server(
123124
_logger.info("Successfully deployed Pathways proxy.")
124125

125126

127+
def _wait_for_placement(
128+
pod_name: str,
129+
num_slices: int,
130+
stream_logs_func=gke_utils.stream_pod_logs,
131+
) -> None:
132+
"""Waits for the placement to be complete by checking proxy logs."""
133+
_logger.info("Streaming proxy logs until the placement is complete...")
134+
with stream_logs_func(pod_name) as log_process:
135+
keywords = [
136+
"placement",
137+
"Signaling to RM",
138+
"Transition slice",
139+
"FAILED_PRECONDITION",
140+
]
141+
end_phrase = "unplaced -> placed"
142+
placement_count = 0
143+
144+
if not log_process.stdout:
145+
_logger.error("Log streaming process stdout is empty. Terminating.")
146+
log_process.terminate()
147+
_, stderr = log_process.communicate()
148+
raise RuntimeError(
149+
"Failed to stream proxy logs: stdout not available.\n"
150+
f"STDERR: {stderr}"
151+
)
152+
153+
for line in log_process.stdout:
154+
line_lower = line.lower()
155+
if any(keyword.lower() in line_lower for keyword in keywords):
156+
_logger.info("Proxy log: %s", line.strip())
157+
158+
if end_phrase.lower() in line_lower:
159+
placement_count += 1
160+
if placement_count < num_slices:
161+
_logger.info(
162+
"TPU slice %d/%d placed!",
163+
placement_count,
164+
num_slices,
165+
)
166+
else:
167+
_logger.info("TPU placement for %d slice(s) complete!", num_slices)
168+
break
169+
170+
126171
def _restore_env_var(key: str, original_value: str | None) -> None:
127172
"""Restores an environment variable to its original value or unsets it."""
128173
if original_value is None:
@@ -147,6 +192,7 @@ class _ISCPathways:
147192
expected_tpu_instances: A dictionary mapping TPU machine types to the number
148193
of instances.
149194
proxy_job_name: The name to use for the deployed proxy.
195+
proxy_pod_name: The name of the proxy pod, assigned during deployment.
150196
proxy_server_image: The image to use for the proxy server.
151197
proxy_options: Configuration options for the Pathways proxy.
152198
"""
@@ -171,6 +217,7 @@ def __init__(
171217
self.pathways_service = pathways_service
172218
self.expected_tpu_instances = expected_tpu_instances
173219
self._proxy_job_name = proxy_job_name
220+
self.proxy_pod_name: str = ""
174221
self._port_forward_process = None
175222
self._proxy_port = None
176223
self.proxy_server_image = proxy_server_image
@@ -220,9 +267,11 @@ def __enter__(self):
220267
)
221268
_logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link)
222269

223-
proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name)
270+
self.proxy_pod_name = gke_utils.wait_for_pod(self._proxy_job_name)
224271
self._proxy_port, self._port_forward_process = (
225-
gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT)
272+
gke_utils.enable_port_forwarding(
273+
self.proxy_pod_name, PROXY_SERVER_PORT
274+
)
226275
)
227276

228277
# Update the JAX backend to use the proxy.
@@ -351,4 +400,20 @@ def connect(
351400
proxy_server_image=proxy_server_image,
352401
proxy_options=proxy_options,
353402
) as t:
403+
if t.proxy_pod_name:
404+
num_slices = sum(t.expected_tpu_instances.values())
405+
placement_thread = threading.Thread(
406+
target=_wait_for_placement,
407+
args=(
408+
t.proxy_pod_name,
409+
num_slices,
410+
),
411+
daemon=True,
412+
)
413+
placement_thread.start()
414+
else:
415+
_logger.warning(
416+
"proxy_pod_name not set on _ISCPathways instance, skipping background"
417+
" _wait_for_placement."
418+
)
354419
yield t

0 commit comments

Comments
 (0)