Skip to content

Commit abbe104

Browse files
authored
Merge pull request #58 from smukil/create_db_selector
Add the concept of DB selectors
2 parents 8574f4f + a8d66e8 commit abbe104

12 files changed

Lines changed: 366 additions & 234 deletions

frontend/static/dev.html

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -484,11 +484,20 @@ <h2>Configure Visualization</h2>
484484
window.app.tearDown();
485485
}
486486

487+
let selector;
488+
if (mock) {
489+
selector = { env: 'SpannerEnv.MOCK' };
490+
} else {
491+
selector = {
492+
env: 'SpannerEnv.CLOUD',
493+
project: project,
494+
instance: instance,
495+
database: database
496+
};
497+
}
498+
487499
const params = {
488-
'project': project,
489-
'instance': instance,
490-
'database': database,
491-
'mock': mock,
500+
'selector': selector,
492501
'graph': graph
493502
};
494503

@@ -546,4 +555,4 @@ <h2>Configure Visualization</h2>
546555
toggleCommandPalette();
547556
</script>
548557
</body>
549-
</html>
558+
</html>

frontend/static/test.html

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,16 @@
6868
}
6969

7070
const mount = document.querySelector('.mount-spanner-test');
71-
params = {
72-
'project': 'project-foo',
73-
'instance': 'instance-foo',
74-
'database': 'database-foo',
75-
'mock': true
76-
}
71+
const params = {
72+
selector: {
73+
env: 'SpannerEnv.MOCK'
74+
},
75+
graph: ''
76+
};
7777
window.app = new SpannerApp({
7878
id: 'spanner-test', port:'', params:params, mount:mount, query: ''
7979
});
8080
});
8181
</script>
8282
</body>
83-
</html>
83+
</html>

spanner_graphs/database.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,63 @@
2525
import csv
2626

2727
from dataclasses import dataclass
28+
from enum import Enum, auto
29+
30+
class SpannerEnv(Enum):
31+
"""Defines the types of Spanner environments the application can connect to."""
32+
CLOUD = auto()
33+
INFRA = auto()
34+
MOCK = auto()
35+
36+
@dataclass
37+
class DatabaseSelector:
38+
"""
39+
A factory and configuration holder for Spanner database connection details.
40+
41+
This class provides a clean way to specify which Spanner database to connect to,
42+
whether it's on Google Cloud, an internal infrastructure, or a local mock.
43+
44+
Attributes:
45+
env: The Spanner environment type.
46+
project: The Google Cloud project.
47+
instance: The Spanner instance.
48+
database: The Spanner database.
49+
infra_db_path: The path for an internal infrastructure database.
50+
"""
51+
env: SpannerEnv
52+
project: str | None = None
53+
instance: str | None = None
54+
database: str | None = None
55+
infra_db_path: str | None = None
56+
57+
@classmethod
58+
def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector':
59+
"""Creates a selector for a Google Cloud Spanner database."""
60+
if not project or not instance or not database:
61+
raise ValueError("project, instance, and database are required for Cloud Spanner")
62+
return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database)
63+
64+
@classmethod
65+
def infra(cls, infra_db_path: str) -> 'DatabaseSelector':
66+
"""Creates a selector for an internal infrastructure Spanner database."""
67+
if not infra_db_path:
68+
raise ValueError("infra_db_path is required for Infra Spanner")
69+
return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path)
70+
71+
@classmethod
72+
def mock(cls) -> 'DatabaseSelector':
73+
"""Creates a selector for a mock Spanner database."""
74+
return cls(env=SpannerEnv.MOCK)
75+
76+
def get_key(self) -> str:
77+
if self.env == SpannerEnv.CLOUD:
78+
return f"cloud_{self.project}_{self.instance}_{self.database}"
79+
elif self.env == SpannerEnv.INFRA:
80+
return f"infra_{self.infra_db_path}"
81+
elif self.env == SpannerEnv.MOCK:
82+
return "mock"
83+
else:
84+
raise ValueError("Unknown Spanner environment")
2885

2986
class SpannerQueryResult(NamedTuple):
3087
# A dict where each key is a field name returned in the query and the list

spanner_graphs/exec_env.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,73 @@
1616
"""
1717
This module maintains state for the execution environment of a session
1818
"""
19-
from typing import Dict, Union
2019

