2727from dataclasses import dataclass
2828from enum import Enum , auto
2929
30+
3031class SpannerEnv (Enum ):
3132 """Defines the types of Spanner environments the application can connect to."""
33+
3234 CLOUD = auto ()
3335 INFRA = auto ()
3436 MOCK = auto ()
37+ EXPERIMENTAL_HOST = auto ()
38+
3539
3640@dataclass
3741class DatabaseSelector :
@@ -47,42 +51,74 @@ class DatabaseSelector:
4751 instance: The Spanner instance.
4852 database: The Spanner database.
4953 infra_db_path: The path for an internal infrastructure database.
54+ experimental_host: The Spanner experimental host endpoint.
55+ ca_certificate: CA certificate path for the experimental host endpoint.
56+
5057 """
58+
5159 env : SpannerEnv
5260 project : str | None = None
5361 instance : str | None = None
5462 database : str | None = None
5563 infra_db_path : str | None = None
64+ experimental_host : str | None = None
65+ ca_certificate : str | None = None
66+
5667
5768 @classmethod
58- def cloud (cls , project : str , instance : str , database : str ) -> ' DatabaseSelector' :
69+ def cloud (cls , project : str , instance : str , database : str ) -> " DatabaseSelector" :
5970 """Creates a selector for a Google Cloud Spanner database."""
6071 if not project or not instance or not database :
61- raise ValueError ("project, instance, and database are required for Cloud Spanner" )
62- return cls (env = SpannerEnv .CLOUD , project = project , instance = instance , database = database )
72+ raise ValueError (
73+ "project, instance, and database are required for Cloud Spanner"
74+ )
75+ return cls (
76+ env = SpannerEnv .CLOUD , project = project , instance = instance , database = database
77+ )
6378
6479 @classmethod
65- def infra (cls , infra_db_path : str ) -> ' DatabaseSelector' :
80+ def infra (cls , infra_db_path : str ) -> " DatabaseSelector" :
6681 """Creates a selector for an internal infrastructure Spanner database."""
6782 if not infra_db_path :
6883 raise ValueError ("infra_db_path is required for Infra Spanner" )
6984 return cls (env = SpannerEnv .INFRA , infra_db_path = infra_db_path )
7085
7186 @classmethod
72- def mock (cls ) -> ' DatabaseSelector' :
87+ def mock (cls ) -> " DatabaseSelector" :
7388 """Creates a selector for a mock Spanner database."""
7489 return cls (env = SpannerEnv .MOCK )
7590
91+ @classmethod
92+ def experimental_host (
93+ cls , experimental_host : str , database : str , ca_certificate : str | None = None ,
94+ ) -> "DatabaseSelector" :
95+ """Creates a selector for a Google Experimental Host Spanner database."""
96+ if not database :
97+ raise ValueError (
98+ "database is required for Experimental Host Spanner Endpoint"
99+ )
100+ return cls (
101+ env = SpannerEnv .EXPERIMENTAL_HOST ,
102+ project = "default" ,
103+ instance = "default" ,
104+ database = database ,
105+ experimental_host = experimental_host ,
106+ ca_certificate = ca_certificate ,
107+ )
108+
76109 def get_key (self ) -> str :
77110 if self .env == SpannerEnv .CLOUD :
78111 return f"cloud_{ self .project } _{ self .instance } _{ self .database } "
79112 elif self .env == SpannerEnv .INFRA :
80113 return f"infra_{ self .infra_db_path } "
81114 elif self .env == SpannerEnv .MOCK :
82115 return "mock"
116+ elif self .env == SpannerEnv .EXPERIMENTAL_HOST :
117+ return f"experimental_host_{ self .database } "
83118 else :
84119 raise ValueError ("Unknown Spanner environment" )
85120
121+
86122class SpannerQueryResult (NamedTuple ):
87123 # A dict where each key is a field name returned in the query and the list
88124 # contains all items of the same type found for the given field
@@ -96,6 +132,7 @@ class SpannerQueryResult(NamedTuple):
96132 # The error message if any
97133 err : Exception | None
98134
135+
99136class SpannerDatabase (ABC ):
100137 """The spanner class holding the database connection"""
101138
@@ -116,6 +153,7 @@ def execute_query(
116153 ) -> SpannerQueryResult :
117154 pass
118155
156+
119157# Represents the name and type of a field in a Spanner query result. (Implementation-agnostic)
120158@dataclass
121159class SpannerFieldInfo :
@@ -136,8 +174,7 @@ def _load_data(self):
136174 csv_reader = csv .reader (csvfile )
137175 headers = next (csv_reader )
138176 self .fields = [
139- SpannerFieldInfo (name = header , typename = "JSON" )
140- for header in headers
177+ SpannerFieldInfo (name = header , typename = "JSON" ) for header in headers
141178 ]
142179
143180 for row in csv_reader :
@@ -153,22 +190,17 @@ def _load_data(self):
153190 def __iter__ (self ):
154191 return iter (self ._rows )
155192
156- class MockSpannerDatabase ():
193+
194+ class MockSpannerDatabase :
157195 """Mock database class"""
158196
159197 def __init__ (self ):
160198 dirname = os .path .dirname (__file__ )
161- self .graph_csv_path = os .path .join (
162- dirname , "graph_mock_data.csv" )
163- self .schema_json_path = os .path .join (
164- dirname , "graph_mock_schema.json" )
199+ self .graph_csv_path = os .path .join (dirname , "graph_mock_data.csv" )
200+ self .schema_json_path = os .path .join (dirname , "graph_mock_schema.json" )
165201 self .schema_json : dict = {}
166202
167- def execute_query (
168- self ,
169- _ : str ,
170- limit : int = 5
171- ) -> SpannerQueryResult :
203+ def execute_query (self , _ : str , limit : int = 5 ) -> SpannerQueryResult :
172204 """Mock execution of query"""
173205
174206 # Fetch the schema
@@ -182,12 +214,12 @@ def execute_query(
182214
183215 if len (fields ) == 0 :
184216 return SpannerQueryResult (
185- data = data ,
186- fields = fields ,
187- rows = rows ,
188- schema_json = self .schema_json ,
189- err = None
190- )
217+ data = data ,
218+ fields = fields ,
219+ rows = rows ,
220+ schema_json = self .schema_json ,
221+ err = None ,
222+ )
191223
192224 for i , row in enumerate (results ):
193225 if limit is not None and i >= limit :
@@ -196,9 +228,5 @@ def execute_query(
196228 data [field .name ].append (value )
197229
198230 return SpannerQueryResult (
199- data = data ,
200- fields = fields ,
201- rows = rows ,
202- schema_json = self .schema_json ,
203- err = None
204- )
231+ data = data , fields = fields , rows = rows , schema_json = self .schema_json , err = None
232+ )
0 commit comments