Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 997bab0

Browse files
committed
fix: add retries for failed first statements
Add retries if the first statement in a read/write transaction fails, as the statement then does not return a transaction ID. In order to ensure that we get a transaction ID, we first execute an explicit BeginTransaction RPC and then retry the original statement. We return the response of the retry to the application, regardless of whether the retry fails or succeeds.

The reason that we retry with a BeginTransaction AND include the first statement is to guarantee transaction consistency. If we were to leave the first statement out of the transaction, then it would not be guaranteed that the error condition that caused the failure in the first place is actually still true when the transaction commits. This would break the transaction guarantees.

Example (pseudo-code):

```sql
-- The following statement fails with ALREADY_EXISTS
insert into some_table (id, value) values (1, 'One');

-- Execute an explicit BeginTransaction RPC.
begin;

-- Retry the initial statement. This ensures that
-- whatever the response is, this response will be
-- valid for the entire transaction.
insert into some_table (id, value) values (1, 'One');

-- This is guaranteed to return a row.
select * from some_table where id=1;

-- ... execute the rest of the transaction ...
commit;
```

If we had not included the initial insert statement in the retried transaction, then there would be no guarantee that the select statement would actually return any rows, as other transactions could in theory have deleted it in the meantime.
1 parent ac59847 commit 997bab0

File tree

7 files changed

+775
-90
lines changed

7 files changed

+775
-90
lines changed

google/cloud/spanner_dbapi/batch_dml_executor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
104104
connection._transaction = None
105105
raise Aborted(status.message)
106106
elif status.code != OK:
107+
if not transaction._transaction_id:
108+
# This should normally not happen,
109+
# but we safeguard against it just to be sure.
110+
transaction._reset_and_begin()
111+
continue
107112
raise OperationalError(status.message)
108113

109114
cursor._batch_dml_rows_count = res
@@ -116,6 +121,11 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
116121
raise
117122
else:
118123
connection._transaction_helper.retry_transaction()
124+
except Exception as ex:
125+
if not transaction._transaction_id:
126+
transaction._reset_and_begin()
127+
continue
128+
raise ex
119129

120130

121131
def _do_batch_update_autocommit(transaction, statements):

