diff --git a/crates/lib-core/src/parser/segments/from.rs b/crates/lib-core/src/parser/segments/from.rs index 26fcae7a8..d5dce0319 100644 --- a/crates/lib-core/src/parser/segments/from.rs +++ b/crates/lib-core/src/parser/segments/from.rs @@ -85,7 +85,14 @@ impl FromExpressionElementSegment { let alias_expression = self .0 - .child(const { &SyntaxSet::new(&[SyntaxKind::AliasExpression]) }); + .child(const { &SyntaxSet::new(&[SyntaxKind::AliasExpression]) }) + .or_else(|| { + self.0 + .child(const { &SyntaxSet::new(&[SyntaxKind::Bracketed]) }) + .and_then(|bracketed| { + bracketed.child(const { &SyntaxSet::new(&[SyntaxKind::AliasExpression]) }) + }) + }); if let Some(alias_expression) = alias_expression { let segment = alias_expression.child( const { diff --git a/crates/lib-dialects/src/ansi.rs b/crates/lib-dialects/src/ansi.rs index 6e9044029..9877efc94 100644 --- a/crates/lib-dialects/src/ansi.rs +++ b/crates/lib-dialects/src/ansi.rs @@ -4707,34 +4707,52 @@ pub fn raw_dialect() -> Dialect { ( "FromExpressionElementSegment".into(), NodeMatcher::new(SyntaxKind::FromExpressionElement, |_| { - Sequence::new(vec![ - Ref::new("PreTableFunctionKeywordsGrammar") - .optional() + let base_from_expression_element = || { + Sequence::new(vec![ + Ref::new("PreTableFunctionKeywordsGrammar") + .optional() + .to_matchable(), + optionally_bracketed(vec![ + Ref::new("TableExpressionSegment").to_matchable(), + ]) .to_matchable(), - optionally_bracketed(vec![Ref::new("TableExpressionSegment").to_matchable()]) + Ref::new("AliasExpressionSegment") + .exclude(one_of(vec![ + Ref::new("FromClauseTerminatorGrammar").to_matchable(), + Ref::new("SamplingExpressionSegment").to_matchable(), + Ref::new("JoinLikeClauseGrammar").to_matchable(), + LookaheadExclude::new("WITH", "(").to_matchable(), + ])) + .optional() + .to_matchable(), + Sequence::new(vec![ + Ref::keyword("WITH").to_matchable(), + Ref::keyword("OFFSET").to_matchable(), + Ref::new("AliasExpressionSegment").to_matchable(), + ]) + .config(|this| this.optional()) .to_matchable(), - Ref::new("AliasExpressionSegment") - .exclude(one_of(vec![ - Ref::new("FromClauseTerminatorGrammar").to_matchable(), - Ref::new("SamplingExpressionSegment").to_matchable(), - Ref::new("JoinLikeClauseGrammar").to_matchable(), - LookaheadExclude::new("WITH", "(").to_matchable(), - ])) - .optional() + Ref::new("SamplingExpressionSegment") + .optional() + .to_matchable(), + Ref::new("PostTableExpressionGrammar") + .optional() + .to_matchable(), + ]) + .to_matchable() + }; + + one_of(vec![ + base_from_expression_element(), + Bracketed::new(vec![ + Sequence::new(vec![ + base_from_expression_element(), + AnyNumberOf::new(vec![Ref::new("JoinClauseSegment").to_matchable()]) + .to_matchable(), + ]) .to_matchable(), - Sequence::new(vec![ - Ref::keyword("WITH").to_matchable(), - Ref::keyword("OFFSET").to_matchable(), - Ref::new("AliasExpressionSegment").to_matchable(), ]) - .config(|this| this.optional()) .to_matchable(), - Ref::new("SamplingExpressionSegment") - .optional() - .to_matchable(), - Ref::new("PostTableExpressionGrammar") - .optional() - .to_matchable(), ]) .to_matchable() }) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 50c790fa1..450d599c7 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -1,8 +1,6 @@ -use std::iter::zip; -use std::ops::{Index, IndexMut}; +use std::ops::Index; use hashbrown::{HashMap, HashSet}; -use itertools::{Itertools, enumerate}; use smol_str::{SmolStr, StrExt, ToSmolStr, format_smolstr}; use sqruff_lib_core::dialects::Dialect; use sqruff_lib_core::dialects::common::AliasInfo; @@ -28,34 +26,128 @@ const SELECT_TYPES: SyntaxSet = SyntaxSet::new(&[ SyntaxKind::SelectStatement, ]); -fn config_mapping(key: &str) -> SyntaxSet { - match key { - "join" => SyntaxSet::single(SyntaxKind::JoinClause), - "from" => SyntaxSet::single(SyntaxKind::FromExpressionElement), - "both" => SyntaxSet::new(&[SyntaxKind::JoinClause, SyntaxKind::FromExpressionElement]), - _ => unreachable!("Invalid value for 'forbid_subquery_in': {key}"), +fn normalize_identifier_name(raw: &str) -> SmolStr { + let is_bracket_quoted = raw.starts_with('[') && raw.ends_with(']') && raw.len() >= 2; + let is_matching_quote_quoted = matches!(raw.chars().next(), Some('"') | Some('`')) + && raw.len() >= 2 + && raw.chars().next() == raw.chars().last(); + + if is_bracket_quoted || is_matching_quote_quoted { + raw[1..raw.len() - 1].into() + } else { + raw.into() + } +} + +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +enum ForbidSubqueryIn { + Join, + From, + #[default] + Both, +} + +impl ForbidSubqueryIn { + fn syntax_set(self) -> SyntaxSet { + match self { + Self::Join => SyntaxSet::single(SyntaxKind::JoinClause), + Self::From => SyntaxSet::single(SyntaxKind::FromExpressionElement), + Self::Both => { + SyntaxSet::new(&[SyntaxKind::JoinClause, SyntaxKind::FromExpressionElement]) + } + } + } +} + +impl std::str::FromStr for ForbidSubqueryIn { + type Err = String; + + fn from_str(value: &str) -> Result { + match value { + "join" => Ok(Self::Join), + "from" => Ok(Self::From), + "both" => Ok(Self::Both), + _ => Err(format!( + "Invalid value for forbid_subquery_in: {value}. Must be one of [join, from, \ + both]" + )), + } } } -#[allow(dead_code)] -struct NestedSubQuerySummary<'a> { - query: Query<'a>, +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct IdentifierSpec { + raw: SmolStr, + normalized: SmolStr, + kind: SyntaxKind, +} + +impl IdentifierSpec { + fn new(raw: SmolStr, kind: SyntaxKind) -> Self { + Self { + normalized: normalize_identifier_name(&raw), + raw, + kind: match kind { + SyntaxKind::Identifier + | SyntaxKind::NakedIdentifier + | SyntaxKind::QuotedIdentifier => kind, + _ => SyntaxKind::NakedIdentifier, + }, + } + } + + fn from_segment(segment: &ErasedSegment) -> Self { + Self::new(segment.raw().clone(), segment.get_type()) + } + + fn from_alias(alias: &AliasInfo) -> Option { + if alias.ref_str.is_empty() { + return None; + } + + alias.segment.as_ref().map(Self::from_segment).or_else(|| { + Some(Self::new( + alias.ref_str.clone(), + SyntaxKind::NakedIdentifier, + )) + }) + } + + fn generated(raw: SmolStr) -> Self { + Self::new(raw, SyntaxKind::NakedIdentifier) + } +} + +struct NestedSubqueryCandidate<'a> { selectable: Selectable<'a>, table_alias: AliasInfo, - select_source_names: HashSet, +} + +struct LintQueryResult { + lint_result: LintResult, + from_expression: ErasedSegment, + alias_name: IdentifierSpec, + subquery_parent: ErasedSegment, + cte_source: Option, + is_fixable: bool, } #[derive(Clone, Debug, Default)] pub(crate) struct RuleST05 { - forbid_subquery_in: String, + forbid_subquery_in: ForbidSubqueryIn, } impl Rule for RuleST05 { fn load_from_config(&self, config: &HashMap) -> Result { - Ok(RuleST05 { - forbid_subquery_in: config["forbid_subquery_in"].as_string().unwrap().into(), - } - .erased()) + let forbid_subquery_in = match config.get("forbid_subquery_in") { + Some(value) => value + .as_string() + .ok_or_else(|| "Rule ST05 expects a string for `forbid_subquery_in`".to_string())? + .parse()?, + None => self.forbid_subquery_in, + }; + + Ok(RuleST05 { forbid_subquery_in }.erased()) } fn name(&self) -> &'static str { @@ -105,6 +197,8 @@ join c using(x) let functional_context = FunctionalContext::new(context); let segment = functional_context.segment(); let parent_stack = functional_context.parent_stack(); + let insert_parent = parent_stack + .find_last_where(|it: &ErasedSegment| it.is_type(SyntaxKind::InsertStatement)); let is_select = segment.all_match(|it: &ErasedSegment| SELECT_TYPES.contains(it.get_type())); @@ -115,28 +209,58 @@ join c using(x) return Vec::new(); } + let is_with = + segment.all_match(|it: &ErasedSegment| it.is_type(SyntaxKind::WithCompoundStatement)); let query: Query<'_> = Query::from_segment(&context.segment, context.dialect, None); let mut ctes = CTEBuilder::default(); - for cte in query.inner.borrow().ctes.values() { - ctes.insert_cte(cte.inner.borrow().cte_definition_segment.clone().unwrap()); + if is_with { + // Mirror SQLFluff by preserving all CTEs in document order, including + // non-select expression CTEs (e.g. ClickHouse `expr AS alias`) which + // Query::from_segment intentionally skips in `query.ctes`. + for cte in context.segment.recursive_crawl( + const { &SyntaxSet::single(SyntaxKind::CommonTableExpression) }, + false, + const { &SyntaxSet::single(SyntaxKind::WithCompoundStatement) }, + false, + ) { + ctes.insert_cte(cte); + } } - let is_with = - segment.all_match(|it: &ErasedSegment| it.is_type(SyntaxKind::WithCompoundStatement)); let is_recursive = is_with && !segment .children_where(|it: &ErasedSegment| it.is_keyword("recursive")) .is_empty(); let case_preference = get_case_preference(&segment); + let mut root_segment = segment.first().unwrap().clone(); + let output_select = if is_with { + segment.children_where(|it: &ErasedSegment| { + matches!( + it.get_type(), + SyntaxKind::InsertStatement + | SyntaxKind::SetExpression + | SyntaxKind::SelectStatement + ) + }) + } else if context.dialect.name == DialectKind::Tsql && !insert_parent.is_empty() { + root_segment = insert_parent.first().unwrap().clone(); + insert_parent + } else { + segment.clone() + }; + let bracketed_ctas = parent_stack + .base + .iter() + .rev() + .take(2) + .map(|it| it.get_type()) + .eq([SyntaxKind::CreateTableStatement, SyntaxKind::Bracketed]); - let clone_map = SegmentCloneMap::new( - segment.first().unwrap().clone(), - segment.first().unwrap().deep_clone(), - ); + let clone_map = SegmentCloneMap::new(root_segment.deep_clone()); - let results = self.lint_query( + let mut results_list = self.lint_query( context.tables, context.dialect, query, @@ -145,19 +269,16 @@ join c using(x) &clone_map, ); - let mut lint_results = Vec::with_capacity(results.len()); - let mut is_fixable = true; - - let mut subquery_parent = None; - let mut local_fixes = Vec::new(); - let mut q = Vec::new(); - for result in results { - let (lint_result, from_expression, alias_name, subquery_parent_slot) = result; - subquery_parent = Some(subquery_parent_slot.clone()); - let this_seg_clone = clone_map[&from_expression].clone(); - let new_table_ref = create_table_ref(context.tables, &alias_name, context.dialect); + for result in &results_list { + if !result.is_fixable { + continue; + } + + let this_seg_clone = clone_map[&result.from_expression].clone(); + let new_table_ref = + create_table_ref(context.tables, &result.alias_name, context.dialect); local_fixes.push(LintFix::replace( this_seg_clone.clone(), @@ -172,64 +293,99 @@ join c using(x) ], None, )); - - q.push(subquery_parent_slot); - - let bracketed_ctas = parent_stack - .base - .iter() - .rev() - .take(2) - .map(|it| it.get_type()) - .eq([SyntaxKind::CreateTableStatement, SyntaxKind::Bracketed]); - - if bracketed_ctas || ctes.has_duplicate_aliases() || is_recursive { - is_fixable = false; - } - - lint_results.push(lint_result); } - if !is_fixable { - return lint_results; + if bracketed_ctas || is_recursive || local_fixes.is_empty() { + return results_list + .into_iter() + .map(|result| result.lint_result) + .collect(); } let mut fixes = HashMap::default(); compute_anchor_edit_info(&mut fixes, local_fixes); let (new_root, _, _) = clone_map.root.apply_fixes(&mut fixes); - let clone_map = SegmentCloneMap::new(segment.first().unwrap().clone(), new_root.clone()); - for subquery_parent_slot in q { - ctes.replace_with_clone(subquery_parent_slot, &clone_map); + let clone_map = SegmentCloneMap::new(new_root.clone()); + for result in &results_list { + if result.is_fixable { + ctes.replace_with_clone(result.subquery_parent.clone(), &clone_map); + } } + for result in &results_list { + let Some(cte_source) = &result.cte_source else { + continue; + }; - let _segment = Segments::new(new_root, None); - let output_select = if is_with { - _segment.children_where(|it: &ErasedSegment| { - matches!( - it.get_type(), - SyntaxKind::SetExpression | SyntaxKind::SelectStatement - ) - }) - } else { - _segment.clone() - }; + let cte_source_clone = cte_source.deep_clone(); + let cte_clone_map = SegmentCloneMap::new(cte_source_clone.clone()); + let mut cte_fixes = Vec::new(); + + for nested_result in &results_list { + if !nested_result.is_fixable + || cte_clone_map.get(&nested_result.from_expression).is_none() + { + continue; + } + + let from_expression_clone = cte_clone_map[&nested_result.from_expression].clone(); + let new_table_ref = + create_table_ref(context.tables, &nested_result.alias_name, context.dialect); + + cte_fixes.push(LintFix::replace( + from_expression_clone.clone(), + vec![ + SegmentBuilder::node( + context.tables.next_id(), + from_expression_clone.get_type(), + context.dialect.name, + vec![new_table_ref], + ) + .finish(), + ], + None, + )); + } + + let cte_source = if cte_fixes.is_empty() { + cte_source_clone + } else { + let mut fixes = HashMap::default(); + compute_anchor_edit_info(&mut fixes, cte_fixes); + cte_source_clone.apply_fixes(&mut fixes).0 + }; + + let new_cte = create_cte_seg( + context.tables, + result.alias_name.clone(), + cte_source, + case_preference, + context.dialect, + ); + ctes.replace_cte(result.alias_name.normalized.as_str(), new_cte); + } // If there's no SELECT statement (e.g., WITH ... INSERT/UPDATE/DELETE), // we can't safely create fixes, so return lint results without fixes. if output_select.is_empty() { - return lint_results; + return results_list + .into_iter() + .map(|result| result.lint_result) + .collect(); } - for result in &mut lint_results { - let subquery_parent = subquery_parent.clone().unwrap(); - let output_select_clone = output_select[0].clone(); + let mut output_select_clone = clone_map[&output_select[0]].clone(); + + for result in &mut results_list { + if !result.is_fixable { + continue; + } let mut fixes = ctes.ensure_space_after_from( context.tables, output_select[0].clone(), - &output_select_clone, - subquery_parent, + &mut output_select_clone, + result.subquery_parent.clone(), ); let new_select = ctes.compose_select( @@ -239,16 +395,19 @@ join c using(x) case_preference, ); - result.fixes = vec![LintFix::replace( - segment.first().unwrap().clone(), + result.lint_result.fixes = vec![LintFix::replace( + root_segment.clone(), vec![new_select], None, )]; - result.fixes.append(&mut fixes); + result.lint_result.fixes.append(&mut fixes); } - lint_results + results_list + .into_iter() + .map(|result| result.lint_result) + .collect() } fn is_fix_compatible(&self) -> bool { @@ -269,66 +428,89 @@ impl RuleST05 { ctes: &mut CTEBuilder, case_preference: Case, segment_clone_map: &SegmentCloneMap, - ) -> Vec<(LintResult, ErasedSegment, SmolStr, ErasedSegment)> { + ) -> Vec { let mut acc = Vec::new(); for nsq in self.nested_subqueries(query, dialect) { - let (alias_name, _) = ctes.create_cte_alias(Some(&nsq.table_alias)); + let alias_name = ctes.create_cte_alias(Some(&nsq.table_alias)); let anchor = nsq .table_alias .from_expression_element .segments() .first() .cloned() - .unwrap(); + .expect("from_expression_element should have at least one segment"); - let new_cte = create_cte_seg( - tables, - alias_name.clone(), - segment_clone_map[&anchor].clone(), - case_preference, - dialect, - ); + let mut is_fixable = !ctes.is_name_used(&alias_name.normalized); + let bracket_anchor = if anchor.is_type(SyntaxKind::TableExpression) { + let Some(bracket_anchor) = + anchor.child(const { &SyntaxSet::single(SyntaxKind::Bracketed) }) + else { + continue; + }; - ctes.insert_cte(new_cte); + bracket_anchor + } else { + anchor.clone() + }; - if nsq.query.inner.borrow().selectables.len() != 1 { - continue; + if !bracket_anchor.is_type(SyntaxKind::Bracketed) + || bracket_anchor + .child(const { &SyntaxSet::single(SyntaxKind::TableExpression) }) + .is_some() + { + is_fixable = false; } - let select = nsq.query.inner.borrow().selectables[0].clone().selectable; - let anchor = anchor.recursive_crawl( - const { - &SyntaxSet::new(&[ - SyntaxKind::Keyword, - SyntaxKind::Symbol, - SyntaxKind::StartBracket, - SyntaxKind::EndBracket, - ]) - }, - true, - &SyntaxSet::EMPTY, - true, - )[0] - .clone(); + if is_fixable { + let new_cte = create_cte_seg( + tables, + alias_name.clone(), + segment_clone_map[&bracket_anchor].clone(), + case_preference, + dialect, + ); + + ctes.insert_cte(new_cte); + } + + let anchor = anchor + .recursive_crawl( + const { + &SyntaxSet::new(&[ + SyntaxKind::Keyword, + SyntaxKind::Symbol, + SyntaxKind::StartBracket, + SyntaxKind::EndBracket, + ]) + }, + true, + &SyntaxSet::EMPTY, + true, + ) + .first() + .cloned() + .expect("expected a code anchor inside the offending subquery"); let res = LintResult::new( anchor.into(), Vec::new(), format!( "{} clauses should not contain subqueries. Use CTEs instead", - select.get_type().as_str() + nsq.selectable.selectable.get_type().as_str() ) .into(), None, ); - acc.push(( - res, - nsq.table_alias.from_expression_element, - alias_name.clone(), - nsq.query.inner.borrow().selectables[0].clone().selectable, - )); + acc.push(LintQueryResult { + lint_result: res, + from_expression: nsq.table_alias.from_expression_element, + alias_name: alias_name.clone(), + subquery_parent: nsq.selectable.selectable, + cte_source: is_fixable.then_some(bracket_anchor), + is_fixable, + }); } acc @@ -338,21 +520,32 @@ impl RuleST05 { &self, query: Query<'a>, dialect: &'a Dialect, - ) -> Vec> { + ) -> Vec> { let mut acc = Vec::new(); + self.collect_nested_subqueries(&query, dialect, &mut acc); + acc + } - let parent_types = config_mapping(&self.forbid_subquery_in); + fn collect_nested_subqueries<'a>( + &self, + query: &Query<'a>, + dialect: &'a Dialect, + acc: &mut Vec>, + ) { + let parent_types = self.forbid_subquery_in.syntax_set(); let mut queries = vec![query.clone()]; queries.extend(query.inner.borrow().ctes.values().cloned()); - for (i, q) in enumerate(queries) { - for selectable in &q.inner.borrow().selectables { + for current_query in queries { + let selectables = current_query.inner.borrow().selectables.clone(); + + for selectable in selectables { let Some(select_info) = selectable.select_info() else { continue; }; let mut select_source_names = HashSet::new(); - for table_alias in select_info.table_aliases { + for table_alias in &select_info.table_aliases { if !table_alias.ref_str.is_empty() { select_source_names.insert(table_alias.ref_str.clone()); } @@ -360,8 +553,10 @@ impl RuleST05 { if let Some(object_reference) = &table_alias.object_reference { select_source_names.insert(object_reference.raw().to_smolstr()); } + } - let Some(query) = + for table_alias in select_info.table_aliases { + let Some(subquery) = Query::from_root(&table_alias.from_expression_element, dialect) else { continue; @@ -370,48 +565,35 @@ impl RuleST05 { let path_to = selectable .selectable .path_to(&table_alias.from_expression_element); - - if !(parent_types.contains(table_alias.from_expression_element.get_type()) + let is_target_parent = parent_types + .contains(table_alias.from_expression_element.get_type()) || path_to .iter() - .any(|ps| parent_types.contains(ps.segment.get_type()))) - { - continue; - } + .any(|ps| parent_types.contains(ps.segment.get_type())); - if is_correlated_subquery( - Segments::new( - query - .inner - .borrow() - .selectables - .first() - .unwrap() - .selectable - .clone(), - None, - ), - &select_source_names, - dialect, - ) { + let Some(nested_selectable) = + subquery.inner.borrow().selectables.first().cloned() + else { continue; - } - - acc.push(NestedSubQuerySummary { - query: q.clone(), - selectable: selectable.clone(), - table_alias: table_alias.clone(), - select_source_names: select_source_names.clone(), - }); + }; - if i > 0 { - acc.append(&mut self.nested_subqueries(query.clone(), dialect)); + if is_target_parent + && !is_correlated_subquery( + Segments::new(nested_selectable.selectable.clone(), None), + &select_source_names, + dialect, + ) + { + acc.push(NestedSubqueryCandidate { + selectable: selectable.clone(), + table_alias: table_alias.clone(), + }); } + + self.collect_nested_subqueries(&subquery, dialect, acc); } } } - - acc } } @@ -455,6 +637,7 @@ fn is_correlated_subquery( #[derive(Default)] struct CTEBuilder { ctes: Vec, + used_names: HashSet, name_idx: usize, } @@ -507,68 +690,48 @@ impl CTEBuilder { &self, tables: &Tables, output_select: ErasedSegment, - output_select_clone: &ErasedSegment, + output_select_clone: &mut ErasedSegment, subquery_parent: ErasedSegment, ) -> Vec { let mut fixes = Vec::new(); if subquery_parent.is(&output_select) { - let (missing_space_after_from, _from_clause, _from_clause_children, _from_segment) = - Self::missing_space_after_from(output_select_clone.clone()); - - if missing_space_after_from { - todo!() - } - } else { - let (missing_space_after_from, _from_clause, _from_clause_children, from_segment) = - Self::missing_space_after_from(subquery_parent); - - if missing_space_after_from { - fixes.push(LintFix::create_after( - from_segment.unwrap().base[0].clone(), - vec![SegmentBuilder::whitespace(tables.next_id(), " ")], - None, - )) + if let Some(from_keyword) = Self::from_keyword_needing_separator(output_select_clone) { + let mut anchor_fixes = HashMap::default(); + compute_anchor_edit_info( + &mut anchor_fixes, + vec![LintFix::create_after( + from_keyword, + vec![SegmentBuilder::whitespace(tables.next_id(), " ")], + None, + )], + ); + *output_select_clone = output_select_clone.apply_fixes(&mut anchor_fixes).0; } + } else if let Some(from_keyword) = Self::from_keyword_needing_separator(&subquery_parent) { + fixes.push(LintFix::create_after( + from_keyword, + vec![SegmentBuilder::whitespace(tables.next_id(), " ")], + None, + )) } fixes } - fn missing_space_after_from( - segment: ErasedSegment, - ) -> ( - bool, - Option, - Option, - Option, - ) { - let mut missing_space_after_from = false; - let from_clause_children = None; - let mut from_segment = None; - let from_clause = segment.child(const { &SyntaxSet::single(SyntaxKind::FromClause) }); - - if let Some(from_clause) = &from_clause { - let from_clause_children = Segments::from_vec(from_clause.segments().to_vec(), None); - from_segment = from_clause_children - .find_first_where(|it: &ErasedSegment| it.is_keyword("FROM")) - .into(); - if !from_segment.as_ref().unwrap().is_empty() - && from_clause_children - .after(&from_segment.as_ref().unwrap().base[0]) - .take_while(|it| it.is_whitespace()) - .is_empty() - { - missing_space_after_from = true; - } - } - - ( - missing_space_after_from, - from_clause, - from_clause_children, - from_segment, - ) + fn from_keyword_needing_separator(segment: &ErasedSegment) -> Option { + let from_clause = segment.child(const { &SyntaxSet::single(SyntaxKind::FromClause) })?; + let from_clause_children = Segments::from_vec(from_clause.segments().to_vec(), None); + let from_keyword = from_clause_children + .find_first_where(|it: &ErasedSegment| it.is_keyword("FROM")) + .first() + .cloned()?; + let has_separator = !from_clause_children + .after(&from_keyword) + .take_while(|it| it.is_whitespace() || it.is_comment() || it.is_meta()) + .is_empty(); + + (!has_separator).then_some(from_keyword) } } @@ -578,13 +741,15 @@ impl CTEBuilder { segment: ErasedSegment, clone_map: &SegmentCloneMap, ) { - for (idx, cte) in enumerate(&self.ctes) { + for (idx, cte) in self.ctes.iter().enumerate() { if cte .recursive_crawl_all(false) .into_iter() .any(|seg| segment.is(&seg)) { - self.ctes[idx] = clone_map[&self.ctes[idx]].clone(); + if let Some(cte_clone) = clone_map.get(&self.ctes[idx]) { + self.ctes[idx] = cte_clone.clone(); + } return; } } @@ -592,55 +757,56 @@ impl CTEBuilder { } impl CTEBuilder { - fn list_used_names(&self) -> Vec { - let mut used_names = Vec::new(); - - for cte in &self.ctes { - let id_seg = cte - .child( - const { - &SyntaxSet::new(&[ - SyntaxKind::Identifier, - SyntaxKind::NakedIdentifier, - SyntaxKind::QuotedIdentifier, - ]) - }, - ) - .unwrap(); - - let cte_name = if id_seg.is_type(SyntaxKind::QuotedIdentifier) { - let raw = id_seg.raw(); - raw[1..raw.len() - 1].to_smolstr() - } else { - id_seg.raw().to_smolstr() - }; - - used_names.push(cte_name); - } + fn cte_name(cte: &ErasedSegment) -> SmolStr { + let id_seg = cte + .child( + const { + &SyntaxSet::new(&[ + SyntaxKind::Identifier, + SyntaxKind::NakedIdentifier, + SyntaxKind::QuotedIdentifier, + ]) + }, + ) + .expect("CTE should contain an identifier segment"); - used_names + normalize_identifier_name(id_seg.raw().as_ref()) } - fn has_duplicate_aliases(&self) -> bool { - let used_names = self.list_used_names(); - !used_names.into_iter().all_unique() + fn is_name_used(&self, name: &SmolStr) -> bool { + self.used_names.contains(name) } - fn create_cte_alias(&mut self, alias: Option<&AliasInfo>) -> (SmolStr, bool) { - if let Some(alias) = alias.filter(|alias| alias.aliased && !alias.ref_str.is_empty()) { - return (alias.ref_str.clone(), false); + fn replace_cte(&mut self, cte_name: &str, new_cte: ErasedSegment) { + if let Some(idx) = self + .ctes + .iter() + .position(|cte| Self::cte_name(cte) == cte_name) + { + self.used_names.remove(&Self::cte_name(&self.ctes[idx])); + self.ctes[idx] = new_cte; + self.used_names.insert(Self::cte_name(&self.ctes[idx])); } + } - self.name_idx += 1; - let name = format_smolstr!("prep_{}", self.name_idx); - if self.list_used_names().iter().contains(&name) { - return self.create_cte_alias(None); + fn create_cte_alias(&mut self, alias: Option<&AliasInfo>) -> IdentifierSpec { + if let Some(alias) = alias.filter(|alias| alias.aliased) + && let Some(alias_name) = IdentifierSpec::from_alias(alias) + { + return alias_name; } - (name, true) + loop { + self.name_idx += 1; + let name = format_smolstr!("prep_{}", self.name_idx); + if !self.is_name_used(&name) { + return IdentifierSpec::generated(name); + } + } } fn insert_cte(&mut self, cte: ErasedSegment) { + let cte_name = Self::cte_name(&cte); let inbound_subquery = Segments::new(cte.clone(), None) .children_all() .find_first_where(|it: &ErasedSegment| it.get_position_marker().is_some()); @@ -656,7 +822,7 @@ impl CTEBuilder { .children_all() .last() .cloned() - .unwrap(), + .expect("CTE should contain a trailing selectable segment"), None, ), inbound_subquery.clone(), @@ -667,6 +833,7 @@ impl CTEBuilder { .unwrap_or(self.ctes.len()); self.ctes.insert(insert_position, cte); + self.used_names.insert(cte_name); } } @@ -674,11 +841,15 @@ fn is_child(maybe_parent: Segments, maybe_child: Segments) -> bool { let child_markers = maybe_child[0].get_position_marker().unwrap(); let parent_pos = maybe_parent[0].get_position_marker().unwrap(); - if child_markers < &parent_pos.start_point_marker() { + // Preserve the original source nesting relationship even after earlier fix + // passes have changed working positions within the tree. Structural rewrites + // like ST05 rely on this to keep generated CTEs ahead of the CTEs/selects + // they were extracted from. + if child_markers.source_slice.start < parent_pos.source_slice.start { return false; } - if child_markers > &parent_pos.end_point_marker() { + if child_markers.source_slice.end > parent_pos.source_slice.end { return false; } @@ -718,7 +889,7 @@ fn segmentify(tables: &Tables, input_el: &str, casing: Case) -> ErasedSegment { fn create_cte_seg( tables: &Tables, - alias_name: SmolStr, + alias_name: IdentifierSpec, subquery: ErasedSegment, case_preference: Case, dialect: &Dialect, @@ -728,8 +899,7 @@ fn create_cte_seg( SyntaxKind::CommonTableExpression, dialect.name, vec![ - SegmentBuilder::token(tables.next_id(), &alias_name, SyntaxKind::NakedIdentifier) - .finish(), + SegmentBuilder::token(tables.next_id(), &alias_name.raw, alias_name.kind).finish(), SegmentBuilder::whitespace(tables.next_id(), " "), segmentify(tables, "AS", case_preference), SegmentBuilder::whitespace(tables.next_id(), " "), @@ -739,7 +909,11 @@ fn create_cte_seg( .finish() } -fn create_table_ref(tables: &Tables, table_name: &str, dialect: &Dialect) -> ErasedSegment { +fn create_table_ref( + tables: &Tables, + table_name: &IdentifierSpec, + dialect: &Dialect, +) -> ErasedSegment { SegmentBuilder::node( tables.next_id(), SyntaxKind::TableExpression, @@ -750,12 +924,8 @@ fn create_table_ref(tables: &Tables, table_name: &str, dialect: &Dialect) -> Era SyntaxKind::TableReference, dialect.name, vec![ - SegmentBuilder::token( - tables.next_id(), - table_name, - SyntaxKind::NakedIdentifier, - ) - .finish(), + SegmentBuilder::token(tables.next_id(), &table_name.raw, table_name.kind) + .finish(), ], ) .finish(), @@ -766,32 +936,27 @@ fn create_table_ref(tables: &Tables, table_name: &str, dialect: &Dialect) -> Era pub(crate) struct SegmentCloneMap { root: ErasedSegment, - segment_map: HashMap, + segment_map: HashMap, } impl Index<&ErasedSegment> for SegmentCloneMap { type Output = ErasedSegment; fn index(&self, index: &ErasedSegment) -> &Self::Output { - &self.segment_map[&index.addr()] + &self.segment_map[&index.id()] } } -impl IndexMut<&ErasedSegment> for SegmentCloneMap { - fn index_mut(&mut self, index: &ErasedSegment) -> &mut Self::Output { - self.segment_map.get_mut(&index.addr()).unwrap() +impl SegmentCloneMap { + fn get(&self, old_segment: &ErasedSegment) -> Option<&ErasedSegment> { + self.segment_map.get(&old_segment.id()) } -} -impl SegmentCloneMap { - fn new(segment: ErasedSegment, segment_copy: ErasedSegment) -> Self { + fn new(segment_copy: ErasedSegment) -> Self { let mut segment_map = HashMap::new(); - for (old_segment, new_segment) in zip( - segment.recursive_crawl_all(false), - segment_copy.recursive_crawl_all(false), - ) { - segment_map.insert(old_segment.addr(), new_segment); + for segment in segment_copy.recursive_crawl_all(false) { + segment_map.insert(segment.id(), segment.clone()); } Self { @@ -800,3 +965,555 @@ impl SegmentCloneMap { } } } + +#[cfg(test)] +mod tests { + use hashbrown::HashMap; + + use crate::core::config::FluffConfig; + use crate::core::config::Value; + use crate::core::linter::core::Linter; + use crate::core::rules::Rule; + + use super::{ForbidSubqueryIn, RuleST05}; + + fn st05_linter(dialect: &str) -> Linter { + let config = FluffConfig::from_source( + &format!( + r#" +[sqruff] +dialect = {dialect} +rules = ST05 + +[sqruff:rules:structure.subquery] +forbid_subquery_in = both +"# + ), + None, + ); + + Linter::new(config, None, None, false).unwrap() + } + + fn assert_fix(dialect: &str, source: &str, expected: &str) { + let mut linter = st05_linter(dialect); + let actual = linter + .lint_string_wrapped(source, true) + .unwrap() + .fix_string(); + + pretty_assertions::assert_eq!(actual, expected); + } + + fn assert_lint_without_fix(dialect: &str, source: &str, expected_violations: usize) { + let mut linter = st05_linter(dialect); + let linted = linter.lint_string_wrapped(source, false).unwrap(); + assert_eq!(linted.violations().len(), expected_violations); + + let mut linter = st05_linter(dialect); + let fixed = linter + .lint_string_wrapped(source, true) + .unwrap() + .fix_string(); + pretty_assertions::assert_eq!(fixed, source); + } + + fn assert_pass(dialect: &str, source: &str) { + let mut linter = st05_linter(dialect); + let linted = linter.lint_string_wrapped(source, false).unwrap(); + assert!(linted.violations().is_empty(), "{:?}", linted.violations()); + } + + #[test] + fn st05_fixes_double_nested_subquery_without_panicking() { + let source = r#"WITH q AS ( + SELECT + t1.a + FROM + table_1 AS t1 + INNER JOIN + table_2 AS t2 USING (a) + LEFT JOIN ( + SELECT DISTINCT a FROM table_3 + WHERE c = 'v1' + ) AS dns USING (a) + LEFT JOIN ( + SELECT DISTINCT a FROM table_5 + LEFT JOIN ( + SELECT DISTINCT + a, + b + FROM table_6 + WHERE c < 5 + ) AS t4 + USING (a) + WHERE table_5.b = 'v2' + ) AS dcod USING (a) +) +SELECT + a +FROM + q; +"#; + let expected = r#"WITH dns AS ( + SELECT DISTINCT a FROM table_3 + WHERE c = 'v1' + ), +t4 AS ( + SELECT DISTINCT + a, + b + FROM table_6 + WHERE c < 5 + ), +dcod AS ( + SELECT DISTINCT a FROM table_5 + LEFT JOIN t4 + USING (a) + WHERE table_5.b = 'v2' + ), +q AS ( + SELECT + t1.a + FROM + table_1 AS t1 + INNER JOIN + table_2 AS t2 USING (a) + LEFT JOIN dns USING (a) + LEFT JOIN dcod USING (a) +) +SELECT + a +FROM + q; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_order_4782_without_panicking() { + let source = r#"WITH +cte_1 AS ( + SELECT + subquery_a.field_a, + subquery_a.field_b + FROM ( + SELECT + subquery_b.field_a, + alias_a.field_d, + alias_a.field_b, + alias_b.field_c + FROM table_b AS alias_a + INNER JOIN + (SELECT * FROM table_a) AS subquery_b + ON subquery_b.field_a >= alias_a.field_d + LEFT OUTER JOIN table_b AS alias_b ON alias_a.field_b = alias_b.field_c + ) AS subquery_a +), + +cte_2 AS ( + SELECT * + FROM table_c + WHERE field_a > 0 + ORDER BY field_b DESC +), + +join_ctes AS ( + SELECT * FROM cte_1 LEFT OUTER JOIN cte_2 ON cte_1.field_a = cte_2.field_a +) + +SELECT * +FROM join_ctes; +"#; + let expected = r#"WITH subquery_b AS (SELECT * FROM table_a), +subquery_a AS ( + SELECT + subquery_b.field_a, + alias_a.field_d, + alias_a.field_b, + alias_b.field_c + FROM table_b AS alias_a + INNER JOIN + subquery_b + ON subquery_b.field_a >= alias_a.field_d + LEFT OUTER JOIN table_b AS alias_b ON alias_a.field_b = alias_b.field_c + ), +cte_1 AS ( + SELECT + subquery_a.field_a, + subquery_a.field_b + FROM subquery_a +), +cte_2 AS ( + SELECT * + FROM table_c + WHERE field_a > 0 + ORDER BY field_b DESC +), +join_ctes AS ( + SELECT * FROM cte_1 LEFT OUTER JOIN cte_2 ON cte_1.field_a = cte_2.field_a +) +SELECT * +FROM join_ctes; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_same_named_nested_subqueries_across_ctes() { + let source = r#"with purchases_in_the_last_year as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -1, current_date()) + ) group by customer_id +) + +, purchases_in_the_last_three_years as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -3, current_date()) + ) group by customer_id +) + + +select distinct + c.customer_id + , ly.attrlist as attrlist_last_year + , l3y.attrlist as attrlist_last_three_years +from + customers as c +left outer join + purchases_in_the_last_year as ly + on c.customer_id = ly.customer_id +left outer join + purchases_in_the_last_three_years as l3y + on c.customer_id = l3y.customer_id +; +"#; + let expected = r#"with prep_1 as ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -1, current_date()) + ), +purchases_in_the_last_year as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from prep_1 group by customer_id +), +prep_2 as ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -3, current_date()) + ), +purchases_in_the_last_three_years as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from prep_2 group by customer_id +) +select distinct + c.customer_id + , ly.attrlist as attrlist_last_year + , l3y.attrlist as attrlist_last_three_years +from + customers as c +left outer join + purchases_in_the_last_year as ly + on c.customer_id = ly.customer_id +left outer join + purchases_in_the_last_three_years as l3y + on c.customer_id = l3y.customer_id +; +"#; + + assert_fix("snowflake", source, expected); + } + + #[test] + fn st05_partially_fixes_duplicate_aliases_in_order_5265_case() { + let source = r#"WITH +cte1 AS ( + SELECT COUNT(*) AS qty + FROM some_table AS st + LEFT JOIN ( + SELECT 'first' AS id + ) AS oops + ON st.id = oops.id +), +cte2 AS ( + SELECT COUNT(*) AS other_qty + FROM other_table AS sot + LEFT JOIN ( + SELECT 'middle' AS id + ) AS another + ON sot.id = another.id + LEFT JOIN ( + SELECT 'last' AS id + ) AS oops + ON sot.id = oops.id +) +SELECT CURRENT_DATE(); +"#; + let expected = r#"WITH oops AS ( + SELECT 'first' AS id + ), +cte1 AS ( + SELECT COUNT(*) AS qty + FROM some_table AS st + LEFT JOIN oops + ON st.id = oops.id +), +another AS ( + SELECT 'middle' AS id + ), +cte2 AS ( + SELECT COUNT(*) AS other_qty + FROM other_table AS sot + LEFT JOIN another + ON sot.id = another.id + LEFT JOIN ( + SELECT 'last' AS id + ) AS oops + ON sot.id = oops.id +) +SELECT CURRENT_DATE(); +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_inserts_space_after_from_when_rewriting_root_subquery() { + let source = r#"CREATE TABLE t +AS +SELECT + col1 +FROM( + SELECT 'x' AS col1 +) x +"#; + let expected = r#"CREATE TABLE t +AS +WITH x AS ( + SELECT 'x' AS col1 +) +SELECT + col1 +FROM x +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_set_subquery_in_second_query() { + let source = r#"SELECT 1 AS value_name +UNION +SELECT value +FROM (SELECT 2 AS value_name); +"#; + let expected = r#"WITH prep_1 AS (SELECT 2 AS value_name) +SELECT 1 AS value_name +UNION +SELECT value +FROM prep_1; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_multiple_set_subqueries_in_second_query() { + let source = r#"SELECT 1 AS value_name +UNION +SELECT value +FROM (SELECT 2 AS value_name) +CROSS JOIN (SELECT 1 as v2); +"#; + let expected = r#"WITH prep_1 AS (SELECT 2 AS value_name), +prep_2 AS (SELECT 1 as v2) +SELECT 1 AS value_name +UNION +SELECT value +FROM prep_1 +CROSS JOIN prep_2; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_rewrites_tsql_insert_to_with_before_insert() { + let source = r#"INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +(SELECT +Id +,Attribute +FROM Table2) Subq +ON Main.Id = Subq.Id +"#; + let expected = r#"WITH Subq AS (SELECT +Id +,Attribute +FROM Table2) +INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +Subq +ON Main.Id = Subq.Id +"#; + + assert_fix("tsql", source, expected); + } + + #[test] + fn st05_rewrites_existing_tsql_with_insert_statement() { + let source = r#"WITH MainTable as ( + select * + from sales +) + +INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +(SELECT +Id +,Attribute +FROM Table2) Subq +ON Main.Id = Subq.Id +"#; + let expected = r#"WITH MainTable as ( + select * + from sales +), +Subq AS (SELECT +Id +,Attribute +FROM Table2) +INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +Subq +ON Main.Id = Subq.Id +"#; + + assert_fix("tsql", source, expected); + } + + #[test] + fn st05_lints_nested_joined_subquery_without_fixing() { + let source = r#"SELECT + x.a, + w2.b +FROM x +LEFT JOIN ( + ( + SELECT + w.a, + w.b, + w.c + FROM w + ) AS w2 + LEFT JOIN y + ON w2.a = y.a +) + ON x.a = w2.a; +"#; + + assert_lint_without_fix("ansi", source, 1); + } + + #[test] + fn st05_ignores_nested_table_function_subqueries() { + let source = r#"SELECT * +FROM `func`(( + SELECT 1 +)); +"#; + + assert_pass("bigquery", source); + } + + #[test] + fn st05_load_from_config_rejects_invalid_value() { + let config = HashMap::from_iter([( + "forbid_subquery_in".into(), + Value::String("sideways".into()), + )]); + + let err = RuleST05::default().load_from_config(&config).unwrap_err(); + assert_eq!( + err, + "Invalid value for forbid_subquery_in: sideways. Must be one of [join, from, both]" + ); + } + + #[test] + fn st05_load_from_config_accepts_valid_value() { + let config = + HashMap::from_iter([("forbid_subquery_in".into(), Value::String("from".into()))]); + + assert!(RuleST05::default().load_from_config(&config).is_ok()); + } + + #[test] + fn st05_default_config_is_safe_and_uses_both() { + assert_eq!( + RuleST05::default().forbid_subquery_in, + ForbidSubqueryIn::Both + ); + assert!( + RuleST05::default() + .load_from_config(&HashMap::default()) + .is_ok() + ); + } +} diff --git a/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml b/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml index 99cc0aa33..4debdd699 100644 --- a/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml +++ b/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml @@ -638,3 +638,90 @@ test_fail_subquery_in_cte_3: rules: structure.subquery: forbid_subquery_in: both + +test_fail_clickhouse_expression_cte_preserved: + fail_str: | + insert into demo_table + with + array('A', 'B') as keep_list + select + orders.id, + default_tags.tag + from ( + select 1 as id, 'A' as tag + ) as orders + left join ( + select 1 as id, 'x' as tag + ) as default_tags + on orders.id = default_tags.id + where orders.tag in keep_list + fix_str: | + insert into demo_table + with array('A', 'B') as keep_list, + default_tags as ( + select 1 as id, 'x' as tag + ) + select + orders.id, + default_tags.tag + from ( + select 1 as id, 'A' as tag + ) as orders + left join default_tags + on orders.id = default_tags.id + where orders.tag in keep_list + configs: + core: + dialect: clickhouse + +test_fail_nested_join_subquery_inside_allowed_from_subquery: + fail_str: | + select + outer_q.x, + outer_q.y + from ( + select + a.x, + b.y + from a + join ( + select x, y from b + ) as b on a.x = b.x + ) as outer_q + fix_str: | + with b as ( + select x, y from b + ) + select + outer_q.x, + outer_q.y + from ( + select + a.x, + b.y + from a + join b on a.x = b.x + ) as outer_q + configs: + rules: + structure.subquery: + forbid_subquery_in: join + +test_fail_preserves_quoted_aliases_in_generated_ctes: + fail_str: | + select + a.x, + "b".y + from a + join ( + select x, y from b + ) as "b" on a.x = "b".x + fix_str: | + with "b" as ( + select x, y from b + ) + select + a.x, + "b".y + from a + join "b" on a.x = "b".x diff --git a/crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml b/crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml new file mode 100644 index 000000000..1e5f80f4a --- /dev/null +++ b/crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml @@ -0,0 +1,58 @@ +rule: ST05,LT09 + +test_fix_generated_cte_stays_before_first_use: + # https://github.com/sqlfluff/sqlfluff/issues/4137 + # Tests multi-rule interaction: ST05 extracts the join subquery into a CTE, + # and LT09 reformats the SELECT targets. The generated CTE must stay before + # the first CTE that references it. + fail_str: | + with + + cte1 as ( + select t1.x, t2.y + from tbl1 t1 + join (select x, y from tbl2) t2 + on t1.x = t2.x + ) + + , cte2 as ( + select x, y from tbl2 t2 + ) + + select x, y from cte1 + union all + select x, y from cte2 + ; + fix_str: | + with t2 as (select + x, + y + from tbl2), + cte1 as ( + select + t1.x, + t2.y + from tbl1 t1 + join t2 + on t1.x = t2.x + ), + cte2 as ( + select + x, + y + from tbl2 t2 + ) + select + x, + y + from cte1 + union all + select + x, + y + from cte2 + ; + configs: + rules: + structure.subquery: + forbid_subquery_in: both