|
| 1 | +"""Measurement Unit Validation Rule. |
| 2 | +
|
| 3 | +OMOP semantic rule: |
| 4 | +When a query filters measurement.value_as_number against a numeric threshold, |
| 5 | +it must also constrain unit_concept_id. |
| 6 | +
|
| 7 | +The measurement table stores numeric results alongside their units. The same |
| 8 | +clinical concept (e.g. blood glucose, HbA1c) can be recorded in different |
| 9 | +units across source systems and ETL pipelines: |
| 10 | +
|
| 11 | + Blood glucose: 5.5 (mmol/L) vs 100 (mg/dL) |
| 12 | + HbA1c: 7.0 (%) vs 53 (mmol/mol) |
| 13 | +
|
| 14 | +Filtering on a numeric threshold without specifying the unit means both |
| 15 | +representations are tested against the same cutoff, silently including or |
| 16 | +excluding patients based on which unit convention was used at their site. |
| 17 | +
|
| 18 | +Correct pattern: |
| 19 | + WHERE m.value_as_number > 7.0 |
| 20 | + AND m.unit_concept_id = 8554 -- % (UCUM) |
| 21 | +""" |
| 22 | + |
| 23 | +from typing import Dict, List |
| 24 | + |
| 25 | +from sqlglot import exp |
| 26 | + |
| 27 | +from fastssv.core.base import Rule, RuleViolation, Severity |
| 28 | +from fastssv.core.helpers import ( |
| 29 | + extract_aliases, |
| 30 | + is_in_where_or_join_clause, |
| 31 | + normalize_name, |
| 32 | + parse_sql, |
| 33 | + resolve_table_col, |
| 34 | + uses_table, |
| 35 | +) |
| 36 | +from fastssv.core.registry import register |
| 37 | + |
# Comparison operators that indicate a numeric threshold filter.
# Equality (exp.EQ) is included alongside the four inequality operators;
# BETWEEN is a distinct node type (exp.Between) and is not part of this tuple.
_NUMERIC_COMPARISON_TYPES = (exp.GT, exp.GTE, exp.LT, exp.LTE, exp.EQ)
| 40 | + |
| 41 | + |
def _find_value_as_number_threshold(
    tree: exp.Expression,
    aliases: Dict[str, str],
) -> bool:
    """Report whether the query tests value_as_number against numeric literals.

    Two predicate shapes are recognised, and only when they sit inside a
    WHERE or JOIN condition:

    * a binary comparison from ``_NUMERIC_COMPARISON_TYPES`` with the column
      on either side and a numeric literal on the other;
    * ``value_as_number BETWEEN low AND high`` where both bounds are numeric
      literals.

    A literal wrapped in a unary minus (e.g. ``-1.5``) counts as numeric.
    The column may be unqualified, or resolve to the ``measurement`` table.

    Args:
        tree: Parsed statement to search.
        aliases: Alias mapping handed to ``resolve_table_col``.

    Returns:
        True as soon as one matching predicate is found, otherwise False.
    """

    def is_numeric_operand(operand) -> bool:
        # Look through a unary minus so that -1.5 is treated like 1.5.
        inner = operand.this if isinstance(operand, exp.Neg) else operand
        # String literals are exp.Literal too, hence the is_number check.
        return isinstance(inner, exp.Literal) and inner.is_number

    def targets_measurement_value(column) -> bool:
        owner, column_name = resolve_table_col(column, aliases)
        if normalize_name(column_name) != "value_as_number":
            return False
        # An unqualified column is accepted as-is; a qualified one must
        # resolve to the measurement table.
        return owner is None or normalize_name(owner) == "measurement"

    # Shape 1: <column> OP <number> or <number> OP <column>.
    for comparison in tree.find_all(_NUMERIC_COMPARISON_TYPES):
        if not is_in_where_or_join_clause(comparison):
            continue

        operands = (comparison.left, comparison.right)
        for column, threshold in (operands, operands[::-1]):
            if (
                isinstance(column, exp.Column)
                and is_numeric_operand(threshold)
                and targets_measurement_value(column)
            ):
                return True

    # Shape 2: <column> BETWEEN <number> AND <number>.
    for between in tree.find_all(exp.Between):
        if not is_in_where_or_join_clause(between):
            continue

        column = between.this
        if not isinstance(column, exp.Column):
            continue

        bounds = (between.args.get("low"), between.args.get("high"))
        if all(is_numeric_operand(bound) for bound in bounds) and targets_measurement_value(column):
            return True

    return False
