diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py index 00dd047..fcf6096 100644 --- a/spanner_graphs/cloud_database.py +++ b/spanner_graphs/cloud_database.py @@ -39,11 +39,29 @@ def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldI class CloudSpannerDatabase(SpannerDatabase): """Concrete implementation for Spanner database on the cloud.""" - def __init__(self, project_id: str, instance_id: str, - database_id: str) -> None: - credentials, _ = _get_default_credentials_with_project() - self.client = spanner.Client( - project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) + def __init__( + self, + project_id: str, + instance_id: str, + database_id: str, + experimental_host: str | None = None, + use_plain_text: bool = False, + ca_certificate: str | None = None, + client_certificate: str | None = None, + client_key: str | None = None, + ) -> None: + if experimental_host: + self.client = spanner.Client( + use_plain_text=use_plain_text, + experimental_host=experimental_host, + ca_certificate=ca_certificate, + client_certificate=client_certificate, + client_key=client_key, + ) + else: + credentials, _ = _get_default_credentials_with_project() + self.client = spanner.Client( + project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) self.instance = self.client.instance(instance_id) logger = logging.getLogger("spanner_graphs") logger.setLevel(logging.CRITICAL) diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 63d94d4..176cf20 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -32,6 +32,8 @@ class SpannerEnv(Enum): CLOUD = auto() INFRA = auto() MOCK = auto() + EXPERIMENTAL_HOST = auto() + @dataclass class DatabaseSelector: @@ -47,12 +49,22 @@ class DatabaseSelector: instance: The Spanner instance. database: The Spanner database. infra_db_path: The path for an internal infrastructure database. + experimental_host: The Spanner experimental host endpoint. + use_plain_text: Whether to use plain text for the experimental host endpoint. + ca_certificate: CA certificate path for the experimental host endpoint. + client_certificate: Client certificate path for the experimental host endpoint. + client_key: Client key path for the experimental host endpoint. """ env: SpannerEnv project: str | None = None instance: str | None = None database: str | None = None infra_db_path: str | None = None + experimental_host: str | None = None + use_plain_text: bool = False + ca_certificate: str | None = None + client_certificate: str | None = None + client_key: str | None = None @classmethod def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector': @@ -73,6 +85,27 @@ def mock(cls) -> 'DatabaseSelector': """Creates a selector for a mock Spanner database.""" return cls(env=SpannerEnv.MOCK) + @classmethod + def experimental_host( + cls, experimental_host: str, database: str, use_plain_text: bool = False, ca_certificate: str | None = None, client_certificate: str | None = None, client_key: str | None = None + ) -> "DatabaseSelector": + """Creates a selector for a Google Experimental Host Spanner database.""" + if not database: + raise ValueError( + "database is required for Experimental Host Spanner Endpoint" + ) + return cls( + env=SpannerEnv.EXPERIMENTAL_HOST, + project="default", + instance="default", + database=database, + experimental_host=experimental_host, + use_plain_text=use_plain_text, + ca_certificate=ca_certificate, + client_certificate=client_certificate, + client_key=client_key, + ) + def get_key(self) -> str: if self.env == SpannerEnv.CLOUD: return f"cloud_{self.project}_{self.instance}_{self.database}" @@ -80,6 +113,8 @@ def get_key(self) -> str: return f"infra_{self.infra_db_path}" elif self.env == SpannerEnv.MOCK: return "mock" + elif self.env == SpannerEnv.EXPERIMENTAL_HOST: + return f"experimental_host_{self.database}" else: raise ValueError("Unknown Spanner environment") diff --git a/spanner_graphs/exec_env.py b/spanner_graphs/exec_env.py index 93ed825..ea2a1b8 100644 --- a/spanner_graphs/exec_env.py +++ b/spanner_graphs/exec_env.py @@ -81,6 +81,24 @@ def get_database_instance( raise RuntimeError( "Infra Spanner support is not available in this environment." ) + elif selector.env == SpannerEnv.EXPERIMENTAL_HOST: + try: + cloud_db_module = importlib.import_module("spanner_graphs.cloud_database") + CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase") + db = CloudSpannerDatabase( + selector.project, + selector.instance, + selector.database, + selector.experimental_host, + selector.use_plain_text, + selector.ca_certificate, + selector.client_certificate, + selector.client_key, + ) + except ImportError: + raise RuntimeError( + "Spanner experimental host support is not available in this environment." + ) else: raise ValueError(f"Unsupported Spanner environment: {selector.env}") diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 6324207..7d9732d 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -65,6 +65,10 @@ def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector: return DatabaseSelector.infra(selector_dict['infra_db_path']) elif env == SpannerEnv.MOCK: return DatabaseSelector.mock() + elif env == SpannerEnv.EXPERIMENTAL_HOST: + return DatabaseSelector.experimental_host( + selector_dict["experimental_host"], selector_dict["database"], selector_dict["use_plain_text"], selector_dict["ca_certificate"], selector_dict["client_certificate"], selector_dict["client_key"] + ) raise ValueError(f"Invalid env in selector dict: {selector_dict}") except Exception as e: print (f"Unexpected error when fetching selector: {e}") diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index 1741e0d..db2155a 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -197,14 +197,59 @@ def spanner_graph(self, line: str, cell: str): parser.add_argument("--infra_db_path", action="store_true", help="Connect to internal Infra Spanner") + parser.add_argument( + "--experimental_host", + type=str, + required=False, + help="Spanner experimental host endpoint", + ) + parser.add_argument( + "--use_plain_text", + action="store_true", + help="[Experimental Host Only] Use plain text communication for the experimental host", + ) + parser.add_argument( + "--ca_certificate", + type=str, + required=False, + help="[Experimental Host Only] CA certificate path for the experimental host", + ) + parser.add_argument( + "--client_certificate", + type=str, + required=False, + help="[Experimental Host Only] Client certificate path for the experimental host", + ) + parser.add_argument( + "--client_key", + type=str, + required=False, + help="[Experimental Host Only] Client key path for the experimental host", + ) try: args = parser.parse_args(line.split()) selector = None + if not args.experimental_host: + if args.use_plain_text or args.ca_certificate or args.client_certificate or args.client_key: + raise ValueError("use_plain_text, ca_certificate, client_certificate and client_key are only supported for Experimental Host") if args.mock: selector = DatabaseSelector.mock() elif args.infra_db_path: selector = DatabaseSelector.infra(infra_db_path=args.database) + elif args.experimental_host: + if args.use_plain_text: + if args.ca_certificate or args.client_certificate or args.client_key: + raise ValueError("When use_plain_text is true, no other certificate parameters should be set.") + elif not args.ca_certificate: + raise ValueError("Either use_plain_text must be true or ca_certificate must be set.") + + if bool(args.client_certificate) != bool(args.client_key): + raise ValueError("client_certificate and client_key must both be provided together.") + + selector = DatabaseSelector.experimental_host( + experimental_host=args.experimental_host, database=args.database, use_plain_text=args.use_plain_text, ca_certificate=args.ca_certificate, client_certificate=args.client_certificate, client_key=args.client_key + ) else: if not (args.project and args.instance): raise ValueError( @@ -226,6 +271,7 @@ def spanner_graph(self, line: str, cell: str): print(f"Error: {e}") print(" %%spanner_graph --project --instance --database ") print(" %%spanner_graph --mock") + print(" %%spanner_graph --experimental_host --database [--use_plain_text] [--ca_certificate ] [--client_certificate ] [--client_key ]") print(" Graph query here...") def load_ipython_extension(ipython):