21-
from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase
22-
from spanner_graphs.cloud_database import CloudSpannerDatabase
20+
import importlib
21+
from typing import Dict, Union
22+
from spanner_graphs.database import (
23+
SpannerDatabase,
24+
MockSpannerDatabase,
25+
DatabaseSelector,
26+
SpannerEnv,
27+
)
2328

2429
# Global dict of database instances created in a single session
2530
database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {}
2631

27-
def get_database_instance(project: str, instance: str, database: str, mock = False):
28-
if mock:
32+
def get_database_instance(
33+
selector: DatabaseSelector,
34+
) -> Union[SpannerDatabase, MockSpannerDatabase]:
35+
"""Gets a cached or new database instance based on the selector.
36+
37+
Args:
38+
selector: A `DatabaseSelector` object that specifies which database to
39+
connect to.
40+
41+
Returns:
42+
An initialized `SpannerDatabase` or `MockSpannerDatabase` instance.
43+
A CloudSpannerDatabase will only be available in public environments.
44+
An InfraSpannerDatabase will only be available in internal environments.
45+
46+
Raises:
47+
RuntimeError: If the required Spanner client library (for Cloud or Infra)
48+
is not installed in the environment.
49+
ValueError: If the selector specifies an unknown or unsupported
50+
environment.
51+
"""
52+
if selector.env == SpannerEnv.MOCK:
2953
return MockSpannerDatabase()
3054

31-
key = f"{project}_{instance}_{database}"
55+
key = selector.get_key()
3256
db = database_instances.get(key)
57+
if db:
58+
return db
3359

34-
# Currently, we only create and return CloudSpannerDatabase instances. In the future, different
35-
# implementations could be introduced.
36-
if not db:
37-
db = CloudSpannerDatabase(project, instance, database)
38-
database_instances[key] = db
60+
elif selector.env == SpannerEnv.CLOUD:
61+
try:
62+
cloud_db_module = importlib.import_module(
63+
"spanner_graphs.cloud_database"
64+
)
65+
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
66+
db = CloudSpannerDatabase(
67+
selector.project, selector.instance, selector.database
68+
)
69+
except ImportError:
70+
raise RuntimeError(
71+
"Cloud Spanner support is not available in this environment."
72+
)
73+
elif selector.env == SpannerEnv.INFRA:
74+
try:
75+
infra_db_module = importlib.import_module(
76+
"spanner_graphs.infra_database"
77+
)
78+
InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase")
79+
db = InfraSpannerDatabase(selector.infra_db_path)
80+
except ImportError:
81+
raise RuntimeError(
82+
"Infra Spanner support is not available in this environment."
83+
)
84+
else:
85+
raise ValueError(f"Unsupported Spanner environment: {selector.env}")
3986

87+
database_instances[key] = db
4088
return db
41-

spanner_graphs/graph_server.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from spanner_graphs.conversion import get_nodes_edges
2727
from spanner_graphs.exec_env import get_database_instance
28-
from spanner_graphs.database import SpannerQueryResult
28+
from spanner_graphs.database import DatabaseSelector, SpannerQueryResult, SpannerEnv
2929

