Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import string
import subprocess
import threading
import time
from typing import Any

import jax
import jax.extend.backend as jax_backend
import pathwaysutils
from pathwaysutils.experimental.shared_pathways_service import gke_utils
from pathwaysutils.experimental.shared_pathways_service import metrics_collector
from pathwaysutils.experimental.shared_pathways_service import validators


Expand Down Expand Up @@ -128,6 +130,8 @@ def _wait_for_placement(
pod_name: str,
num_slices: int,
stream_logs_func=gke_utils.stream_pod_logs,
metrics_collector_inst: Any = None,
start_time: float | None = None,
) -> None:
"""Waits for the placement to be complete by checking proxy logs."""
_logger.info("Streaming proxy logs until the placement is complete...")
Expand Down Expand Up @@ -165,7 +169,16 @@ def _wait_for_placement(
)
else:
_logger.info("TPU placement for %d slice(s) complete!", num_slices)
if metrics_collector_inst:
metrics_collector_inst.record_user_waiting(0)
if start_time:
duration = time.time() - start_time
metrics_collector_inst.record_assignment_time(duration)
metrics_collector_inst.record_successful_request()
break
else:
if metrics_collector_inst:
metrics_collector_inst.record_user_waiting(1)


def _restore_env_var(key: str, original_value: str | None) -> None:
Expand Down Expand Up @@ -195,11 +208,14 @@ class _ISCPathways:
proxy_pod_name: The name of the proxy pod, assigned during deployment.
proxy_server_image: The image to use for the proxy server.
proxy_options: Configuration options for the Pathways proxy.
metrics_collector: The metrics collector instance if enabled.
start_time: The start time of the TPU assignment.
"""

def __init__(
self,
*, cluster: str,
*,
cluster: str,
project: str,
region: str,
gcs_bucket: str,
Expand All @@ -208,6 +224,7 @@ def __init__(
proxy_job_name: str,
proxy_server_image: str,
proxy_options: ProxyOptions | None = None,
collect_service_metrics: bool = False,
):
"""Initializes the TPU manager."""
self.cluster = cluster
Expand All @@ -223,6 +240,10 @@ def __init__(
self.proxy_server_image = proxy_server_image
self.proxy_options = proxy_options or ProxyOptions()
self._old_jax_platforms = None
self.metrics_collector = None
if collect_service_metrics:
self.metrics_collector = metrics_collector.MetricsCollector(self.project)
self.start_time = None
self._old_jax_backend_target = None
self._old_jax_platforms_config = None
self._old_jax_backend_target_config = None
Expand All @@ -237,8 +258,25 @@ def __repr__(self):
f"proxy_options={self.proxy_options})"
)

def _get_total_chips(self) -> int:
"""Calculates total chips from expected_tpu_instances."""
total_chips = 0
for tpu_type, count in self.expected_tpu_instances.items():
parts = tpu_type.split(":")
topology = parts[1]
dimensions = [int(d) for d in topology.split("x")]
chips_per_instance = 1
for d in dimensions:
chips_per_instance *= d
total_chips += chips_per_instance * count
return total_chips

def __enter__(self):
"""Enters the context manager, ensuring cluster exists."""
if self.metrics_collector:
self.metrics_collector.record_active_user(True)
self.metrics_collector.record_capacity_in_use(self._get_total_chips())

self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper())
self._old_jax_backend_target = os.environ.get(
_JAX_BACKEND_TARGET_KEY.upper()
Expand All @@ -251,6 +289,7 @@ def __enter__(self):
)

try:
self.start_time = time.time()
_deploy_pathways_proxy_server(
pathways_service=self.pathways_service,
proxy_job_name=self._proxy_job_name,
Expand Down Expand Up @@ -303,14 +342,19 @@ def __exit__(self, exc_type, exc_value, traceback):

def _cleanup(self) -> None:
"""Cleans up resources created by the ISCPathways context."""
# 1. Clear JAX caches and run garbage collection.
# Reset metrics on exit.
if self.metrics_collector:
self.metrics_collector.record_active_user(False)
self.metrics_collector.record_capacity_in_use(0)

# Clear JAX caches and run garbage collection.
_logger.info("Starting Pathways proxy cleanup.")
jax_backend.clear_backends()
jax.clear_caches()
gc.collect()
_logger.info("Cleared JAX caches and ran garbage collection.")

# 2. Terminate the port forwarding process.
# Terminate the port forwarding process.
if self._port_forward_process:
_logger.info("Terminating port forwarding process...")
self._port_forward_process.terminate()
Expand All @@ -323,12 +367,12 @@ def _cleanup(self) -> None:
e,
)

# 3. Delete the proxy GKE job.
# Delete the proxy GKE job.
_logger.info("Deleting Pathways proxy...")
gke_utils.delete_gke_job(self._proxy_job_name)
_logger.info("Pathways proxy GKE job deletion complete.")

# 4. Restore JAX variables.
# Restore JAX variables.
_logger.info("Restoring JAX env and config variables...")
_restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms)
_restore_env_var(
Expand All @@ -353,6 +397,7 @@ def connect(
proxy_job_name: str | None = None,
proxy_server_image: str = DEFAULT_PROXY_IMAGE,
proxy_options: ProxyOptions | None = None,
collect_service_metrics: bool = False,
) -> Iterator["_ISCPathways"]:
"""Connects to a Pathways server if the cluster exists. If not, creates it.

