diff --git a/frontend/static/dev.html b/frontend/static/dev.html index 5b1841e..93cc458 100644 --- a/frontend/static/dev.html +++ b/frontend/static/dev.html @@ -484,11 +484,20 @@

Configure Visualization

window.app.tearDown(); } + let selector; + if (mock) { + selector = { env: 'SpannerEnv.MOCK' }; + } else { + selector = { + env: 'SpannerEnv.CLOUD', + project: project, + instance: instance, + database: database + }; + } + const params = { - 'project': project, - 'instance': instance, - 'database': database, - 'mock': mock, + 'selector': selector, 'graph': graph }; @@ -546,4 +555,4 @@

Configure Visualization

toggleCommandPalette(); - \ No newline at end of file + diff --git a/frontend/static/test.html b/frontend/static/test.html index 915e2b2..a85d6db 100644 --- a/frontend/static/test.html +++ b/frontend/static/test.html @@ -68,16 +68,16 @@ } const mount = document.querySelector('.mount-spanner-test'); - params = { - 'project': 'project-foo', - 'instance': 'instance-foo', - 'database': 'database-foo', - 'mock': true - } + const params = { + selector: { + env: 'SpannerEnv.MOCK' + }, + graph: '' + }; window.app = new SpannerApp({ id: 'spanner-test', port:'', params:params, mount:mount, query: '' }); }); - \ No newline at end of file + diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 91db0ac..63d94d4 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -25,6 +25,63 @@ import csv from dataclasses import dataclass +from enum import Enum, auto + +class SpannerEnv(Enum): + """Defines the types of Spanner environments the application can connect to.""" + CLOUD = auto() + INFRA = auto() + MOCK = auto() + +@dataclass +class DatabaseSelector: + """ + A factory and configuration holder for Spanner database connection details. + + This class provides a clean way to specify which Spanner database to connect to, + whether it's on Google Cloud, an internal infrastructure, or a local mock. + + Attributes: + env: The Spanner environment type. + project: The Google Cloud project. + instance: The Spanner instance. + database: The Spanner database. + infra_db_path: The path for an internal infrastructure database. + """ + env: SpannerEnv + project: str | None = None + instance: str | None = None + database: str | None = None + infra_db_path: str | None = None + + @classmethod + def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector': + """Creates a selector for a Google Cloud Spanner database.""" + if not project or not instance or not database: + raise ValueError("project, instance, and database are required for Cloud Spanner") + return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database) + + @classmethod + def infra(cls, infra_db_path: str) -> 'DatabaseSelector': + """Creates a selector for an internal infrastructure Spanner database.""" + if not infra_db_path: + raise ValueError("infra_db_path is required for Infra Spanner") + return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path) + + @classmethod + def mock(cls) -> 'DatabaseSelector': + """Creates a selector for a mock Spanner database.""" + return cls(env=SpannerEnv.MOCK) + + def get_key(self) -> str: + if self.env == SpannerEnv.CLOUD: + return f"cloud_{self.project}_{self.instance}_{self.database}" + elif self.env == SpannerEnv.INFRA: + return f"infra_{self.infra_db_path}" + elif self.env == SpannerEnv.MOCK: + return "mock" + else: + raise ValueError("Unknown Spanner environment") class SpannerQueryResult(NamedTuple): # A dict where each key is a field name returned in the query and the list diff --git a/spanner_graphs/exec_env.py b/spanner_graphs/exec_env.py index 4a60efe..93ed825 100644 --- a/spanner_graphs/exec_env.py +++ b/spanner_graphs/exec_env.py @@ -16,26 +16,73 @@ """ This module maintains state for the execution environment of a session """ -from typing import Dict, Union -from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase -from spanner_graphs.cloud_database import CloudSpannerDatabase +import importlib +from typing import Dict, Union +from spanner_graphs.database import ( + SpannerDatabase, + MockSpannerDatabase, + DatabaseSelector, + SpannerEnv, +) # Global dict of database instances created in a single session database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {} -def get_database_instance(project: str, instance: str, database: str, mock = False): - if mock: +def get_database_instance( + selector: DatabaseSelector, +) -> Union[SpannerDatabase, MockSpannerDatabase]: + """Gets a cached or new database instance based on the selector. + + Args: + selector: A `DatabaseSelector` object that specifies which database to + connect to. + + Returns: + An initialized `SpannerDatabase` or `MockSpannerDatabase` instance. + A CloudSpannerDatabase will only be available in public environments. + An InfraSpannerDatabase will only be available in internal environments. + + Raises: + RuntimeError: If the required Spanner client library (for Cloud or Infra) + is not installed in the environment. + ValueError: If the selector specifies an unknown or unsupported + environment. + """ + if selector.env == SpannerEnv.MOCK: return MockSpannerDatabase() - key = f"{project}_{instance}_{database}" + key = selector.get_key() db = database_instances.get(key) + if db: + return db - # Currently, we only create and return CloudSpannerDatabase instances. In the future, different - # implementations could be introduced. - if not db: - db = CloudSpannerDatabase(project, instance, database) - database_instances[key] = db + elif selector.env == SpannerEnv.CLOUD: + try: + cloud_db_module = importlib.import_module( + "spanner_graphs.cloud_database" + ) + CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase") + db = CloudSpannerDatabase( + selector.project, selector.instance, selector.database + ) + except ImportError: + raise RuntimeError( + "Cloud Spanner support is not available in this environment." + ) + elif selector.env == SpannerEnv.INFRA: + try: + infra_db_module = importlib.import_module( + "spanner_graphs.infra_database" + ) + InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase") + db = InfraSpannerDatabase(selector.infra_db_path) + except ImportError: + raise RuntimeError( + "Infra Spanner support is not available in this environment." + ) + else: + raise ValueError(f"Unsupported Spanner environment: {selector.env}") + database_instances[key] = db return db - diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index cf318c3..6324207 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -25,7 +25,7 @@ from spanner_graphs.conversion import get_nodes_edges from spanner_graphs.exec_env import get_database_instance -from spanner_graphs.database import SpannerQueryResult +from spanner_graphs.database import DatabaseSelector, SpannerQueryResult, SpannerEnv # Supported types for a property PROPERTY_TYPE_SET = { @@ -52,6 +52,24 @@ class EdgeDirection(Enum): INCOMING = "INCOMING" OUTGOING = "OUTGOING" + +def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector: + """ + Picks the correct DB selector based on the environment the server is running in. + """ + try: + env = SpannerEnv[selector_dict['env'].split('.')[-1]] + if env == SpannerEnv.CLOUD: + return DatabaseSelector.cloud(selector_dict['project'], selector_dict['instance'], selector_dict['database']) + elif env == SpannerEnv.INFRA: + return DatabaseSelector.infra(selector_dict['infra_db_path']) + elif env == SpannerEnv.MOCK: + return DatabaseSelector.mock() + raise ValueError(f"Invalid env in selector dict: {selector_dict}") + except Exception as e: + print (f"Unexpected error when fetching selector: {e}") + + def is_valid_property_type(property_type: str) -> bool: """ Validates a property type. @@ -79,7 +97,7 @@ def is_valid_property_type(property_type: str) -> bool: return True def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection): - required_fields = ["project", "instance", "database", "graph", "uid", "node_labels", "direction"] + required_fields = ["uid", "node_labels", "direction"] missing_fields = [field for field in required_fields if data.get(field) is None] if missing_fields: @@ -146,7 +164,8 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio return validated_properties, direction def execute_node_expansion( - params_str: str, + selector_dict: Dict[str, Any], + graph: str, request: dict) -> dict: """Execute a node expansion query to find connected nodes and edges. @@ -158,13 +177,8 @@ def execute_node_expansion( dict: A dictionary containing the query response with nodes and edges. """ - params = json.loads(params_str) - node_properties, direction = validate_node_expansion_request(params | request) + node_properties, direction = validate_node_expansion_request(request) - project = params.get("project") - instance = params.get("instance") - database = params.get("database") - graph = params.get("graph") uid = request.get("uid") node_labels = request.get("node_labels") edge_label = request.get("edge_label") @@ -204,14 +218,11 @@ def execute_node_expansion( RETURN TO_JSON(e) as e, TO_JSON(d) as d """ - return execute_query(project, instance, database, query, mock=False) + return execute_query(selector_dict, query) def execute_query( - project: str, - instance: str, - database: str, + selector_dict: Dict[str, Any], query: str, - mock: bool = False, ) -> Dict[str, Any]: """Executes a query against a database and formats the result. @@ -220,19 +231,14 @@ def execute_query( If the query fails, it returns a detailed error message, optionally including the database schema to aid in debugging. - Args: - project: The cloud project ID. - instance: The database instance name. - database: The database name. - query: The query string to execute. - mock: If True, use a mock database instance for testing. Defaults to False. - Returns: A dictionary containing either the structured 'response' with nodes, edges, and other data, or an 'error' key with a descriptive message. """ try: - db_instance = get_database_instance(project, instance, database, mock) + selector = dict_to_selector(selector_dict) + db_instance = get_database_instance(selector) + result: SpannerQueryResult = db_instance.execute_query(query) if len(result.rows) == 0 and result.err: @@ -382,32 +388,25 @@ def handle_post_query(self): data = self.parse_post_data() params = json.loads(data["params"]) response = execute_query( - project=params["project"], - instance=params["instance"], - database=params["database"], - query=data["query"], - mock=params["mock"] + selector_dict=params["selector"], + query=data["query"] ) self.do_data_response(response) def handle_post_node_expansion(self): - """Handle POST requests for node expansion. - - Expects a JSON payload with: - - params: A JSON string containing connection parameters (project, instance, database, graph) - - request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label) - """ try: data = self.parse_post_data() + params = json.loads(data.get("params")) + selector_dict = params["selector"] + graph = params.get("graph") + request_data = data.get("request") - # Execute node expansion with: - # - params_str: JSON string with connection parameters (project, instance, database, graph) - # - request: Dict with node details (uid, node_labels, node_properties, direction, edge_label) self.do_data_response(execute_node_expansion( - params_str=data.get("params"), - request=data.get("request") + selector_dict=selector_dict, + graph=graph, + request=request_data )) - except BaseException as e: + except Exception as e: self.do_error_response(e) return diff --git a/spanner_graphs/graph_visualization.py b/spanner_graphs/graph_visualization.py index 8a9cd77..30ace90 100644 --- a/spanner_graphs/graph_visualization.py +++ b/spanner_graphs/graph_visualization.py @@ -57,7 +57,7 @@ def generate_visualization_html(query: str, port: int, params: str): search_dir = parent template_content = _load_file([search_dir, 'frontend', 'static', 'jupyter.html']) - + # Load the JavaScript bundle directly js_file_path = os.path.join(search_dir, 'third_party', 'index.js') try: diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index b412006..1741e0d 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -24,6 +24,7 @@ import sys from threading import Thread import re +from dataclasses import is_dataclass, asdict from IPython.core.display import HTML, JSON from IPython.core.magic import Magics, magics_class, cell_magic @@ -33,6 +34,7 @@ from ipywidgets import interact from jinja2 import Template +from spanner_graphs.database import DatabaseSelector from spanner_graphs.exec_env import get_database_instance from spanner_graphs.graph_server import ( GraphServer, execute_query, execute_node_expansion, @@ -86,11 +88,13 @@ def is_colab() -> bool: def receive_query_request(query: str, params: str): params_dict = json.loads(params) - return JSON(execute_query(project=params_dict["project"], - instance=params_dict["instance"], - database=params_dict["database"], - query=query, - mock=params_dict["mock"])) + selector_dict = params_dict.get("selector") + if not selector_dict: + return JSON({"error": "Missing selector in params"}) + try: + return JSON(execute_query(selector_dict=selector_dict, query=query)) + except Exception as e: + return JSON({"error": str(e)}) def receive_node_expansion_request(request: dict, params_str: str): """Handle node expansion requests in Google Colab environment @@ -103,11 +107,8 @@ def receive_node_expansion_request(request: dict, params_str: str): - direction: str - Direction of expansion ("INCOMING" or "OUTGOING") - edge_label: Optional[str] - Label of edges to filter by params_str: A JSON string containing connection parameters: - - project: str - GCP project ID - - instance: str - Spanner instance ID - - database: str - Spanner database ID + - selector: Dict - The DatabaseSelector object as a dict - graph: str - Graph name - - mock: bool - Whether to use mock data Returns: JSON: A JSON-serialized response containing either: @@ -115,9 +116,23 @@ def receive_node_expansion_request(request: dict, params_str: str): - An error message if the request failed """ try: - return JSON(execute_node_expansion(params_str, request)) + params_dict = json.loads(params_str) + selector_dict = params_dict.get("selector") + graph = params_dict.get("graph") + if not selector_dict: + return JSON({"error": "Missing selector in params"}) + + return JSON(execute_node_expansion(selector_dict=selector_dict, graph=graph, request=request)) except BaseException as e: - return JSON({"error": e}) + return JSON({"error": str(e)}) + +def custom_json_serializer(o): + """A JSON serializer that handles dataclasses and enums.""" + if is_dataclass(o): + return asdict(o) + if isinstance(o, Enum): + return f"{o.__class__.__name__}.{o.name}" + raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") @magics_class class NetworkVisualizationMagics(Magics): @@ -129,6 +144,7 @@ def __init__(self, shell): self.limit = 5 self.args = None self.cell = None + self.selector = None if is_colab(): from google.colab import output @@ -149,17 +165,18 @@ def visualize(self): if match: graph = match.group(1) + # Pack the selector and graph into the params to be sent to the GraphServer + params = { + "selector": self.selector, + "graph": graph + } + # Generate the HTML content html_content = generate_visualization_html( query=self.cell, port=GraphServer.port, - params=json.dumps({ - "project": self.args.project, - "instance": self.args.instance, - "database": self.args.database, - "mock": self.args.mock, - "graph": graph - })) + params=json.dumps(params, default=custom_json_serializer)) + display(HTML(html_content)) @cell_magic @@ -177,35 +194,40 @@ def spanner_graph(self, line: str, cell: str): parser.add_argument("--mock", action="store_true", help="Use mock database") + parser.add_argument("--infra_db_path", + action="store_true", + help="Connect to internal Infra Spanner") try: args = parser.parse_args(line.split()) - if not args.mock: - if not (args.project and args.instance and args.database): + selector = None + if args.mock: + selector = DatabaseSelector.mock() + elif args.infra_db_path: + selector = DatabaseSelector.infra(infra_db_path=args.database) + else: + if not (args.project and args.instance): raise ValueError( - "Please provide `--project`, `--instance`, " - "and `--database` values for your query.") - if not cell or not cell.strip(): - print("Error: Query is required.") - return + "Please provide `--project` and `--instance` for Cloud Spanner." + ) + selector = DatabaseSelector.cloud(args.project, args.instance, args.database) - self.args = parser.parse_args(line.split()) + if not args.mock and (not cell or not cell.strip()): + print("Error: Query is required.") + return + + self.args = args self.cell = cell - self.database = get_database_instance( - self.args.project, - self.args.instance, - self.args.database, - mock=self.args.mock) + self.selector = selector + self.database = get_database_instance(self.selector) clear_output(wait=True) self.visualize() except BaseException as e: print(f"Error: {e}") - print("Usage: %%spanner_graph --project PROJECT_ID " - "--instance INSTANCE_ID --database DATABASE_ID " - "[--mock] ") + print(" %%spanner_graph --project --instance --database ") + print(" %%spanner_graph --mock") print(" Graph query here...") - def load_ipython_extension(ipython): """Registration function""" ipython.register_magics(NetworkVisualizationMagics) diff --git a/tests/graph_server_test.py b/tests/graph_server_test.py index 7b405e2..8a881af 100644 --- a/tests/graph_server_test.py +++ b/tests/graph_server_test.py @@ -6,6 +6,7 @@ is_valid_property_type, execute_node_expansion, ) +from spanner_graphs.database import SpannerEnv class TestPropertyTypeHandling(unittest.TestCase): def test_validate_property_type_valid_types(self): @@ -75,12 +76,14 @@ def test_property_value_formatting(self, mock_execute_query): ("ENUM", "ENUM_VALUE", "'''ENUM_VALUE'''"), ] - params = json.dumps({ + selector_dict = { + "env": str(SpannerEnv.CLOUD), "project": "test-project", "instance": "test-instance", "database": "test-database", - "graph": "test-graph", - }) + "infra_db_path": None + } + graph = "test-graph" for type_str, value, expected_format in test_cases: with self.subTest(type=type_str, value=value): @@ -95,13 +98,14 @@ def test_property_value_formatting(self, mock_execute_query): } execute_node_expansion( - params_str=params, + selector_dict=selector_dict, + graph=graph, request=request ) # Extract the actual formatted value from the query last_call = mock_execute_query.call_args[0] # Get the positional args - query = last_call[3] # The query is the 4th positional arg + query = last_call[1] # The query is the 2nd positional arg # Find the WHERE clause in the query and extract the value where_line = [line for line in query.split('\n') if 'WHERE' in line][0] @@ -117,12 +121,14 @@ def test_property_value_formatting_no_type(self, mock_execute_query): # Create a property dictionary with string type (since null type is not allowed) prop_dict = {"key": "test_property", "value": "test_value", "type": "STRING"} - params = json.dumps({ + selector_dict = { + "env": str(SpannerEnv.CLOUD), "project": "test-project", "instance": "test-instance", "database": "test-database", - "graph": "test-graph", - }) + "infra_db_path": None + } + graph = "test-graph" request = { "uid": "test-uid", @@ -132,13 +138,14 @@ def test_property_value_formatting_no_type(self, mock_execute_query): } execute_node_expansion( - params_str=params, + selector_dict=selector_dict, + graph=graph, request=request ) # Extract the actual formatted value from the query - last_call = mock_execute_query.call_args[0] - query = last_call[3] + last_call = mock_execute_query.call_args[0] # Get the positional args + query = last_call[1] # The query is the 2nd positional arg where_line = [line for line in query.split('\n') if 'WHERE' in line][0] expected_pattern = "n.test_property='''test_value'''" self.assertIn(expected_pattern, where_line, diff --git a/tests/magics_test.py b/tests/magics_test.py index fef2fac..51da9de 100644 --- a/tests/magics_test.py +++ b/tests/magics_test.py @@ -3,6 +3,7 @@ from IPython.core.interactiveshell import InteractiveShell from spanner_graphs.graph_server import GraphServer from spanner_graphs.magics import NetworkVisualizationMagics, load_ipython_extension +from spanner_graphs.database import DatabaseSelector class TestNetworkVisualizationMagics(unittest.TestCase): def setUp(self): @@ -11,6 +12,7 @@ def setUp(self): # Initialize our magic class self.magics = NetworkVisualizationMagics(self.ip) + self.magics.selector = None # Initialize selector @classmethod def tearDownClass(cls): @@ -34,38 +36,55 @@ def test_magic_registration(self): self.ip.register_magics.assert_called_once_with(NetworkVisualizationMagics) @patch('spanner_graphs.magics.get_database_instance') - @patch('spanner_graphs.magics.GraphServer') - @patch('spanner_graphs.magics.display') - def test_spanner_graph_magic_with_valid_args(self, mock_display, mock_server, mock_db): - """Test the %%spanner_graph magic with valid arguments""" + @patch('spanner_graphs.magics.generate_visualization_html') + def test_spanner_graph_magic_with_cloud_args(self, mock_generate_html, mock_db): + """Test the %%spanner_graph magic with valid cloud arguments""" # Setup mock database mock_db.return_value = MagicMock() - - # Setup mock server - mock_server.port = 8080 + mock_generate_html.return_value = "" # Test line with valid arguments line = "--project test_project --instance test_instance --database test_db" cell = "SELECT * FROM test_table" # Execute the magic - result = self.magics.spanner_graph(line, cell) + self.magics.spanner_graph(line, cell) + + # Verify database was initialized with correct parameters + expected_selector = DatabaseSelector.cloud("test_project", "test_instance", "test_db") + mock_db.assert_called_once_with(expected_selector) + self.assertEqual(self.magics.selector, expected_selector) + + # Verify display was called (exact HTML content verification would be complex) + mock_generate_html.assert_called_once() + + @patch('spanner_graphs.magics.get_database_instance') + @patch('spanner_graphs.magics.generate_visualization_html') + def test_spanner_graph_magic_with_mock_args(self, mock_generate_html, mock_db): + """Test the %%spanner_graph magic with mock arguments""" + # Setup mock database + mock_db.return_value = MagicMock() + mock_generate_html.return_value = "" + + # Test line with valid arguments + line = "--mock" + cell = "SELECT * FROM test_table" + + # Execute the magic + self.magics.spanner_graph(line, cell) # Verify database was initialized with correct parameters - mock_db.assert_called_once_with( - "test_project", - "test_instance", - "test_db", - mock=False - ) + expected_selector = DatabaseSelector.mock() + mock_db.assert_called_once_with(expected_selector) + self.assertEqual(self.magics.selector, expected_selector) # Verify display was called (exact HTML content verification would be complex) - mock_display.assert_called_once() + mock_generate_html.assert_called_once() def test_spanner_graph_magic_with_invalid_args(self): """Test the %%spanner_graph magic with invalid arguments""" - # Test with missing required arguments - line = "--project test_project" # Missing instance and database + # Test with missing required arguments for cloud + line = "--project test_project --database test_db" # Missing instance cell = "SELECT * FROM test_table" # Execute the magic and capture output @@ -74,8 +93,7 @@ def test_spanner_graph_magic_with_invalid_args(self): # Verify error message was printed mock_print.assert_any_call( - "Error: Please provide `--project`, `--instance`, " - "and `--database` values for your query." + "Error: Please provide `--project` and `--instance` for Cloud Spanner." ) def test_spanner_graph_magic_with_empty_cell(self): diff --git a/tests/node_expansion_test.py b/tests/node_expansion_test.py index 900caab..172d680 100644 --- a/tests/node_expansion_test.py +++ b/tests/node_expansion_test.py @@ -4,23 +4,29 @@ from spanner_graphs.magics import receive_node_expansion_request from spanner_graphs.graph_server import EdgeDirection +from spanner_graphs.database import DatabaseSelector, SpannerEnv class TestNodeExpansion(unittest.TestCase): def setUp(self): self.sample_request = { "uid": "node-123", - "node_key_property_name": "id", - "node_key_property_value": "123", - "node_key_property_type": "INT64", + "node_labels": ["Person"], + "node_properties": [ + {"key": "id", "value": "123", "type": "INT64"} + ], "direction": "OUTGOING", "edge_label": "CONNECTS_TO" } + # Updated params to use DatabaseSelector structure self.sample_params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", + "selector": { + "env": str(SpannerEnv.CLOUD), + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "infra_db_path": None + }, "graph": "test_graph", - "mock": False }) @patch('spanner_graphs.magics.validate_node_expansion_request') @@ -36,30 +42,16 @@ def test_receive_node_expansion_request(self, mock_execute, mock_validate): } } - # Create request and params objects - request = { - "uid": "node-123", - "node_labels": ["Person"], - "node_properties": [ - {"key": "id", "value": "123", "type": "INT64"} - ], - "direction": "OUTGOING", - "edge_label": "CONNECTS_TO" - } - - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) - # Call the function - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(self.sample_request, self.sample_params) # Verify execute_node_expansion was called with correct parameters - mock_execute.assert_called_once_with(params, request) + params_dict = json.loads(self.sample_params) + mock_execute.assert_called_once_with( + selector_dict=params_dict["selector"], + graph=params_dict["graph"], + request=self.sample_request + ) # Verify the result is wrapped in JSON self.assertEqual(result.data, mock_execute.return_value) @@ -77,30 +69,20 @@ def test_receive_node_expansion_request_without_edge_label(self, mock_execute, m } } - # Create request without edge_label and params objects - request = { - "uid": "node-123", - "node_labels": ["Person"], - "node_properties": [ - {"key": "id", "value": "123", "type": "INT64"} - ], - "direction": "OUTGOING" - # No edge_label - } - - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) + # Create request without edge_label + request = self.sample_request.copy() + del request["edge_label"] # Call the function - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(request, self.sample_params) # Verify execute_node_expansion was called with correct parameters - mock_execute.assert_called_once_with(params, request) + params_dict = json.loads(self.sample_params) + mock_execute.assert_called_once_with( + selector_dict=params_dict["selector"], + graph=params_dict["graph"], + request=request + ) # Verify the result is wrapped in JSON self.assertEqual(result.data, mock_execute.return_value) @@ -121,17 +103,10 @@ def test_invalid_property_type(self, mock_validate): "direction": "OUTGOING" } - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) - # Call the function and verify it returns an error response - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(request, self.sample_params) self.assertIn("error", result.data) + self.assertIn("Invalid property type", result.data["error"]) @patch('spanner_graphs.magics.validate_node_expansion_request') def test_invalid_direction(self, mock_validate): @@ -149,17 +124,10 @@ def test_invalid_direction(self, mock_validate): "direction": "INVALID_DIRECTION" } - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) - # Call the function and verify it returns an error response - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(request, self.sample_params) self.assertIn("error", result.data) + self.assertIn("Invalid direction", result.data["error"]) if __name__ == '__main__': unittest.main() diff --git a/tests/sample_notebook_test.py b/tests/sample_notebook_test.py index 17400d2..ab12c53 100644 --- a/tests/sample_notebook_test.py +++ b/tests/sample_notebook_test.py @@ -4,6 +4,7 @@ from IPython.core.interactiveshell import InteractiveShell from spanner_graphs.graph_server import GraphServer from spanner_graphs.magics import NetworkVisualizationMagics, load_ipython_extension +from spanner_graphs.database import DatabaseSelector class TestSampleNotebook(unittest.TestCase): def setUp(self): @@ -12,6 +13,7 @@ def setUp(self): # Initialize our magic class self.magics = NetworkVisualizationMagics(self.ip) + self.magics.selector = None # Load the notebook content with open('sample.ipynb', 'r') as f: @@ -59,29 +61,24 @@ def test_notebook_cells(self): # Test the mock visualization with mocked dependencies with patch('spanner_graphs.magics.get_database_instance') as mock_db, \ - patch('spanner_graphs.magics.GraphServer') as mock_server, \ - patch('spanner_graphs.magics.display') as mock_display: + patch('spanner_graphs.magics.generate_visualization_html') as mock_generate_html: mock_db.return_value = MagicMock() - mock_server.port = 8080 + mock_generate_html.return_value = "" # Test with a valid query since empty cell is handled by IPython line = '--mock' cell = 'GRAPH FinGraph\nMATCH p = (a)-[e]->(b)\nRETURN TO_JSON(p) AS path\nLIMIT 100' # Execute the magic with a valid query - result = self.magics.spanner_graph(line, cell) + self.magics.spanner_graph(line, cell) # Verify database was initialized with mock=True - mock_db.assert_called_once_with( - None, # project - None, # instance - None, # database - mock=True - ) + expected_selector = DatabaseSelector.mock() + mock_db.assert_called_once_with(expected_selector) # Verify display was called - mock_display.assert_called_once() + mock_generate_html.assert_called_once() # Fourth cell should be the Spanner Graph query query_cell = self.code_cells[3] @@ -97,29 +94,28 @@ def test_notebook_cells(self): # Test the query with mocked dependencies with patch('spanner_graphs.magics.get_database_instance') as mock_db, \ - patch('spanner_graphs.magics.GraphServer') as mock_server, \ - patch('spanner_graphs.magics.display') as mock_display: + patch('spanner_graphs.magics.generate_visualization_html') as mock_generate_html: mock_db.return_value = MagicMock() - mock_server.port = 8080 + mock_generate_html.return_value = "" # Extract the actual line and cell content from the notebook line = next(line for line in query_cell['source'] if line.startswith('%%spanner_graph')).replace('%%spanner_graph ', '') cell = ''.join(line for line in query_cell['source'] if not line.startswith('%%spanner_graph')) # Execute the magic with the actual notebook content - result = self.magics.spanner_graph(line, cell) + self.magics.spanner_graph(line, cell) # Verify database was initialized with placeholder values - mock_db.assert_called_once_with( + expected_selector = DatabaseSelector.cloud( "{project_id}", "{instance_name}", - "{database_name}", - mock=False + "{database_name}" ) + mock_db.assert_called_once_with(expected_selector) # Verify display was called - mock_display.assert_called_once() + mock_generate_html.assert_called_once() if __name__ == '__main__': unittest.main() diff --git a/tests/server_test.py b/tests/server_test.py index ecdf514..9c1bc28 100644 --- a/tests/server_test.py +++ b/tests/server_test.py @@ -16,6 +16,7 @@ import requests import json from spanner_graphs.graph_server import GraphServer +from spanner_graphs.database import SpannerEnv class TestSpannerServer(unittest.TestCase): def setUp(self): @@ -39,15 +40,15 @@ def test_post_query_with_mock(self): """Test querying with mock database""" # Build the request URL route = GraphServer.build_route(GraphServer.endpoints["post_query"]) - + # Create request data with the new structure params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "mock": True + "selector": { + "env": str(SpannerEnv.MOCK) + }, + "graph": "TestGraph" }) - + request_data = { "params": params, "query": "GRAPH TestGraph MATCH (n) RETURN n" @@ -55,11 +56,11 @@ def test_post_query_with_mock(self): # Send POST request response = requests.post(route, json=request_data) - + # Verify response self.assertEqual(response.status_code, 200) response_data = response.json() - + # Check response structure self.assertIn("response", response_data) response = response_data["response"] @@ -72,13 +73,13 @@ def test_post_query_with_mock(self): # Verify we got some data self.assertTrue(len(response["nodes"]) > 0, "Should have at least one node") self.assertTrue(len(response["edges"]) > 0, "Should have at least one edge") - + # Verify node structure node = response["nodes"][0] self.assertIn("identifier", node) self.assertIn("labels", node) self.assertIn("properties", node) - + # Verify edge structure edge = response["edges"][0] self.assertIn("identifier", edge) @@ -91,25 +92,33 @@ def test_node_expansion_error_handling(self): """Test that errors in node expansion are properly handled and returned.""" # Build the request URL route = GraphServer.build_route(GraphServer.endpoints["post_node_expansion"]) - + # Create request data with invalid fields to trigger validation error - request_data = { - "project": "test-project", - "instance": "test-instance", - "database": "test-database", + params = { + "selector": { + "env": str(SpannerEnv.CLOUD), + "project_id": "test-project", + "instance_id": "test-instance", + "database_id": "test-database" + }, "graph": "test-graph", - "uid": "test-uid", - # Missing required node_labels field - "direction": "INVALID_DIRECTION" # Invalid direction + } + request_data = { + "params": json.dumps(params), + "request": { + "uid": "test-uid", + # Missing required node_labels field + "direction": "INVALID_DIRECTION" # Invalid direction + } } # Send POST request response = requests.post(route, json=request_data) - + # Verify response self.assertEqual(response.status_code, 200) # Server still returns 200 but with error data response_data = response.json() - + # Check error presence self.assertIn("error", response_data) self.assertIsNotNone(response_data["error"])