diff --git a/pyproject.toml b/pyproject.toml index 56d66ecff5..c9d5fd4447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "requests", "rich[jupyter]", "ruamel.yaml", - "sqlglot~=30.0.1", + "sqlglot~=30.2.1", "tenacity", "time-machine", "json-stream" diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 7a002faebb..07e8be2908 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -2341,23 +2341,20 @@ def init(cursor: t.Any) -> None: return init +_CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = { + ConnectionConfig, # type: ignore[type-abstract] + BaseDuckDBConnectionConfig, # type: ignore[type-abstract] +} + CONNECTION_CONFIG_TO_TYPE = { # Map all subclasses of ConnectionConfig to the value of their `type_` field. tpe.all_field_infos()["type_"].default: tpe - for tpe in subclasses( - __name__, - ConnectionConfig, - exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, - ) + for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE) } DIALECT_TO_TYPE = { tpe.all_field_infos()["type_"].default: tpe.DIALECT - for tpe in subclasses( - __name__, - ConnectionConfig, - exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, - ) + for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE) } INIT_DISPLAY_INFO_TO_TYPE = { @@ -2365,11 +2362,7 @@ def init(cursor: t.Any) -> None: tpe.DISPLAY_ORDER, tpe.DISPLAY_NAME, ) - for tpe in subclasses( - __name__, - ConnectionConfig, - exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, - ) + for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE) } diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py index 970defee62..9d9d1d3c79 100644 --- a/sqlmesh/core/config/scheduler.py +++ b/sqlmesh/core/config/scheduler.py @@ -144,9 +144,10 @@ def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str return default_catalogs_per_gateway -SCHEDULER_CONFIG_TO_TYPE = { +SCHEDULER_CONFIG_TO_TYPE: t.Dict[str, t.Type[SchedulerConfig]] = { tpe.all_field_infos()["type_"].default: tpe for tpe in subclasses(__name__, BaseConfig, exclude={BaseConfig}) + if issubclass(tpe, SchedulerConfig) } diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 3e8f4fe9a7..565c629789 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import DialectType from sqlglot.dialects import DuckDB, Snowflake, TSQL import sqlglot.dialects.athena as athena +import sqlglot.generators.athena as athena_generators from sqlglot.parsers.athena import AthenaTrinoParser from sqlglot.helper import seq_get from sqlglot.optimizer.normalize_identifiers import normalize_identifiers @@ -1048,8 +1049,8 @@ def extend_sqlglot() -> None: if dialect == athena.Athena: tokenizers.add(athena._TrinoTokenizer) parsers.add(AthenaTrinoParser) - generators.add(athena._TrinoGenerator) - generators.add(athena._HiveGenerator) + generators.add(athena_generators.AthenaTrinoGenerator) + generators.add(athena_generators._HiveGenerator) if hasattr(dialect, "Tokenizer"): tokenizers.add(dialect.Tokenizer) diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index 4547ac0528..8dc4172f9f 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -318,4 +318,5 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: return None -BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude={Rule})) +_RULE_EXCLUDE: t.Set[t.Type[Rule]] = {Rule} # type: ignore[type-abstract] +BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude=_RULE_EXCLUDE)) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 4b7b1bac02..cb951b4f9e 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -840,7 +840,8 @@ def _load_linting_rules(self) -> RuleSet: if os.path.getsize(path): self._track_file(path) module = import_python_file(path, self.config_path) - module_rules = subclasses(module.__name__, Rule, exclude={Rule}) + _rule_exclude: t.Set[t.Type[Rule]] = {Rule} # type: ignore[type-abstract] + module_rules = subclasses(module.__name__, Rule, exclude=_rule_exclude) for user_rule in module_rules: user_rules[user_rule.name] = user_rule diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 888acbb8eb..9370bffdeb 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -294,7 +294,9 @@ def evaluate_macros( return node transformed = exp.replace_tree( - expression.copy(), evaluate_macros, prune=lambda n: isinstance(n, exp.Lambda) + expression.copy(), + evaluate_macros, # type: ignore[arg-type] + prune=lambda n: isinstance(n, exp.Lambda), ) if changed: diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 7683956064..9f403cbcb4 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -331,7 +331,7 @@ def _resolve_table( deployability_index: t.Optional[DeployabilityIndex] = None, ) -> exp.Table: table = exp.replace_tables( - exp.maybe_parse(table_name, into=exp.Table, dialect=self._dialect), + t.cast(exp.Table, exp.maybe_parse(table_name, into=exp.Table, dialect=self._dialect)), { **self._to_table_mapping((snapshots or {}).values(), deployability_index), **(table_mapping or {}), diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index df99227f89..97cb0c19ba 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -255,12 +255,13 @@ def __init__( self.source_alias = source_alias self.target_alias = target_alias + cols: t.List[str] = ensure_list(skip_columns) self.skip_columns = { normalize_identifiers( - exp.parse_identifier(t.cast(str, col)), + exp.parse_identifier(col), dialect=self.model_dialect or self.dialect, ).name - for col in ensure_list(skip_columns) + for col in cols } self._on = on diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index 629e8f8d5b..e6fd130f1b 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -612,20 +612,27 @@ def _concurrent_render_context(self) -> t.Iterator[None]: - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation """ import time_machine + from sqlglot.generator import _DISPATCH_CACHE lock_ctx: AbstractContextManager = ( self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext() ) time_ctx: AbstractContextManager = nullcontext() dialect_patch_ctx: AbstractContextManager = nullcontext() + dispatch_patch_ctx: AbstractContextManager = nullcontext() if self._execution_time: + generator_class = self._test_adapter_dialect.generator_class time_ctx = time_machine.travel(self._execution_time, tick=False) - dialect_patch_ctx = patch.dict( - self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms - ) + dialect_patch_ctx = patch.dict(generator_class.TRANSFORMS, self._transforms) + + # sqlglot caches a dispatch table per generator class, so we need to patch + # it as well to ensure the overridden transforms are actually used + dispatch = _DISPATCH_CACHE.get(generator_class) + if dispatch is not None: + dispatch_patch_ctx = patch.dict(dispatch, self._transforms) - with lock_ctx, time_ctx, dialect_patch_ctx: + with lock_ctx, time_ctx, dialect_patch_ctx, dispatch_patch_ctx: yield def _execute(self, query: exp.Query | str) -> pd.DataFrame: diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index fce561a24d..aae4de871c 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -846,7 +846,8 @@ def _build_test_name(node: ManifestNode, dependencies: Dependencies) -> str: continue if isinstance(val, dict): val = list(val.values()) - val = [re.sub("[^0-9a-zA-Z_]+", "_", str(v)) for v in ensure_list(val)] + items: t.List[t.Any] = ensure_list(val) + val = [re.sub("[^0-9a-zA-Z_]+", "_", str(v)) for v in items] arg_val_parts.extend(val) unique_args = "__".join(arg_val_parts) if arg_val_parts else "" unique_args = f"_{unique_args}" if unique_args else "" diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 8bc81e2774..5e3e5f979b 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -265,9 +265,10 @@ def _get_fields( elif isinstance(v, exp.Expr): expressions = [v] else: + items: t.List[t.Any] = ensure_list(v) expressions = [ parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry # type: ignore[misc] - for entry in ensure_list(v) + for entry in items ] results = [] diff --git a/tests/conftest.py b/tests/conftest.py index 46086444bd..4d1bb23577 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -571,7 +571,7 @@ def ignore(src, names): return [name for name in names if name == ".cache"] def _make_function( - paths: t.Union[t.Union[str, Path], t.Collection[t.Union[str, Path]]], + paths: t.Union[str, Path, t.List[t.Union[str, Path]], t.Tuple[t.Union[str, Path], ...]], ) -> t.List[Path]: paths = ensure_list(paths) all_paths = [Path(p) for p in paths] diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index 5dccd90ed2..9fb2a442a1 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -151,7 +151,7 @@ def test_model_to_sqlmesh_fields(dbt_dummy_postgres_config: PostgresConfig): assert kind.on_additive_change == OnAdditiveChange.ALLOW assert ( kind.merge_filter.sql(dialect=model.dialect) # type: ignore - == """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7'""" + == """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7 DAY'""" ) model = model_config.update_with({"dialect": "snowflake"}).to_sqlmesh(context) diff --git a/web/client/src/workers/sqlglot/sqlglot.py b/web/client/src/workers/sqlglot/sqlglot.py index 998d467923..435a79823f 100644 --- a/web/client/src/workers/sqlglot/sqlglot.py +++ b/web/client/src/workers/sqlglot/sqlglot.py @@ -17,7 +17,10 @@ def parse_to_json(sql: str, read: DialectType = None) -> str: return json.dumps( - [exp.dump() if exp else {} for exp in sqlglot.parse(sql, read=read, error_level="ignore")] + [ + exp.dump() if exp else {} + for exp in sqlglot.parse(sql, read=read, error_level=sqlglot.ErrorLevel.IGNORE) + ] ) diff --git a/web/server/watcher.py b/web/server/watcher.py index 588f6c5e22..8bc87c8719 100644 --- a/web/server/watcher.py +++ b/web/server/watcher.py @@ -16,7 +16,6 @@ invalidate_context_cache, ) from web.server.utils import is_relative_to -from sqlglot.helper import ensure_list async def watch_project() -> None: @@ -40,10 +39,10 @@ async def watch_project() -> None: async for entries in awatch( settings.project_path, watch_filter=DefaultFilter( - ignore_paths=ensure_list(DefaultFilter.ignore_paths) + ignore_paths, - ignore_entity_patterns=ensure_list(DefaultFilter.ignore_entity_patterns) + ignore_paths=list(DefaultFilter.ignore_paths or []) + ignore_paths, + ignore_entity_patterns=list(DefaultFilter.ignore_entity_patterns or []) + ignore_entity_patterns, - ignore_dirs=ensure_list(DefaultFilter.ignore_dirs) + ignore_dirs, + ignore_dirs=list(DefaultFilter.ignore_dirs or []) + ignore_dirs, ), ): changes: t.List[models.ArtifactChange] = []