Expand All @@ -370,6 +415,8 @@ def connect(
default will be used.
proxy_options: Configuration options for the Pathways proxy. If not
provided, no extra options will be used.
collect_service_metrics: Whether to collect usage metrics for Shared
Pathways Service.

Yields:
The Pathways manager.
Expand Down Expand Up @@ -399,6 +446,7 @@ def connect(
proxy_job_name=proxy_job_name,
proxy_server_image=proxy_server_image,
proxy_options=proxy_options,
collect_service_metrics=collect_service_metrics,
) as t:
if t.proxy_pod_name:
num_slices = sum(t.expected_tpu_instances.values())
Expand All @@ -407,6 +455,9 @@ def connect(
args=(
t.proxy_pod_name,
num_slices,
gke_utils.stream_pod_logs,
t.metrics_collector,
t.start_time,
),
daemon=True,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Metrics collector for Shared Pathways Service."""

import logging
import time
from typing import Any, Dict

try:
# pylint: disable=g-import-not-at-top
from google.api_core import exceptions
from google.cloud import monitoring_v3
except ImportError:
pass

_logger = logging.getLogger(__name__)


# All metrics are published as Cloud Monitoring custom metrics under this
# prefix; a metric's full type is METRIC_PREFIX + <suffix>.
METRIC_PREFIX = "custom.googleapis.com/shared_pathways_service/"

# Metric type suffixes.
_METRIC_NUM_ACTIVE_USERS = "num_active_users"
_METRIC_CAPACITY_IN_USE = "capacity_in_use"
_METRIC_ASSIGNMENT_TIME = "assignment_time"
_METRIC_NUM_SUCCESSFUL_REQS = "num_successful_reqs"
_METRIC_NUM_USERS_WAITING = "num_users_waiting"
# Descriptors registered with Cloud Monitoring when a MetricsCollector is
# constructed; each dict is expanded as keyword arguments to
# MetricsCollector._create_metric_descriptor (entries without "display_name"
# fall back to their "name").
_METRIC_DESCRIPTORS = [
    {
        "name": _METRIC_NUM_ACTIVE_USERS,
        "description": "Number of active users at any given time",
        "value_type": "INT64",
        "unit": "1",
    },
    {
        "name": _METRIC_CAPACITY_IN_USE,
        "description": "Number of chips that are actively running workloads",
        "value_type": "INT64",
        "unit": "chips",
        "display_name": "Capacity (chips) in use",
    },
    {
        "name": _METRIC_ASSIGNMENT_TIME,
        "description": "Time to assign slice(s) to an incoming client",
        "value_type": "DOUBLE",
        "unit": "s",
        "display_name": "Capacity assignment time",
    },
    {
        "name": _METRIC_NUM_SUCCESSFUL_REQS,
        "description": (
            "Number of user requests that got capacity assignment successfully"
        ),
        "value_type": "INT64",
        "unit": "1",
        "display_name": "Successful capacity assignment requests",
    },
    {
        "name": _METRIC_NUM_USERS_WAITING,
        "description": "Number of users waiting for capacity",
        "value_type": "INT64",
        "unit": "1",
        "display_name": "Users waiting",
    },
]


class MetricsCollector:
  """Collects usage metrics for Shared Pathways Service and reports to Cloud Monitoring.

  All metrics are written as custom metrics (under METRIC_PREFIX) against the
  configured project. The descriptors in _METRIC_DESCRIPTORS are registered
  idempotently at construction time.

  NOTE(review): this class assumes the optional `google.cloud.monitoring_v3`
  import at module top succeeded; if it did not, construction raises
  NameError.
  """

  def __init__(self, project_id: str):
    """Initializes the collector and registers metric descriptors.

    Args:
      project_id: The GCP project to which metrics are reported.
    """
    self.project_id = project_id
    self.client = monitoring_v3.MetricServiceClient()
    self.project_name = f"projects/{self.project_id}"
    # Idempotent: _create_metric_descriptor swallows AlreadyExists.
    for descriptor in _METRIC_DESCRIPTORS:
      self._create_metric_descriptor(**descriptor)
    _logger.info("Metrics collection initialized.")

  def _create_time_series_object(
      self,
      metric_type: str,
      value: Any,
      value_type: str,
      metric_labels: Dict[str, str] | None = None,
      resource_type: str = "global",
      resource_labels: Dict[str, str] | None = None,
  ) -> Any:
    """Creates a TimeSeries object with a single point at the current time.

    Args:
      metric_type: Metric type suffix; prefixed with METRIC_PREFIX.
      value: The point's value.
      value_type: TypedValue field name, e.g. "int64_value" or "double_value".
      metric_labels: Optional labels to set on the metric.
      resource_type: Monitored resource type; defaults to "global".
      resource_labels: Optional labels to set on the monitored resource.

    Returns:
      A monitoring_v3.TimeSeries. Typed as Any to avoid failing when
      monitoring_v3 is not available.
    """
    series = monitoring_v3.TimeSeries()
    series.metric.type = METRIC_PREFIX + metric_type
    series.resource.type = resource_type
    if resource_labels:
      series.resource.labels.update(resource_labels)
    if metric_labels:
      series.metric.labels.update(metric_labels)

    # Split the wall-clock timestamp into whole seconds and nanoseconds, as
    # required by the TimeInterval proto.
    now = time.time()
    seconds = int(now)
    nanos = int((now - seconds) * 10**9)

    point = monitoring_v3.Point(
        interval=monitoring_v3.TimeInterval(
            end_time={"seconds": seconds, "nanos": nanos}
        ),
        value=monitoring_v3.TypedValue(**{value_type: value}),
    )
    series.points.append(point)
    return series

  def _send_metric(
      self,
      metric_type: str,
      value: Any,
      value_type: str,
      metric_labels: Dict[str, str] | None = None,
  ):
    """Sends a single metric point to Cloud Monitoring."""
    series = self._create_time_series_object(
        metric_type, value, value_type, metric_labels
    )
    self.client.create_time_series(name=self.project_name, time_series=[series])
    _logger.debug("Sent metric %s: %s", metric_type, value)

  def _create_metric_descriptor(
      self,
      name: str,
      description: str,
      value_type: str,
      unit: str,
      metric_kind: str = "GAUGE",
      display_name: str | None = None,
  ):
    """Creates a metric descriptor if not already present.

    Args:
      name: Metric type suffix; prefixed with METRIC_PREFIX.
      description: Human-readable metric description.
      value_type: Cloud Monitoring value type, e.g. "INT64" or "DOUBLE".
      unit: Unit string, e.g. "1", "s", "chips".
      metric_kind: Cloud Monitoring metric kind; defaults to "GAUGE".
      display_name: Console display name; defaults to `name`.
    """
    metric_type = METRIC_PREFIX + name
    display_name = display_name or name

    try:
      self.client.create_metric_descriptor(
          # Reuse the precomputed resource name for consistency with
          # _send_metric (it is identical to f"projects/{self.project_id}").
          name=self.project_name,
          metric_descriptor={
              "type": metric_type,
              "metric_kind": metric_kind,
              "value_type": value_type,
              "description": description,
              "display_name": display_name,
              "unit": unit,
          },
      )
      _logger.info("Created metric descriptor: %s", metric_type)
    except exceptions.AlreadyExists:
      # A previous run already registered this descriptor; nothing to do.
      _logger.debug("Metric descriptor %s already exists.", metric_type)

  def record_active_user(self, is_active: bool):
    """Records the number of active users (1 for active, 0 for inactive)."""
    self._send_metric(
        _METRIC_NUM_ACTIVE_USERS, 1 if is_active else 0, "int64_value"
    )

  def record_capacity_in_use(self, chips: int):
    """Records the number of chips in use."""
    self._send_metric(_METRIC_CAPACITY_IN_USE, chips, "int64_value")

  def record_assignment_time(self, duration_seconds: float):
    """Records the time taken to assign slices."""
    self._send_metric(_METRIC_ASSIGNMENT_TIME, duration_seconds, "double_value")

  def record_successful_request(self):
    """Records a successful request."""
    self._send_metric(_METRIC_NUM_SUCCESSFUL_REQS, 1, "int64_value")

  def record_user_waiting(self, waiting: int = 1):
    """Records whether a user is waiting for capacity.

    The parameter is required because call sites invoke this with an explicit
    0/1 argument (e.g. record_user_waiting(0) once placement completes); the
    original zero-argument signature would have raised TypeError there. The
    default of 1 keeps any existing no-argument calls working.

    Args:
      waiting: 1 while the user is waiting for capacity, 0 once capacity is
        assigned.
    """
    self._send_metric(_METRIC_NUM_USERS_WAITING, waiting, "int64_value")
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
"Configuration options for the Pathways proxy. Specify entries in the form"
' "key:value". For example: --proxy_options=use_insecure_credentials:true',
)
# Opt-in flag: when true, usage metrics are reported to Cloud Monitoring for
# the duration of the Pathways session (off by default).
flags.DEFINE_bool(
    "collect_service_metrics",
    False,
    "Whether to enable metrics collection for Shared Pathways Service.",
)

flags.mark_flags_as_required([
"cluster",
Expand Down Expand Up @@ -68,6 +73,7 @@ def main(argv: Sequence[str]) -> None:
proxy_server_image=FLAGS.proxy_server_image
or isc_pathways.DEFAULT_PROXY_IMAGE,
proxy_options=proxy_options,
collect_service_metrics=FLAGS.collect_service_metrics,
):
orig_matrix = jnp.zeros(5)
result_matrix = orig_matrix + 1
Expand Down
Loading
Loading