diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 3c3b95a..7aabe9e 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -10,12 +10,14 @@ import string import subprocess import threading +import time from typing import Any import jax import jax.extend.backend as jax_backend import pathwaysutils from pathwaysutils.experimental.shared_pathways_service import gke_utils +from pathwaysutils.experimental.shared_pathways_service import metrics_collector from pathwaysutils.experimental.shared_pathways_service import validators @@ -128,6 +130,8 @@ def _wait_for_placement( pod_name: str, num_slices: int, stream_logs_func=gke_utils.stream_pod_logs, + metrics_collector_inst: Any = None, + start_time: float | None = None, ) -> None: """Waits for the placement to be complete by checking proxy logs.""" _logger.info("Streaming proxy logs until the placement is complete...") @@ -165,7 +169,16 @@ def _wait_for_placement( ) else: _logger.info("TPU placement for %d slice(s) complete!", num_slices) + if metrics_collector_inst: + metrics_collector_inst.record_user_waiting(0) + if start_time: + duration = time.time() - start_time + metrics_collector_inst.record_assignment_time(duration) + metrics_collector_inst.record_successful_request() break + else: + if metrics_collector_inst: + metrics_collector_inst.record_user_waiting(1) def _restore_env_var(key: str, original_value: str | None) -> None: @@ -195,11 +208,14 @@ class _ISCPathways: proxy_pod_name: The name of the proxy pod, assigned during deployment. proxy_server_image: The image to use for the proxy server. proxy_options: Configuration options for the Pathways proxy. + metrics_collector: The metrics collector instance if enabled. + start_time: The start time of the TPU assignment. 
""" def __init__( self, - *, cluster: str, + *, + cluster: str, project: str, region: str, gcs_bucket: str, @@ -208,6 +224,7 @@ def __init__( proxy_job_name: str, proxy_server_image: str, proxy_options: ProxyOptions | None = None, + collect_service_metrics: bool = False, ): """Initializes the TPU manager.""" self.cluster = cluster @@ -223,6 +240,10 @@ def __init__( self.proxy_server_image = proxy_server_image self.proxy_options = proxy_options or ProxyOptions() self._old_jax_platforms = None + self.metrics_collector = None + if collect_service_metrics: + self.metrics_collector = metrics_collector.MetricsCollector(self.project) + self.start_time = None self._old_jax_backend_target = None self._old_jax_platforms_config = None self._old_jax_backend_target_config = None @@ -237,8 +258,25 @@ def __repr__(self): f"proxy_options={self.proxy_options})" ) + def _get_total_chips(self) -> int: + """Calculates total chips from expected_tpu_instances.""" + total_chips = 0 + for tpu_type, count in self.expected_tpu_instances.items(): + parts = tpu_type.split(":") + topology = parts[1] + dimensions = [int(d) for d in topology.split("x")] + chips_per_instance = 1 + for d in dimensions: + chips_per_instance *= d + total_chips += chips_per_instance * count + return total_chips + def __enter__(self): """Enters the context manager, ensuring cluster exists.""" + if self.metrics_collector: + self.metrics_collector.record_active_user(True) + self.metrics_collector.record_capacity_in_use(self._get_total_chips()) + self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper()) self._old_jax_backend_target = os.environ.get( _JAX_BACKEND_TARGET_KEY.upper() @@ -251,6 +289,7 @@ def __enter__(self): ) try: + self.start_time = time.time() _deploy_pathways_proxy_server( pathways_service=self.pathways_service, proxy_job_name=self._proxy_job_name, @@ -303,14 +342,19 @@ def __exit__(self, exc_type, exc_value, traceback): def _cleanup(self) -> None: """Cleans up resources created by the 
ISCPathways context.""" - # 1. Clear JAX caches and run garbage collection. + # Reset metrics on exit. + if self.metrics_collector: + self.metrics_collector.record_active_user(False) + self.metrics_collector.record_capacity_in_use(0) + + # Clear JAX caches and run garbage collection. _logger.info("Starting Pathways proxy cleanup.") jax_backend.clear_backends() jax.clear_caches() gc.collect() _logger.info("Cleared JAX caches and ran garbage collection.") - # 2. Terminate the port forwarding process. + # Terminate the port forwarding process. if self._port_forward_process: _logger.info("Terminating port forwarding process...") self._port_forward_process.terminate() @@ -323,12 +367,12 @@ def _cleanup(self) -> None: e, ) - # 3. Delete the proxy GKE job. + # Delete the proxy GKE job. _logger.info("Deleting Pathways proxy...") gke_utils.delete_gke_job(self._proxy_job_name) _logger.info("Pathways proxy GKE job deletion complete.") - # 4. Restore JAX variables. + # Restore JAX variables. _logger.info("Restoring JAX env and config variables...") _restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms) _restore_env_var( @@ -353,6 +397,7 @@ def connect( proxy_job_name: str | None = None, proxy_server_image: str = DEFAULT_PROXY_IMAGE, proxy_options: ProxyOptions | None = None, + collect_service_metrics: bool = False, ) -> Iterator["_ISCPathways"]: """Connects to a Pathways server if the cluster exists. If not, creates it. @@ -370,6 +415,8 @@ def connect( default will be used. proxy_options: Configuration options for the Pathways proxy. If not provided, no extra options will be used. + collect_service_metrics: Whether to collect usage metrics for Shared + Pathways Service. Yields: The Pathways manager. 
@@ -399,6 +446,7 @@ def connect( proxy_job_name=proxy_job_name, proxy_server_image=proxy_server_image, proxy_options=proxy_options, + collect_service_metrics=collect_service_metrics, ) as t: if t.proxy_pod_name: num_slices = sum(t.expected_tpu_instances.values()) @@ -407,6 +455,9 @@ def connect( args=( t.proxy_pod_name, num_slices, + gke_utils.stream_pod_logs, + t.metrics_collector, + t.start_time, ), daemon=True, ) diff --git a/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py b/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py new file mode 100644 index 0000000..24268e2 --- /dev/null +++ b/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py @@ -0,0 +1,171 @@ +"""Metrics collector for Shared Pathways Service.""" + +import logging +import time +from typing import Any, Dict + +try: + # pylint: disable=g-import-not-at-top + from google.api_core import exceptions + from google.cloud import monitoring_v3 +except ImportError: + pass + +_logger = logging.getLogger(__name__) + + +METRIC_PREFIX = "custom.googleapis.com/shared_pathways_service/" + +_METRIC_NUM_ACTIVE_USERS = "num_active_users" +_METRIC_CAPACITY_IN_USE = "capacity_in_use" +_METRIC_ASSIGNMENT_TIME = "assignment_time" +_METRIC_NUM_SUCCESSFUL_REQS = "num_successful_reqs" +_METRIC_NUM_USERS_WAITING = "num_users_waiting" +_METRIC_DESCRIPTORS = [ + { + "name": _METRIC_NUM_ACTIVE_USERS, + "description": "Number of active users at any given time", + "value_type": "INT64", + "unit": "1", + }, + { + "name": _METRIC_CAPACITY_IN_USE, + "description": "Number of chips that are actively running workloads", + "value_type": "INT64", + "unit": "chips", + "display_name": "Capacity (chips) in use", + }, + { + "name": _METRIC_ASSIGNMENT_TIME, + "description": "Time to assign slice(s) to an incoming client", + "value_type": "DOUBLE", + "unit": "s", + "display_name": "Capacity assignment time", + }, + { + "name": _METRIC_NUM_SUCCESSFUL_REQS, + "description": ( + 
"Number of user requests that got capacity assignment successfully" + ), + "value_type": "INT64", + "unit": "1", + "display_name": "Successful capacity assignment requests", + }, + { + "name": _METRIC_NUM_USERS_WAITING, + "description": "Number of users waiting for capacity", + "value_type": "INT64", + "unit": "1", + "display_name": "Users waiting", + }, +] + + +class MetricsCollector: + """Collects usage metrics for Shared Pathways Service and reports to Cloud Monitoring.""" + + def __init__(self, project_id: str): + self.project_id = project_id + self.client = monitoring_v3.MetricServiceClient() + self.project_name = f"projects/{self.project_id}" + for descriptor in _METRIC_DESCRIPTORS: + self._create_metric_descriptor(**descriptor) + _logger.info("Metrics collection initialized.") + + def _create_time_series_object( + self, + metric_type: str, + value: Any, + value_type: str, + metric_labels: Dict[str, str] | None = None, + resource_type: str = "global", + resource_labels: Dict[str, str] | None = None, + ) -> Any: + """Creates a TimeSeries object for a single metric.""" + # Using Any for return type to avoid failing when monitoring_v3 is not + # available. 
+ series = monitoring_v3.TimeSeries() + series.metric.type = METRIC_PREFIX + metric_type + series.resource.type = resource_type + if resource_labels: + series.resource.labels.update(resource_labels) + if metric_labels: + series.metric.labels.update(metric_labels) + + now = time.time() + seconds = int(now) + nanos = int((now - seconds) * 10**9) + + point = monitoring_v3.Point( + interval=monitoring_v3.TimeInterval( + end_time={"seconds": seconds, "nanos": nanos} + ), + value=monitoring_v3.TypedValue(**{value_type: value}), + ) + series.points.append(point) + return series + + def _send_metric( + self, + metric_type: str, + value: Any, + value_type: str, + metric_labels: Dict[str, str] | None = None, + ): + """Sends a single metric to Cloud Monitoring.""" + series = self._create_time_series_object( + metric_type, value, value_type, metric_labels + ) + self.client.create_time_series(name=self.project_name, time_series=[series]) + _logger.debug("Sent metric %s: %s", metric_type, value) + + def _create_metric_descriptor( + self, + name: str, + description: str, + value_type: str, + unit: str, + metric_kind: str = "GAUGE", + display_name: str | None = None, + ): + """Creates a metric descriptor if not already present.""" + metric_type = METRIC_PREFIX + name + display_name = display_name or name + + try: + self.client.create_metric_descriptor( + name=f"projects/{self.project_id}", + metric_descriptor={ + "type": metric_type, + "metric_kind": metric_kind, + "value_type": value_type, + "description": description, + "display_name": display_name, + "unit": unit, + }, + ) + _logger.info("Created metric descriptor: %s", metric_type) + except exceptions.AlreadyExists: + _logger.debug("Metric descriptor %s already exists.", metric_type) + + def record_active_user(self, is_active: bool): + """Records the number of active users (1 for active, 0 for inactive).""" + self._send_metric( + _METRIC_NUM_ACTIVE_USERS, 1 if is_active else 0, "int64_value" + ) + + def 
record_capacity_in_use(self, chips: int): +    """Records the number of chips in use.""" +    self._send_metric(_METRIC_CAPACITY_IN_USE, chips, "int64_value") + +  def record_assignment_time(self, duration_seconds: float): +    """Records the time taken to assign slices.""" +    self._send_metric(_METRIC_ASSIGNMENT_TIME, duration_seconds, "double_value") + +  def record_successful_request(self): +    """Records a successful request.""" +    self._send_metric(_METRIC_NUM_SUCCESSFUL_REQS, 1, "int64_value") + +  def record_user_waiting(self, value: int = 1): +    """Records whether a user is waiting for capacity (1 waiting, 0 not).""" +    self._send_metric(_METRIC_NUM_USERS_WAITING, value, "int64_value") diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py index db63c28..f07220e 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py @@ -41,6 +41,11 @@ "Configuration options for the Pathways proxy. Specify entries in the form" ' "key:value". 
For example: --proxy_options=use_insecure_credentials:true', ) +flags.DEFINE_bool( + "collect_service_metrics", + False, + "Whether to enable metrics collection for Shared Pathways Service.", +) flags.mark_flags_as_required([ "cluster", @@ -68,6 +73,7 @@ def main(argv: Sequence[str]) -> None: proxy_server_image=FLAGS.proxy_server_image or isc_pathways.DEFAULT_PROXY_IMAGE, proxy_options=proxy_options, + collect_service_metrics=FLAGS.collect_service_metrics, ): orig_matrix = jnp.zeros(5) result_matrix = orig_matrix + 1 diff --git a/pathwaysutils/experimental/shared_pathways_service/run_workload.py b/pathwaysutils/experimental/shared_pathways_service/run_workload.py index f662c35..c666f52 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_workload.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_workload.py @@ -67,6 +67,14 @@ _COMMAND = flags.DEFINE_string( "command", None, "The command to run on TPUs.", required=True ) +_COLLECT_SERVICE_METRICS = flags.DEFINE_bool( + "collect_service_metrics", + False, + "Whether to enable metrics collection for Shared Pathways Service. If" + " enabled, the service will collect usage metrics such as TPU assignment" + " time, active user count, capacity in use etc. The metrics will be" + " stored in Cloud Monitoring.", +) flags.register_validator( "proxy_options", @@ -93,6 +101,7 @@ def run_command( command: str, proxy_server_image: str | None = None, proxy_options: Sequence[str] | None = None, + collect_service_metrics: bool = False, connect_fn: Callable[..., ContextManager[Any]] = isc_pathways.connect, ) -> None: """Run the TPU workload within a Shared Pathways connection. @@ -108,6 +117,8 @@ def run_command( command: The command to run on TPUs. proxy_server_image: The proxy server image to use. proxy_options: Configuration options for the Pathways proxy. + collect_service_metrics: Whether to collect usage metrics for Shared Pathways + Service. Defaults to False. 
connect_fn: The function to use for establishing the connection context, expected to be a callable that returns a context manager. @@ -130,6 +141,7 @@ def run_command( else isc_pathways.DEFAULT_PROXY_IMAGE ), proxy_options=parsed_proxy_options, + collect_service_metrics=collect_service_metrics, ): logging.info("Connection established. Running command: %r", command) try: @@ -160,6 +172,7 @@ def main(argv: Sequence[str]) -> None: command=_COMMAND.value, proxy_server_image=_PROXY_SERVER_IMAGE.value, proxy_options=_PROXY_OPTIONS.value, + collect_service_metrics=_COLLECT_SERVICE_METRICS.value, )