3030
# Supported types for a property
3131
PROPERTY_TYPE_SET = {
@@ -52,6 +52,24 @@ class EdgeDirection(Enum):
5252
INCOMING = "INCOMING"
5353
OUTGOING = "OUTGOING"
5454

55+
56+
def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector:
57+
"""
58+
Picks the correct DB selector based on the environment the server is running in.
59+
"""
60+
try:
61+
env = SpannerEnv[selector_dict['env'].split('.')[-1]]
62+
if env == SpannerEnv.CLOUD:
63+
return DatabaseSelector.cloud(selector_dict['project'], selector_dict['instance'], selector_dict['database'])
64+
elif env == SpannerEnv.INFRA:
65+
return DatabaseSelector.infra(selector_dict['infra_db_path'])
66+
elif env == SpannerEnv.MOCK:
67+
return DatabaseSelector.mock()
68+
raise ValueError(f"Invalid env in selector dict: {selector_dict}")
69+
except Exception as e:
70+
print (f"Unexpected error when fetching selector: {e}")
71+
72+
5573
def is_valid_property_type(property_type: str) -> bool:
5674
"""
5775
Validates a property type.
@@ -79,7 +97,7 @@ def is_valid_property_type(property_type: str) -> bool:
7997
return True
8098

8199
def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection):
82-
required_fields = ["project", "instance", "database", "graph", "uid", "node_labels", "direction"]
100+
required_fields = ["uid", "node_labels", "direction"]
83101
missing_fields = [field for field in required_fields if data.get(field) is None]
84102

85103
if missing_fields:
@@ -146,7 +164,8 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio
146164
return validated_properties, direction
147165

148166
def execute_node_expansion(
149-
params_str: str,
167+
selector_dict: Dict[str, Any],
168+
graph: str,
150169
request: dict) -> dict:
151170
"""Execute a node expansion query to find connected nodes and edges.
152171
@@ -158,13 +177,8 @@ def execute_node_expansion(
158177
dict: A dictionary containing the query response with nodes and edges.
159178
"""
160179

161-
params = json.loads(params_str)
162-
node_properties, direction = validate_node_expansion_request(params | request)
180+
node_properties, direction = validate_node_expansion_request(request)
163181

164-
project = params.get("project")
165-
instance = params.get("instance")
166-
database = params.get("database")
167-
graph = params.get("graph")
168182
uid = request.get("uid")
169183
node_labels = request.get("node_labels")
170184
edge_label = request.get("edge_label")
@@ -204,14 +218,11 @@ def execute_node_expansion(
204218
RETURN TO_JSON(e) as e, TO_JSON(d) as d
205219
"""
206220

207-
return execute_query(project, instance, database, query, mock=False)
221+
return execute_query(selector_dict, query)
208222

209223
def execute_query(
210-
project: str,
211-
instance: str,
212-
database: str,
224+
selector_dict: Dict[str, Any],
213225
query: str,
214-
mock: bool = False,
215226
) -> Dict[str, Any]:
216227
"""Executes a query against a database and formats the result.
217228
@@ -220,19 +231,14 @@ def execute_query(
220231
If the query fails, it returns a detailed error message, optionally
221232
including the database schema to aid in debugging.
222233
223-
Args:
224-
project: The cloud project ID.
225-
instance: The database instance name.
226-
database: The database name.
227-
query: The query string to execute.
228-
mock: If True, use a mock database instance for testing. Defaults to False.
229-
230234
Returns:
231235
A dictionary containing either the structured 'response' with nodes,
232236
edges, and other data, or an 'error' key with a descriptive message.
233237
"""
234238
try:
235-
db_instance = get_database_instance(project, instance, database, mock)
239+
selector = dict_to_selector(selector_dict)
240+
db_instance = get_database_instance(selector)
241+
236242
result: SpannerQueryResult = db_instance.execute_query(query)
237243

238244
if len(result.rows) == 0 and result.err:
@@ -382,32 +388,25 @@ def handle_post_query(self):
382388
data = self.parse_post_data()
383389
params = json.loads(data["params"])
384390
response = execute_query(
385-
project=params["project"],
386-
instance=params["instance"],
387-
database=params["database"],
388-
query=data["query"],
389-
mock=params["mock"]
391+
selector_dict=params["selector"],
392+
query=data["query"]
390393
)
391394
self.do_data_response(response)
392395

393396
def handle_post_node_expansion(self):
394-
"""Handle POST requests for node expansion.
395-
396-
Expects a JSON payload with:
397-
- params: A JSON string containing connection parameters (project, instance, database, graph)
398-
- request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label)
399-
"""
400397
try:
401398
data = self.parse_post_data()
399+
params = json.loads(data.get("params"))
400+
selector_dict = params["selector"]
401+
graph = params.get("graph")
402+
request_data = data.get("request")
402403

403-
# Execute node expansion with:
404-
# - params_str: JSON string with connection parameters (project, instance, database, graph)
405-
# - request: Dict with node details (uid, node_labels, node_properties, direction, edge_label)
406404
self.do_data_response(execute_node_expansion(
407-
params_str=data.get("params"),
408-
request=data.get("request")
405+
selector_dict=selector_dict,
406+
graph=graph,
407+
request=request_data
409408
))
410-
except BaseException as e:
409+
except Exception as e:
411410
self.do_error_response(e)
412411
return
413412

spanner_graphs/graph_visualization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def generate_visualization_html(query: str, port: int, params: str):
5757
search_dir = parent
5858

5959
template_content = _load_file([search_dir, 'frontend', 'static', 'jupyter.html'])
60-
60+
6161
# Load the JavaScript bundle directly
6262
js_file_path = os.path.join(search_dir, 'third_party', 'index.js')
6363
try:

0 commit comments

Comments
 (0)