google/cloud/spanner_dbapi/cursor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,16 @@ def _execute_in_rw_transaction(self):
366366
raise
367367
else:
368368
self.transaction_helper.retry_transaction()
369+
except Exception as ex:
370+
# In case of inline-begin failure, the transaction isn't started.
371+
# We immediately retry with an explicit BeginTransaction.
372+
transaction = getattr(self.connection, "_transaction", None)
373+
if transaction and not transaction._transaction_id:
374+
transaction._reset_and_begin()
375+
376+
# Let the existing retry loop handle the retry of the statement
377+
continue
378+
raise ex
369379
else:
370380
self.connection.database.run_in_transaction(
371381
self._do_execute_update_in_autocommit,

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import inspect
1616
import grpc
1717
from concurrent import futures
18+
from dataclasses import dataclass
1819

19-
from google.protobuf import empty_pb2
2020
from grpc_status.rpc_status import _Status
21+
from google.rpc.code_pb2 import OK
22+
from google.protobuf import empty_pb2
2123

2224
from google.cloud.spanner_v1 import (
2325
TransactionOptions,
@@ -53,10 +55,23 @@ def get_result(self, sql: str) -> result_set.ResultSet:
5355
return result
5456

5557
def add_error(self, method: str, error: _Status):
    """Register an error that the mock should return for ``method``.

    The error is appended to a per-method FIFO queue, so several errors
    can be registered for consecutive calls of the same method. It is also
    stored in the single-slot ``errors`` mapping, where a later
    registration for the same method overwrites an earlier one.

    Args:
        method: Name of the RPC method the error applies to.
        error: Status to return when that method is invoked.
    """
    # The queue mapping is created lazily on first use.
    if not hasattr(self, "_errors_list"):
        self._errors_list = {}
    self._errors_list.setdefault(method, []).append(error)
    # Keep the single-error mapping populated as well.
    self.errors[method] = error
5764

5865
def pop_error(self, context):
5966
name = inspect.currentframe().f_back.f_code.co_name
67+
if hasattr(self, "_errors_list") and name in self._errors_list:
68+
if self._errors_list[name]:
69+
error = self._errors_list[name].pop(0)
70+
context.abort_with_status(error)
71+
return
72+
return # Queue is empty, return normally (no error)
73+
74+
# Fallback to single error
6075
error: _Status | None = self.errors.pop(name, None)
6176
if error:
6277
context.abort_with_status(error)
@@ -94,6 +109,12 @@ def get_result_as_partial_result_sets(
94109
return partials
95110

96111

112+
@dataclass
class BatchDmlResponseConfig:
    """Canned outcome for one ExecuteBatchDml call on the mock server."""

    # Status that is copied into the ExecuteBatchDml response.
    status: _Status
    # When False, the transaction is not attached to the result metadata,
    # so the response carries no transaction ID.
    include_transaction_id: bool = True
116+
117+
97118
# An in-memory mock Spanner server that can be used for testing.
98119
class SpannerServicer(spanner_grpc.SpannerServicer):
99120
def __init__(self):
@@ -103,6 +124,7 @@ def __init__(self):
103124
self.transaction_counter = 0
104125
self.transactions = {}
105126
self._mock_spanner = MockSpanner()
127+
self._batch_dml_response_configs = []
106128

107129
@property
108130
def mock_spanner(self):
@@ -115,6 +137,15 @@ def requests(self):
115137
def clear_requests(self):
116138
self._requests = []
117139

140+
def add_batch_dml_response_status(self, status, include_transaction_id=True):
    """Queue a canned response configuration for a batch DML call.

    Args:
        status: Status to place in the queued response configuration.
        include_transaction_id: Stored on the configuration; defaults to
            True.
    """
    # Create the queue on demand in case it has not been set up yet.
    if not hasattr(self, "_batch_dml_response_configs"):
        self._batch_dml_response_configs = []
    queued = BatchDmlResponseConfig(
        status=status,
        include_transaction_id=include_transaction_id,
    )
    self._batch_dml_response_configs.append(queued)
148+
118149
def CreateSession(self, request, context):
119150
self._requests.append(request)
120151
return self.__create_session(request.database, request.session)
@@ -176,6 +207,14 @@ def ExecuteBatchDml(self, request, context):
176207
self.mock_spanner.pop_error(context)
177208
response = spanner.ExecuteBatchDmlResponse()
178209
started_transaction = self.__maybe_create_transaction(request)
210+
211+
config = None
212+
if (
213+
hasattr(self, "_batch_dml_response_configs")
214+
and self._batch_dml_response_configs
215+
):
216+
config = self._batch_dml_response_configs.pop(0)
217+
179218
first = True
180219
for statement in request.statements:
181220
result = self.mock_spanner.get_result(statement.sql)
@@ -184,8 +223,16 @@ def ExecuteBatchDml(self, request, context):
184223
self.mock_spanner.get_result(statement.sql)
185224
)
186225
result.metadata = result_set.ResultSetMetadata(result.metadata)
187-
result.metadata.transaction = started_transaction
226+
if config is None or config.include_transaction_id:
227+
result.metadata.transaction = started_transaction
228+
first = False
188229
response.result_sets.append(result)
230+
231+
if config is not None:
232+
response.status.CopyFrom(config.status)
233+
else:
234+
response.status.code = OK
235+
189236
return response
190237

191238
def Read(self, request, context):

google/cloud/spanner_v1/transaction.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,12 @@ def wrapped_method(*args, **kwargs):
214214

215215
self.rolled_back = True
216216

217+
def _reset_and_begin(self):
    """Reset the request counters and begin the transaction explicitly.

    This function can be used to reset the transaction and execute an
    explicit BeginTransaction RPC if the first statement in the
    transaction failed, and that statement included an inlined
    BeginTransaction option.
    """
    # Both request counters are cleared before begin() is invoked.
    self._execute_sql_request_count = 0
    self._read_request_count = 0
    self.begin()
222+
217223
def commit(
218224
self, return_commit_stats=False, request_options=None, max_commit_delay=None
219225
):

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
import os
1516
import unittest
1617

1718
import grpc
@@ -65,6 +66,19 @@ def aborted_status() -> _Status:
6566
return status
6667

6768

69+
def invalid_argument_status() -> _Status:
    """Build a gRPC status describing an INVALID_ARGUMENT error.

    Returns:
        A ``_Status`` whose code and details come from an
        INVALID_ARGUMENT ``status_pb2.Status``, with that proto serialized
        into the ``grpc-status-details-bin`` trailing metadata entry.
    """
    detail = status_pb2.Status(
        code=code_pb2.INVALID_ARGUMENT,
        message="Invalid argument.",
    )
    # The serialized proto travels as binary trailing metadata.
    trailers = (("grpc-status-details-bin", detail.SerializeToString()),)
    return _Status(
        code=code_to_grpc_status_code(detail.code),
        details=detail.message,
        trailing_metadata=trailers,
    )
80+
81+
6882
def _make_partial_result_sets(
6983
fields: list[tuple[str, TypeCode]], results: list[dict]
7084
) -> list[result_set.PartialResultSet]:
@@ -174,6 +188,9 @@ class MockServerTestBase(unittest.TestCase):
174188

175189
def __init__(self, *args, **kwargs):
176190
super(MockServerTestBase, self).__init__(*args, **kwargs)
191+
# Disable built-in metrics for tests to avoid Unauthenticated errors
192+
os.environ["SPANNER_DISABLE_BUILTIN_METRICS"] = "true"
193+
177194
self._client = None
178195
self._instance = None
179196
self._database = None

0 commit comments

Comments
 (0)