| 103 | + |
| 104 | + |
def _has_unit_concept_constraint(tree: exp.Expression) -> bool:
    """Report whether unit_concept_id is used in a WHERE or JOIN condition.

    Only occurrences in a filtering/join context count; a unit_concept_id that
    appears solely in the SELECT list does not constrain which rows the
    numeric threshold is applied to, so it is deliberately ignored.

    Args:
        tree: Parsed statement to search.

    Returns:
        True if any column named ``unit_concept_id`` is found in a WHERE or
        JOIN clause, otherwise False.
    """
    return any(
        normalize_name(column.name) == "unit_concept_id"
        and is_in_where_or_join_clause(column)
        for column in tree.find_all(exp.Column)
    )
| 117 | + |
| 118 | + |
@register
class MeasurementUnitValidationRule(Rule):
    """Flag numeric thresholds on measurement.value_as_number that lack a unit filter."""

    # Rule metadata: identifier, display name, long description, severity and
    # the remediation hint attached to this rule.
    rule_id = "semantic.measurement_unit_validation"
    name = "Measurement Unit Validation"
    description = (
        "Detects queries that filter measurement.value_as_number against a numeric "
        "threshold without also constraining unit_concept_id. The same measurement "
        "concept can be stored in different units across sites (e.g. glucose in "
        "mmol/L vs mg/dL). A numeric threshold applied without a unit filter silently "
        "mixes patients measured in different unit conventions."
    )
    severity = Severity.WARNING
    suggested_fix = (
        "Add a unit_concept_id constraint alongside the numeric threshold: "
        "AND m.unit_concept_id = <unit_concept_id>. "
        "Look up the correct UCUM unit concept ID in the OMOP vocabulary "
        "(e.g. SELECT concept_id FROM concept WHERE concept_code = '%' "
        "AND vocabulary_id = 'UCUM')."
    )

    def validate(self, sql: str, dialect: str = "postgres") -> List[RuleViolation]:
        """Check each parsed statement and collect unit-validation violations.

        A statement yields one violation when it references the measurement
        table, compares value_as_number against a numeric literal in a WHERE
        or JOIN condition, and has no unit_concept_id in such a condition.

        Args:
            sql: SQL text to validate.
            dialect: SQL dialect passed to ``parse_sql``.

        Returns:
            One violation per offending statement; an empty list when nothing
            is flagged or when ``parse_sql`` reports a parse error.
        """
        statements, parse_error = parse_sql(sql, dialect)
        if parse_error:
            # SQL that failed to parse is not reported by this rule.
            return []

        message = (
            "Query filters measurement.value_as_number against a numeric "
            "threshold without constraining unit_concept_id. The same "
            "measurement concept can be stored in different units across "
            "sites (e.g. glucose: 5.5 mmol/L vs 100 mg/dL), making the "
            "numeric threshold unreliable without a unit filter."
        )

        found: List[RuleViolation] = []
        for statement in statements:
            # Skip empty parse results and statements that never touch measurement.
            if statement is None or not uses_table(statement, "measurement"):
                continue

            alias_map = extract_aliases(statement)
            if not _find_value_as_number_threshold(statement, alias_map):
                continue

            # A unit constraint anywhere in WHERE/JOIN satisfies the rule.
            if _has_unit_concept_constraint(statement):
                continue

            details = {
                "column": "measurement.value_as_number",
                "missing": "unit_concept_id constraint",
            }
            found.append(self.create_violation(message=message, details=details))

        return found
| 180 | + |
| 181 | + |
# Names exported by ``from <module> import *``; the helper functions stay private.
__all__ = ["MeasurementUnitValidationRule"]
0 commit comments