Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions spanner_graphs/cloud_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,29 @@ def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldI

class CloudSpannerDatabase(SpannerDatabase):
"""Concrete implementation for Spanner database on the cloud."""
def __init__(self, project_id: str, instance_id: str,
database_id: str) -> None:
credentials, _ = _get_default_credentials_with_project()
self.client = spanner.Client(
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))
def __init__(
self,
project_id: str,
instance_id: str,
database_id: str,
experimental_host: str | None = None,
use_plain_text: bool = False,
ca_certificate: str | None = None,
client_certificate: str | None = None,
client_key: str | None = None,
) -> None:
if experimental_host:
self.client = spanner.Client(
use_plain_text=use_plain_text,
experimental_host=experimental_host,
ca_certificate=ca_certificate,
client_certificate=client_certificate,
client_key=client_key,
)
else:
credentials, _ = _get_default_credentials_with_project()
self.client = spanner.Client(
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))
self.instance = self.client.instance(instance_id)
logger = logging.getLogger("spanner_graphs")
logger.setLevel(logging.CRITICAL)
Expand Down
35 changes: 35 additions & 0 deletions spanner_graphs/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class SpannerEnv(Enum):
CLOUD = auto()
INFRA = auto()
MOCK = auto()
EXPERIMENTAL_HOST = auto()


@dataclass
class DatabaseSelector:
Expand All @@ -47,12 +49,22 @@ class DatabaseSelector:
instance: The Spanner instance.
database: The Spanner database.
infra_db_path: The path for an internal infrastructure database.
experimental_host: The Spanner experimental host endpoint.
use_plain_text: Whether to use plain text for the experimental host endpoint.
ca_certificate: CA certificate path for the experimental host endpoint.
client_certificate: Client certificate path for the experimental host endpoint.
client_key: Client key path for the experimental host endpoint.
"""
env: SpannerEnv
project: str | None = None
instance: str | None = None
database: str | None = None
infra_db_path: str | None = None
experimental_host: str | None = None
use_plain_text: bool = False
ca_certificate: str | None = None
client_certificate: str | None = None
client_key: str | None = None

@classmethod
def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector':
Expand All @@ -73,13 +85,36 @@ def mock(cls) -> 'DatabaseSelector':
"""Creates a selector for a mock Spanner database."""
return cls(env=SpannerEnv.MOCK)

@classmethod
def experimental_host(
cls, experimental_host: str, database: str, use_plain_text: bool = False, ca_certificate: str | None = None, client_certificate: str | None = None, client_key: str | None = None
) -> "DatabaseSelector":
"""Creates a selector for a Google Experimental Host Spanner database."""
if not database:
raise ValueError(
"database is required for Experimental Host Spanner Endpoint"
)
return cls(
env=SpannerEnv.EXPERIMENTAL_HOST,
project="default",
instance="default",
database=database,
experimental_host=experimental_host,
use_plain_text=use_plain_text,
ca_certificate=ca_certificate,
client_certificate=client_certificate,
client_key=client_key,
)

def get_key(self) -> str:
if self.env == SpannerEnv.CLOUD:
return f"cloud_{self.project}_{self.instance}_{self.database}"
elif self.env == SpannerEnv.INFRA:
return f"infra_{self.infra_db_path}"
elif self.env == SpannerEnv.MOCK:
return "mock"
elif self.env == SpannerEnv.EXPERIMENTAL_HOST:
return f"experimental_host_{self.database}"
else:
raise ValueError("Unknown Spanner environment")

Expand Down
18 changes: 18 additions & 0 deletions spanner_graphs/exec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@ def get_database_instance(
raise RuntimeError(
"Infra Spanner support is not available in this environment."
)
elif selector.env == SpannerEnv.EXPERIMENTAL_HOST:
try:
cloud_db_module = importlib.import_module("spanner_graphs.cloud_database")
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
db = CloudSpannerDatabase(
selector.project,
selector.instance,
selector.database,
selector.experimental_host,
selector.use_plain_text,
selector.ca_certificate,
selector.client_certificate,
selector.client_key,
)
except ImportError:
raise RuntimeError(
"Spanner experimental host support is not available in this environment."
)
else:
raise ValueError(f"Unsupported Spanner environment: {selector.env}")

Expand Down
4 changes: 4 additions & 0 deletions spanner_graphs/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector:
return DatabaseSelector.infra(selector_dict['infra_db_path'])
elif env == SpannerEnv.MOCK:
return DatabaseSelector.mock()
elif env == SpannerEnv.EXPERIMENTAL_HOST:
return DatabaseSelector.experimental_host(
selector_dict["experimental_host"], selector_dict["database"], selector_dict["use_plain_text"], selector_dict["ca_certificate"], selector_dict["client_certificate"], selector_dict["client_key"]
)
raise ValueError(f"Invalid env in selector dict: {selector_dict}")
except Exception as e:
print (f"Unexpected error when fetching selector: {e}")
Expand Down
46 changes: 46 additions & 0 deletions spanner_graphs/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,59 @@ def spanner_graph(self, line: str, cell: str):
parser.add_argument("--infra_db_path",
action="store_true",
help="Connect to internal Infra Spanner")
parser.add_argument(
"--experimental_host",
type=str,
required=False,
help="Spanner experimental host endpoint",
Comment thread
sagnghos marked this conversation as resolved.
)
parser.add_argument(
"--use_plain_text",
action="store_true",
help="[Experimental Host Only] Use plain text communication for the experimental host",
)
parser.add_argument(
"--ca_certificate",
type=str,
required=False,
help="[Experimental Host Only] CA certificate path for the experimental host",
)
parser.add_argument(
"--client_certificate",
type=str,
required=False,
help="[Experimental Host Only] Client certificate path for the experimental host",
)
parser.add_argument(
"--client_key",
type=str,
required=False,
help="[Experimental Host Only] Client key path for the experimental host",
)

try:
args = parser.parse_args(line.split())
selector = None
if not args.experimental_host:
if args.use_plain_text or args.ca_certificate or args.client_certificate or args.client_key:
raise ValueError("use_plain_text, ca_certificate, client_certificate and client_key are only supported for Experimental Host")
if args.mock:
selector = DatabaseSelector.mock()
elif args.infra_db_path:
selector = DatabaseSelector.infra(infra_db_path=args.database)
elif args.experimental_host:
Comment thread
sagnghos marked this conversation as resolved.
if args.use_plain_text:
if args.ca_certificate or args.client_certificate or args.client_key:
raise ValueError("When use_plain_text is true, no other certificate parameters should be set.")
elif not args.ca_certificate:
raise ValueError("Either use_plain_text must be true or ca_certificate must be set.")

if bool(args.client_certificate) != bool(args.client_key):
raise ValueError("client_certificate and client_key must both be provided together.")

selector = DatabaseSelector.experimental_host(
experimental_host=args.experimental_host, database=args.database, use_plain_text=args.use_plain_text, ca_certificate=args.ca_certificate, client_certificate=args.client_certificate, client_key=args.client_key
)
else:
if not (args.project and args.instance):
raise ValueError(
Expand All @@ -226,6 +271,7 @@ def spanner_graph(self, line: str, cell: str):
print(f"Error: {e}")
print(" %%spanner_graph --project <proj> --instance <inst> --database <db>")
print(" %%spanner_graph --mock")
print(" %%spanner_graph --experimental_host <host> --database <db> [--use_plain_text] [--ca_certificate <path>] [--client_certificate <path>] [--client_key <path>]")
print(" Graph query here...")

def load_ipython_extension(ipython):
Expand Down