Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 1341f21

Browse files
committed
fix: Metrics thread-safety refactor and Batch.commit idempotency fix
1 parent 67c682e commit 1341f21

File tree

9 files changed

+232
-105
lines changed

9 files changed

+232
-105
lines changed

google/cloud/spanner_v1/batch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Context manager for Cloud Spanner batched writes."""
16+
1617
import functools
1718
from typing import List, Optional
1819

@@ -242,6 +243,8 @@ def commit(
242243
observability_options=getattr(database, "observability_options", None),
243244
metadata=metadata,
244245
) as span, MetricsCapture():
246+
nth_request = getattr(database, "_next_nth_request", 0)
247+
attempt = AtomicCounter(0)
245248

246249
def wrapped_method():
247250
commit_request = CommitRequest(
@@ -256,8 +259,8 @@ def wrapped_method():
256259
# should be increased. attempt can only be increased if
257260
# we encounter UNAVAILABLE or INTERNAL.
258261
call_metadata, error_augmenter = database.with_error_augmentation(
259-
getattr(database, "_next_nth_request", 0),
260-
1,
262+
nth_request,
263+
attempt.increment(),
261264
metadata,
262265
span,
263266
)

google/cloud/spanner_v1/client.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
* a :class:`~google.cloud.spanner_v1.instance.Instance` owns a
2424
:class:`~google.cloud.spanner_v1.database.Database`
2525
"""
26+
2627
import grpc
2728
import os
2829
import logging
@@ -108,6 +109,42 @@ def _get_spanner_enable_builtin_metrics_env():
108109
return os.getenv(SPANNER_DISABLE_BUILTIN_METRICS_ENV_VAR) != "true"
109110

110111

112+
def _initialize_metrics(project, credentials):
    """Set up Spanner built-in metrics once per process.

    Installs an OpenTelemetry meter provider and instantiates the
    ``SpannerMetricsTracerFactory``. The module-level
    ``_metrics_monitor_initialized`` flag is checked before and after
    acquiring ``_metrics_monitor_lock`` so that concurrent callers do not
    both run the setup.

    When ``_get_spanner_emulator_host()`` returns a truthy value the no-op
    meter provider is installed; otherwise a ``MeterProvider`` exporting
    through ``CloudMonitoringMetricsExporter`` is installed.

    Any exception raised during setup is logged as a warning and swallowed.
    In that case the flag is left unset, so a later call attempts the setup
    again.

    Args:
        project: Project ID passed to the Cloud Monitoring exporter.
        credentials: Credentials passed to the Cloud Monitoring exporter.
    """
    global _metrics_monitor_initialized

    # Cheap check without the lock for the common, already-initialized case.
    if _metrics_monitor_initialized:
        return

    with _metrics_monitor_lock:
        # Another thread may have finished the setup while we were waiting.
        if _metrics_monitor_initialized:
            return

        meter_provider = metrics.NoOpMeterProvider()
        try:
            if not _get_spanner_emulator_host():
                exporter = CloudMonitoringMetricsExporter(
                    project_id=project,
                    credentials=credentials,
                )
                reader = PeriodicExportingMetricReader(
                    exporter,
                    export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
                )
                meter_provider = MeterProvider(metric_readers=[reader])
            metrics.set_meter_provider(meter_provider)
            SpannerMetricsTracerFactory()
            # Only mark as initialized once every step above succeeded.
            _metrics_monitor_initialized = True
        except Exception as e:
            log.warning(
                "Failed to initialize Spanner built-in metrics. Error: %s",
                e,
            )
146+
147+
111148
class Client(ClientWithProject):
112149
"""Client for interacting with Cloud Spanner API.
113150
@@ -262,31 +299,7 @@ def __init__(
262299
and not disable_builtin_metrics
263300
and HAS_GOOGLE_CLOUD_MONITORING_INSTALLED
264301
):
265-
if not _metrics_monitor_initialized:
266-
with _metrics_monitor_lock:
267-
if not _metrics_monitor_initialized:
268-
meter_provider = metrics.NoOpMeterProvider()
269-
try:
270-
if not _get_spanner_emulator_host():
271-
meter_provider = MeterProvider(
272-
metric_readers=[
273-
PeriodicExportingMetricReader(
274-
CloudMonitoringMetricsExporter(
275-
project_id=project,
276-
credentials=credentials,
277-
),
278-
export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
279-
),
280-
]
281-
)
282-
metrics.set_meter_provider(meter_provider)
283-
SpannerMetricsTracerFactory()
284-
_metrics_monitor_initialized = True
285-
except Exception as e:
286-
log.warning(
287-
"Failed to initialize Spanner built-in metrics. Error: %s",
288-
e,
289-
)
302+
_initialize_metrics(project, credentials)
290303
else:
291304
SpannerMetricsTracerFactory(enabled=False)
292305

google/cloud/spanner_v1/metrics/metrics_capture.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,18 @@
2323
from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory
2424

2525

26+
from contextvars import Token
27+
28+
2629
class MetricsCapture:
2730
"""Context manager for capturing metrics in Cloud Spanner operations.
2831
2932
This class provides a context manager interface to automatically handle
3033
the start and completion of metrics tracing for a given operation.
3134
"""
3235

36+
_token: Token
37+
3338
def __enter__(self):
3439
"""Enter the runtime context related to this object.
3540
@@ -45,11 +50,13 @@ def __enter__(self):
4550
return self
4651

