Skip to content

Commit 1988838

Browse files
committed
feat: add support for experimental host Spanner endpoints
1 parent b1c7873 commit 1988838

File tree

5 files changed

+294
-151
lines changed

5 files changed

+294
-151
lines changed

spanner_graphs/cloud_database.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,76 @@
2727
import logging
2828
import pydata_google_auth
2929

30-
from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo
30+
from spanner_graphs.database import (
31+
SpannerDatabase,
32+
MockSpannerDatabase,
33+
SpannerQueryResult,
34+
SpannerFieldInfo,
35+
)
36+
3137

3238
def _get_default_credentials_with_project():
3339
return pydata_google_auth.default(
34-
scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False)
40+
scopes=["https://www.googleapis.com/auth/cloud-platform"],
41+
use_local_webserver=False,
42+
)
43+
3544

3645
def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]:
37-
"""Converts a list of StructType.Field to a list of SpannerFieldInfo."""
38-
return [SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name) for field in fields]
46+
"""Converts a list of StructType.Field to a list of SpannerFieldInfo."""
47+
return [
48+
SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name)
49+
for field in fields
50+
]
51+
3952

4053
class CloudSpannerDatabase(SpannerDatabase):
4154
"""Concrete implementation for Spanner database on the cloud."""
42-
def __init__(self, project_id: str, instance_id: str,
43-
database_id: str) -> None:
44-
credentials, _ = _get_default_credentials_with_project()
45-
self.client = spanner.Client(
46-
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))
55+
56+
def __init__(
57+
self,
58+
project_id: str,
59+
instance_id: str,
60+
database_id: str,
61+
experimental_host: str | None = None,
62+
ca_certificate: str | None = None,
63+
) -> None:
64+
from google.auth.credentials import AnonymousCredentials
65+
66+
if experimental_host:
67+
self.client = spanner.Client(
68+
project=project_id,
69+
credentials=AnonymousCredentials(),
70+
experimental_host=experimental_host,
71+
ca_certificate=ca_certificate,
72+
)
73+
else:
74+
credentials, _ = _get_default_credentials_with_project()
75+
self.client = spanner.Client(
76+
project=project_id,
77+
credentials=credentials,
78+
client_options=ClientOptions(quota_project_id=project_id),
79+
)
4780
self.instance = self.client.instance(instance_id)
4881
logger = logging.getLogger("spanner_graphs")
4982
logger.setLevel(logging.CRITICAL)
5083
self.database = self.instance.database(database_id, logger=logger)
5184
self.schema_json: Any | None = None
5285

5386
def __repr__(self) -> str:
54-
return (f"<CloudSpannerDatabase["
55-
f"project:{self.client.project_name},"
56-
f"instance:{self.instance.name},"
57-
f"db:{self.database.name}]>")
87+
return (
88+
f"<CloudSpannerDatabase["
89+
f"project:{self.client.project_name},"
90+
f"instance:{self.instance.name},"
91+
f"db:{self.database.name}]>"
92+
)
5893

5994
def _extract_graph_name(self, query: str) -> str:
6095
words = query.strip().split()
6196
if len(words) < 3:
62-
raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)")
97+
raise ValueError(
98+
"invalid query: must contain at least (GRAPH, graph_name and query)"
99+
)
63100

64101
if words[0].upper() != "GRAPH":
65102
raise ValueError("invalid query: GRAPH must be the first word")
@@ -81,7 +118,9 @@ def _get_schema_for_graph(self, graph_query: str) -> Any | None:
81118
params = {"graph_name": graph_name}
82119
param_type = {"graph_name": spanner.param_types.STRING}
83120

84-
result = snapshot.execute_sql(schema_query, params=params, param_types=param_type)
121+
result = snapshot.execute_sql(
122+
schema_query, params=params, param_types=param_type
123+
)
85124
schema_rows = list(result)
86125

87126
if schema_rows:
@@ -117,15 +156,13 @@ def execute_query(
117156
params = dict(limit=limit)
118157

119158
try:
120-
results = snapshot.execute_sql(query, params=params, param_types=param_types)
159+
results = snapshot.execute_sql(
160+
query, params=params, param_types=param_types
161+
)
121162
rows = list(results)
122163
except Exception as e:
123164
return SpannerQueryResult(
124-
data={},
125-
fields=[],
126-
rows=[],
127-
schema_json=self.schema_json,
128-
err=e
165+
data={}, fields=[], rows=[], schema_json=self.schema_json, err=e
129166
)
130167

131168
fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields)
@@ -137,7 +174,7 @@ def execute_query(
137174
fields=fields,
138175
rows=rows,
139176
schema_json=self.schema_json,
140-
err=None
177+
err=None,
141178
)
142179

