Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions frontend/static/dev.html
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,20 @@ <h2>Configure Visualization</h2>
window.app.tearDown();
}

let selector;
if (mock) {
selector = { env: 'SpannerEnv.MOCK' };
} else {
selector = {
env: 'SpannerEnv.CLOUD',
project: project,
instance: instance,
database: database
};
}

const params = {
'project': project,
'instance': instance,
'database': database,
'mock': mock,
'selector': selector,
'graph': graph
};

Expand Down Expand Up @@ -546,4 +555,4 @@ <h2>Configure Visualization</h2>
toggleCommandPalette();
</script>
</body>
</html>
</html>
14 changes: 7 additions & 7 deletions frontend/static/test.html
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@
}

const mount = document.querySelector('.mount-spanner-test');
params = {
'project': 'project-foo',
'instance': 'instance-foo',
'database': 'database-foo',
'mock': true
}
const params = {
selector: {
env: 'SpannerEnv.MOCK'
},
graph: ''
};
window.app = new SpannerApp({
id: 'spanner-test', port:'', params:params, mount:mount, query: ''
});
});
</script>
</body>
</html>
</html>
57 changes: 57 additions & 0 deletions spanner_graphs/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,63 @@
import csv

from dataclasses import dataclass
from enum import Enum, auto

class SpannerEnv(Enum):
"""Defines the types of Spanner environments the application can connect to."""
CLOUD = auto()
INFRA = auto()
MOCK = auto()

@dataclass
class DatabaseSelector:
"""
A factory and configuration holder for Spanner database connection details.

This class provides a clean way to specify which Spanner database to connect to,
whether it's on Google Cloud, an internal infrastructure, or a local mock.

Attributes:
env: The Spanner environment type.
project: The Google Cloud project.
instance: The Spanner instance.
database: The Spanner database.
infra_db_path: The path for an internal infrastructure database.
"""
env: SpannerEnv
project: str | None = None
instance: str | None = None
database: str | None = None
infra_db_path: str | None = None

@classmethod
def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector':
"""Creates a selector for a Google Cloud Spanner database."""
if not project or not instance or not database:
raise ValueError("project, instance, and database are required for Cloud Spanner")
return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database)

@classmethod
def infra(cls, infra_db_path: str) -> 'DatabaseSelector':
"""Creates a selector for an internal infrastructure Spanner database."""
if not infra_db_path:
raise ValueError("infra_db_path is required for Infra Spanner")
return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path)

@classmethod
def mock(cls) -> 'DatabaseSelector':
"""Creates a selector for a mock Spanner database."""
return cls(env=SpannerEnv.MOCK)

def get_key(self) -> str:
if self.env == SpannerEnv.CLOUD:
return f"cloud_{self.project}_{self.instance}_{self.database}"
elif self.env == SpannerEnv.INFRA:
return f"infra_{self.infra_db_path}"
elif self.env == SpannerEnv.MOCK:
return "mock"
else:
raise ValueError("Unknown Spanner environment")

class SpannerQueryResult(NamedTuple):
# A dict where each key is a field name returned in the query and the list
Expand Down
71 changes: 59 additions & 12 deletions spanner_graphs/exec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,73 @@
"""
This module maintains state for the execution environment of a session
"""
from typing import Dict, Union

from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase
from spanner_graphs.cloud_database import CloudSpannerDatabase
import importlib
from typing import Dict, Union
from spanner_graphs.database import (
SpannerDatabase,
MockSpannerDatabase,
DatabaseSelector,
SpannerEnv,
)

# Global dict of database instances created in a single session
database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {}

def get_database_instance(project: str, instance: str, database: str, mock = False):
if mock:
def get_database_instance(
selector: DatabaseSelector,
) -> Union[SpannerDatabase, MockSpannerDatabase]:
"""Gets a cached or new database instance based on the selector.

Args:
selector: A `DatabaseSelector` object that specifies which database to
connect to.

Returns:
An initialized `SpannerDatabase` or `MockSpannerDatabase` instance.
A CloudSpannerDatabase will only be available in public environments.
An InfraSpannerDatabase will only be available in internal environments.

Raises:
RuntimeError: If the required Spanner client library (for Cloud or Infra)
is not installed in the environment.
ValueError: If the selector specifies an unknown or unsupported
environment.
"""
if selector.env == SpannerEnv.MOCK:
return MockSpannerDatabase()

key = f"{project}_{instance}_{database}"
key = selector.get_key()
db = database_instances.get(key)
if db:
return db

# Currently, we only create and return CloudSpannerDatabase instances. In the future, different
# implementations could be introduced.
if not db:
db = CloudSpannerDatabase(project, instance, database)
database_instances[key] = db
elif selector.env == SpannerEnv.CLOUD:
try:
cloud_db_module = importlib.import_module(
"spanner_graphs.cloud_database"
)
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
db = CloudSpannerDatabase(
selector.project, selector.instance, selector.database
)
except ImportError:
raise RuntimeError(
"Cloud Spanner support is not available in this environment."
)
elif selector.env == SpannerEnv.INFRA:
try:
infra_db_module = importlib.import_module(
"spanner_graphs.infra_database"
)
InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase")
db = InfraSpannerDatabase(selector.infra_db_path)
except ImportError:
raise RuntimeError(
"Infra Spanner support is not available in this environment."
)
else:
raise ValueError(f"Unsupported Spanner environment: {selector.env}")

