diff --git a/cluster/src/test/scala/org/apache/spark/sql/SnappySQLQuerySuite.scala b/cluster/src/test/scala/org/apache/spark/sql/SnappySQLQuerySuite.scala
index 5d3ac507da..daea619cdf 100644
--- a/cluster/src/test/scala/org/apache/spark/sql/SnappySQLQuerySuite.scala
+++ b/cluster/src/test/scala/org/apache/spark/sql/SnappySQLQuerySuite.scala
@@ -253,9 +253,59 @@ class SnappySQLQuerySuite extends SnappyFunSuite {
       "(exists (select col1 from r2 where r2.col1=r1.col1) " +
       "or exists(select col1 from r3 where r3.col1=r1.col1))")
-    val result = df.collect()
     checkAnswer(df, Seq(Row(1, "1", "1", 100),
       Row(2, "2", "2", 2),
       Row(4, "4", "4", 4)
     ))
+    snc.dropTable("r1", ifExists = true)
+  }
+
+  test("Delete duplicate rows with WITH and window function") {
+    val snc = new SnappySession(sc)
+    snc.dropTable("r1", ifExists = true)
+    snc.sql("create table r1(col1 INT, col2 STRING, col3 String, col4 Int)" +
+        " using column ")
+
+
+    snc.insert("r1", Row(1, "1", "1", 100))
+    snc.insert("r1", Row(1, "1", "1", 100))
+    snc.insert("r1", Row(2, "4", "4", 4))
+    snc.insert("r1", Row(2, "4", "4", 4))
+    snc.sql("WITH dups AS " +
+        "(SELECT col1, ROW_NUMBER() OVER" +
+        " (PARTITION BY col1 ORDER BY ( SELECT 0))" +
+        " RN FROM r1) DELETE from dups where rn > 1;")
+
+    val df = snc.sql("Select * from r1")
+    checkAnswer(df, Seq(Row(1, "1", "1", 100),
+      Row(2, "4", "4", 4)))
+  }
+
+  test("Update duplicate rows with WITH and window function") {
+    val snc = new SnappySession(sc)
+    snc.dropTable("r1", ifExists = true)
+    snc.sql("create table r1(col1 INT, col2 STRING, col3 String, col4 Int)" +
+        " using column ")
+
+
+    snc.insert("r1", Row(1, "1", "1", 100))
+    snc.insert("r1", Row(1, "1", "1", 100))
+    snc.insert("r1", Row(2, "4", "4", 4))
+    snc.insert("r1", Row(2, "4", "4", 4))
+    snc.sql("WITH dups AS " +
+        "(SELECT col1, ROW_NUMBER() OVER" +
+        " (PARTITION BY col1 ORDER BY ( SELECT 0))" +
+        " RN FROM r1) update dups set col1 = 99 where rn > 1;")
+
+    val df = snc.sql("Select * from r1")
+    checkAnswer(df, Seq(Row(99, "1", "1", 100),
+      Row(99, "4", "4", 4), Row(1, "1", "1", 100), Row(2, "4", "4", 4)))
+  }
+
+  test("nested CTEs") {
+    val snc = new SnappySession(sc)
+    val df = snc.sql("select * from range(10) where id" +
+        " not in (select id from range(2) union all select id from range(2))")
+
+    df.show
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala b/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
index c265a47a0f..f47903c1f7 100644
--- a/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
+++ b/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
@@ -1060,7 +1060,7 @@ class SnappyParser(session: SnappySession)
   protected final def ctes: Rule1[LogicalPlan] = rule {
     WITH ~ ((identifier ~ AS.? ~ '(' ~ ws ~ query ~ ')' ~ ws ~>
         ((id: String, p: LogicalPlan) => (id, p))) + commaSep) ~
-      (query | insert) ~> ((r: Seq[(String, LogicalPlan)], s: LogicalPlan) =>
+      (query | insert | delete | update) ~> ((r: Seq[(String, LogicalPlan)], s: LogicalPlan) =>
       With(s, r.map(ns => (ns._1, SubqueryAlias(ns._1, ns._2, None)))))
   }
 
diff --git a/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala b/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala
index 40a19751bb..e1e723a488 100644
--- a/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala
+++ b/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala
@@ -424,6 +424,25 @@ class SnappySessionState(snappySession: SnappySession)
       }
     }
 
+    def projectKeyAttributes(table: LogicalPlan,
+        newChild: LogicalPlan,
+        keyAttrs: Seq[NamedExpression]): (LogicalPlan, LogicalPlan) = {
+      val transformedChild = newChild.transformUp {
+        case Project(attr, ch)
+          if keyAttrs.forall(k => ch.output.map(_.name).contains(k.name)) =>
+          Project(attr ++ keyAttrs, ch)
+      }
+      val physicalTables = table.collect {
+        case lr@LogicalRelation(mutable: MutableRelation, _, _) => lr
+      }
+      if (physicalTables.size > 1 || physicalTables.isEmpty) {
+        throw new AnalysisException("Update/delete must target exactly one mutable table." +
+            " If you are using a subquery/CTE in the FROM clause, ensure" +
+            " it references only one mutable relation")
+      }
+      (physicalTables.head, transformedChild)
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case c: DMLExternalTable if !c.query.resolved =>
         c.copy(query = analyzeQuery(c.query))
@@ -469,9 +488,15 @@ class SnappySessionState(snappySession: SnappySession)
         // any extra columns
         val allReferences = newChild.references ++
             AttributeSet(newUpdateExprs.flatMap(_.references)) ++ AttributeSet(keyAttrs)
-        u.copy(child = Project(newChild.output.filter(allReferences.contains), newChild),
+
+        val (physicalTable, transformedChild) = projectKeyAttributes(table, newChild, keyAttrs)
+
+        u.copy(child = Project(transformedChild.output.filter(allReferences.contains),
+            transformedChild),
           keyColumns = keyAttrs.map(_.toAttribute),
-          updateColumns = updateAttrs.map(_.toAttribute), updateExpressions = newUpdateExprs)
+          updateColumns = updateAttrs.map(_.toAttribute),
+          updateExpressions = newUpdateExprs,
+          table = physicalTable)
       }
 
       case d@Delete(table, child, keyColumns) if keyColumns.isEmpty && child.resolved =>
@@ -480,7 +505,8 @@ class SnappySessionState(snappySession: SnappySession)
         // if this is a row table with no PK, then fallback to direct execution
         if (keyAttrs.isEmpty) newChild
         else {
-          d.copy(child = Project(keyAttrs, newChild),
+          val (physicalTable, transformedChild) = projectKeyAttributes(table, newChild, keyAttrs)
+          d.copy(table = physicalTable, child = Project(keyAttrs, transformedChild),
             keyColumns = keyAttrs.map(_.toAttribute))
         }
       case d@DeleteFromTable(_, child) if child.resolved =>
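
Usage sketch (not part of the patch): the SnappyParser change lets a WITH clause terminate in a DELETE or UPDATE rather than only a query or INSERT, and projectKeyAttributes in the analyzer resolves the CTE down to the single mutable table it wraps. The snippet below mirrors the new tests; the table name `t` and the surrounding setup are illustrative, and `sc` is an existing SparkContext as in SnappyFunSuite.

import org.apache.spark.sql.{Row, SnappySession}

val snc = new SnappySession(sc)
snc.sql("create table t(col1 INT, col2 STRING) using column")
snc.insert("t", Row(1, "a"))
snc.insert("t", Row(1, "a")) // exact duplicate

// The CTE numbers the rows within each col1 partition; the terminal
// DELETE then removes every row but the first, i.e. the duplicates.
snc.sql("WITH dups AS " +
    "(SELECT col1, ROW_NUMBER() OVER" +
    " (PARTITION BY col1 ORDER BY ( SELECT 0))" +
    " RN FROM t) DELETE from dups where rn > 1")

snc.sql("select * from t").show() // a single row (1, a) remains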