143180
for row_data in rows:
@@ -152,5 +189,5 @@ def execute_query(
152189
fields=fields,
153190
rows=rows,
154191
schema_json=self.schema_json,
155-
err=None
192+
err=None,
156193
)

spanner_graphs/database.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@
2727
from dataclasses import dataclass
2828
from enum import Enum, auto
2929

30+
3031
class SpannerEnv(Enum):
3132
"""Defines the types of Spanner environments the application can connect to."""
33+
3234
CLOUD = auto()
3335
INFRA = auto()
3436
MOCK = auto()
37+
EXPERIMENTAL_HOST = auto()
38+
3539

3640
@dataclass
3741
class DatabaseSelector:
@@ -47,42 +51,74 @@ class DatabaseSelector:
4751
instance: The Spanner instance.
4852
database: The Spanner database.
4953
infra_db_path: The path for an internal infrastructure database.
54+
experimental_host: The Spanner experimental host endpoint.
55+
ca_certificate: CA certificate path for the experimental host endpoint.
56+
5057
"""
58+
5159
env: SpannerEnv
5260
project: str | None = None
5361
instance: str | None = None
5462
database: str | None = None
5563
infra_db_path: str | None = None
64+
experimental_host: str | None = None
65+
ca_certificate: str | None = None
66+
5667

5768
@classmethod
58-
def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector':
69+
def cloud(cls, project: str, instance: str, database: str) -> "DatabaseSelector":
5970
"""Creates a selector for a Google Cloud Spanner database."""
6071
if not project or not instance or not database:
61-
raise ValueError("project, instance, and database are required for Cloud Spanner")
62-
return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database)
72+
raise ValueError(
73+
"project, instance, and database are required for Cloud Spanner"
74+
)
75+
return cls(
76+
env=SpannerEnv.CLOUD, project=project, instance=instance, database=database
77+
)
6378

6479
@classmethod
65-
def infra(cls, infra_db_path: str) -> 'DatabaseSelector':
80+
def infra(cls, infra_db_path: str) -> "DatabaseSelector":
6681
"""Creates a selector for an internal infrastructure Spanner database."""
6782
if not infra_db_path:
6883
raise ValueError("infra_db_path is required for Infra Spanner")
6984
return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path)
7085

7186
@classmethod
72-
def mock(cls) -> 'DatabaseSelector':
87+
def mock(cls) -> "DatabaseSelector":
7388
"""Creates a selector for a mock Spanner database."""
7489
return cls(env=SpannerEnv.MOCK)
7590

91+
@classmethod
92+
def experimental_host(
93+
cls, experimental_host: str, database: str, ca_certificate: str | None = None,
94+
) -> "DatabaseSelector":
95+
"""Creates a selector for a Google Experimental Host Spanner database."""
96+
if not database:
97+
raise ValueError(
98+
"database is required for Experimental Host Spanner Endpoint"
99+
)
100+
return cls(
101+
env=SpannerEnv.EXPERIMENTAL_HOST,
102+
project="default",
103+
instance="default",
104+
database=database,
105+
experimental_host=experimental_host,
106+
ca_certificate=ca_certificate,
107+
)
108+
76109
def get_key(self) -> str:
77110
if self.env == SpannerEnv.CLOUD:
78111
return f"cloud_{self.project}_{self.instance}_{self.database}"
79112
elif self.env == SpannerEnv.INFRA:
80113
return f"infra_{self.infra_db_path}"
81114
elif self.env == SpannerEnv.MOCK:
82115
return "mock"
116+
elif self.env == SpannerEnv.EXPERIMENTAL_HOST:
117+
return f"experimental_host_{self.database}"
83118
else:
84119
raise ValueError("Unknown Spanner environment")
85120

121+
86122
class SpannerQueryResult(NamedTuple):
87123
# A dict where each key is a field name returned in the query and the list
88124
# contains all items of the same type found for the given field
@@ -96,6 +132,7 @@ class SpannerQueryResult(NamedTuple):
96132
# The error message if any
97133
err: Exception | None
98134

135+
99136
class SpannerDatabase(ABC):
100137
"""The spanner class holding the database connection"""
101138

@@ -116,6 +153,7 @@ def execute_query(
116153
) -> SpannerQueryResult:
117154
pass
118155

156+
119157
# Represents the name and type of a field in a Spanner query result. (Implementation-agnostic)
120158
@dataclass
121159
class SpannerFieldInfo:
@@ -136,8 +174,7 @@ def _load_data(self):
136174
csv_reader = csv.reader(csvfile)
137175
headers = next(csv_reader)
138176
self.fields = [
139-
SpannerFieldInfo(name=header, typename="JSON")
140-
for header in headers
177+
SpannerFieldInfo(name=header, typename="JSON") for header in headers
141178
]
142179

143180
for row in csv_reader:
@@ -153,22 +190,17 @@ def _load_data(self):
153190
def __iter__(self):
154191
return iter(self._rows)
155192

156-
class MockSpannerDatabase():
193+
194+
class MockSpannerDatabase:
157195
"""Mock database class"""
158196

159197
def __init__(self):
160198
dirname = os.path.dirname(__file__)
161-
self.graph_csv_path = os.path.join(
162-
dirname, "graph_mock_data.csv")
163-
self.schema_json_path = os.path.join(
164-
dirname, "graph_mock_schema.json")
199+
self.graph_csv_path = os.path.join(dirname, "graph_mock_data.csv")
200+
self.schema_json_path = os.path.join(dirname, "graph_mock_schema.json")
165201
self.schema_json: dict = {}
166202

167-
def execute_query(
168-
self,
169-
_: str,
170-
limit: int = 5
171-
) -> SpannerQueryResult:
203+
def execute_query(self, _: str, limit: int = 5) -> SpannerQueryResult:
172204
"""Mock execution of query"""
173205

174206
# Fetch the schema
@@ -182,12 +214,12 @@ def execute_query(
182214

183215
if len(fields) == 0:
184216
return SpannerQueryResult(
185-
data=data,
186-
fields=fields,
187-
rows=rows,
188-
schema_json=self.schema_json,
189-
err=None
190-
)
217+
data=data,
218+
fields=fields,
219+
rows=rows,
220+
schema_json=self.schema_json,
221+
err=None,
222+
)
191223

192224
for i, row in enumerate(results):
193225
if limit is not None and i >= limit:
@@ -196,9 +228,5 @@ def execute_query(
196228
data[field.name].append(value)
197229

198230
return SpannerQueryResult(
199-
data=data,
200-
fields=fields,
201-
rows=rows,
202-
schema_json=self.schema_json,
203-
err=None
204-
)
231+
data=data, fields=fields, rows=rows, schema_json=self.schema_json, err=None
232+
)

spanner_graphs/exec_env.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright 2024 Google LLC
32

43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,6 +28,7 @@
2928
# Global dict of database instances created in a single session
3029
database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {}
3130

31+
3232
def get_database_instance(
3333
selector: DatabaseSelector,
3434
) -> Union[SpannerDatabase, MockSpannerDatabase]:
@@ -59,9 +59,7 @@ def get_database_instance(
5959

6060
elif selector.env == SpannerEnv.CLOUD:
6161
try:
62-
cloud_db_module = importlib.import_module(
63-
"spanner_graphs.cloud_database"
64-
)
62+
cloud_db_module = importlib.import_module("spanner_graphs.cloud_database")
6563
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
6664
db = CloudSpannerDatabase(
6765
selector.project, selector.instance, selector.database
@@ -72,15 +70,28 @@ def get_database_instance(
7270
)
7371
elif selector.env == SpannerEnv.INFRA:
7472
try:
75-
infra_db_module = importlib.import_module(
76-
"spanner_graphs.infra_database"
77-
)
73+
infra_db_module = importlib.import_module("spanner_graphs.infra_database")
7874
InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase")
7975
db = InfraSpannerDatabase(selector.infra_db_path)
8076
except ImportError:
8177
raise RuntimeError(
8278
"Infra Spanner support is not available in this environment."
8379
)
80+
elif selector.env == SpannerEnv.EXPERIMENTAL_HOST:
81+
try:
82+
cloud_db_module = importlib.import_module("spanner_graphs.cloud_database")
83+
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
84+
db = CloudSpannerDatabase(
85+
selector.project,
86+
selector.instance,
87+
selector.database,
88+
selector.experimental_host,
89+
selector.ca_certificate,
90+
)
91+
except ImportError:
92+
raise RuntimeError(
93+
"Spanner experimental host support is not available in this environment."
94+
)
8495
else:
8596
raise ValueError(f"Unsupported Spanner environment: {selector.env}")
8697

0 commit comments

Comments
 (0)