Skip to content

Commit aa51df3

Browse files
guptaakacopybara-github
authored andcommitted
Integrate metrics collection into ISC Pathways
This change introduces a new `metrics_collector` module to track key metrics within the ISC Pathways client. The `_ISCPathways` context manager initializes and uses the `MetricsCollector` to record the below metrics if `collect_isc_metrics` flag is enabled: - Active user count on context entry and exit - Capacity in use based on requested TPU instances - Assignment time upon successful TPU placement - Successful request count - Low capacity failures when "FAILED_PRECONDITION" is detected in proxy logs PiperOrigin-RevId: 899626835
1 parent c2a9fe0 commit aa51df3

File tree

4 files changed

+246
-5
lines changed

4 files changed

+246
-5
lines changed

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
import string
1111
import subprocess
1212
import threading
13+
import time
1314
from typing import Any
1415

1516
import jax
1617
import jax.extend.backend as jax_backend
1718
import pathwaysutils
1819
from pathwaysutils.experimental.shared_pathways_service import gke_utils
20+
from pathwaysutils.experimental.shared_pathways_service import metrics_collector
1921
from pathwaysutils.experimental.shared_pathways_service import validators
2022

2123

@@ -128,6 +130,8 @@ def _wait_for_placement(
128130
pod_name: str,
129131
num_slices: int,
130132
stream_logs_func=gke_utils.stream_pod_logs,
133+
metrics_collector_inst: Any = None,
134+
start_time: float | None = None,
131135
) -> None:
132136
"""Waits for the placement to be complete by checking proxy logs."""
133137
_logger.info("Streaming proxy logs until the placement is complete...")
@@ -165,7 +169,16 @@ def _wait_for_placement(
165169
)
166170
else:
167171
_logger.info("TPU placement for %d slice(s) complete!", num_slices)
172+
if metrics_collector_inst:
173+
metrics_collector_inst.record_user_waiting(0)
174+
if start_time:
175+
duration = time.time() - start_time
176+
metrics_collector_inst.record_assignment_time(duration)
177+
metrics_collector_inst.record_successful_request()
168178
break
179+
else:
180+
if metrics_collector_inst:
181+
metrics_collector_inst.record_user_waiting(1)
169182

170183

