From 7c4f95b75b0758f147fe4a2225530720f341a03b Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 15:15:38 -0400 Subject: [PATCH 01/10] fix(st05): repair per-result clone flow and nested CTE regressions --- crates/lib/src/rules/structure/st05.rs | 630 ++++++++++++++++++++----- 1 file changed, 517 insertions(+), 113 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 50c790fa1..0c5af130e 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -1,4 +1,3 @@ -use std::iter::zip; use std::ops::{Index, IndexMut}; use hashbrown::{HashMap, HashSet}; @@ -45,6 +44,15 @@ struct NestedSubQuerySummary<'a> { select_source_names: HashSet, } +struct LintQueryResult { + lint_result: LintResult, + from_expression: ErasedSegment, + alias_name: SmolStr, + subquery_parent: ErasedSegment, + cte_source: Option, + is_fixable: bool, +} + #[derive(Clone, Debug, Default)] pub(crate) struct RuleST05 { forbid_subquery_in: String, @@ -130,13 +138,27 @@ join c using(x) .is_empty(); let case_preference = get_case_preference(&segment); + let output_select = if is_with { + segment.children_where(|it: &ErasedSegment| { + matches!( + it.get_type(), + SyntaxKind::SetExpression | SyntaxKind::SelectStatement + ) + }) + } else { + segment.clone() + }; + let bracketed_ctas = parent_stack + .base + .iter() + .rev() + .take(2) + .map(|it| it.get_type()) + .eq([SyntaxKind::CreateTableStatement, SyntaxKind::Bracketed]); - let clone_map = SegmentCloneMap::new( - segment.first().unwrap().clone(), - segment.first().unwrap().deep_clone(), - ); + let clone_map = SegmentCloneMap::new(segment.first().unwrap().deep_clone()); - let results = self.lint_query( + let mut results_list = self.lint_query( context.tables, context.dialect, query, @@ -145,19 +167,16 @@ join c using(x) &clone_map, ); - let mut lint_results = Vec::with_capacity(results.len()); - let mut is_fixable = true; - - let mut subquery_parent = None; - let mut local_fixes = Vec::new(); - let mut q = Vec::new(); - for result in results { - let (lint_result, from_expression, alias_name, subquery_parent_slot) = result; - subquery_parent = Some(subquery_parent_slot.clone()); - let this_seg_clone = clone_map[&from_expression].clone(); - let new_table_ref = create_table_ref(context.tables, &alias_name, context.dialect); + for result in &results_list { + if !result.is_fixable { + continue; + } + + let this_seg_clone = clone_map[&result.from_expression].clone(); + let new_table_ref = + create_table_ref(context.tables, &result.alias_name, context.dialect); local_fixes.push(LintFix::replace( this_seg_clone.clone(), @@ -172,64 +191,99 @@ join c using(x) ], None, )); - - q.push(subquery_parent_slot); - - let bracketed_ctas = parent_stack - .base - .iter() - .rev() - .take(2) - .map(|it| it.get_type()) - .eq([SyntaxKind::CreateTableStatement, SyntaxKind::Bracketed]); - - if bracketed_ctas || ctes.has_duplicate_aliases() || is_recursive { - is_fixable = false; - } - - lint_results.push(lint_result); } - if !is_fixable { - return lint_results; + if bracketed_ctas || is_recursive || local_fixes.is_empty() { + return results_list + .into_iter() + .map(|result| result.lint_result) + .collect(); } let mut fixes = HashMap::default(); compute_anchor_edit_info(&mut fixes, local_fixes); let (new_root, _, _) = clone_map.root.apply_fixes(&mut fixes); - let clone_map = SegmentCloneMap::new(segment.first().unwrap().clone(), new_root.clone()); - for subquery_parent_slot in q { - ctes.replace_with_clone(subquery_parent_slot, &clone_map); + let clone_map = SegmentCloneMap::new(new_root.clone()); + for result in &results_list { + if result.is_fixable { + ctes.replace_with_clone(result.subquery_parent.clone(), &clone_map); + } } + for result in &results_list { + let Some(cte_source) = &result.cte_source else { + continue; + }; - let _segment = Segments::new(new_root, None); - let output_select = if is_with { - _segment.children_where(|it: &ErasedSegment| { - matches!( - it.get_type(), - SyntaxKind::SetExpression | SyntaxKind::SelectStatement - ) - }) - } else { - _segment.clone() - }; + let cte_source_clone = cte_source.deep_clone(); + let cte_clone_map = SegmentCloneMap::new(cte_source_clone.clone()); + let mut cte_fixes = Vec::new(); + + for nested_result in &results_list { + if !nested_result.is_fixable + || cte_clone_map.get(&nested_result.from_expression).is_none() + { + continue; + } + + let from_expression_clone = cte_clone_map[&nested_result.from_expression].clone(); + let new_table_ref = + create_table_ref(context.tables, &nested_result.alias_name, context.dialect); + + cte_fixes.push(LintFix::replace( + from_expression_clone.clone(), + vec![ + SegmentBuilder::node( + context.tables.next_id(), + from_expression_clone.get_type(), + context.dialect.name, + vec![new_table_ref], + ) + .finish(), + ], + None, + )); + } + + let cte_source = if cte_fixes.is_empty() { + cte_source_clone + } else { + let mut fixes = HashMap::default(); + compute_anchor_edit_info(&mut fixes, cte_fixes); + cte_source_clone.apply_fixes(&mut fixes).0 + }; + + let new_cte = create_cte_seg( + context.tables, + result.alias_name.clone(), + cte_source, + case_preference, + context.dialect, + ); + ctes.replace_cte(&result.alias_name, new_cte); + } // If there's no SELECT statement (e.g., WITH ... INSERT/UPDATE/DELETE), // we can't safely create fixes, so return lint results without fixes. if output_select.is_empty() { - return lint_results; + return results_list + .into_iter() + .map(|result| result.lint_result) + .collect(); } - for result in &mut lint_results { - let subquery_parent = subquery_parent.clone().unwrap(); - let output_select_clone = output_select[0].clone(); + let output_select_clone = clone_map[&output_select[0]].clone(); + + for result in &mut results_list { + if !result.is_fixable { + continue; + } let mut fixes = ctes.ensure_space_after_from( context.tables, output_select[0].clone(), &output_select_clone, - subquery_parent, + result.subquery_parent.clone(), ); let new_select = ctes.compose_select( @@ -239,16 +293,19 @@ join c using(x) case_preference, ); - result.fixes = vec![LintFix::replace( + result.lint_result.fixes = vec![LintFix::replace( segment.first().unwrap().clone(), vec![new_select], None, )]; - result.fixes.append(&mut fixes); + result.lint_result.fixes.append(&mut fixes); } - lint_results + results_list + .into_iter() + .map(|result| result.lint_result) + .collect() } fn is_fix_compatible(&self) -> bool { @@ -269,7 +326,7 @@ impl RuleST05 { ctes: &mut CTEBuilder, case_preference: Case, segment_clone_map: &SegmentCloneMap, - ) -> Vec<(LintResult, ErasedSegment, SmolStr, ErasedSegment)> { + ) -> Vec { let mut acc = Vec::new(); for nsq in self.nested_subqueries(query, dialect) { @@ -282,21 +339,39 @@ impl RuleST05 { .cloned() .unwrap(); - let new_cte = create_cte_seg( - tables, - alias_name.clone(), - segment_clone_map[&anchor].clone(), - case_preference, - dialect, - ); + let mut is_fixable = !ctes.list_used_names().contains(&alias_name); + let bracket_anchor = if anchor.is_type(SyntaxKind::TableExpression) { + let Some(bracket_anchor) = + anchor.child(const { &SyntaxSet::single(SyntaxKind::Bracketed) }) + else { + continue; + }; + + bracket_anchor + } else { + anchor.clone() + }; + + if !bracket_anchor.is_type(SyntaxKind::Bracketed) + || bracket_anchor + .child(const { &SyntaxSet::single(SyntaxKind::TableExpression) }) + .is_some() + { + is_fixable = false; + } - ctes.insert_cte(new_cte); + if is_fixable { + let new_cte = create_cte_seg( + tables, + alias_name.clone(), + segment_clone_map[&bracket_anchor].clone(), + case_preference, + dialect, + ); - if nsq.query.inner.borrow().selectables.len() != 1 { - continue; + ctes.insert_cte(new_cte); } - let select = nsq.query.inner.borrow().selectables[0].clone().selectable; let anchor = anchor.recursive_crawl( const { &SyntaxSet::new(&[ @@ -317,18 +392,20 @@ impl RuleST05 { Vec::new(), format!( "{} clauses should not contain subqueries. Use CTEs instead", - select.get_type().as_str() + nsq.selectable.selectable.get_type().as_str() ) .into(), None, ); - acc.push(( - res, - nsq.table_alias.from_expression_element, - alias_name.clone(), - nsq.query.inner.borrow().selectables[0].clone().selectable, - )); + acc.push(LintQueryResult { + lint_result: res, + from_expression: nsq.table_alias.from_expression_element, + alias_name: alias_name.clone(), + subquery_parent: nsq.selectable.selectable, + cte_source: is_fixable.then_some(bracket_anchor), + is_fixable, + }); } acc @@ -584,7 +661,9 @@ impl CTEBuilder { .into_iter() .any(|seg| segment.is(&seg)) { - self.ctes[idx] = clone_map[&self.ctes[idx]].clone(); + if let Some(cte_clone) = clone_map.get(&self.ctes[idx]) { + self.ctes[idx] = cte_clone.clone(); + } return; } } @@ -592,38 +671,39 @@ impl CTEBuilder { } impl CTEBuilder { - fn list_used_names(&self) -> Vec { - let mut used_names = Vec::new(); - - for cte in &self.ctes { - let id_seg = cte - .child( - const { - &SyntaxSet::new(&[ - SyntaxKind::Identifier, - SyntaxKind::NakedIdentifier, - SyntaxKind::QuotedIdentifier, - ]) - }, - ) - .unwrap(); - - let cte_name = if id_seg.is_type(SyntaxKind::QuotedIdentifier) { - let raw = id_seg.raw(); - raw[1..raw.len() - 1].to_smolstr() - } else { - id_seg.raw().to_smolstr() - }; + fn cte_name(cte: &ErasedSegment) -> SmolStr { + let id_seg = cte + .child( + const { + &SyntaxSet::new(&[ + SyntaxKind::Identifier, + SyntaxKind::NakedIdentifier, + SyntaxKind::QuotedIdentifier, + ]) + }, + ) + .unwrap(); - used_names.push(cte_name); + if id_seg.is_type(SyntaxKind::QuotedIdentifier) { + let raw = id_seg.raw(); + raw[1..raw.len() - 1].to_smolstr() + } else { + id_seg.raw().to_smolstr() } + } - used_names + fn list_used_names(&self) -> Vec { + self.ctes.iter().map(Self::cte_name).collect() } - fn has_duplicate_aliases(&self) -> bool { - let used_names = self.list_used_names(); - !used_names.into_iter().all_unique() + fn replace_cte(&mut self, cte_name: &str, new_cte: ErasedSegment) { + if let Some(idx) = self + .ctes + .iter() + .position(|cte| Self::cte_name(cte) == cte_name) + { + self.ctes[idx] = new_cte; + } } fn create_cte_alias(&mut self, alias: Option<&AliasInfo>) -> (SmolStr, bool) { @@ -766,32 +846,33 @@ fn create_table_ref(tables: &Tables, table_name: &str, dialect: &Dialect) -> Era pub(crate) struct SegmentCloneMap { root: ErasedSegment, - segment_map: HashMap, + segment_map: HashMap, } impl Index<&ErasedSegment> for SegmentCloneMap { type Output = ErasedSegment; fn index(&self, index: &ErasedSegment) -> &Self::Output { - &self.segment_map[&index.addr()] + &self.segment_map[&index.id()] } } impl IndexMut<&ErasedSegment> for SegmentCloneMap { fn index_mut(&mut self, index: &ErasedSegment) -> &mut Self::Output { - self.segment_map.get_mut(&index.addr()).unwrap() + self.segment_map.get_mut(&index.id()).unwrap() } } impl SegmentCloneMap { - fn new(segment: ErasedSegment, segment_copy: ErasedSegment) -> Self { + fn get(&self, old_segment: &ErasedSegment) -> Option<&ErasedSegment> { + self.segment_map.get(&old_segment.id()) + } + + fn new(segment_copy: ErasedSegment) -> Self { let mut segment_map = HashMap::new(); - for (old_segment, new_segment) in zip( - segment.recursive_crawl_all(false), - segment_copy.recursive_crawl_all(false), - ) { - segment_map.insert(old_segment.addr(), new_segment); + for segment in segment_copy.recursive_crawl_all(false) { + segment_map.insert(segment.id(), segment.clone()); } Self { @@ -800,3 +881,326 @@ impl SegmentCloneMap { } } } + +#[cfg(test)] +mod tests { + use crate::core::config::FluffConfig; + use crate::core::linter::core::Linter; + + fn st05_linter(dialect: &str) -> Linter { + let config = FluffConfig::from_source( + &format!( + r#" +[sqruff] +dialect = {dialect} +rules = ST05 + +[sqruff:rules:structure.subquery] +forbid_subquery_in = both +"# + ), + None, + ); + + Linter::new(config, None, None, false).unwrap() + } + + fn assert_fix(dialect: &str, source: &str, expected: &str) { + let mut linter = st05_linter(dialect); + let actual = linter + .lint_string_wrapped(source, true) + .unwrap() + .fix_string(); + + pretty_assertions::assert_eq!(actual, expected); + } + + #[test] + fn st05_fixes_double_nested_subquery_without_panicking() { + let source = r#"WITH q AS ( + SELECT + t1.a + FROM + table_1 AS t1 + INNER JOIN + table_2 AS t2 USING (a) + LEFT JOIN ( + SELECT DISTINCT a FROM table_3 + WHERE c = 'v1' + ) AS dns USING (a) + LEFT JOIN ( + SELECT DISTINCT a FROM table_5 + LEFT JOIN ( + SELECT DISTINCT + a, + b + FROM table_6 + WHERE c < 5 + ) AS t4 + USING (a) + WHERE table_5.b = 'v2' + ) AS dcod USING (a) +) +SELECT + a +FROM + q; +"#; + let expected = r#"WITH dns AS ( + SELECT DISTINCT a FROM table_3 + WHERE c = 'v1' + ), +t4 AS ( + SELECT DISTINCT + a, + b + FROM table_6 + WHERE c < 5 + ), +dcod AS ( + SELECT DISTINCT a FROM table_5 + LEFT JOIN t4 + USING (a) + WHERE table_5.b = 'v2' + ), +q AS ( + SELECT + t1.a + FROM + table_1 AS t1 + INNER JOIN + table_2 AS t2 USING (a) + LEFT JOIN dns USING (a) + LEFT JOIN dcod USING (a) +) +SELECT + a +FROM + q; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_order_4782_without_panicking() { + let source = r#"WITH +cte_1 AS ( + SELECT + subquery_a.field_a, + subquery_a.field_b + FROM ( + SELECT + subquery_b.field_a, + alias_a.field_d, + alias_a.field_b, + alias_b.field_c + FROM table_b AS alias_a + INNER JOIN + (SELECT * FROM table_a) AS subquery_b + ON subquery_b.field_a >= alias_a.field_d + LEFT OUTER JOIN table_b AS alias_b ON alias_a.field_b = alias_b.field_c + ) AS subquery_a +), + +cte_2 AS ( + SELECT * + FROM table_c + WHERE field_a > 0 + ORDER BY field_b DESC +), + +join_ctes AS ( + SELECT * FROM cte_1 LEFT OUTER JOIN cte_2 ON cte_1.field_a = cte_2.field_a +) + +SELECT * +FROM join_ctes; +"#; + let expected = r#"WITH subquery_b AS (SELECT * FROM table_a), +subquery_a AS ( + SELECT + subquery_b.field_a, + alias_a.field_d, + alias_a.field_b, + alias_b.field_c + FROM table_b AS alias_a + INNER JOIN + subquery_b + ON subquery_b.field_a >= alias_a.field_d + LEFT OUTER JOIN table_b AS alias_b ON alias_a.field_b = alias_b.field_c + ), +cte_1 AS ( + SELECT + subquery_a.field_a, + subquery_a.field_b + FROM subquery_a +), +cte_2 AS ( + SELECT * + FROM table_c + WHERE field_a > 0 + ORDER BY field_b DESC +), +join_ctes AS ( + SELECT * FROM cte_1 LEFT OUTER JOIN cte_2 ON cte_1.field_a = cte_2.field_a +) +SELECT * +FROM join_ctes; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_same_named_nested_subqueries_across_ctes() { + let source = r#"with purchases_in_the_last_year as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -1, current_date()) + ) group by customer_id +) + +, purchases_in_the_last_three_years as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -3, current_date()) + ) group by customer_id +) + + +select distinct + c.customer_id + , ly.attrlist as attrlist_last_year + , l3y.attrlist as attrlist_last_three_years +from + customers as c +left outer join + purchases_in_the_last_year as ly + on c.customer_id = ly.customer_id +left outer join + purchases_in_the_last_three_years as l3y + on c.customer_id = l3y.customer_id +; +"#; + let expected = r#"with prep_1 as ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -1, current_date()) + ), +purchases_in_the_last_year as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from prep_1 group by customer_id +), +prep_2 as ( + select + o.customer_id + , p.attr + from + order_line_item as o + inner join product as p + on o.product_id = p.product_id + and o.time_placed >= dateadd(year, -3, current_date()) + ), +purchases_in_the_last_three_years as ( + select + customer_id + , arrayagg(distinct attr) within group (order by attr asc) as attrlist + from prep_2 group by customer_id +) +select distinct + c.customer_id + , ly.attrlist as attrlist_last_year + , l3y.attrlist as attrlist_last_three_years +from + customers as c +left outer join + purchases_in_the_last_year as ly + on c.customer_id = ly.customer_id +left outer join + purchases_in_the_last_three_years as l3y + on c.customer_id = l3y.customer_id +; +"#; + + assert_fix("snowflake", source, expected); + } + + #[test] + fn st05_partially_fixes_duplicate_aliases_in_order_5265_case() { + let source = r#"WITH +cte1 AS ( + SELECT COUNT(*) AS qty + FROM some_table AS st + LEFT JOIN ( + SELECT 'first' AS id + ) AS oops + ON st.id = oops.id +), +cte2 AS ( + SELECT COUNT(*) AS other_qty + FROM other_table AS sot + LEFT JOIN ( + SELECT 'middle' AS id + ) AS another + ON sot.id = another.id + LEFT JOIN ( + SELECT 'last' AS id + ) AS oops + ON sot.id = oops.id +) +SELECT CURRENT_DATE(); +"#; + let expected = r#"WITH oops AS ( + SELECT 'first' AS id + ), +cte1 AS ( + SELECT COUNT(*) AS qty + FROM some_table AS st + LEFT JOIN oops + ON st.id = oops.id +), +another AS ( + SELECT 'middle' AS id + ), +cte2 AS ( + SELECT COUNT(*) AS other_qty + FROM other_table AS sot + LEFT JOIN another + ON sot.id = another.id + LEFT JOIN ( + SELECT 'last' AS id + ) AS oops + ON sot.id = oops.id +) +SELECT CURRENT_DATE(); +"#; + + assert_fix("ansi", source, expected); + } +} From a2a9047ae3226640b20d16231defcb645f211a9d Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 15:24:39 -0400 Subject: [PATCH 02/10] fix(st05): preserve whitespace after FROM in root rewrites --- crates/lib/src/rules/structure/st05.rs | 42 +++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 0c5af130e..8eecdbf8a 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -272,7 +272,7 @@ join c using(x) .collect(); } - let output_select_clone = clone_map[&output_select[0]].clone(); + let mut output_select_clone = clone_map[&output_select[0]].clone(); for result in &mut results_list { if !result.is_fixable { @@ -282,7 +282,7 @@ join c using(x) let mut fixes = ctes.ensure_space_after_from( context.tables, output_select[0].clone(), - &output_select_clone, + &mut output_select_clone, result.subquery_parent.clone(), ); @@ -584,17 +584,26 @@ impl CTEBuilder { &self, tables: &Tables, output_select: ErasedSegment, - output_select_clone: &ErasedSegment, + output_select_clone: &mut ErasedSegment, subquery_parent: ErasedSegment, ) -> Vec { let mut fixes = Vec::new(); if subquery_parent.is(&output_select) { - let (missing_space_after_from, _from_clause, _from_clause_children, _from_segment) = + let (missing_space_after_from, _from_clause, _from_clause_children, from_segment) = Self::missing_space_after_from(output_select_clone.clone()); if missing_space_after_from { - todo!() + let mut anchor_fixes = HashMap::default(); + compute_anchor_edit_info( + &mut anchor_fixes, + vec![LintFix::create_after( + from_segment.unwrap().base[0].clone(), + vec![SegmentBuilder::whitespace(tables.next_id(), " ")], + None, + )], + ); + *output_select_clone = output_select_clone.apply_fixes(&mut anchor_fixes).0; } } else { let (missing_space_after_from, _from_clause, _from_clause_children, from_segment) = @@ -1199,6 +1208,29 @@ cte2 AS ( ON sot.id = oops.id ) SELECT CURRENT_DATE(); +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_inserts_space_after_from_when_rewriting_root_subquery() { + let source = r#"CREATE TABLE t +AS +SELECT + col1 +FROM( + SELECT 'x' AS col1 +) x +"#; + let expected = r#"CREATE TABLE t +AS +WITH x AS ( + SELECT 'x' AS col1 +) +SELECT + col1 +FROM x "#; assert_fix("ansi", source, expected); From af1f87d78d8f29de971843d37bdfd18877d05639 Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 15:27:46 -0400 Subject: [PATCH 03/10] test(st05): add regression coverage for UNION branch subqueries --- crates/lib/src/rules/structure/st05.rs | 37 ++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 8eecdbf8a..beb94f6e3 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -1231,6 +1231,43 @@ WITH x AS ( SELECT col1 FROM x +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_set_subquery_in_second_query() { + let source = r#"SELECT 1 AS value_name +UNION +SELECT value +FROM (SELECT 2 AS value_name); +"#; + let expected = r#"WITH prep_1 AS (SELECT 2 AS value_name) +SELECT 1 AS value_name +UNION +SELECT value +FROM prep_1; +"#; + + assert_fix("ansi", source, expected); + } + + #[test] + fn st05_fixes_multiple_set_subqueries_in_second_query() { + let source = r#"SELECT 1 AS value_name +UNION +SELECT value +FROM (SELECT 2 AS value_name) +CROSS JOIN (SELECT 1 as v2); +"#; + let expected = r#"WITH prep_1 AS (SELECT 2 AS value_name), +prep_2 AS (SELECT 1 as v2) +SELECT 1 AS value_name +UNION +SELECT value +FROM prep_1 +CROSS JOIN prep_2; "#; assert_fix("ansi", source, expected); From 4f96c6dcd58a050161877d797ffec180eabe16dc Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 15:36:18 -0400 Subject: [PATCH 04/10] fix(st05): support T-SQL WITH-before-INSERT rewrites --- crates/lib/src/rules/structure/st05.rs | 89 +++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index beb94f6e3..13bf82d2e 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -113,6 +113,8 @@ join c using(x) let functional_context = FunctionalContext::new(context); let segment = functional_context.segment(); let parent_stack = functional_context.parent_stack(); + let insert_parent = parent_stack + .find_last_where(|it: &ErasedSegment| it.is_type(SyntaxKind::InsertStatement)); let is_select = segment.all_match(|it: &ErasedSegment| SELECT_TYPES.contains(it.get_type())); @@ -138,13 +140,19 @@ join c using(x) .is_empty(); let case_preference = get_case_preference(&segment); + let mut root_segment = segment.first().unwrap().clone(); let output_select = if is_with { segment.children_where(|it: &ErasedSegment| { matches!( it.get_type(), - SyntaxKind::SetExpression | SyntaxKind::SelectStatement + SyntaxKind::InsertStatement + | SyntaxKind::SetExpression + | SyntaxKind::SelectStatement ) }) + } else if context.dialect.name == DialectKind::Tsql && !insert_parent.is_empty() { + root_segment = insert_parent.first().unwrap().clone(); + insert_parent } else { segment.clone() }; @@ -156,7 +164,7 @@ join c using(x) .map(|it| it.get_type()) .eq([SyntaxKind::CreateTableStatement, SyntaxKind::Bracketed]); - let clone_map = SegmentCloneMap::new(segment.first().unwrap().deep_clone()); + let clone_map = SegmentCloneMap::new(root_segment.deep_clone()); let mut results_list = self.lint_query( context.tables, @@ -294,7 +302,7 @@ join c using(x) ); result.lint_result.fixes = vec![LintFix::replace( - segment.first().unwrap().clone(), + root_segment.clone(), vec![new_select], None, )]; @@ -1272,4 +1280,79 @@ CROSS JOIN prep_2; assert_fix("ansi", source, expected); } + + #[test] + fn st05_rewrites_tsql_insert_to_with_before_insert() { + let source = r#"INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +(SELECT +Id +,Attribute +FROM Table2) Subq +ON Main.Id = Subq.Id +"#; + let expected = r#"WITH Subq AS (SELECT +Id +,Attribute +FROM Table2) +INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +Subq +ON Main.Id = Subq.Id +"#; + + assert_fix("tsql", source, expected); + } + + #[test] + fn st05_rewrites_existing_tsql_with_insert_statement() { + let source = r#"WITH MainTable as ( + select * + from sales +) + +INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +(SELECT +Id +,Attribute +FROM Table2) Subq +ON Main.Id = Subq.Id +"#; + let expected = r#"WITH MainTable as ( + select * + from sales +), +Subq AS (SELECT +Id +,Attribute +FROM Table2) +INSERT INTO Table1 (Id,Name,Attribute) +SELECT +Main.Id +,Main.Name +,Subq.Attribute +FROM MainTable AS Main +LEFT JOIN +Subq +ON Main.Id = Subq.Id +"#; + + assert_fix("tsql", source, expected); + } } From 0c58b57eb189fa97620273cd9a493406a8b49d1f Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 15:46:21 -0400 Subject: [PATCH 05/10] fix(st05): handle nested joined subqueries without unsafe fixes --- crates/lib-core/src/parser/segments/from.rs | 9 ++- crates/lib-dialects/src/ansi.rs | 64 +++++++++++++-------- crates/lib/src/rules/structure/st05.rs | 53 +++++++++++++++++ 3 files changed, 102 insertions(+), 24 deletions(-) diff --git a/crates/lib-core/src/parser/segments/from.rs b/crates/lib-core/src/parser/segments/from.rs index 26fcae7a8..d5dce0319 100644 --- a/crates/lib-core/src/parser/segments/from.rs +++ b/crates/lib-core/src/parser/segments/from.rs @@ -85,7 +85,14 @@ impl FromExpressionElementSegment { let alias_expression = self .0 - .child(const { &SyntaxSet::new(&[SyntaxKind::AliasExpression]) }); + .child(const { &SyntaxSet::new(&[SyntaxKind::AliasExpression]) }) + .or_else(|| { + self.0 + .child(const { &SyntaxSet::new(&[SyntaxKind::Bracketed]) }) + .and_then(|bracketed| { + bracketed.child(const { &SyntaxSet::new(&[SyntaxKind::AliasExpression]) }) + }) + }); if let Some(alias_expression) = alias_expression { let segment = alias_expression.child( const { diff --git a/crates/lib-dialects/src/ansi.rs b/crates/lib-dialects/src/ansi.rs index 6e9044029..9877efc94 100644 --- a/crates/lib-dialects/src/ansi.rs +++ b/crates/lib-dialects/src/ansi.rs @@ -4707,34 +4707,52 @@ pub fn raw_dialect() -> Dialect { ( "FromExpressionElementSegment".into(), NodeMatcher::new(SyntaxKind::FromExpressionElement, |_| { - Sequence::new(vec![ - Ref::new("PreTableFunctionKeywordsGrammar") - .optional() + let base_from_expression_element = || { + Sequence::new(vec![ + Ref::new("PreTableFunctionKeywordsGrammar") + .optional() + .to_matchable(), + optionally_bracketed(vec![ + Ref::new("TableExpressionSegment").to_matchable(), + ]) .to_matchable(), - optionally_bracketed(vec![Ref::new("TableExpressionSegment").to_matchable()]) + Ref::new("AliasExpressionSegment") + .exclude(one_of(vec![ + Ref::new("FromClauseTerminatorGrammar").to_matchable(), + Ref::new("SamplingExpressionSegment").to_matchable(), + Ref::new("JoinLikeClauseGrammar").to_matchable(), + LookaheadExclude::new("WITH", "(").to_matchable(), + ])) + .optional() + .to_matchable(), + Sequence::new(vec![ + Ref::keyword("WITH").to_matchable(), + Ref::keyword("OFFSET").to_matchable(), + Ref::new("AliasExpressionSegment").to_matchable(), + ]) + .config(|this| this.optional()) .to_matchable(), - Ref::new("AliasExpressionSegment") - .exclude(one_of(vec![ - Ref::new("FromClauseTerminatorGrammar").to_matchable(), - Ref::new("SamplingExpressionSegment").to_matchable(), - Ref::new("JoinLikeClauseGrammar").to_matchable(), - LookaheadExclude::new("WITH", "(").to_matchable(), - ])) - .optional() + Ref::new("SamplingExpressionSegment") + .optional() + .to_matchable(), + Ref::new("PostTableExpressionGrammar") + .optional() + .to_matchable(), + ]) + .to_matchable() + }; + + one_of(vec![ + base_from_expression_element(), + Bracketed::new(vec![ + Sequence::new(vec![ + base_from_expression_element(), + AnyNumberOf::new(vec![Ref::new("JoinClauseSegment").to_matchable()]) + .to_matchable(), + ]) .to_matchable(), - Sequence::new(vec![ - Ref::keyword("WITH").to_matchable(), - Ref::keyword("OFFSET").to_matchable(), - Ref::new("AliasExpressionSegment").to_matchable(), ]) - .config(|this| this.optional()) .to_matchable(), - Ref::new("SamplingExpressionSegment") - .optional() - .to_matchable(), - Ref::new("PostTableExpressionGrammar") - .optional() - .to_matchable(), ]) .to_matchable() }) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 13bf82d2e..e3199a803 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -932,6 +932,25 @@ forbid_subquery_in = both pretty_assertions::assert_eq!(actual, expected); } + fn assert_lint_without_fix(dialect: &str, source: &str, expected_violations: usize) { + let mut linter = st05_linter(dialect); + let linted = linter.lint_string_wrapped(source, false).unwrap(); + assert_eq!(linted.violations().len(), expected_violations); + + let mut linter = st05_linter(dialect); + let fixed = linter + .lint_string_wrapped(source, true) + .unwrap() + .fix_string(); + pretty_assertions::assert_eq!(fixed, source); + } + + fn assert_pass(dialect: &str, source: &str) { + let mut linter = st05_linter(dialect); + let linted = linter.lint_string_wrapped(source, false).unwrap(); + assert!(linted.violations().is_empty(), "{:?}", linted.violations()); + } + #[test] fn st05_fixes_double_nested_subquery_without_panicking() { let source = r#"WITH q AS ( @@ -1355,4 +1374,38 @@ ON Main.Id = Subq.Id assert_fix("tsql", source, expected); } + + #[test] + fn st05_lints_nested_joined_subquery_without_fixing() { + let source = r#"SELECT + x.a, + w2.b +FROM x +LEFT JOIN ( + ( + SELECT + w.a, + w.b, + w.c + FROM w + ) AS w2 + LEFT JOIN y + ON w2.a = y.a +) + ON x.a = w2.a; +"#; + + assert_lint_without_fix("ansi", source, 1); + } + + #[test] + fn st05_ignores_nested_table_function_subqueries() { + let source = r#"SELECT * +FROM `func`(( + SELECT 1 +)); +"#; + + assert_pass("bigquery", source); + } } From 0e97458b231712e1732b3580c3f1a5edf78ea3c6 Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 15:47:55 -0400 Subject: [PATCH 06/10] fix(st05): validate forbid_subquery_in config values --- crates/lib/src/rules/structure/st05.rs | 41 ++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index e3199a803..9ddc4edef 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -60,10 +60,17 @@ pub(crate) struct RuleST05 { impl Rule for RuleST05 { fn load_from_config(&self, config: &HashMap) -> Result { - Ok(RuleST05 { - forbid_subquery_in: config["forbid_subquery_in"].as_string().unwrap().into(), + match config["forbid_subquery_in"].as_string() { + Some("join" | "from" | "both") => Ok(RuleST05 { + forbid_subquery_in: config["forbid_subquery_in"].as_string().unwrap().into(), + } + .erased()), + Some(value) => Err(format!( + "Invalid value for forbid_subquery_in: {value}. Must be one of [join, from, \ + both]" + )), + None => Err("Rule ST05 expects a string for `forbid_subquery_in`".into()), } - .erased()) } fn name(&self) -> &'static str { @@ -901,8 +908,14 @@ impl SegmentCloneMap { #[cfg(test)] mod tests { + use hashbrown::HashMap; + use crate::core::config::FluffConfig; + use crate::core::config::Value; use crate::core::linter::core::Linter; + use crate::core::rules::Rule; + + use super::RuleST05; fn st05_linter(dialect: &str) -> Linter { let config = FluffConfig::from_source( @@ -1408,4 +1421,26 @@ FROM `func`(( assert_pass("bigquery", source); } + + #[test] + fn st05_load_from_config_rejects_invalid_value() { + let config = HashMap::from_iter([( + "forbid_subquery_in".into(), + Value::String("sideways".into()), + )]); + + let err = RuleST05::default().load_from_config(&config).unwrap_err(); + assert_eq!( + err, + "Invalid value for forbid_subquery_in: sideways. Must be one of [join, from, both]" + ); + } + + #[test] + fn st05_load_from_config_accepts_valid_value() { + let config = + HashMap::from_iter([("forbid_subquery_in".into(), Value::String("from".into()))]); + + assert!(RuleST05::default().load_from_config(&config).is_ok()); + } } From d53c6d60484f91e933254e82e01db2dd5798d75b Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 16:03:50 -0400 Subject: [PATCH 07/10] fix(st05): preserve ClickHouse expression CTEs during subquery rewrites --- crates/lib/src/rules/structure/st05.rs | 68 ++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 9ddc4edef..3abffc8e4 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -132,15 +132,25 @@ join c using(x) return Vec::new(); } + let is_with = + segment.all_match(|it: &ErasedSegment| it.is_type(SyntaxKind::WithCompoundStatement)); let query: Query<'_> = Query::from_segment(&context.segment, context.dialect, None); let mut ctes = CTEBuilder::default(); - for cte in query.inner.borrow().ctes.values() { - ctes.insert_cte(cte.inner.borrow().cte_definition_segment.clone().unwrap()); + if is_with { + // Mirror SQLFluff by preserving all CTEs in document order, including + // non-select expression CTEs (e.g. ClickHouse `expr AS alias`) which + // Query::from_segment intentionally skips in `query.ctes`. + for cte in context.segment.recursive_crawl( + const { &SyntaxSet::single(SyntaxKind::CommonTableExpression) }, + false, + const { &SyntaxSet::single(SyntaxKind::WithCompoundStatement) }, + false, + ) { + ctes.insert_cte(cte); + } } - let is_with = - segment.all_match(|it: &ErasedSegment| it.is_type(SyntaxKind::WithCompoundStatement)); let is_recursive = is_with && !segment .children_where(|it: &ErasedSegment| it.is_keyword("recursive")) @@ -1422,6 +1432,56 @@ FROM `func`(( assert_pass("bigquery", source); } + #[test] + fn st05_preserves_clickhouse_expression_ctes_when_fixing_join_subquery() { + let config = FluffConfig::from_source( + r#" +[sqruff] +dialect = clickhouse +rules = ST05 +"#, + None, + ); + let source = r#"insert into demo_table +with +array('A', 'B') as keep_list +select + orders.id, + default_tags.tag +from ( + select 1 as id, 'A' as tag +) as orders +left join ( + select 1 as id, 'x' as tag +) as default_tags + on orders.id = default_tags.id +where orders.tag in keep_list +"#; + let expected = r#"insert into demo_table +with array('A', 'B') as keep_list, +default_tags as ( + select 1 as id, 'x' as tag +) +select + orders.id, + default_tags.tag +from ( + select 1 as id, 'A' as tag +) as orders +left join default_tags + on orders.id = default_tags.id +where orders.tag in keep_list +"#; + + let mut linter = Linter::new(config, None, None, false).unwrap(); + let actual = linter + .lint_string_wrapped(source, true) + .unwrap() + .fix_string(); + + pretty_assertions::assert_eq!(actual, expected); + } + #[test] fn st05_load_from_config_rejects_invalid_value() { let config = HashMap::from_iter([( From cad2e5e277e4e9239a312c68fe38eb3dc4106f8b Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 16:10:01 -0400 Subject: [PATCH 08/10] fix(st05): preserve generated CTE order across LT09 rewrites --- crates/lib/src/rules/structure/st05.rs | 78 +++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 3abffc8e4..b9cc77bc4 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -788,11 +788,15 @@ fn is_child(maybe_parent: Segments, maybe_child: Segments) -> bool { let child_markers = maybe_child[0].get_position_marker().unwrap(); let parent_pos = maybe_parent[0].get_position_marker().unwrap(); - if child_markers < &parent_pos.start_point_marker() { + // Preserve the original source nesting relationship even after earlier fix + // passes have changed working positions within the tree. Structural rewrites + // like ST05 rely on this to keep generated CTEs ahead of the CTEs/selects + // they were extracted from. + if child_markers.source_slice.start < parent_pos.source_slice.start { return false; } - if child_markers > &parent_pos.end_point_marker() { + if child_markers.source_slice.end > parent_pos.source_slice.end { return false; } @@ -1503,4 +1507,74 @@ where orders.tag in keep_list assert!(RuleST05::default().load_from_config(&config).is_ok()); } + + #[test] + fn st05_with_lt09_keeps_generated_cte_before_first_use() { + let config = FluffConfig::from_source( + r#" +[sqruff] +dialect = ansi +rules = ST05,LT09 + +[sqruff:rules:structure.subquery] +forbid_subquery_in = both +"#, + None, + ); + let source = r#"with + +cte1 as ( + select t1.x, t2.y + from tbl1 t1 + join (select x, y from tbl2) t2 + on t1.x = t2.x +) + +, cte2 as ( + select x, y from tbl2 t2 +) + +select x, y from cte1 +union all +select x, y from cte2 +; +"#; + let expected = r#"with t2 as (select +x, +y +from tbl2), +cte1 as ( + select +t1.x, +t2.y + from tbl1 t1 + join t2 + on t1.x = t2.x +), +cte2 as ( + select +x, +y +from tbl2 t2 +) +select +x, +y +from cte1 +union all +select +x, +y +from cte2 +; +"#; + + let mut linter = Linter::new(config, None, None, false).unwrap(); + let fixed = linter + .lint_string_wrapped(source, true) + .unwrap() + .fix_string(); + + pretty_assertions::assert_eq!(fixed, expected); + } } From 0a938ee81b2efbfe350eb387d8fa1211c212e540 Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 16:14:28 -0400 Subject: [PATCH 09/10] test(st05): move regression coverage to YAML fixtures --- crates/lib/src/rules/structure/st05.rs | 119 ------------------ .../fixtures/rules/std_rule_cases/ST05.yml | 35 ++++++ .../rules/std_rule_cases/ST05_LT09.yml | 58 +++++++++ 3 files changed, 93 insertions(+), 119 deletions(-) create mode 100644 crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index b9cc77bc4..8c1099cdd 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -1436,56 +1436,6 @@ FROM `func`(( assert_pass("bigquery", source); } - #[test] - fn st05_preserves_clickhouse_expression_ctes_when_fixing_join_subquery() { - let config = FluffConfig::from_source( - r#" -[sqruff] -dialect = clickhouse -rules = ST05 -"#, - None, - ); - let source = r#"insert into demo_table -with -array('A', 'B') as keep_list -select - orders.id, - default_tags.tag -from ( - select 1 as id, 'A' as tag -) as orders -left join ( - select 1 as id, 'x' as tag -) as default_tags - on orders.id = default_tags.id -where orders.tag in keep_list -"#; - let expected = r#"insert into demo_table -with array('A', 'B') as keep_list, -default_tags as ( - select 1 as id, 'x' as tag -) -select - orders.id, - default_tags.tag -from ( - select 1 as id, 'A' as tag -) as orders -left join default_tags - on orders.id = default_tags.id -where orders.tag in keep_list -"#; - - let mut linter = Linter::new(config, None, None, false).unwrap(); - let actual = linter - .lint_string_wrapped(source, true) - .unwrap() - .fix_string(); - - pretty_assertions::assert_eq!(actual, expected); - } - #[test] fn st05_load_from_config_rejects_invalid_value() { let config = HashMap::from_iter([( @@ -1508,73 +1458,4 @@ where orders.tag in keep_list assert!(RuleST05::default().load_from_config(&config).is_ok()); } - #[test] - fn st05_with_lt09_keeps_generated_cte_before_first_use() { - let config = FluffConfig::from_source( - r#" -[sqruff] -dialect = ansi -rules = ST05,LT09 - -[sqruff:rules:structure.subquery] -forbid_subquery_in = both -"#, - None, - ); - let source = r#"with - -cte1 as ( - select t1.x, t2.y - from tbl1 t1 - join (select x, y from tbl2) t2 - on t1.x = t2.x -) - -, cte2 as ( - select x, y from tbl2 t2 -) - -select x, y from cte1 -union all -select x, y from cte2 -; -"#; - let expected = r#"with t2 as (select -x, -y -from tbl2), -cte1 as ( - select -t1.x, -t2.y - from tbl1 t1 - join t2 - on t1.x = t2.x -), -cte2 as ( - select -x, -y -from tbl2 t2 -) -select -x, -y -from cte1 -union all -select -x, -y -from cte2 -; -"#; - - let mut linter = Linter::new(config, None, None, false).unwrap(); - let fixed = linter - .lint_string_wrapped(source, true) - .unwrap() - .fix_string(); - - pretty_assertions::assert_eq!(fixed, expected); - } } diff --git a/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml b/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml index 99cc0aa33..1ef656bab 100644 --- a/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml +++ b/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml @@ -638,3 +638,38 @@ test_fail_subquery_in_cte_3: rules: structure.subquery: forbid_subquery_in: both + +test_fail_clickhouse_expression_cte_preserved: + fail_str: | + insert into demo_table + with + array('A', 'B') as keep_list + select + orders.id, + default_tags.tag + from ( + select 1 as id, 'A' as tag + ) as orders + left join ( + select 1 as id, 'x' as tag + ) as default_tags + on orders.id = default_tags.id + where orders.tag in keep_list + fix_str: | + insert into demo_table + with array('A', 'B') as keep_list, + default_tags as ( + select 1 as id, 'x' as tag + ) + select + orders.id, + default_tags.tag + from ( + select 1 as id, 'A' as tag + ) as orders + left join default_tags + on orders.id = default_tags.id + where orders.tag in keep_list + configs: + core: + dialect: clickhouse diff --git a/crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml b/crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml new file mode 100644 index 000000000..1e5f80f4a --- /dev/null +++ b/crates/lib/test/fixtures/rules/std_rule_cases/ST05_LT09.yml @@ -0,0 +1,58 @@ +rule: ST05,LT09 + +test_fix_generated_cte_stays_before_first_use: + # https://github.com/sqlfluff/sqlfluff/issues/4137 + # Tests multi-rule interaction: ST05 extracts the join subquery into a CTE, + # and LT09 reformats the SELECT targets. The generated CTE must stay before + # the first CTE that references it. + fail_str: | + with + + cte1 as ( + select t1.x, t2.y + from tbl1 t1 + join (select x, y from tbl2) t2 + on t1.x = t2.x + ) + + , cte2 as ( + select x, y from tbl2 t2 + ) + + select x, y from cte1 + union all + select x, y from cte2 + ; + fix_str: | + with t2 as (select + x, + y + from tbl2), + cte1 as ( + select + t1.x, + t2.y + from tbl1 t1 + join t2 + on t1.x = t2.x + ), + cte2 as ( + select + x, + y + from tbl2 t2 + ) + select + x, + y + from cte1 + union all + select + x, + y + from cte2 + ; + configs: + rules: + structure.subquery: + forbid_subquery_in: both From 823e0531bd921c224740c04a41c3b7971d73a978 Mon Sep 17 00:00:00 2001 From: gvozdvmozgu Date: Mon, 23 Mar 2026 17:40:43 -0400 Subject: [PATCH 10/10] refactor(st05): harden traversal and identifier handling --- crates/lib/src/rules/structure/st05.rs | 396 ++++++++++-------- .../fixtures/rules/std_rule_cases/ST05.yml | 52 +++ 2 files changed, 279 insertions(+), 169 deletions(-) diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index 8c1099cdd..450d599c7 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -1,7 +1,6 @@ -use std::ops::{Index, IndexMut}; +use std::ops::Index; use hashbrown::{HashMap, HashSet}; -use itertools::{Itertools, enumerate}; use smol_str::{SmolStr, StrExt, ToSmolStr, format_smolstr}; use sqruff_lib_core::dialects::Dialect; use sqruff_lib_core::dialects::common::AliasInfo; @@ -27,27 +26,107 @@ const SELECT_TYPES: SyntaxSet = SyntaxSet::new(&[ SyntaxKind::SelectStatement, ]); -fn config_mapping(key: &str) -> SyntaxSet { - match key { - "join" => SyntaxSet::single(SyntaxKind::JoinClause), - "from" => SyntaxSet::single(SyntaxKind::FromExpressionElement), - "both" => SyntaxSet::new(&[SyntaxKind::JoinClause, SyntaxKind::FromExpressionElement]), - _ => unreachable!("Invalid value for 'forbid_subquery_in': {key}"), +fn normalize_identifier_name(raw: &str) -> SmolStr { + let is_bracket_quoted = raw.starts_with('[') && raw.ends_with(']') && raw.len() >= 2; + let is_matching_quote_quoted = matches!(raw.chars().next(), Some('"') | Some('`')) + && raw.len() >= 2 + && raw.chars().next() == raw.chars().last(); + + if is_bracket_quoted || is_matching_quote_quoted { + raw[1..raw.len() - 1].into() + } else { + raw.into() + } +} + +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +enum ForbidSubqueryIn { + Join, + From, + #[default] + Both, +} + +impl ForbidSubqueryIn { + fn syntax_set(self) -> SyntaxSet { + match self { + Self::Join => SyntaxSet::single(SyntaxKind::JoinClause), + Self::From => SyntaxSet::single(SyntaxKind::FromExpressionElement), + Self::Both => { + SyntaxSet::new(&[SyntaxKind::JoinClause, SyntaxKind::FromExpressionElement]) + } + } + } +} + +impl std::str::FromStr for ForbidSubqueryIn { + type Err = String; + + fn from_str(value: &str) -> Result { + match value { + "join" => Ok(Self::Join), + "from" => Ok(Self::From), + "both" => Ok(Self::Both), + _ => Err(format!( + "Invalid value for forbid_subquery_in: {value}. Must be one of [join, from, \ + both]" + )), + } } } -#[allow(dead_code)] -struct NestedSubQuerySummary<'a> { - query: Query<'a>, +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct IdentifierSpec { + raw: SmolStr, + normalized: SmolStr, + kind: SyntaxKind, +} + +impl IdentifierSpec { + fn new(raw: SmolStr, kind: SyntaxKind) -> Self { + Self { + normalized: normalize_identifier_name(&raw), + raw, + kind: match kind { + SyntaxKind::Identifier + | SyntaxKind::NakedIdentifier + | SyntaxKind::QuotedIdentifier => kind, + _ => SyntaxKind::NakedIdentifier, + }, + } + } + + fn from_segment(segment: &ErasedSegment) -> Self { + Self::new(segment.raw().clone(), segment.get_type()) + } + + fn from_alias(alias: &AliasInfo) -> Option { + if alias.ref_str.is_empty() { + return None; + } + + alias.segment.as_ref().map(Self::from_segment).or_else(|| { + Some(Self::new( + alias.ref_str.clone(), + SyntaxKind::NakedIdentifier, + )) + }) + } + + fn generated(raw: SmolStr) -> Self { + Self::new(raw, SyntaxKind::NakedIdentifier) + } +} + +struct NestedSubqueryCandidate<'a> { selectable: Selectable<'a>, table_alias: AliasInfo, - select_source_names: HashSet, } struct LintQueryResult { lint_result: LintResult, from_expression: ErasedSegment, - alias_name: SmolStr, + alias_name: IdentifierSpec, subquery_parent: ErasedSegment, cte_source: Option, is_fixable: bool, @@ -55,22 +134,20 @@ struct LintQueryResult { #[derive(Clone, Debug, Default)] pub(crate) struct RuleST05 { - forbid_subquery_in: String, + forbid_subquery_in: ForbidSubqueryIn, } impl Rule for RuleST05 { fn load_from_config(&self, config: &HashMap) -> Result { - match config["forbid_subquery_in"].as_string() { - Some("join" | "from" | "both") => Ok(RuleST05 { - forbid_subquery_in: config["forbid_subquery_in"].as_string().unwrap().into(), - } - .erased()), - Some(value) => Err(format!( - "Invalid value for forbid_subquery_in: {value}. Must be one of [join, from, \ - both]" - )), - None => Err("Rule ST05 expects a string for `forbid_subquery_in`".into()), - } + let forbid_subquery_in = match config.get("forbid_subquery_in") { + Some(value) => value + .as_string() + .ok_or_else(|| "Rule ST05 expects a string for `forbid_subquery_in`".to_string())? + .parse()?, + None => self.forbid_subquery_in, + }; + + Ok(RuleST05 { forbid_subquery_in }.erased()) } fn name(&self) -> &'static str { @@ -285,7 +362,7 @@ join c using(x) case_preference, context.dialect, ); - ctes.replace_cte(&result.alias_name, new_cte); + ctes.replace_cte(result.alias_name.normalized.as_str(), new_cte); } // If there's no SELECT statement (e.g., WITH ... INSERT/UPDATE/DELETE), @@ -355,16 +432,16 @@ impl RuleST05 { let mut acc = Vec::new(); for nsq in self.nested_subqueries(query, dialect) { - let (alias_name, _) = ctes.create_cte_alias(Some(&nsq.table_alias)); + let alias_name = ctes.create_cte_alias(Some(&nsq.table_alias)); let anchor = nsq .table_alias .from_expression_element .segments() .first() .cloned() - .unwrap(); + .expect("from_expression_element should have at least one segment"); - let mut is_fixable = !ctes.list_used_names().contains(&alias_name); + let mut is_fixable = !ctes.is_name_used(&alias_name.normalized); let bracket_anchor = if anchor.is_type(SyntaxKind::TableExpression) { let Some(bracket_anchor) = anchor.child(const { &SyntaxSet::single(SyntaxKind::Bracketed) }) @@ -397,20 +474,23 @@ impl RuleST05 { ctes.insert_cte(new_cte); } - let anchor = anchor.recursive_crawl( - const { - &SyntaxSet::new(&[ - SyntaxKind::Keyword, - SyntaxKind::Symbol, - SyntaxKind::StartBracket, - SyntaxKind::EndBracket, - ]) - }, - true, - &SyntaxSet::EMPTY, - true, - )[0] - .clone(); + let anchor = anchor + .recursive_crawl( + const { + &SyntaxSet::new(&[ + SyntaxKind::Keyword, + SyntaxKind::Symbol, + SyntaxKind::StartBracket, + SyntaxKind::EndBracket, + ]) + }, + true, + &SyntaxSet::EMPTY, + true, + ) + .first() + .cloned() + .expect("expected a code anchor inside the offending subquery"); let res = LintResult::new( anchor.into(), @@ -440,21 +520,32 @@ impl RuleST05 { &self, query: Query<'a>, dialect: &'a Dialect, - ) -> Vec> { + ) -> Vec> { let mut acc = Vec::new(); + self.collect_nested_subqueries(&query, dialect, &mut acc); + acc + } - let parent_types = config_mapping(&self.forbid_subquery_in); + fn collect_nested_subqueries<'a>( + &self, + query: &Query<'a>, + dialect: &'a Dialect, + acc: &mut Vec>, + ) { + let parent_types = self.forbid_subquery_in.syntax_set(); let mut queries = vec![query.clone()]; queries.extend(query.inner.borrow().ctes.values().cloned()); - for (i, q) in enumerate(queries) { - for selectable in &q.inner.borrow().selectables { + for current_query in queries { + let selectables = current_query.inner.borrow().selectables.clone(); + + for selectable in selectables { let Some(select_info) = selectable.select_info() else { continue; }; let mut select_source_names = HashSet::new(); - for table_alias in select_info.table_aliases { + for table_alias in &select_info.table_aliases { if !table_alias.ref_str.is_empty() { select_source_names.insert(table_alias.ref_str.clone()); } @@ -462,8 +553,10 @@ impl RuleST05 { if let Some(object_reference) = &table_alias.object_reference { select_source_names.insert(object_reference.raw().to_smolstr()); } + } - let Some(query) = + for table_alias in select_info.table_aliases { + let Some(subquery) = Query::from_root(&table_alias.from_expression_element, dialect) else { continue; @@ -472,48 +565,35 @@ impl RuleST05 { let path_to = selectable .selectable .path_to(&table_alias.from_expression_element); - - if !(parent_types.contains(table_alias.from_expression_element.get_type()) + let is_target_parent = parent_types + .contains(table_alias.from_expression_element.get_type()) || path_to .iter() - .any(|ps| parent_types.contains(ps.segment.get_type()))) - { - continue; - } + .any(|ps| parent_types.contains(ps.segment.get_type())); - if is_correlated_subquery( - Segments::new( - query - .inner - .borrow() - .selectables - .first() - .unwrap() - .selectable - .clone(), - None, - ), - &select_source_names, - dialect, - ) { + let Some(nested_selectable) = + subquery.inner.borrow().selectables.first().cloned() + else { continue; - } - - acc.push(NestedSubQuerySummary { - query: q.clone(), - selectable: selectable.clone(), - table_alias: table_alias.clone(), - select_source_names: select_source_names.clone(), - }); + }; - if i > 0 { - acc.append(&mut self.nested_subqueries(query.clone(), dialect)); + if is_target_parent + && !is_correlated_subquery( + Segments::new(nested_selectable.selectable.clone(), None), + &select_source_names, + dialect, + ) + { + acc.push(NestedSubqueryCandidate { + selectable: selectable.clone(), + table_alias: table_alias.clone(), + }); } + + self.collect_nested_subqueries(&subquery, dialect, acc); } } } - - acc } } @@ -557,6 +637,7 @@ fn is_correlated_subquery( #[derive(Default)] struct CTEBuilder { ctes: Vec, + used_names: HashSet, name_idx: usize, } @@ -615,71 +696,42 @@ impl CTEBuilder { let mut fixes = Vec::new(); if subquery_parent.is(&output_select) { - let (missing_space_after_from, _from_clause, _from_clause_children, from_segment) = - Self::missing_space_after_from(output_select_clone.clone()); - - if missing_space_after_from { + if let Some(from_keyword) = Self::from_keyword_needing_separator(output_select_clone) { let mut anchor_fixes = HashMap::default(); compute_anchor_edit_info( &mut anchor_fixes, vec![LintFix::create_after( - from_segment.unwrap().base[0].clone(), + from_keyword, vec![SegmentBuilder::whitespace(tables.next_id(), " ")], None, )], ); *output_select_clone = output_select_clone.apply_fixes(&mut anchor_fixes).0; } - } else { - let (missing_space_after_from, _from_clause, _from_clause_children, from_segment) = - Self::missing_space_after_from(subquery_parent); - - if missing_space_after_from { - fixes.push(LintFix::create_after( - from_segment.unwrap().base[0].clone(), - vec![SegmentBuilder::whitespace(tables.next_id(), " ")], - None, - )) - } + } else if let Some(from_keyword) = Self::from_keyword_needing_separator(&subquery_parent) { + fixes.push(LintFix::create_after( + from_keyword, + vec![SegmentBuilder::whitespace(tables.next_id(), " ")], + None, + )) } fixes } - fn missing_space_after_from( - segment: ErasedSegment, - ) -> ( - bool, - Option, - Option, - Option, - ) { - let mut missing_space_after_from = false; - let from_clause_children = None; - let mut from_segment = None; - let from_clause = segment.child(const { &SyntaxSet::single(SyntaxKind::FromClause) }); - - if let Some(from_clause) = &from_clause { - let from_clause_children = Segments::from_vec(from_clause.segments().to_vec(), None); - from_segment = from_clause_children - .find_first_where(|it: &ErasedSegment| it.is_keyword("FROM")) - .into(); - if !from_segment.as_ref().unwrap().is_empty() - && from_clause_children - .after(&from_segment.as_ref().unwrap().base[0]) - .take_while(|it| it.is_whitespace()) - .is_empty() - { - missing_space_after_from = true; - } - } - - ( - missing_space_after_from, - from_clause, - from_clause_children, - from_segment, - ) + fn from_keyword_needing_separator(segment: &ErasedSegment) -> Option { + let from_clause = segment.child(const { &SyntaxSet::single(SyntaxKind::FromClause) })?; + let from_clause_children = Segments::from_vec(from_clause.segments().to_vec(), None); + let from_keyword = from_clause_children + .find_first_where(|it: &ErasedSegment| it.is_keyword("FROM")) + .first() + .cloned()?; + let has_separator = !from_clause_children + .after(&from_keyword) + .take_while(|it| it.is_whitespace() || it.is_comment() || it.is_meta()) + .is_empty(); + + (!has_separator).then_some(from_keyword) } } @@ -689,7 +741,7 @@ impl CTEBuilder { segment: ErasedSegment, clone_map: &SegmentCloneMap, ) { - for (idx, cte) in enumerate(&self.ctes) { + for (idx, cte) in self.ctes.iter().enumerate() { if cte .recursive_crawl_all(false) .into_iter() @@ -716,18 +768,13 @@ impl CTEBuilder { ]) }, ) - .unwrap(); + .expect("CTE should contain an identifier segment"); - if id_seg.is_type(SyntaxKind::QuotedIdentifier) { - let raw = id_seg.raw(); - raw[1..raw.len() - 1].to_smolstr() - } else { - id_seg.raw().to_smolstr() - } + normalize_identifier_name(id_seg.raw().as_ref()) } - fn list_used_names(&self) -> Vec { - self.ctes.iter().map(Self::cte_name).collect() + fn is_name_used(&self, name: &SmolStr) -> bool { + self.used_names.contains(name) } fn replace_cte(&mut self, cte_name: &str, new_cte: ErasedSegment) { @@ -736,25 +783,30 @@ impl CTEBuilder { .iter() .position(|cte| Self::cte_name(cte) == cte_name) { + self.used_names.remove(&Self::cte_name(&self.ctes[idx])); self.ctes[idx] = new_cte; + self.used_names.insert(Self::cte_name(&self.ctes[idx])); } } - fn create_cte_alias(&mut self, alias: Option<&AliasInfo>) -> (SmolStr, bool) { - if let Some(alias) = alias.filter(|alias| alias.aliased && !alias.ref_str.is_empty()) { - return (alias.ref_str.clone(), false); + fn create_cte_alias(&mut self, alias: Option<&AliasInfo>) -> IdentifierSpec { + if let Some(alias) = alias.filter(|alias| alias.aliased) + && let Some(alias_name) = IdentifierSpec::from_alias(alias) + { + return alias_name; } - self.name_idx += 1; - let name = format_smolstr!("prep_{}", self.name_idx); - if self.list_used_names().iter().contains(&name) { - return self.create_cte_alias(None); + loop { + self.name_idx += 1; + let name = format_smolstr!("prep_{}", self.name_idx); + if !self.is_name_used(&name) { + return IdentifierSpec::generated(name); + } } - - (name, true) } fn insert_cte(&mut self, cte: ErasedSegment) { + let cte_name = Self::cte_name(&cte); let inbound_subquery = Segments::new(cte.clone(), None) .children_all() .find_first_where(|it: &ErasedSegment| it.get_position_marker().is_some()); @@ -770,7 +822,7 @@ impl CTEBuilder { .children_all() .last() .cloned() - .unwrap(), + .expect("CTE should contain a trailing selectable segment"), None, ), inbound_subquery.clone(), @@ -781,6 +833,7 @@ impl CTEBuilder { .unwrap_or(self.ctes.len()); self.ctes.insert(insert_position, cte); + self.used_names.insert(cte_name); } } @@ -836,7 +889,7 @@ fn segmentify(tables: &Tables, input_el: &str, casing: Case) -> ErasedSegment { fn create_cte_seg( tables: &Tables, - alias_name: SmolStr, + alias_name: IdentifierSpec, subquery: ErasedSegment, case_preference: Case, dialect: &Dialect, @@ -846,8 +899,7 @@ fn create_cte_seg( SyntaxKind::CommonTableExpression, dialect.name, vec![ - SegmentBuilder::token(tables.next_id(), &alias_name, SyntaxKind::NakedIdentifier) - .finish(), + SegmentBuilder::token(tables.next_id(), &alias_name.raw, alias_name.kind).finish(), SegmentBuilder::whitespace(tables.next_id(), " "), segmentify(tables, "AS", case_preference), SegmentBuilder::whitespace(tables.next_id(), " "), @@ -857,7 +909,11 @@ fn create_cte_seg( .finish() } -fn create_table_ref(tables: &Tables, table_name: &str, dialect: &Dialect) -> ErasedSegment { +fn create_table_ref( + tables: &Tables, + table_name: &IdentifierSpec, + dialect: &Dialect, +) -> ErasedSegment { SegmentBuilder::node( tables.next_id(), SyntaxKind::TableExpression, @@ -868,12 +924,8 @@ fn create_table_ref(tables: &Tables, table_name: &str, dialect: &Dialect) -> Era SyntaxKind::TableReference, dialect.name, vec![ - SegmentBuilder::token( - tables.next_id(), - table_name, - SyntaxKind::NakedIdentifier, - ) - .finish(), + SegmentBuilder::token(tables.next_id(), &table_name.raw, table_name.kind) + .finish(), ], ) .finish(), @@ -895,12 +947,6 @@ impl Index<&ErasedSegment> for SegmentCloneMap { } } -impl IndexMut<&ErasedSegment> for SegmentCloneMap { - fn index_mut(&mut self, index: &ErasedSegment) -> &mut Self::Output { - self.segment_map.get_mut(&index.id()).unwrap() - } -} - impl SegmentCloneMap { fn get(&self, old_segment: &ErasedSegment) -> Option<&ErasedSegment> { self.segment_map.get(&old_segment.id()) @@ -929,7 +975,7 @@ mod tests { use crate::core::linter::core::Linter; use crate::core::rules::Rule; - use super::RuleST05; + use super::{ForbidSubqueryIn, RuleST05}; fn st05_linter(dialect: &str) -> Linter { let config = FluffConfig::from_source( @@ -1458,4 +1504,16 @@ FROM `func`(( assert!(RuleST05::default().load_from_config(&config).is_ok()); } + #[test] + fn st05_default_config_is_safe_and_uses_both() { + assert_eq!( + RuleST05::default().forbid_subquery_in, + ForbidSubqueryIn::Both + ); + assert!( + RuleST05::default() + .load_from_config(&HashMap::default()) + .is_ok() + ); + } } diff --git a/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml b/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml index 1ef656bab..4debdd699 100644 --- a/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml +++ b/crates/lib/test/fixtures/rules/std_rule_cases/ST05.yml @@ -673,3 +673,55 @@ test_fail_clickhouse_expression_cte_preserved: configs: core: dialect: clickhouse + +test_fail_nested_join_subquery_inside_allowed_from_subquery: + fail_str: | + select + outer_q.x, + outer_q.y + from ( + select + a.x, + b.y + from a + join ( + select x, y from b + ) as b on a.x = b.x + ) as outer_q + fix_str: | + with b as ( + select x, y from b + ) + select + outer_q.x, + outer_q.y + from ( + select + a.x, + b.y + from a + join b on a.x = b.x + ) as outer_q + configs: + rules: + structure.subquery: + forbid_subquery_in: join + +test_fail_preserves_quoted_aliases_in_generated_ctes: + fail_str: | + select + a.x, + "b".y + from a + join ( + select x, y from b + ) as "b" on a.x = "b".x + fix_str: | + with "b" as ( + select x, y from b + ) + select + a.x, + "b".y + from a + join "b" on a.x = "b".x