4752
# Define a new metrics tracer for the new operation
48-
SpannerMetricsTracerFactory.current_metrics_tracer = (
49-
factory.create_metrics_tracer()
53+
# Set the context var and keep the token for reset
54+
tracer = factory.create_metrics_tracer()
55+
self._token = SpannerMetricsTracerFactory._current_metrics_tracer_ctx.set(
56+
tracer
5057
)
51-
if SpannerMetricsTracerFactory.current_metrics_tracer:
52-
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_start()
58+
if tracer:
59+
tracer.record_operation_start()
5360
return self
5461

5562
def __exit__(self, exc_type, exc_value, traceback):
@@ -70,6 +77,11 @@ def __exit__(self, exc_type, exc_value, traceback):
7077
if not SpannerMetricsTracerFactory().enabled:
7178
return False
7279

73-
if SpannerMetricsTracerFactory.current_metrics_tracer:
74-
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_completion()
80+
tracer = SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()
81+
if tracer:
82+
tracer.record_operation_completion()
83+
84+
# Reset the context var using the token
85+
if getattr(self, "_token", None):
86+
SpannerMetricsTracerFactory._current_metrics_tracer_ctx.reset(self._token)
7587
return False # Propagate the exception if any

google/cloud/spanner_v1/metrics/metrics_interceptor.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,17 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None:
9797
Args:
9898
resources (Dict[str, str]): A dictionary containing project, instance, and database information.
9999
"""
100-
if SpannerMetricsTracerFactory.current_metrics_tracer is None:
100+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
101+
if tracer is None:
101102
return
102103

103104
if resources:
104105
if "project" in resources:
105-
SpannerMetricsTracerFactory.current_metrics_tracer.set_project(
106-
resources["project"]
107-
)
106+
tracer.set_project(resources["project"])
108107
if "instance" in resources:
109-
SpannerMetricsTracerFactory.current_metrics_tracer.set_instance(
110-
resources["instance"]
111-
)
108+
tracer.set_instance(resources["instance"])
112109
if "database" in resources:
113-
SpannerMetricsTracerFactory.current_metrics_tracer.set_database(
114-
resources["database"]
115-
)
110+
tracer.set_database(resources["database"])
116111

117112
def intercept(self, invoked_method, request_or_iterator, call_details):
118113
"""Intercept gRPC calls to collect metrics.
@@ -126,10 +121,8 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
126121
The RPC response
127122
"""
128123
factory = SpannerMetricsTracerFactory()
129-
if (
130-
SpannerMetricsTracerFactory.current_metrics_tracer is None
131-
or not factory.enabled
132-
):
124+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
125+
if tracer is None or not factory.enabled:
133126
return invoked_method(request_or_iterator, call_details)
134127

135128
# Setup Metric Tracer attributes from call details
@@ -142,15 +135,13 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
142135
call_details.method, SPANNER_METHOD_PREFIX
143136
).replace("/", ".")
144137

145-
SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
146-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
138+
tracer.set_method(method_name)
139+
tracer.record_attempt_start()
147140
response = invoked_method(request_or_iterator, call_details)
148-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
141+
tracer.record_attempt_completion()
149142

150143
# Process and send GFE metrics if enabled
151-
if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
144+
if tracer.gfe_enabled:
152145
metadata = response.initial_metadata()
153-
SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
154-
metadata
155-
)
146+
tracer.record_gfe_metrics(metadata)
156147
return response

google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import logging
2121
from .constants import SPANNER_SERVICE_NAME
22+
import contextvars
2223

2324
try:
2425
import mmh3
@@ -43,7 +44,9 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory):
4344
"""A factory for creating SpannerMetricsTracer instances."""
4445

4546
_metrics_tracer_factory: "SpannerMetricsTracerFactory" = None
46-
current_metrics_tracer: MetricsTracer = None
47+
_current_metrics_tracer_ctx = contextvars.ContextVar(
48+
"current_metrics_tracer", default=None
49+
)
4750

4851
def __new__(
4952
cls, enabled: bool = True, gfe_enabled: bool = False
@@ -80,10 +83,18 @@ def __new__(
8083
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled
8184

8285
if cls._metrics_tracer_factory.enabled != enabled:
83-
cls._metrics_tracer_factory.enabeld = enabled
86+
cls._metrics_tracer_factory.enabled = enabled
8487

8588
return cls._metrics_tracer_factory
8689

90+
@staticmethod
def get_current_tracer() -> MetricsTracer:
    """Return the metrics tracer stored in the current context.

    The value is read from the class-level ``_current_metrics_tracer_ctx``
    context variable, whose default is ``None``; callers should therefore
    be prepared for ``None`` when no tracer has been set.
    """
    tracer_ctx = SpannerMetricsTracerFactory._current_metrics_tracer_ctx
    return tracer_ctx.get()
93+
94+
@property
def current_metrics_tracer(self) -> MetricsTracer:
    """Metrics tracer stored in the current context, read via an instance.

    Reads the class-level ``_current_metrics_tracer_ctx`` context variable,
    whose default is ``None``, so the result is ``None`` when no tracer has
    been set in the current context.
    """
    tracer_ctx = SpannerMetricsTracerFactory._current_metrics_tracer_ctx
    return tracer_ctx.get()
97+
8798
@staticmethod
8899
def _generate_client_uid() -> str:
89100
"""Generate a client UID in the form of uuidv4@pid@hostname.

tests/unit/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
5+
@pytest.fixture(autouse=True)
def mock_periodic_exporting_metric_reader():
    """Patch ``PeriodicExportingMetricReader`` for every unit test.

    Both the name imported into ``google.cloud.spanner_v1.client`` and the
    one in ``opentelemetry.sdk.metrics.export`` are replaced with mocks for
    the duration of the test; the intent is to keep tests from making real
    network calls.

    Yields:
        The mock that replaces the reader in the Spanner client module.
    """
    client_reader_target = (
        "google.cloud.spanner_v1.client.PeriodicExportingMetricReader"
    )
    sdk_reader_target = (
        "opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
    )
    with patch(client_reader_target) as client_reader_mock:
        with patch(sdk_reader_target):
            yield client_reader_mock

tests/unit/test_metrics.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,25 @@ def patched_client(monkeypatch):
6565

6666
client_module._metrics_monitor_initialized = False
6767

68-
with patch("google.cloud.spanner_v1.client.CloudMonitoringMetricsExporter"):
68+
with patch(
69+
"google.cloud.spanner_v1.metrics.metrics_exporter.MetricServiceClient"
70+
), patch(
71+
"google.cloud.spanner_v1.metrics.metrics_exporter.CloudMonitoringMetricsExporter"
72+
), patch(
73+
"opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
74+
):
6975
client = Client(
7076
project="test",
7177
credentials=TestCredentials(),
72-
# client_options={"api_endpoint": "none"}
7378
)
7479
yield client
7580

7681
# Resetting
7782
metrics.set_meter_provider(metrics.NoOpMeterProvider())
7883
SpannerMetricsTracerFactory._metrics_tracer_factory = None
79-
SpannerMetricsTracerFactory.current_metrics_tracer = None
80-
client_module._metrics_monitor_initialized = False
84+
# Reset context var
85+
ctx = SpannerMetricsTracerFactory._current_metrics_tracer_ctx
86+
ctx.set(None)
8187

8288

8389
def test_metrics_emission_with_failure_attempt(patched_client):
@@ -92,10 +98,14 @@ def test_metrics_emission_with_failure_attempt(patched_client):
9298
original_intercept = metrics_interceptor.intercept
9399
first_attempt = True
94100

101+
captured_tracer_list = []
102+
95103
def mocked_raise(*args, **kwargs):
96104
raise ServiceUnavailable("Service Unavailable")
97105

98106
def mocked_call(*args, **kwargs):
107+
# Capture the tracer while it is active
108+
captured_tracer_list.append(SpannerMetricsTracerFactory.get_current_tracer())
99109
return _UnaryOutcome(MagicMock(), MagicMock())
100110

101111
def intercept_wrapper(invoked_method, request_or_iterator, call_details):
@@ -113,11 +123,14 @@ def intercept_wrapper(invoked_method, request_or_iterator, call_details):
113123

114124
metrics_interceptor.intercept = intercept_wrapper
115125
patch_path = "google.cloud.spanner_v1.metrics.metrics_exporter.CloudMonitoringMetricsExporter.export"
126+
116127
with patch(patch_path):
117128
with database.snapshot():
118129
pass
119130

120131
# Verify that the attempt count increased from the failed initial attempt
121-
assert (
122-
SpannerMetricsTracerFactory.current_metrics_tracer.current_op.attempt_count
123-
) == 2
132+
# We use the captured tracer from the SUCCESSFUL attempt (the second one)
133+
assert len(captured_tracer_list) > 0
134+
tracer = captured_tracer_list[0]
135+
assert tracer is not None
136+
assert tracer.current_op.attempt_count == 2

0 commit comments

Comments
 (0)