Skip to content

Commit 2679d81

Browse files
committed
Add a new source concept id warning rule
1 parent 6c81baa commit 2679d81

4 files changed

Lines changed: 317 additions & 3 deletions

File tree

src/fastssv/rules/concept_standardization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from .hierarchy_expansion import HierarchyExpansionRule
1010
from .era_table_standard_concepts import EraTableStandardConceptsRule
1111
from .concept_domain_validation import ConceptDomainValidationRule
12+
from .source_concept_id_warning import SourceConceptIdWarningRule
1213

1314
__all__ = [
1415
"StandardConceptEnforcementRule",
1516
"InvalidReasonEnforcementRule",
1617
"HierarchyExpansionRule",
1718
"EraTableStandardConceptsRule",
1819
"ConceptDomainValidationRule",
20+
"SourceConceptIdWarningRule",
1921
]
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""Source Concept ID Usage Warning Rule.
2+
3+
OMOP semantic rule OMOP_022:
4+
The *_source_concept_id columns store the original source vocabulary concept.
5+
For standard analytical queries and cohort identification, use the primary
6+
*_concept_id (standard concept) rather than *_source_concept_id.
7+
8+
Valid uses of source_concept_id:
9+
- Data quality checks
10+
- ETL validation / mapping verification
11+
- Source code exploration
12+
- Provenance tracking
13+
14+
Invalid use (cohort identification):
15+
- SELECT person_id FROM condition_occurrence WHERE condition_source_concept_id = 123
16+
17+
Correct approach:
18+
- SELECT person_id FROM condition_occurrence WHERE condition_concept_id = 456
19+
"""
20+
21+
from typing import Dict, List, Set, Tuple
22+
23+
from sqlglot import exp
24+
25+
from fastssv.core.base import Rule, RuleViolation, Severity
26+
from fastssv.core.helpers import (
27+
extract_aliases,
28+
normalize_name,
29+
parse_sql,
30+
resolve_table_col,
31+
uses_table,
32+
)
33+
from fastssv.core.registry import register
34+
35+
36+
# Maps every recognised *_source_concept_id column to the standard
# *_concept_id column suggested as its replacement in warning messages.
SOURCE_TO_STANDARD: Dict[str, str] = {
    "condition_source_concept_id": "condition_concept_id",
    "drug_source_concept_id": "drug_concept_id",
    "procedure_source_concept_id": "procedure_concept_id",
    "measurement_source_concept_id": "measurement_concept_id",
    "observation_source_concept_id": "observation_concept_id",
    "device_source_concept_id": "device_concept_id",
    "visit_source_concept_id": "visit_concept_id",
    "specimen_source_concept_id": "specimen_concept_id",
}

# The source-concept columns this rule looks for: exactly the keys of the
# mapping above, so the two structures cannot drift apart.
SOURCE_CONCEPT_ID_COLUMNS: Set[str] = set(SOURCE_TO_STANDARD)
57+
58+
59+
def _is_in_where_or_having(node: exp.Expression) -> bool:
    """Report whether ``node`` sits inside a WHERE or HAVING clause.

    Ancestors are visited from the nearest one outward. Reaching a
    ``Where`` or ``Having`` ancestor first yields ``True``; reaching a
    ``Join`` first yields ``False``, so a predicate that belongs to a join
    is not counted as a filter. Running out of ancestors yields ``False``.
    """
    ancestor = node.parent
    while ancestor:
        if isinstance(ancestor, exp.Join):
            return False
        if isinstance(ancestor, (exp.Where, exp.Having)):
            return True
        ancestor = ancestor.parent
    return False
68+
69+
70+
def _find_source_filters(
    tree: exp.Expression,
    aliases: Dict[str, str],
) -> List[str]:
    """Build warning messages for filters on ``*_source_concept_id`` columns.

    Every ``=``, ``!=``/``<>`` and ``IN`` node of ``tree`` that lies in a
    WHERE or HAVING clause is inspected. When either operand is a column
    whose normalized name is in ``SOURCE_CONCEPT_ID_COLUMNS``, one message
    is produced that names the column and its standard counterpart.

    Args:
        tree: Parsed SQL expression to scan.
        aliases: Alias mapping handed through to ``resolve_table_col``.

    Returns:
        The messages in traversal order. A given (column, predicate SQL)
        pair is reported only once.
    """
    messages: List[str] = []
    reported: Set[Tuple[str, str]] = set()

    for node in tree.walk():
        is_comparison = isinstance(node, (exp.EQ, exp.NEQ, exp.In))
        if not (is_comparison and _is_in_where_or_having(node)):
            continue

        # The column may be on either side of the operator; look at the
        # left operand first, then the right one.
        for operand in (node.this, node.expression):
            if not isinstance(operand, exp.Column):
                continue

            # Only the column part of the resolved pair is needed here.
            _, column = resolve_table_col(operand, aliases)
            column_name = normalize_name(column)
            if column_name not in SOURCE_CONCEPT_ID_COLUMNS:
                continue

            # The same predicate text on the same column is reported once.
            dedup_key = (column_name, node.sql())
            if dedup_key in reported:
                continue
            reported.add(dedup_key)

            replacement = SOURCE_TO_STANDARD.get(
                column_name,
                column_name.replace("_source_", "_"),
            )
            messages.append(
                f"Filtering on '{column_name}' for cohort/analytical logic is discouraged. "
                f"Use '{replacement}' (standard concept) instead. "
                "Source concept IDs are intended for ETL validation, mapping QA, or provenance analysis."
            )

    return messages
114+
115+
116+
def _is_likely_analytical_query(tree: exp.Expression) -> bool:
    """Guess whether ``tree`` is a cohort-style / analytical query.

    The heuristic is that such queries usually involve the PERSON table or
    a ``person_id`` column. Returns ``True`` when ``uses_table`` reports
    the ``person`` table for ``tree``, or when any column reference in
    ``tree`` normalizes to ``person_id``; otherwise ``False``.

    NOTE(review): no caller of this helper is visible in this module;
    confirm whether it is meant to be wired into the rule.
    """
    if uses_table(tree, "person"):
        return True

    return any(
        normalize_name(column.name) == "person_id"
        for column in tree.find_all(exp.Column)
    )
126+
127+
128+
def _is_source_exploration_query(tree: exp.Expression) -> bool:
    """Guess whether ``tree`` is exploring source data rather than filtering.

    Only the first ``Select`` found in ``tree`` is examined. The query is
    treated as exploration when any column referenced in that SELECT's
    projection list has a normalized name that contains ``source_value``
    or ends with ``_source_concept_id``. Returns ``False`` when ``tree``
    holds no ``Select`` at all.
    """
    first_select = tree.find(exp.Select)
    if not first_select:
        return False

    def _names_source_data(name: str) -> bool:
        # True for *_source_value-like and *_source_concept_id column names.
        return "source_value" in name or name.endswith("_source_concept_id")

    return any(
        _names_source_data(normalize_name(column.name))
        for projection in first_select.expressions
        for column in projection.find_all(exp.Column)
    )
144+
145+
146+
@register
class SourceConceptIdWarningRule(Rule):
    """Warn when a query filters on a ``*_source_concept_id`` column.

    Each parsed statement is checked independently. Statements that
    ``_is_source_exploration_query`` classifies as source exploration are
    skipped; for the rest, every message from ``_find_source_filters``
    becomes one violation.
    """

    rule_id = "semantic.source_concept_id_warning"
    name = "Source Concept ID Not For Analytical Filtering"
    description = (
        "Avoid using *_source_concept_id for cohort definition or analytical filtering. "
        "Use standard *_concept_id instead."
    )
    severity = Severity.WARNING
    suggested_fix = (
        "Replace *_source_concept_id with corresponding standard *_concept_id column. "
        "If this is for ETL validation or source exploration, this warning can be ignored."
    )

    def validate(self, sql: str, dialect: str = "postgres") -> List[RuleViolation]:
        """Check ``sql`` and return one violation per offending filter.

        Args:
            sql: SQL text to validate; it may hold several statements.
            dialect: Dialect name passed to ``parse_sql``.

        Returns:
            The violations found, in statement order. The list is empty
            when ``parse_sql`` reports an error.
        """
        trees, error = parse_sql(sql, dialect)
        if error:
            # SQL that cannot be parsed produces no findings from this rule.
            return []

        violations: List[RuleViolation] = []
        for tree in trees:
            if not tree:
                continue

            aliases = extract_aliases(tree)

            # Queries that project source columns are treated as source
            # exploration, for which the warning does not apply.
            if _is_source_exploration_query(tree):
                continue

            violations.extend(
                self.create_violation(message=issue)
                for issue in _find_source_filters(tree, aliases)
            )

        return violations


__all__ = ["SourceConceptIdWarningRule"]

tasks/IMPLEMENTATION_STATUS.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ This checklist tracks which rules from `omop_rules.json` have been implemented i
99

1010
**Statistics:**
1111
- Total rules in JSON: 350+
12-
- Implemented: ~20 core rules
12+
- Implemented: ~21 core rules
1313
- Coverage: ~8-10%
1414

1515
---
@@ -60,8 +60,8 @@ This checklist tracks which rules from `omop_rules.json` have been implemented i
6060
- *Implemented as: `joins/join_path_validation.py`*
6161
- [x] **OMOP_021**: measurement_value_as_number_with_unit
6262
- *Implemented as: `domain_specific/measurement/measurement_unit_validation.py`*
63-
- [ ] **OMOP_022**: source_concept_id_not_for_primary_filtering
64-
- *Suggested group: `concept_standardization/`*
63+
- [x] **OMOP_022**: source_concept_id_not_for_primary_filtering
64+
- *Implemented as: `concept_standardization/source_concept_id_warning.py`*
6565
- [ ] **OMOP_023**: death_table_primary_key_is_person_id
6666
- *Suggested group: `domain_specific/death/`*
6767
- [ ] **OMOP_024**: cohort_subject_id_joins_to_person_id

tests/test_semantic_validation.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,5 +1774,128 @@ def test_no_concept_join_not_flagged(self) -> None:
17741774
self.assertEqual(violations, [])
17751775

17761776

1777+
class SourceConceptIdWarningTests(unittest.TestCase):
    """Tests for the source_concept_id warning rule (OMOP_022)."""

    def _run_rule(self, sql: str, dialect: str = "postgres") -> list:
        """Instantiate the registered rule and validate ``sql`` with it."""
        # Imported lazily so the registry is only touched when a test runs.
        from fastssv.core.registry import get_rule

        rule = get_rule("semantic.source_concept_id_warning")()
        return rule.validate(sql, dialect)

    def test_condition_source_concept_id_filter_warns(self) -> None:
        """Filtering on condition_source_concept_id should warn."""
        sql = """
            SELECT DISTINCT person_id
            FROM condition_occurrence
            WHERE condition_source_concept_id = 44836914
        """
        violations = self._run_rule(sql)
        self.assertGreater(len(violations), 0)
        # The message names both the offending and the suggested column.
        self.assertIn("condition_source_concept_id", violations[0].message)
        self.assertIn("condition_concept_id", violations[0].message)

    def test_drug_source_concept_id_filter_warns(self) -> None:
        """Filtering on drug_source_concept_id should warn."""
        sql = """
            SELECT person_id
            FROM drug_exposure
            WHERE drug_source_concept_id = 123456
        """
        violations = self._run_rule(sql)
        self.assertGreater(len(violations), 0)
        self.assertIn("drug_source_concept_id", violations[0].message)

    def test_procedure_source_concept_id_in_clause_warns(self) -> None:
        """Using IN clause with procedure_source_concept_id should warn."""
        sql = """
            SELECT *
            FROM procedure_occurrence
            WHERE procedure_source_concept_id IN (111, 222, 333)
        """
        violations = self._run_rule(sql)
        self.assertGreater(len(violations), 0)
        self.assertIn("procedure_source_concept_id", violations[0].message)

    def test_standard_concept_id_does_not_warn(self) -> None:
        """Using standard *_concept_id should not warn."""
        sql = """
            SELECT person_id
            FROM condition_occurrence
            WHERE condition_concept_id = 201826
        """
        violations = self._run_rule(sql)
        self.assertEqual(violations, [])

    def test_source_concept_id_in_select_not_flagged(self) -> None:
        """Selecting source_concept_id (not filtering) should not warn."""
        sql = """
            SELECT person_id, condition_source_concept_id
            FROM condition_occurrence
            WHERE condition_concept_id = 12345
        """
        violations = self._run_rule(sql)
        self.assertEqual(violations, [])

    def test_source_concept_id_in_group_by_not_flagged(self) -> None:
        """GROUP BY source_concept_id should not warn."""
        sql = """
            SELECT condition_source_concept_id, COUNT(*)
            FROM condition_occurrence
            GROUP BY condition_source_concept_id
        """
        violations = self._run_rule(sql)
        self.assertEqual(violations, [])

    def test_measurement_source_concept_id_warns(self) -> None:
        """Filtering on measurement_source_concept_id should warn."""
        sql = """
            SELECT *
            FROM measurement
            WHERE measurement_source_concept_id = 999
        """
        violations = self._run_rule(sql)
        self.assertGreater(len(violations), 0)

    def test_observation_source_concept_id_warns(self) -> None:
        """Filtering on observation_source_concept_id should warn."""
        sql = """
            SELECT *
            FROM observation
            WHERE observation_source_concept_id = 888
        """
        violations = self._run_rule(sql)
        self.assertGreater(len(violations), 0)

    def test_multiple_source_concept_id_filters_warns_multiple(self) -> None:
        """Multiple source_concept_id filters should generate multiple warnings."""
        sql = """
            SELECT co.person_id
            FROM condition_occurrence co
            JOIN drug_exposure de ON co.person_id = de.person_id
            WHERE co.condition_source_concept_id = 111
              AND de.drug_source_concept_id = 222
        """
        violations = self._run_rule(sql)
        self.assertGreaterEqual(len(violations), 2)

    def test_source_concept_id_comparison_operators_warn(self) -> None:
        """A not-equal (!=) comparison on source_concept_id should warn."""
        sql = """
            SELECT DISTINCT person_id
            FROM condition_occurrence
            WHERE condition_source_concept_id != 0
        """
        violations = self._run_rule(sql)
        self.assertGreater(len(violations), 0)

    def test_no_clinical_tables_not_flagged(self) -> None:
        """Query without clinical tables should not trigger."""
        sql = """
            SELECT * FROM concept WHERE concept_id = 12345
        """
        violations = self._run_rule(sql)
        self.assertEqual(violations, [])
1898+
1899+
17771900
if __name__ == "__main__":
17781901
unittest.main()

0 commit comments

Comments
 (0)