database_instances[key] = db
return db

77 changes: 38 additions & 39 deletions spanner_graphs/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from spanner_graphs.conversion import get_nodes_edges
from spanner_graphs.exec_env import get_database_instance
from spanner_graphs.database import SpannerQueryResult
from spanner_graphs.database import DatabaseSelector, SpannerQueryResult, SpannerEnv

# Supported types for a property
PROPERTY_TYPE_SET = {
Expand All @@ -52,6 +52,24 @@ class EdgeDirection(Enum):
INCOMING = "INCOMING"
OUTGOING = "OUTGOING"


def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector:
"""
Picks the correct DB selector based on the environment the server is running in.
"""
try:
env = SpannerEnv[selector_dict['env'].split('.')[-1]]
if env == SpannerEnv.CLOUD:
return DatabaseSelector.cloud(selector_dict['project'], selector_dict['instance'], selector_dict['database'])
elif env == SpannerEnv.INFRA:
return DatabaseSelector.infra(selector_dict['infra_db_path'])
elif env == SpannerEnv.MOCK:
return DatabaseSelector.mock()
raise ValueError(f"Invalid env in selector dict: {selector_dict}")
except Exception as e:
print (f"Unexpected error when fetching selector: {e}")


def is_valid_property_type(property_type: str) -> bool:
"""
Validates a property type.
Expand Down Expand Up @@ -79,7 +97,7 @@ def is_valid_property_type(property_type: str) -> bool:
return True

def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection):
required_fields = ["project", "instance", "database", "graph", "uid", "node_labels", "direction"]
required_fields = ["uid", "node_labels", "direction"]
missing_fields = [field for field in required_fields if data.get(field) is None]

if missing_fields:
Expand Down Expand Up @@ -146,7 +164,8 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio
return validated_properties, direction

def execute_node_expansion(
params_str: str,
selector_dict: Dict[str, Any],
graph: str,
request: dict) -> dict:
"""Execute a node expansion query to find connected nodes and edges.

Expand All @@ -158,13 +177,8 @@ def execute_node_expansion(
dict: A dictionary containing the query response with nodes and edges.
"""

params = json.loads(params_str)
node_properties, direction = validate_node_expansion_request(params | request)
node_properties, direction = validate_node_expansion_request(request)

project = params.get("project")
instance = params.get("instance")
database = params.get("database")
graph = params.get("graph")
uid = request.get("uid")
node_labels = request.get("node_labels")
edge_label = request.get("edge_label")
Expand Down Expand Up @@ -204,14 +218,11 @@ def execute_node_expansion(
RETURN TO_JSON(e) as e, TO_JSON(d) as d
"""

return execute_query(project, instance, database, query, mock=False)
return execute_query(selector_dict, query)

def execute_query(
project: str,
instance: str,
database: str,
selector_dict: Dict[str, Any],
query: str,
mock: bool = False,
) -> Dict[str, Any]:
"""Executes a query against a database and formats the result.

Expand All @@ -220,19 +231,14 @@ def execute_query(
If the query fails, it returns a detailed error message, optionally
including the database schema to aid in debugging.

Args:
project: The cloud project ID.
instance: The database instance name.
database: The database name.
query: The query string to execute.
mock: If True, use a mock database instance for testing. Defaults to False.

Returns:
A dictionary containing either the structured 'response' with nodes,
edges, and other data, or an 'error' key with a descriptive message.
"""
try:
db_instance = get_database_instance(project, instance, database, mock)
selector = dict_to_selector(selector_dict)
db_instance = get_database_instance(selector)

result: SpannerQueryResult = db_instance.execute_query(query)

if len(result.rows) == 0 and result.err:
Expand Down Expand Up @@ -382,32 +388,25 @@ def handle_post_query(self):
data = self.parse_post_data()
params = json.loads(data["params"])
response = execute_query(
project=params["project"],
instance=params["instance"],
database=params["database"],
query=data["query"],
mock=params["mock"]
selector_dict=params["selector"],
query=data["query"]
)
self.do_data_response(response)

def handle_post_node_expansion(self):
"""Handle POST requests for node expansion.

Expects a JSON payload with:
- params: A JSON string containing connection parameters (project, instance, database, graph)
- request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label)
"""
try:
data = self.parse_post_data()
params = json.loads(data.get("params"))
selector_dict = params["selector"]
graph = params.get("graph")
request_data = data.get("request")

# Execute node expansion with:
# - params_str: JSON string with connection parameters (project, instance, database, graph)
# - request: Dict with node details (uid, node_labels, node_properties, direction, edge_label)
self.do_data_response(execute_node_expansion(
params_str=data.get("params"),
request=data.get("request")
selector_dict=selector_dict,
graph=graph,
request=request_data
))
except BaseException as e:
except Exception as e:
self.do_error_response(e)
return

Expand Down
2 changes: 1 addition & 1 deletion spanner_graphs/graph_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def generate_visualization_html(query: str, port: int, params: str):
search_dir = parent

template_content = _load_file([search_dir, 'frontend', 'static', 'jupyter.html'])

# Load the JavaScript bundle directly
js_file_path = os.path.join(search_dir, 'third_party', 'index.js')
try:
Expand Down
Loading