From 968c27493146c1b3ab38903743d4a0854660044b Mon Sep 17 00:00:00 2001 From: Shihao Shenzhang Date: Thu, 15 May 2025 17:48:39 +0100 Subject: [PATCH 1/5] Fixed the CTE validation issue --- src/omcp/sql_validator.py | 52 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/src/omcp/sql_validator.py b/src/omcp/sql_validator.py index 8a4bc24..5afa96b 100644 --- a/src/omcp/sql_validator.py +++ b/src/omcp/sql_validator.py @@ -6,6 +6,7 @@ import sqlglot.expressions as exp import typing as t import omcp.exceptions as ex +from sqlglot.optimizer.scope import build_scope OMOP_TABLES = [ "care_site", @@ -91,12 +92,31 @@ def _check_is_select_query( "Only SELECT statements are allowed for security reasons." ) - def _check_is_omop_table(self, tables: t.List[exp.Table]) -> ex.TableNotFoundError: + def _check_is_omop_table(self, parsed_sql: exp.Expression) -> ex.TableNotFoundError: + """ + Check if all real table references in the query are OMOP CDM tables and + ignores CTEs (defined in WITH clauses). + + Args: + parsed_sql (exp.Expression): The parsed SQL expression. + + Return: + TableNotFoundError: If any non-OMOP tables are found. + """ + root = build_scope(parsed_sql) + tables = [ + source + for scope in root.traverse() + for alias, (node, source) in scope.selected_sources.items() + if isinstance(source, exp.Table) + ] + not_omop_tables = [ table.name.lower() for table in tables if table.name.lower() not in OMOP_TABLES ] + if not_omop_tables: return ex.TableNotFoundError( f"Tables not found in OMOP CDM: {', '.join(not_omop_tables)}" @@ -216,7 +236,7 @@ def validate_sql(self, sql: str): errors.append(ex.ColumnNotFoundError("No columns found in the query.")) # Check is OMOP table - errors.append(self._check_is_omop_table(tables)) + errors.append(self._check_is_omop_table(parsed_sql)) # Check for excluded tables errors.append(self._check_unauthorized_tables(tables)) @@ -234,3 +254,31 @@ def validate_sql(self, sql: str): finally: errors = list(filter(None, errors)) # Remove None values from the list return errors + + +if __name__ == "__main__": + query = """ + WITH lisinopril_patients AS ( + SELECT DISTINCT person_id + FROM base.drug_exposure d + JOIN base.concept c ON d.drug_concept_id = c.concept_id + WHERE c.concept_name LIKE '%lisinopril%' + )SELECT + c.concept_name as condition_name, + COUNT(DISTINCT co.person_id) as patient_count, + COUNT(*) as occurrence_count + FROM + base.condition_occurrence co + JOIN + base.concept c ON co.condition_concept_id = c.concept_id + JOIN + lisinopril_patients lp ON co.person_id = lp.person_id + GROUP BY + c.concept_name + ORDER BY + patient_count DESC, occurrence_count DESC + LIMIT 20 + """ + sql_validator = SQLValidator( allow_source_value_columns=False, exclude_tables=None, exclude_columns=None) + sql_validator.validate_sql(query) + From 678e3c349817d909e36a476d9cff4f1b9149cff9 Mon Sep 17 00:00:00 2001 From: Shihao Shenzhang Date: Thu, 15 May 2025 17:55:16 +0100 Subject: [PATCH 2/5] Removed debug logs --- src/omcp/sql_validator.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/omcp/sql_validator.py b/src/omcp/sql_validator.py index 5afa96b..a023f8b 100644 --- a/src/omcp/sql_validator.py +++ b/src/omcp/sql_validator.py @@ -254,31 +254,3 @@ def validate_sql(self, sql: str): finally: errors = list(filter(None, errors)) # Remove None values from the list return errors - - -if __name__ == "__main__": - query = """ - WITH lisinopril_patients AS ( - SELECT DISTINCT person_id - FROM base.drug_exposure d - JOIN base.concept c ON d.drug_concept_id = c.concept_id - WHERE c.concept_name LIKE '%lisinopril%' - )SELECT - c.concept_name as condition_name, - COUNT(DISTINCT co.person_id) as patient_count, - COUNT(*) as occurrence_count - FROM - base.condition_occurrence co - JOIN - base.concept c ON co.condition_concept_id = c.concept_id - JOIN - lisinopril_patients lp ON co.person_id = lp.person_id - GROUP BY - c.concept_name - ORDER BY - patient_count DESC, occurrence_count DESC - LIMIT 20 - """ - sql_validator = SQLValidator( allow_source_value_columns=False, exclude_tables=None, exclude_columns=None) - sql_validator.validate_sql(query) - From ece4339b7583b06ca72f958ecea4870ca7994b8c Mon Sep 17 00:00:00 2001 From: Shihao Shenzhang Date: Fri, 16 May 2025 20:23:19 +0100 Subject: [PATCH 3/5] Minor errors amended --- docs/installation.md | 14 +++++++------- src/omcp/sql_validator.py | 8 ++++---- tests/test_sql_validator.py | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 9cd9b6a..c82f77b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -179,7 +179,7 @@ Start Claude Desktop and the OMCP server should automatically be available for u ## Integrating with Localhost models 🦙 -### Step 1: Install Ollama +### Step 1: Install Ollama Download and install [Ollama](https://ollama.com/) from the official website. To check if Ollama has been installed properly, open a terminal and type: ```bash @@ -195,7 +195,7 @@ Go to the [Ollama models](https://ollama.com/search) and copy the name of the mo ollama pull cogito:14b ``` -The process will take a while depending on the size of the model, but when it finishes type in the terminal: +The process will take a while depending on the size of the model, but when it finishes type in the terminal: ```bash ollama list @@ -207,7 +207,7 @@ if everything went well, you should see the model you have pulled from Ollama. I We are going to use [Librechat](https://www.librechat.ai/) as the end-user interface. 1. In the OMCP project, navigate to the directory where the `main.py` file is located, go to the function `def main()` and change the `transport` from `stdio` to `sse`. - + ```python def main(): """Main function to run the MCP server.""" @@ -218,11 +218,11 @@ We are going to use [Librechat](https://www.librechat.ai/) as the end-user inter ``` 2. In the same directory where `main.py` is located, run the following command: - + ```python python main.py ``` - + You should see something like this in the terminal: ``` INFO: Started server process [96250] @@ -272,11 +272,11 @@ We are going to use [Librechat](https://www.librechat.ai/) as the end-user inter [+] Running 5/5 ✔ Container vectordb Started ✔ Container chat-meilisearch Started - ✔ Container chat-mongodb Started + ✔ Container chat-mongodb Started ✔ Container rag_api Started ✔ Container LibreChat Started ``` - + 9. Finally, go to the browser and type `localhost:3080`, if it is the first time using Librechat, you need to create an account. Then select the model you pulled, in my case `cogito:14b` and in the chat, just next to the `Code Interpreter` you should see the MCP Tool, click on it and select `omop_mcp`. You should see something like this: diff --git a/src/omcp/sql_validator.py b/src/omcp/sql_validator.py index a023f8b..da4f4c8 100644 --- a/src/omcp/sql_validator.py +++ b/src/omcp/sql_validator.py @@ -94,7 +94,7 @@ def _check_is_select_query( def _check_is_omop_table(self, parsed_sql: exp.Expression) -> ex.TableNotFoundError: """ - Check if all real table references in the query are OMOP CDM tables and + Check if all real table references in the query are OMOP CDM tables and ignores CTEs (defined in WITH clauses). Args: @@ -105,9 +105,9 @@ def _check_is_omop_table(self, parsed_sql: exp.Expression) -> ex.TableNotFoundEr """ root = build_scope(parsed_sql) tables = [ - source - for scope in root.traverse() - for alias, (node, source) in scope.selected_sources.items() + source + for scope in root.traverse() + for alias, (node, source) in scope.selected_sources.items() if isinstance(source, exp.Table) ] diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py index a4ab381..230c105 100644 --- a/tests/test_sql_validator.py +++ b/tests/test_sql_validator.py @@ -1,6 +1,7 @@ import pytest from omcp.sql_validator import SQLValidator import omcp.exceptions as ex +import sqlglot @pytest.fixture @@ -125,3 +126,39 @@ def test_source_concept_id_columns(self, validator): assert len(errors) == 1 assert isinstance(errors[0], ex.UnauthorizedColumnError) assert "Source value columns are not allowed" in str(errors[0]) + + def test_check_is_omop_table_ignores_cte(self, validator): + """Test that _check_is_omop_table ignores CTEs""" + sql = """ + WITH patient AS ( + SELECT person_id, gender_concept_id FROM person + ), + visits AS ( + SELECT visit_occurrence_id, person_id FROM visit_occurrence + ) + SELECT p.person_id, v.visit_occurrence_id + FROM patient p + JOIN visits v ON p.person_id = v.person_id + """ + parsed_sql = sqlglot.parse_one(sql) + error = validator._check_is_omop_table(parsed_sql) + assert error is None, f"Expected no error, got: {error}" + + def test_check_is_omop_table_ignores_multiple_ctes(self, validator): + """Test that _check_is_omop_table ignores multiple CTEs with non-OMOP tables""" + sql = """ + WITH temp_users AS ( + SELECT person_id, year_of_birth, gender_concept_id + FROM person + ), + temp_visits AS ( + SELECT visit_occurrence_id, person_id, visit_start_date + FROM visit_occurrence + ) + SELECT p.person_id, p.year_of_birth, v.visit_start_date + FROM temp_users p + JOIN temp_visits v ON p.person_id = v.person_id + """ + parsed_sql = sqlglot.parse_one(sql) + error = validator._check_is_omop_table(parsed_sql) + assert error is None, f"Expected no error, got: {error}" From 689e3989c2206e5fe07df024633c3bedc023f921 Mon Sep 17 00:00:00 2001 From: Shihao Shenzhang Date: Fri, 16 May 2025 20:34:57 +0100 Subject: [PATCH 4/5] Refactored the test code --- tests/test_sql_validator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py index 230c105..c1d3a07 100644 --- a/tests/test_sql_validator.py +++ b/tests/test_sql_validator.py @@ -140,9 +140,8 @@ def test_check_is_omop_table_ignores_cte(self, validator): FROM patient p JOIN visits v ON p.person_id = v.person_id """ - parsed_sql = sqlglot.parse_one(sql) - error = validator._check_is_omop_table(parsed_sql) - assert error is None, f"Expected no error, got: {error}" + errors = validator.validate_sql(sql) + assert len(errors) == 0, f"Expected no errors, got: {errors}" def test_check_is_omop_table_ignores_multiple_ctes(self, validator): """Test that _check_is_omop_table ignores multiple CTEs with non-OMOP tables""" @@ -159,6 +158,6 @@ def test_check_is_omop_table_ignores_multiple_ctes(self, validator): FROM temp_users p JOIN temp_visits v ON p.person_id = v.person_id """ - parsed_sql = sqlglot.parse_one(sql) - error = validator._check_is_omop_table(parsed_sql) - assert error is None, f"Expected no error, got: {error}" + + errors = validator.validate_sql(sql) + assert len(errors) == 0, f"Expected no errors, got: {errors}" From 63cbeb14de9832f60197718f964dd05f1b56100d Mon Sep 17 00:00:00 2001 From: Shihao Shenzhang Date: Fri, 16 May 2025 20:37:18 +0100 Subject: [PATCH 5/5] pre-commit --- tests/test_sql_validator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py index c1d3a07..2c8ce08 100644 --- a/tests/test_sql_validator.py +++ b/tests/test_sql_validator.py @@ -1,7 +1,6 @@ import pytest from omcp.sql_validator import SQLValidator import omcp.exceptions as ex -import sqlglot @pytest.fixture