Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"requests",
"rich[jupyter]",
"ruamel.yaml",
"sqlglot~=30.0.1",
"sqlglot~=30.2.1",
"tenacity",
"time-machine",
"json-stream"
Expand Down
23 changes: 8 additions & 15 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,35 +2341,28 @@ def init(cursor: t.Any) -> None:
return init


_CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = {
ConnectionConfig, # type: ignore[type-abstract]
BaseDuckDBConnectionConfig, # type: ignore[type-abstract]
}

CONNECTION_CONFIG_TO_TYPE = {
# Map all subclasses of ConnectionConfig to the value of their `type_` field.
tpe.all_field_infos()["type_"].default: tpe
for tpe in subclasses(
__name__,
ConnectionConfig,
exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
)
for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE)
}

DIALECT_TO_TYPE = {
tpe.all_field_infos()["type_"].default: tpe.DIALECT
for tpe in subclasses(
__name__,
ConnectionConfig,
exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
)
for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE)
}

INIT_DISPLAY_INFO_TO_TYPE = {
tpe.all_field_infos()["type_"].default: (
tpe.DISPLAY_ORDER,
tpe.DISPLAY_NAME,
)
for tpe in subclasses(
__name__,
ConnectionConfig,
exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
)
for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE)
}


Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/core/config/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str
return default_catalogs_per_gateway


