Skip to content

Commit 72edc5d

Browse files
authored
Merge pull request #20 from fastomop/feat/ss-databricks
Feat/ss databricks
2 parents 421e81d + 70c8529 commit 72edc5d

3 files changed

Lines changed: 817 additions & 15 deletions

File tree

src/omcp/db.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ibis.backends.databricks import Backend as DatabricksBackend
1414

1515
from omcp.sql_validator import SQLValidator
16+
from omcp.transpiler import transpile_query
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -113,6 +114,11 @@ def __init__(
113114
self.cdm_schema = cdm_schema
114115
self.vocab_schema = vocab_schema
115116

117+
# Determine database dialect from connection string
118+
self.target_dialect = self._get_dialect_from_connection_string(
119+
connection_string
120+
)
121+
116122
# Try initial connection
117123
logger.info(f"Initializing connection to: {connection_string}")
118124
try:
@@ -127,6 +133,29 @@ def __init__(
127133
except Exception as e:
128134
raise ConnectionError(f"Failed to connect to database: {str(e)}")
129135

136+
def _get_dialect_from_connection_string(self, connection_string: str) -> str:
    """
    Determine the SQL dialect from the connection string.

    The match is made on the leading part of the connection string and is
    case-insensitive, since URI schemes are case-insensitive
    (``Databricks://`` and ``databricks://`` are the same scheme).

    Args:
        connection_string: Database connection string

    Returns:
        SQL dialect name: 'databricks', 'postgres' or 'duckdb'.
        Anything unrecognised falls back to 'postgres'.
    """
    normalized = connection_string.lower()

    if normalized.startswith("databricks://"):
        return "databricks"
    # Bare prefix on purpose: covers "postgres://", "postgresql://" and
    # driver-qualified forms such as "postgresql+psycopg2://".
    if normalized.startswith("postgres"):
        return "postgres"
    if normalized.startswith("duckdb://"):
        return "duckdb"

    # Default to postgres for unknown dialects.
    # Only the scheme is logged: the full connection string may embed
    # passwords or access tokens and must not end up in log files.
    scheme, separator, _ = normalized.partition("://")
    if not separator:
        scheme = "<none>"
    logger.warning(
        f"Unknown dialect for connection scheme: {scheme}, defaulting to postgres"
    )
    return "postgres"
158+
130159
def _ensure_connected(self):
131160
"""Ensure we have a valid database connection."""
132161
with self._conn_lock:
@@ -304,12 +333,34 @@ def read_query(self, query: str) -> str:
304333
errors,
305334
)
306335

336+
# Transpile query if needed (postgres -> databricks, etc.)
337+
# We assume Claude generates queries in postgres dialect by default
338+
source_dialect = "postgres"
339+
transpiled_query = query
340+
341+
if self.target_dialect != source_dialect:
342+
logger.info(
343+
f"Transpiling query from {source_dialect} to {self.target_dialect}"
344+
)
345+
try:
346+
transpiled_query = transpile_query(
347+
query, source_dialect, self.target_dialect
348+
)
349+
logger.debug(f"Original query: {query}")
350+
logger.debug(f"Transpiled query: {transpiled_query}")
351+
except Exception as transpile_error:
352+
logger.warning(
353+
f"Transpilation failed: {transpile_error}, using original query"
354+
)
355+
# If transpilation fails, fall back to original query
356+
transpiled_query = query
357+
307358
# Ensure connected
308359
self._ensure_connected()
309360

310-
# Execute the validated query
361+
# Execute the validated and transpiled query
311362
with self._conn_lock:
312-
result = self._conn.sql(query).limit(self.row_limit)
363+
result = self._conn.sql(transpiled_query).limit(self.row_limit)
313364
df = result.execute()
314365
# Convert dataframe to csv
315366
return df.to_csv(index=False)
@@ -325,7 +376,7 @@ def read_query(self, query: str) -> str:
325376
try:
326377
self._ensure_connected()
327378
with self._conn_lock:
328-
result = self._conn.sql(query).limit(self.row_limit)
379+
result = self._conn.sql(transpiled_query).limit(self.row_limit)
329380
df = result.execute()
330381
return df.to_csv(index=False)
331382
except Exception as retry_error:

src/omcp/main.py

Lines changed: 118 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -345,37 +345,45 @@ def signal_handler(signum, frame):
345345

346346
@mcp_app.tool(
    name="Get_Information_Schema",
    description="Get the database schema name and type. Returns the schema prefix to use for table references (e.g., 'gold') and the database type (e.g., 'databricks').",
)
@capture_context(tool_name="Get_Information_Schema")
def get_information_schema() -> mcp.types.CallToolResult:
    """Report the configured schema prefix and SQL dialect.

    Reads two attributes of the module-level ``db`` object and returns them
    as a short text block, which is everything a caller needs in order to
    write correctly-qualified SQL:

    - ``db.cdm_schema``: the schema prefix for table references
    - ``db.target_dialect``: the SQL dialect name

    Args:
        None
    Returns:
        A CallToolResult whose text content is
        "Schema: <schema>\\nDatabase Type: <dialect>", with the same two
        values repeated in the result's ``_meta`` mapping. On any exception
        an error result (``isError=True``) carrying the message is returned.
    """
    try:
        logger.debug("Getting schema information...")

        # Nothing is queried here; both values come from the db object.
        prefix = db.cdm_schema
        dialect = db.target_dialect
        summary = f"Schema: {prefix}\nDatabase Type: {dialect}"
        logger.debug(f"Schema info: {summary}")

        text_part = mcp.types.TextContent(type="text", text=summary)
        return mcp.types.CallToolResult(
            content=[text_part],
            _meta={"database_type": dialect, "schema_name": prefix},
        )
    except Exception as e:
        failure = f"Failed to retrieve schema information: {e}"
        logger.error(failure)
        return mcp.types.CallToolResult(
            isError=True,
            content=[mcp.types.TextContent(type="text", text=failure)],
        )
@@ -431,6 +439,104 @@ def read_query(query: str) -> mcp.types.CallToolResult:
431439
)
432440