171184
def _restore_env_var(key: str, original_value: str | None) -> None:
@@ -195,11 +208,14 @@ class _ISCPathways:
195208
proxy_pod_name: The name of the proxy pod, assigned during deployment.
196209
proxy_server_image: The image to use for the proxy server.
197210
proxy_options: Configuration options for the Pathways proxy.
211+
metrics_collector: The metrics collector instance if enabled.
212+
start_time: The start time of the TPU assignment.
198213
"""
199214

200215
def __init__(
201216
self,
202-
*, cluster: str,
217+
*,
218+
cluster: str,
203219
project: str,
204220
region: str,
205221
gcs_bucket: str,
@@ -208,6 +224,7 @@ def __init__(
208224
proxy_job_name: str,
209225
proxy_server_image: str,
210226
proxy_options: ProxyOptions | None = None,
227+
collect_service_metrics: bool = False,
211228
):
212229
"""Initializes the TPU manager."""
213230
self.cluster = cluster
@@ -223,6 +240,10 @@ def __init__(
223240
self.proxy_server_image = proxy_server_image
224241
self.proxy_options = proxy_options or ProxyOptions()
225242
self._old_jax_platforms = None
243+
self.metrics_collector = None
244+
if collect_service_metrics:
245+
self.metrics_collector = metrics_collector.MetricsCollector(self.project)
246+
self.start_time = None
226247
self._old_jax_backend_target = None
227248
self._old_jax_platforms_config = None
228249
self._old_jax_backend_target_config = None
@@ -237,8 +258,25 @@ def __repr__(self):
237258
f"proxy_options={self.proxy_options})"
238259
)
239260

261+
def _get_total_chips(self) -> int:
262+
"""Calculates total chips from expected_tpu_instances."""
263+
total_chips = 0
264+
for tpu_type, count in self.expected_tpu_instances.items():
265+
parts = tpu_type.split(":")
266+
topology = parts[1]
267+
dimensions = [int(d) for d in topology.split("x")]
268+
chips_per_instance = 1
269+
for d in dimensions:
270+
chips_per_instance *= d
271+
total_chips += chips_per_instance * count
272+
return total_chips
273+
240274
def __enter__(self):
241275
"""Enters the context manager, ensuring cluster exists."""
276+
if self.metrics_collector:
277+
self.metrics_collector.record_active_user(True)
278+
self.metrics_collector.record_capacity_in_use(self._get_total_chips())
279+
242280
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper())
243281
self._old_jax_backend_target = os.environ.get(
244282
_JAX_BACKEND_TARGET_KEY.upper()
@@ -251,6 +289,7 @@ def __enter__(self):
251289
)
252290

253291
try:
292+
self.start_time = time.time()
254293
_deploy_pathways_proxy_server(
255294
pathways_service=self.pathways_service,
256295
proxy_job_name=self._proxy_job_name,
@@ -303,14 +342,19 @@ def __exit__(self, exc_type, exc_value, traceback):
303342

304343
def _cleanup(self) -> None:
305344
"""Cleans up resources created by the ISCPathways context."""
306-
# 1. Clear JAX caches and run garbage collection.
345+
# Reset metrics on exit.
346+
if self.metrics_collector:
347+
self.metrics_collector.record_active_user(False)
348+
self.metrics_collector.record_capacity_in_use(0)
349+
350+
# Clear JAX caches and run garbage collection.
307351
_logger.info("Starting Pathways proxy cleanup.")
308352
jax_backend.clear_backends()
309353
jax.clear_caches()
310354
gc.collect()
311355
_logger.info("Cleared JAX caches and ran garbage collection.")
312356

313-
# 2. Terminate the port forwarding process.
357+
# Terminate the port forwarding process.
314358
if self._port_forward_process:
315359
_logger.info("Terminating port forwarding process...")
316360
self._port_forward_process.terminate()
@@ -323,12 +367,12 @@ def _cleanup(self) -> None:
323367
e,
324368
)
325369

326-
# 3. Delete the proxy GKE job.
370+
# Delete the proxy GKE job.
327371
_logger.info("Deleting Pathways proxy...")
328372
gke_utils.delete_gke_job(self._proxy_job_name)
329373
_logger.info("Pathways proxy GKE job deletion complete.")
330374

331-
# 4. Restore JAX variables.
375+
# Restore JAX variables.
332376
_logger.info("Restoring JAX env and config variables...")
333377
_restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms)
334378
_restore_env_var(
@@ -353,6 +397,7 @@ def connect(
353397
proxy_job_name: str | None = None,
354398
proxy_server_image: str = DEFAULT_PROXY_IMAGE,
355399
proxy_options: ProxyOptions | None = None,
400+
collect_service_metrics: bool = False,
356401
) -> Iterator["_ISCPathways"]:
357402
"""Connects to a Pathways server if the cluster exists. If not, creates it.
358403
@@ -370,6 +415,8 @@ def connect(
370415
default will be used.
371416
proxy_options: Configuration options for the Pathways proxy. If not
372417
provided, no extra options will be used.
418+
collect_service_metrics: Whether to collect usage metrics for Shared
419+
Pathways Service.
373420
374421
Yields:
375422
The Pathways manager.
@@ -399,6 +446,7 @@ def connect(
399446
proxy_job_name=proxy_job_name,
400447
proxy_server_image=proxy_server_image,
401448
proxy_options=proxy_options,
449+
collect_service_metrics=collect_service_metrics,
402450
) as t:
403451
if t.proxy_pod_name:
404452
num_slices = sum(t.expected_tpu_instances.values())
@@ -407,6 +455,9 @@ def connect(
407455
args=(
408456
t.proxy_pod_name,
409457
num_slices,
458+
gke_utils.stream_pod_logs,
459+
t.metrics_collector,
460+
t.start_time,
410461
),
411462
daemon=True,
412463
)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""Metrics collector for Shared Pathways Service."""
2+
3+
import logging
4+
import time
5+
from typing import Any, Dict
6+
7+
try:
8+
# pylint: disable=g-import-not-at-top
9+
from google.api_core import exceptions
10+
from google.cloud import monitoring_v3
11+
except ImportError:
12+
pass
13+
14+
_logger = logging.getLogger(__name__)
15+
16+
17+
METRIC_PREFIX = "custom.googleapis.com/shared_pathways_service/"
18+
19+
_METRIC_NUM_ACTIVE_USERS = "num_active_users"
20+
_METRIC_CAPACITY_IN_USE = "capacity_in_use"
21+
_METRIC_ASSIGNMENT_TIME = "assignment_time"
22+
_METRIC_NUM_SUCCESSFUL_REQS = "num_successful_reqs"
23+
_METRIC_NUM_USERS_WAITING = "num_users_waiting"
24+
_METRIC_DESCRIPTORS = [
25+
{
26+
"name": _METRIC_NUM_ACTIVE_USERS,
27+
"description": "Number of active users at any given time",
28+
"value_type": "INT64",
29+
"unit": "1",
30+
},
31+
{
32+
"name": _METRIC_CAPACITY_IN_USE,
33+
"description": "Number of chips that are actively running workloads",
34+
"value_type": "INT64",
35+
"unit": "chips",
36+
"display_name": "Capacity (chips) in use",
37+
},
38+
{
39+
"name": _METRIC_ASSIGNMENT_TIME,
40+
"description": "Time to assign slice(s) to an incoming client",
41+
"value_type": "DOUBLE",
42+
"unit": "s",
43+
"display_name": "Capacity assignment time",
44+
},
45+
{
46+
"name": _METRIC_NUM_SUCCESSFUL_REQS,
47+
"description": (
48+
"Number of user requests that got capacity assignment successfully"
49+
),
50+
"value_type": "INT64",
51+
"unit": "1",
52+
"display_name": "Successful capacity assignment requests",
53+
},
54+
{
55+
"name": _METRIC_NUM_USERS_WAITING,
56+
"description": "Number of users waiting for capacity",
57+
"value_type": "INT64",
58+
"unit": "1",
59+
"display_name": "Users waiting",
60+
},
61+
]
62+
63+
64+
class MetricsCollector:
65+
"""Collects usage metrics for Shared Pathways Service and reports to Cloud Monitoring."""
66+
67+
def __init__(self, project_id: str):
68+
self.project_id = project_id
69+
self.client = monitoring_v3.MetricServiceClient()
70+
self.project_name = f"projects/{self.project_id}"
71+
for descriptor in _METRIC_DESCRIPTORS:
72+
self._create_metric_descriptor(**descriptor)
73+
_logger.info("Metrics collection initialized.")
74+
75+
def _create_time_series_object(
76+
self,
77+
metric_type: str,
78+
value: Any,
79+
value_type: str,
80+
metric_labels: Dict[str, str] | None = None,
81+
resource_type: str = "global",
82+
resource_labels: Dict[str, str] | None = None,
83+
) -> Any:
84+
"""Creates a TimeSeries object for a single metric."""
85+
# Using Any for return type to avoid failing when monitoring_v3 is not
86+
# available.
87+
series = monitoring_v3.TimeSeries()
88+
series.metric.type = METRIC_PREFIX + metric_type
89+
series.resource.type = resource_type
90+
if resource_labels:
91+
series.resource.labels.update(resource_labels)
92+
if metric_labels:
93+
series.metric.labels.update(metric_labels)
94+
95+
now = time.time()
96+
seconds = int(now)
97+
nanos = int((now - seconds) * 10**9)
98+
99+
point = monitoring_v3.Point(
100+
interval=monitoring_v3.TimeInterval(
101+
end_time={"seconds": seconds, "nanos": nanos}
102+
),
103+
value=monitoring_v3.TypedValue(**{value_type: value}),
104+
)
105+
series.points.append(point)
106+
return series
107+
108+
def _send_metric(
109+
self,
110+
metric_type: str,
111+
value: Any,
112+
value_type: str,
113+
metric_labels: Dict[str, str] | None = None,
114+
):
115+
"""Sends a single metric to Cloud Monitoring."""
116+
series = self._create_time_series_object(
117+
metric_type, value, value_type, metric_labels
118+
)
119+
self.client.create_time_series(name=self.project_name, time_series=[series])
120+
_logger.debug("Sent metric %s: %s", metric_type, value)
121+
122+
def _create_metric_descriptor(
123+
self,
124+
name: str,
125+
description: str,
126+
value_type: str,
127+
unit: str,
128+
metric_kind: str = "GAUGE",
129+
display_name: str | None = None,
130+
):
131+
"""Creates a metric descriptor if not already present."""
132+
metric_type = METRIC_PREFIX + name
133+
display_name = display_name or name
134+
135+
try:
136+
self.client.create_metric_descriptor(
137+
name=f"projects/{self.project_id}",
138+
metric_descriptor={
139+
"type": metric_type,
140+
"metric_kind": metric_kind,
141+
"value_type": value_type,
142+
"description": description,
143+
"display_name": display_name,
144+
"unit": unit,
145+
},
146+
)
147+
_logger.info("Created metric descriptor: %s", metric_type)
148+
except exceptions.AlreadyExists:
149+
_logger.debug("Metric descriptor %s already exists.", metric_type)
150+
151+
def record_active_user(self, is_active: bool):
152+
"""Records the number of active users (1 for active, 0 for inactive)."""
153+
self._send_metric(
154+
_METRIC_NUM_ACTIVE_USERS, 1 if is_active else 0, "int64_value"
155+
)
156+
157+
def record_capacity_in_use(self, chips: int):
158+
"""Records the number of chips in use."""
159+
self._send_metric(_METRIC_CAPACITY_IN_USE, chips, "int64_value")
160+
161+
def record_assignment_time(self, duration_seconds: float):
162+
"""Records the time taken to assign slices."""
163+
self._send_metric(_METRIC_ASSIGNMENT_TIME, duration_seconds, "double_value")
164+
165+
def record_successful_request(self):
166+
"""Records a successful request."""
167+
self._send_metric(_METRIC_NUM_SUCCESSFUL_REQS, 1, "int64_value")
168+
169+
def record_user_waiting(self):
170+
"""Records a user waiting for capacity."""
171+
self._send_metric(_METRIC_NUM_USERS_WAITING, 1, "int64_value")

pathwaysutils/experimental/shared_pathways_service/run_connect_example.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@
4141
"Configuration options for the Pathways proxy. Specify entries in the form"
4242
' "key:value". For example: --proxy_options=use_insecure_credentials:true',
4343
)
44+
flags.DEFINE_bool(
45+
"collect_service_metrics",
46+
False,
47+
"Whether to enable metrics collection for Shared Pathways Service.",
48+
)
4449

4550
flags.mark_flags_as_required([
4651
"cluster",
@@ -68,6 +73,7 @@ def main(argv: Sequence[str]) -> None:
6873
proxy_server_image=FLAGS.proxy_server_image
6974
or isc_pathways.DEFAULT_PROXY_IMAGE,
7075
proxy_options=proxy_options,
76+
collect_service_metrics=FLAGS.collect_service_metrics,
7177
):
7278
orig_matrix = jnp.zeros(5)
7379
result_matrix = orig_matrix + 1

0 commit comments

Comments
 (0)