SCHEDULER_CONFIG_TO_TYPE = {
SCHEDULER_CONFIG_TO_TYPE: t.Dict[str, t.Type[SchedulerConfig]] = {
tpe.all_field_infos()["type_"].default: tpe
for tpe in subclasses(__name__, BaseConfig, exclude={BaseConfig})
if issubclass(tpe, SchedulerConfig)
}


Expand Down
5 changes: 3 additions & 2 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects import DuckDB, Snowflake, TSQL
import sqlglot.dialects.athena as athena
import sqlglot.generators.athena as athena_generators
from sqlglot.parsers.athena import AthenaTrinoParser
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
Expand Down Expand Up @@ -1048,8 +1049,8 @@ def extend_sqlglot() -> None:
if dialect == athena.Athena:
tokenizers.add(athena._TrinoTokenizer)
parsers.add(AthenaTrinoParser)
generators.add(athena._TrinoGenerator)
generators.add(athena._HiveGenerator)
generators.add(athena_generators.AthenaTrinoGenerator)
generators.add(athena_generators._HiveGenerator)

if hasattr(dialect, "Tokenizer"):
tokenizers.add(dialect.Tokenizer)
Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,4 +318,5 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return None


BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude={Rule}))
_RULE_EXCLUDE: t.Set[t.Type[Rule]] = {Rule} # type: ignore[type-abstract]
BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude=_RULE_EXCLUDE))
3 changes: 2 additions & 1 deletion sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,8 @@ def _load_linting_rules(self) -> RuleSet:
if os.path.getsize(path):
self._track_file(path)
module = import_python_file(path, self.config_path)
module_rules = subclasses(module.__name__, Rule, exclude={Rule})
_rule_exclude: t.Set[t.Type[Rule]] = {Rule} # type: ignore[type-abstract]
module_rules = subclasses(module.__name__, Rule, exclude=_rule_exclude)
for user_rule in module_rules:
user_rules[user_rule.name] = user_rule

Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,9 @@ def evaluate_macros(
return node

transformed = exp.replace_tree(
expression.copy(), evaluate_macros, prune=lambda n: isinstance(n, exp.Lambda)
expression.copy(),
evaluate_macros, # type: ignore[arg-type]
prune=lambda n: isinstance(n, exp.Lambda),
)

if changed:
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _resolve_table(
deployability_index: t.Optional[DeployabilityIndex] = None,
) -> exp.Table:
table = exp.replace_tables(
exp.maybe_parse(table_name, into=exp.Table, dialect=self._dialect),
t.cast(exp.Table, exp.maybe_parse(table_name, into=exp.Table, dialect=self._dialect)),
{
**self._to_table_mapping((snapshots or {}).values(), deployability_index),
**(table_mapping or {}),
Expand Down
5 changes: 3 additions & 2 deletions sqlmesh/core/table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,13 @@ def __init__(
self.source_alias = source_alias
self.target_alias = target_alias

cols: t.List[str] = ensure_list(skip_columns)
self.skip_columns = {
normalize_identifiers(
exp.parse_identifier(t.cast(str, col)),
exp.parse_identifier(col),
dialect=self.model_dialect or self.dialect,
).name
for col in ensure_list(skip_columns)
for col in cols
}

self._on = on
Expand Down
15 changes: 11 additions & 4 deletions sqlmesh/core/test/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,20 +612,27 @@ def _concurrent_render_context(self) -> t.Iterator[None]:
- Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
"""
import time_machine
from sqlglot.generator import _DISPATCH_CACHE

lock_ctx: AbstractContextManager = (
self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext()
)
time_ctx: AbstractContextManager = nullcontext()
dialect_patch_ctx: AbstractContextManager = nullcontext()
dispatch_patch_ctx: AbstractContextManager = nullcontext()

if self._execution_time:
generator_class = self._test_adapter_dialect.generator_class
time_ctx = time_machine.travel(self._execution_time, tick=False)
dialect_patch_ctx = patch.dict(
self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
)
dialect_patch_ctx = patch.dict(generator_class.TRANSFORMS, self._transforms)

# sqlglot caches a dispatch table per generator class, so we need to patch
# it as well to ensure the overridden transforms are actually used
dispatch = _DISPATCH_CACHE.get(generator_class)
if dispatch is not None:
dispatch_patch_ctx = patch.dict(dispatch, self._transforms)

with lock_ctx, time_ctx, dialect_patch_ctx:
with lock_ctx, time_ctx, dialect_patch_ctx, dispatch_patch_ctx:
yield

def _execute(self, query: exp.Query | str) -> pd.DataFrame:
Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,8 @@ def _build_test_name(node: ManifestNode, dependencies: Dependencies) -> str:
continue
if isinstance(val, dict):
val = list(val.values())
val = [re.sub("[^0-9a-zA-Z_]+", "_", str(v)) for v in ensure_list(val)]
items: t.List[t.Any] = ensure_list(val)
val = [re.sub("[^0-9a-zA-Z_]+", "_", str(v)) for v in items]
arg_val_parts.extend(val)
unique_args = "__".join(arg_val_parts) if arg_val_parts else ""
unique_args = f"_{unique_args}" if unique_args else ""
Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,10 @@ def _get_fields(
elif isinstance(v, exp.Expr):
expressions = [v]
else:
items: t.List[t.Any] = ensure_list(v)
expressions = [
parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry # type: ignore[misc]
for entry in ensure_list(v)
for entry in items
]

results = []
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def ignore(src, names):
return [name for name in names if name == ".cache"]

def _make_function(
paths: t.Union[t.Union[str, Path], t.Collection[t.Union[str, Path]]],
paths: t.Union[str, Path, t.List[t.Union[str, Path]], t.Tuple[t.Union[str, Path], ...]],
) -> t.List[Path]:
paths = ensure_list(paths)
all_paths = [Path(p) for p in paths]
Expand Down
2 changes: 1 addition & 1 deletion tests/dbt/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_model_to_sqlmesh_fields(dbt_dummy_postgres_config: PostgresConfig):
assert kind.on_additive_change == OnAdditiveChange.ALLOW
assert (
kind.merge_filter.sql(dialect=model.dialect) # type: ignore
== """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7'"""
== """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7 DAY'"""
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous SQL was incorrect — for a long time there was a bug in the Postgres dialect where we would drop the interval unit. `CURRENT_DATE + INTERVAL '7'` means "add 7 seconds to the current date", whereas the actual filter intended "add 7 days to the current date".

I fixed this upstream, hence the bump to v30.2.1.

)

model = model_config.update_with({"dialect": "snowflake"}).to_sqlmesh(context)
Expand Down
5 changes: 4 additions & 1 deletion web/client/src/workers/sqlglot/sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

def parse_to_json(sql: str, read: DialectType = None) -> str:
return json.dumps(
[exp.dump() if exp else {} for exp in sqlglot.parse(sql, read=read, error_level="ignore")]
[
exp.dump() if exp else {}
for exp in sqlglot.parse(sql, read=read, error_level=sqlglot.ErrorLevel.IGNORE)
]
)


Expand Down
7 changes: 3 additions & 4 deletions web/server/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
invalidate_context_cache,
)
from web.server.utils import is_relative_to
from sqlglot.helper import ensure_list


async def watch_project() -> None:
Expand All @@ -40,10 +39,10 @@ async def watch_project() -> None:
async for entries in awatch(
settings.project_path,
watch_filter=DefaultFilter(
ignore_paths=ensure_list(DefaultFilter.ignore_paths) + ignore_paths,
ignore_entity_patterns=ensure_list(DefaultFilter.ignore_entity_patterns)
ignore_paths=list(DefaultFilter.ignore_paths or []) + ignore_paths,
ignore_entity_patterns=list(DefaultFilter.ignore_entity_patterns or [])
+ ignore_entity_patterns,
ignore_dirs=ensure_list(DefaultFilter.ignore_dirs) + ignore_dirs,
ignore_dirs=list(DefaultFilter.ignore_dirs or []) + ignore_dirs,
),
):
changes: t.List[models.ArtifactChange] = []
Expand Down
Loading