2525
2626from spanner_graphs .conversion import get_nodes_edges
2727from spanner_graphs .exec_env import get_database_instance
28- from spanner_graphs .database import SpannerQueryResult
28+ from spanner_graphs .database import DatabaseSelector , SpannerQueryResult , SpannerEnv
2929
3030# Supported types for a property
3131PROPERTY_TYPE_SET = {
@@ -52,6 +52,24 @@ class EdgeDirection(Enum):
5252 INCOMING = "INCOMING"
5353 OUTGOING = "OUTGOING"
5454
55+
56+ def dict_to_selector (selector_dict : Dict [str , Any ]) -> DatabaseSelector :
57+ """
58+ Picks the correct DB selector based on the environment the server is running in.
59+ """
60+ try :
61+ env = SpannerEnv [selector_dict ['env' ].split ('.' )[- 1 ]]
62+ if env == SpannerEnv .CLOUD :
63+ return DatabaseSelector .cloud (selector_dict ['project' ], selector_dict ['instance' ], selector_dict ['database' ])
64+ elif env == SpannerEnv .INFRA :
65+ return DatabaseSelector .infra (selector_dict ['infra_db_path' ])
66+ elif env == SpannerEnv .MOCK :
67+ return DatabaseSelector .mock ()
68+ raise ValueError (f"Invalid env in selector dict: { selector_dict } " )
69+ except Exception as e :
70+ print (f"Unexpected error when fetching selector: { e } " )
71+
72+
5573def is_valid_property_type (property_type : str ) -> bool :
5674 """
5775 Validates a property type.
@@ -79,7 +97,7 @@ def is_valid_property_type(property_type: str) -> bool:
7997 return True
8098
8199def validate_node_expansion_request (data ) -> (list [NodePropertyForDataExploration ], EdgeDirection ):
82- required_fields = ["project" , "instance" , "database" , "graph" , " uid" , "node_labels" , "direction" ]
100+ required_fields = ["uid" , "node_labels" , "direction" ]
83101 missing_fields = [field for field in required_fields if data .get (field ) is None ]
84102
85103 if missing_fields :
@@ -146,7 +164,8 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio
146164 return validated_properties , direction
147165
148166def execute_node_expansion (
149- params_str : str ,
167+ selector_dict : Dict [str , Any ],
168+ graph : str ,
150169 request : dict ) -> dict :
151170 """Execute a node expansion query to find connected nodes and edges.
152171
@@ -158,13 +177,8 @@ def execute_node_expansion(
158177 dict: A dictionary containing the query response with nodes and edges.
159178 """
160179
161- params = json .loads (params_str )
162- node_properties , direction = validate_node_expansion_request (params | request )
180+ node_properties , direction = validate_node_expansion_request (request )
163181
164- project = params .get ("project" )
165- instance = params .get ("instance" )
166- database = params .get ("database" )
167- graph = params .get ("graph" )
168182 uid = request .get ("uid" )
169183 node_labels = request .get ("node_labels" )
170184 edge_label = request .get ("edge_label" )
@@ -204,14 +218,11 @@ def execute_node_expansion(
204218 RETURN TO_JSON(e) as e, TO_JSON(d) as d
205219 """
206220
207- return execute_query (project , instance , database , query , mock = False )
221+ return execute_query (selector_dict , query )
208222
209223def execute_query (
210- project : str ,
211- instance : str ,
212- database : str ,
224+ selector_dict : Dict [str , Any ],
213225 query : str ,
214- mock : bool = False ,
215226) -> Dict [str , Any ]:
216227 """Executes a query against a database and formats the result.
217228
@@ -220,19 +231,14 @@ def execute_query(
220231 If the query fails, it returns a detailed error message, optionally
221232 including the database schema to aid in debugging.
222233
223- Args:
224- project: The cloud project ID.
225- instance: The database instance name.
226- database: The database name.
227- query: The query string to execute.
228- mock: If True, use a mock database instance for testing. Defaults to False.
229-
230234 Returns:
231235 A dictionary containing either the structured 'response' with nodes,
232236 edges, and other data, or an 'error' key with a descriptive message.
233237 """
234238 try :
235- db_instance = get_database_instance (project , instance , database , mock )
239+ selector = dict_to_selector (selector_dict )
240+ db_instance = get_database_instance (selector )
241+
236242 result : SpannerQueryResult = db_instance .execute_query (query )
237243
238244 if len (result .rows ) == 0 and result .err :
@@ -382,32 +388,25 @@ def handle_post_query(self):
382388 data = self .parse_post_data ()
383389 params = json .loads (data ["params" ])
384390 response = execute_query (
385- project = params ["project" ],
386- instance = params ["instance" ],
387- database = params ["database" ],
388- query = data ["query" ],
389- mock = params ["mock" ]
391+ selector_dict = params ["selector" ],
392+ query = data ["query" ]
390393 )
391394 self .do_data_response (response )
392395
393396 def handle_post_node_expansion (self ):
394- """Handle POST requests for node expansion.
395-
396- Expects a JSON payload with:
397- - params: A JSON string containing connection parameters (project, instance, database, graph)
398- - request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label)
399- """
400397 try :
401398 data = self .parse_post_data ()
399+ params = json .loads (data .get ("params" ))
400+ selector_dict = params ["selector" ]
401+ graph = params .get ("graph" )
402+ request_data = data .get ("request" )
402403
403- # Execute node expansion with:
404- # - params_str: JSON string with connection parameters (project, instance, database, graph)
405- # - request: Dict with node details (uid, node_labels, node_properties, direction, edge_label)
406404 self .do_data_response (execute_node_expansion (
407- params_str = data .get ("params" ),
408- request = data .get ("request" )
405+ selector_dict = selector_dict ,
406+ graph = graph ,
407+ request = request_data
409408 ))
410- except BaseException as e :
409+ except Exception as e :
411410 self .do_error_response (e )
412411 return
413412
0 commit comments