433441

442+
@mcp_app.tool(
    name="Lookup_Drug",
    description="Look up drug concepts by name in the OMOP concept table. Returns standardized drug concepts with concept_id, concept_name, concept_code, vocabulary_id, and domain_id.",
)
@capture_context(tool_name="Lookup_Drug")
def lookup_drug(term: str, limit: int = 10) -> mcp.types.CallToolResult:
    """Look up drug concepts by name.

    This function searches for drug concepts in the OMOP concept table by partial name match.
    Only returns standard, valid drug concepts ordered by name length (shortest first).

    Args:
        term: Drug name to search for (case-insensitive partial match).
            Apostrophes are allowed (e.g. "St. John's wort").
        limit: Maximum number of results to return (default: 10).
            Must be a non-negative integer.

    Returns:
        CSV formatted results with: concept_id, concept_name, concept_code, vocabulary_id, domain_id.
        On any failure (including an invalid ``limit``) an error result
        (``isError=True``) carrying the message is returned instead.
    """
    try:
        schema = db.cdm_schema

        # db.read_query() takes a finished SQL string, so bind parameters are
        # not available here. The caller-supplied values are therefore
        # neutralised before interpolation:
        #   * single quotes are doubled (standard SQL literal escaping), so
        #     the term cannot terminate the string literal;
        #   * limit is forced to an int, so it cannot carry SQL text.
        # NOTE(review): queries are written in the postgres dialect, where a
        # backslash is an ordinary character inside '...'. Confirm this also
        # holds on backends that receive the query untranspiled.
        safe_term = str(term).replace("'", "''")
        safe_limit = int(limit)
        if safe_limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")

        query = f"""
            SELECT concept_id, concept_name, concept_code, vocabulary_id, domain_id
            FROM {schema}.concept
            WHERE LOWER(concept_name) LIKE LOWER('%{safe_term}%')
            AND domain_id = 'Drug'
            AND standard_concept = 'S'
            AND invalid_reason IS NULL
            ORDER BY LENGTH(concept_name), concept_name
            LIMIT {safe_limit}
        """
        logger.info(f"Looking up drug: {term}")
        result = db.read_query(query)
        logger.info(f"Drug lookup completed for: {term}")
        return mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text=result)]
        )
    except Exception as e:
        logger.error(f"Failed to lookup drug '{term}': {e}")
        return mcp.types.CallToolResult(
            isError=True,
            content=[
                mcp.types.TextContent(
                    type="text", text=f"Failed to lookup drug: {str(e)}"
                )
            ],
        )
489+
490+
491+
@mcp_app.tool(
    name="Lookup_Condition",
    description="Look up condition concepts by name in the OMOP concept table. Returns standardized condition concepts with concept_id, concept_name, concept_code, vocabulary_id, and domain_id.",
)
@capture_context(tool_name="Lookup_Condition")
def lookup_condition(term: str, limit: int = 10) -> mcp.types.CallToolResult:
    """Look up condition concepts by name.

    This function searches for condition concepts in the OMOP concept table by partial name match.
    Only returns standard, valid condition concepts ordered by name length (shortest first).

    Args:
        term: Condition name to search for (case-insensitive partial match).
            Apostrophes are allowed (e.g. "Crohn's disease").
        limit: Maximum number of results to return (default: 10).
            Must be a non-negative integer.

    Returns:
        CSV formatted results with: concept_id, concept_name, concept_code, vocabulary_id, domain_id.
        On any failure (including an invalid ``limit``) an error result
        (``isError=True``) carrying the message is returned instead.
    """
    try:
        schema = db.cdm_schema

        # db.read_query() takes a finished SQL string, so bind parameters are
        # not available here. The caller-supplied values are therefore
        # neutralised before interpolation:
        #   * single quotes are doubled (standard SQL literal escaping), so
        #     the term cannot terminate the string literal;
        #   * limit is forced to an int, so it cannot carry SQL text.
        # NOTE(review): queries are written in the postgres dialect, where a
        # backslash is an ordinary character inside '...'. Confirm this also
        # holds on backends that receive the query untranspiled.
        safe_term = str(term).replace("'", "''")
        safe_limit = int(limit)
        if safe_limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")

        query = f"""
            SELECT concept_id, concept_name, concept_code, vocabulary_id, domain_id
            FROM {schema}.concept
            WHERE LOWER(concept_name) LIKE LOWER('%{safe_term}%')
            AND domain_id = 'Condition'
            AND standard_concept = 'S'
            AND invalid_reason IS NULL
            ORDER BY LENGTH(concept_name), concept_name
            LIMIT {safe_limit}
        """
        logger.info(f"Looking up condition: {term}")
        result = db.read_query(query)
        logger.info(f"Condition lookup completed for: {term}")
        return mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text=result)]
        )
    except Exception as e:
        logger.error(f"Failed to lookup condition '{term}': {e}")
        return mcp.types.CallToolResult(
            isError=True,
            content=[
                mcp.types.TextContent(
                    type="text", text=f"Failed to lookup condition: {str(e)}"
                )
            ],
        )
538+
539+
434540
def main():
435541
"""Main function to run the MCP server."""
436542
logger.info(f"Starting OMOP MCP Server with {transport_type.upper()} transport...")

0 commit comments

Comments
 (0)