From 66a79c1cf6a4fc90b5e1d3f60ff9f58c172fc8bc Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Wed, 19 Mar 2025 11:58:23 +0000 Subject: [PATCH 01/12] Fix formatting of table function table arguments in SqlFormatter Changes adapted from trino/PR#14175 Original commit: 5c125b5ef0e355b7f89d4927171dc7dd029d0b18 Author: kasiafi Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../src/main/java/com/facebook/presto/sql/SqlFormatter.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java index 49a76d28b9c80..0013c2a0aa2d8 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java @@ -266,7 +266,11 @@ private void appendTableFunctionArguments(List arguments, protected Void visitTableArgument(TableFunctionTableArgument node, Integer indent) { Relation relation = node.getTable(); - Relation unaliased = relation instanceof AliasedRelation ? ((AliasedRelation) relation).getRelation() : relation; + Node unaliased = relation instanceof AliasedRelation ? ((AliasedRelation) relation).getRelation() : relation; + if (unaliased instanceof TableSubquery) { + // unpack the relation from TableSubquery to avoid adding another pair of parentheses + unaliased = ((TableSubquery) unaliased).getQuery(); + } builder.append("TABLE("); process(unaliased, indent); builder.append(")"); From 91f7107bdc61f1462a8f9cb439af503eea7bd2bc Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Wed, 19 Mar 2025 13:07:25 +0000 Subject: [PATCH 02/12] Pass plan node tag in the context of PlanPrinter Changes adapted from trino/PR#14175 Original commit: a6f537d5519e34a4a46a411e6967d585b382c56f Author: kasiafi Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../sql/planner/planPrinter/PlanPrinter.java | 337 ++++++++++-------- 1 file changed, 182 insertions(+), 155 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index ec0868f02531f..4153c20c32f3d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -203,7 +203,7 @@ private PlanPrinter( this.formatter = rowExpression -> rowExpressionFormatter.formatRowExpression(connectorSession, rowExpression); Visitor visitor = new Visitor(stageExecutionStrategy, types, estimatedStatsAndCosts, session, stats); - planRoot.accept(visitor, null); + planRoot.accept(visitor, new Context()); } public String toText(boolean verbose, int level, boolean verboseOptimizerInfo) @@ -483,7 +483,7 @@ public static String graphvizDistributedPlan(StageInfo stageInfo, FunctionAndTyp } private class Visitor - extends InternalPlanVisitor + extends InternalPlanVisitor { private final Optional stageExecutionStrategy; private final TypeProvider types; @@ -501,14 +501,14 @@ public Visitor(Optional stageExecutionStrategy, TypePr } @Override - public Void visitExplainAnalyze(ExplainAnalyzeNode node, Void context) + public Void visitExplainAnalyze(ExplainAnalyzeNode node, Context context) { - addNode(node, "ExplainAnalyze"); - return processChildren(node, context); + addNode(node, "ExplainAnalyze", context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitJoin(JoinNode node, Void context) + public Void visitJoin(JoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (EquiJoinClause clause : node.getCriteria()) { @@ -519,12 +519,12 @@ public Void visitJoin(JoinNode node, Void context) NodeRepresentation nodeOutput; if (node.isCrossJoin()) { checkState(joinExpressions.isEmpty()); - nodeOutput = addNode(node, "CrossJoin"); + nodeOutput = addNode(node, "CrossJoin", context.getTag()); } else { nodeOutput = addNode(node, node.getType().getJoinLabel(), - format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable()))); + format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable())), context.getTag()); } node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetailsLine("Distribution: %s", distributionType)); @@ -534,51 +534,51 @@ public Void visitJoin(JoinNode node, Void context) getSortExpressionContext(node, functionAndTypeManager) .ifPresent(sortContext -> nodeOutput.appendDetails("SortExpression[%s]", formatter.apply(sortContext.getSortExpression()))); - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitSpatialJoin(SpatialJoinNode node, Void context) + public Void visitSpatialJoin(SpatialJoinNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, node.getType().getJoinLabel(), - format("[%s]", formatter.apply(node.getFilter()))); + format("[%s]", formatter.apply(node.getFilter())), context.getTag()); nodeOutput.appendDetailsLine("Distribution: %s", node.getDistributionType()); - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitSemiJoin(SemiJoinNode node, Void context) + public Void visitSemiJoin(SemiJoinNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, "SemiJoin", format("[%s = %s]%s", node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable(), - formatHash(node.getSourceHashVariable(), node.getFilteringSourceHashVariable()))); + formatHash(node.getSourceHashVariable(), node.getFilteringSourceHashVariable())), context.getTag()); node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetailsLine("Distribution: %s", distributionType)); if (!node.getDynamicFilters().isEmpty()) { nodeOutput.appendDetails(getDynamicFilterAssignments(node)); } - node.getSource().accept(this, context); - node.getFilteringSource().accept(this, context); + node.getSource().accept(this, new Context()); + node.getFilteringSource().accept(this, new Context()); return null; } @Override - public Void visitIndexSource(IndexSourceNode node, Void context) + public Void visitIndexSource(IndexSourceNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, "IndexSource", - format("[%s, lookup = %s]", node.getIndexHandle(), node.getLookupVariables())); + format("[%s, lookup = %s]", node.getIndexHandle(), node.getLookupVariables()), context.getTag()); for (Map.Entry entry : node.getAssignments().entrySet()) { if (node.getOutputVariables().contains(entry.getKey())) { @@ -589,7 +589,7 @@ public Void visitIndexSource(IndexSourceNode node, Void context) } @Override - public Void visitIndexJoin(IndexJoinNode node, Void context) + public Void visitIndexJoin(IndexJoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { @@ -600,15 +600,15 @@ public Void visitIndexJoin(IndexJoinNode node, Void context) addNode(node, format("%sIndexJoin", node.getType().getJoinLabel()), - format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashVariable(), node.getIndexHashVariable()))); - node.getProbeSource().accept(this, context); - node.getIndexSource().accept(this, context); + format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashVariable(), node.getIndexHashVariable())), context.getTag()); + node.getProbeSource().accept(this, new Context()); + node.getIndexSource().accept(this, new Context()); return null; } @Override - public Void visitMergeJoin(MergeJoinNode node, Void context) + public Void visitMergeJoin(MergeJoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (EquiJoinClause clause : node.getCriteria()) { @@ -618,32 +618,32 @@ public Void visitMergeJoin(MergeJoinNode node, Void context) addNode(node, "MergeJoin", - format("[type: %s], [%s]%s", node.getType().getJoinLabel(), Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable()))); - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + format("[type: %s], [%s]%s", node.getType().getJoinLabel(), Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable())), context.getTag()); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitLimit(LimitNode node, Void context) + public Void visitLimit(LimitNode node, Context context) { addNode(node, format("Limit%s", node.isPartial() ? "Partial" : ""), - format("[%s]", node.getCount())); - return processChildren(node, context); + format("[%s]", node.getCount()), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitDistinctLimit(DistinctLimitNode node, Void context) + public Void visitDistinctLimit(DistinctLimitNode node, Context context) { addNode(node, format("DistinctLimit%s", node.isPartial() ? "Partial" : ""), - format("[%s]%s", node.getLimit(), formatHash(node.getHashVariable()))); - return processChildren(node, context); + format("[%s]%s", node.getLimit(), formatHash(node.getHashVariable())), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitAggregation(AggregationNode node, Void context) + public Void visitAggregation(AggregationNode node, Context context) { String type = ""; if (node.getStep() != AggregationNode.Step.SINGLE) { @@ -661,13 +661,13 @@ public Void visitAggregation(AggregationNode node, Void context) } NodeRepresentation nodeOutput = addNode(node, - format("Aggregate%s%s%s", type, key, formatHash(node.getHashVariable()))); + format("Aggregate%s%s%s", type, key, formatHash(node.getHashVariable())), context.getTag()); for (Map.Entry entry : node.getAggregations().entrySet()) { nodeOutput.appendDetailsLine("%s := %s%s", entry.getKey(), formatAggregation(entry.getValue()), formatSourceLocation(entry.getValue().getCall().getSourceLocation(), entry.getKey().getSourceLocation())); } - return processChildren(node, context); + return processChildren(node, new Context()); } private String formatAggregation(AggregationNode.Aggregation aggregation) @@ -694,7 +694,7 @@ private String formatAggregation(AggregationNode.Aggregation aggregation) } @Override - public Void visitGroupId(GroupIdNode node, Void context) + public Void visitGroupId(GroupIdNode node, Context context) { // grouping sets are easier to understand in terms of inputs List> inputGroupingSetSymbols = node.getGroupingSets().stream() @@ -703,27 +703,27 @@ public Void visitGroupId(GroupIdNode node, Void context) .collect(Collectors.toList())) .collect(Collectors.toList()); - NodeRepresentation nodeOutput = addNode(node, "GroupId", format("%s", inputGroupingSetSymbols)); + NodeRepresentation nodeOutput = addNode(node, "GroupId", format("%s", inputGroupingSetSymbols), context.getTag()); for (Map.Entry mapping : node.getGroupingColumns().entrySet()) { nodeOutput.appendDetailsLine("%s := %s%s", mapping.getKey(), mapping.getValue(), formatSourceLocation(mapping.getValue().getSourceLocation(), mapping.getKey().getSourceLocation())); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitMarkDistinct(MarkDistinctNode node, Void context) + public Void visitMarkDistinct(MarkDistinctNode node, Context context) { addNode(node, "MarkDistinct", - format("[distinct=%s marker=%s]%s", formatOutputs(node.getDistinctVariables()), node.getMarkerVariable(), formatHash(node.getHashVariable()))); + format("[distinct=%s marker=%s]%s", formatOutputs(node.getDistinctVariables()), node.getMarkerVariable(), formatHash(node.getHashVariable())), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitWindow(WindowNode node, Void context) + public Void visitWindow(WindowNode node, Context context) { List partitionBy = Lists.transform(node.getPartitionBy(), Functions.toStringFunction()); @@ -763,7 +763,7 @@ public Void visitWindow(WindowNode node, Void context) .collect(Collectors.joining(", ")))); } - NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable()))); + NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable())), context.getTag()); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { CallExpression call = entry.getValue().getFunctionCall(); @@ -777,11 +777,11 @@ public Void visitWindow(WindowNode node, Void context) frameInfo, formatSourceLocation(entry.getValue().getFunctionCall().getSourceLocation(), entry.getKey().getSourceLocation())); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) + public Void visitTopNRowNumber(TopNRowNumberNode node, Context context) { List partitionBy = node.getPartitionBy().stream() .map(Functions.toStringFunction()) @@ -797,15 +797,15 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode(node, format("TopNRowNumber%s", node.isPartial() ? "Partial" : ""), - format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashVariable()))); + format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashVariable())), context.getTag()); nodeOutput.appendDetailsLine("%s := %s%s", node.getRowNumberVariable(), "row_number()", formatSourceLocation(node.getRowNumberVariable().getSourceLocation())); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRowNumber(RowNumberNode node, Void context) + public Void visitRowNumber(RowNumberNode node, Context context) { List partitionBy = Lists.transform(node.getPartitionBy(), Functions.toStringFunction()); List args = new ArrayList<>(); @@ -819,24 +819,24 @@ public Void visitRowNumber(RowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode(node, format("RowNumber%s", node.isPartial() ? "Partial" : ""), - format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable()))); + format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable())), context.getTag()); nodeOutput.appendDetailsLine("%s := %s%s", node.getRowNumberVariable(), "row_number()", formatSourceLocation(node.getRowNumberVariable().getSourceLocation())); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableScan(TableScanNode node, Void context) + public Void visitTableScan(TableScanNode node, Context context) { TableHandle table = node.getTable(); NodeRepresentation nodeOutput; if (stageExecutionStrategy.isPresent()) { nodeOutput = addNode(node, "TableScan", - format("[%s, grouped = %s]", table, stageExecutionStrategy.get().isScanGroupedExecution(node.getId()))); + format("[%s, grouped = %s]", table, stageExecutionStrategy.get().isScanGroupedExecution(node.getId())), context.getTag()); } else { - nodeOutput = addNode(node, "TableScan", format("[%s]", table)); + nodeOutput = addNode(node, "TableScan", format("[%s]", table), context.getTag()); } PlanNodeStats nodeStats = stats.map(s -> s.get(node.getId())).orElse(null); printTableScanInfo(nodeOutput, node, nodeStats); @@ -844,49 +844,49 @@ public Void visitTableScan(TableScanNode node, Void context) } @Override - public Void visitSequence(SequenceNode node, Void context) + public Void visitSequence(SequenceNode node, Context context) { NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "Sequence"); + nodeOutput = addNode(node, "Sequence", context.getTag()); nodeOutput.appendDetails(getCteExecutionOrder(node)); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCteConsumer(CteConsumerNode node, Void context) + public Void visitCteConsumer(CteConsumerNode node, Context context) { NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "CteConsumer"); + nodeOutput = addNode(node, "CteConsumer", context.getTag()); nodeOutput.appendDetailsLine("CTE_NAME: %s", node.getCteId()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCteProducer(CteProducerNode node, Void context) + public Void visitCteProducer(CteProducerNode node, Context context) { NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "CteProducer"); + nodeOutput = addNode(node, "CteProducer", context.getTag()); nodeOutput.appendDetailsLine("CTE_NAME: %s", node.getCteId()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCteReference(CteReferenceNode node, Void context) + public Void visitCteReference(CteReferenceNode node, Context context) { - addNode(node, "CteReference"); - return processChildren(node, context); + addNode(node, "CteReference", context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitValues(ValuesNode node, Void context) + public Void visitValues(ValuesNode node, Context context) { NodeRepresentation nodeOutput; if (node.getValuesNodeLabel().isPresent()) { - nodeOutput = addNode(node, format("Values converted from TableScan[%s]", node.getValuesNodeLabel().get())); + nodeOutput = addNode(node, format("Values converted from TableScan[%s]", node.getValuesNodeLabel().get()), context.getTag()); } else { - nodeOutput = addNode(node, "Values"); + nodeOutput = addNode(node, "Values", context.getTag()); } for (List row : node.getRows()) { nodeOutput.appendDetailsLine("(" + row.stream().map(formatter::apply).collect(Collectors.joining(", ")) + ")"); @@ -895,13 +895,13 @@ public Void visitValues(ValuesNode node, Void context) } @Override - public Void visitFilter(FilterNode node, Void context) + public Void visitFilter(FilterNode node, Context context) { return visitScanFilterAndProjectInfo(node, Optional.of(node), Optional.empty(), context); } @Override - public Void visitProject(ProjectNode node, Void context) + public Void visitProject(ProjectNode node, Context context) { if (node.getSource() instanceof FilterNode) { return visitScanFilterAndProjectInfo(node, Optional.of((FilterNode) node.getSource()), Optional.of(node), context); @@ -914,7 +914,7 @@ private Void visitScanFilterAndProjectInfo( PlanNode node, Optional filterNode, Optional projectNode, - Void context) + Context context) { checkState(projectNode.isPresent() || filterNode.isPresent()); @@ -988,7 +988,8 @@ private Void visitScanFilterAndProjectInfo( format(formatString, arguments.toArray(new Object[0])), allNodes, ImmutableList.of(sourceNode), - ImmutableList.of()); + ImmutableList.of(), + context.getTag()); if (projectNode.isPresent()) { printAssignments(nodeOutput, projectNode.get().getAssignments()); @@ -1000,7 +1001,7 @@ private Void visitScanFilterAndProjectInfo( return null; } - sourceNode.accept(this, context); + sourceNode.accept(this, new Context()); return null; } @@ -1068,18 +1069,18 @@ else if (predicate.isNone()) { } @Override - public Void visitUnnest(UnnestNode node, Void context) + public Void visitUnnest(UnnestNode node, Context context) { addNode(node, "Unnest", - format("[replicate=%s, unnest=%s]", formatOutputs(node.getReplicateVariables()), formatOutputs(node.getUnnestVariables().keySet()))); - return processChildren(node, context); + format("[replicate=%s, unnest=%s]", formatOutputs(node.getReplicateVariables()), formatOutputs(node.getUnnestVariables().keySet())), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitOutput(OutputNode node, Void context) + public Void visitOutput(OutputNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "Output", format("[%s]", Joiner.on(", ").join(node.getColumnNames()))); + NodeRepresentation nodeOutput = addNode(node, "Output", format("[%s]", Joiner.on(", ").join(node.getColumnNames())), context.getTag()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); VariableReferenceExpression variable = node.getOutputVariables().get(i); @@ -1087,22 +1088,22 @@ public Void visitOutput(OutputNode node, Void context) nodeOutput.appendDetailsLine("%s := %s%s", name, variable, formatSourceLocation(variable.getSourceLocation())); } } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTopN(TopNNode node, Void context) + public Void visitTopN(TopNNode node, Context context) { Iterable keys = Iterables.transform(node.getOrderingScheme().getOrderByVariables(), input -> input + " " + node.getOrderingScheme().getOrdering(input)); addNode(node, format("TopN%s", node.getStep() == TopNNode.Step.PARTIAL ? "Partial" : ""), - format("[%s by (%s)]", node.getCount(), Joiner.on(", ").join(keys))); - return processChildren(node, context); + format("[%s by (%s)]", node.getCount(), Joiner.on(", ").join(keys)), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitSort(SortNode node, Void context) + public Void visitSort(SortNode node, Context context) { Iterable keys = Iterables.transform(node.getOrderingScheme().getOrderByVariables(), input -> input + " " + node.getOrderingScheme().getOrdering(input)); @@ -1110,52 +1111,53 @@ public Void visitSort(SortNode node, Void context) if (!node.getPartitionBy().isEmpty()) { detail = format("%s[Partition by %s]", detail, Joiner.on(", ").join(node.getPartitionBy())); } - addNode(node, format("%sSort", node.isPartial() ? "Partial" : ""), detail); + addNode(node, format("%sSort", node.isPartial() ? "Partial" : ""), detail, context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRemoteSource(RemoteSourceNode node, Void context) + public Void visitRemoteSource(RemoteSourceNode node, Context context) { addNode(node, format("Remote%s", node.getOrderingScheme().isPresent() ? "Merge" : "Source"), format("[%s]", Joiner.on(',').join(node.getSourceFragmentIds())), ImmutableList.of(), ImmutableList.of(), - node.getSourceFragmentIds()); + node.getSourceFragmentIds(), + context.getTag()); return null; } @Override - public Void visitUnion(UnionNode node, Void context) + public Void visitUnion(UnionNode node, Context context) { - addNode(node, "Union"); + addNode(node, "Union", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitIntersect(IntersectNode node, Void context) + public Void visitIntersect(IntersectNode node, Context context) { - addNode(node, "Intersect"); + addNode(node, "Intersect", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitExcept(ExceptNode node, Void context) + public Void visitExcept(ExceptNode node, Context context) { - addNode(node, "Except"); + addNode(node, "Except", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableWriter(TableWriterNode node, Void context) + public Void visitTableWriter(TableWriterNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "TableWriter"); + NodeRepresentation nodeOutput = addNode(node, "TableWriter", context.getTag()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); VariableReferenceExpression variable = node.getColumns().get(i); @@ -1168,40 +1170,40 @@ public Void visitTableWriter(TableWriterNode node, Void context) .orElse(0); nodeOutput.appendDetailsLine("Statistics collected: %s", statisticsCollected); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableWriteMerge(TableWriterMergeNode node, Void context) + public Void visitTableWriteMerge(TableWriterMergeNode node, Context context) { - addNode(node, "TableWriterMerge"); - return processChildren(node, context); + addNode(node, "TableWriterMerge", context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) + public Void visitStatisticsWriterNode(StatisticsWriterNode node, Context context) { - addNode(node, "StatisticsWriter", format("[%s]", node.getTableHandle())); - return processChildren(node, context); + addNode(node, "StatisticsWriter", format("[%s]", node.getTableHandle()), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitTableFinish(TableFinishNode node, Void context) + public Void visitTableFinish(TableFinishNode node, Context context) { - addNode(node, "TableCommit", format("[%s]", node.getTarget())); - return processChildren(node, context); + addNode(node, "TableCommit", format("[%s]", node.getTarget()), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitSample(SampleNode node, Void context) + public Void visitSample(SampleNode node, Context context) { - addNode(node, "Sample", format("[%s: %s]", node.getSampleType(), node.getSampleRatio())); + addNode(node, "Sample", format("[%s: %s]", node.getSampleType(), node.getSampleRatio()), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitExchange(ExchangeNode node, Void context) + public Void visitExchange(ExchangeNode node, Context context) { if (node.getOrderingScheme().isPresent()) { OrderingScheme orderingScheme = node.getOrderingScheme().get(); @@ -1212,7 +1214,7 @@ public Void visitExchange(ExchangeNode node, Void context) addNode(node, format("%sMerge", UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, node.getScope().toString())), - format("[%s]", Joiner.on(", ").join(orderBy))); + format("[%s]", Joiner.on(", ").join(orderBy)), context.getTag()); } else if (node.getScope().isLocal()) { addNode(node, @@ -1221,7 +1223,7 @@ else if (node.getScope().isLocal()) { node.getPartitioningScheme().getPartitioning().getHandle(), node.getPartitioningScheme().isReplicateNullsAndAny() ? " - REPLICATE NULLS AND ANY" : "", formatHash(node.getPartitioningScheme().getHashColumn()), - Joiner.on(", ").join(node.getPartitioningScheme().getPartitioning().getArguments()))); + Joiner.on(", ").join(node.getPartitioningScheme().getPartitioning().getArguments())), context.getTag()); } else { addNode(node, @@ -1230,83 +1232,84 @@ else if (node.getScope().isLocal()) { node.getType(), node.getPartitioningScheme().getEncoding(), node.getPartitioningScheme().isReplicateNullsAndAny() ? " - REPLICATE NULLS AND ANY" : "", - formatHash(node.getPartitioningScheme().getHashColumn()))); + formatHash(node.getPartitioningScheme().getHashColumn())), context.getTag()); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitDelete(DeleteNode node, Void context) + public Void visitDelete(DeleteNode node, Context context) { - addNode(node, "Delete"); + addNode(node, "Delete", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitUpdate(UpdateNode node, Void context) + public Void visitUpdate(UpdateNode node, Context context) { - addNode(node, "Update"); + addNode(node, "Update", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitMetadataDelete(MetadataDeleteNode node, Void context) + public Void visitMetadataDelete(MetadataDeleteNode node, Context context) { - addNode(node, "MetadataDelete", format("[%s]", node.getTableHandle())); + addNode(node, "MetadataDelete", format("[%s]", node.getTableHandle()), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + public Void visitEnforceSingleRow(EnforceSingleRowNode node, Context context) { - addNode(node, "EnforceSingleRow"); + addNode(node, "EnforceSingleRow", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitAssignUniqueId(AssignUniqueId node, Void context) + public Void visitAssignUniqueId(AssignUniqueId node, Context context) { - addNode(node, "AssignUniqueId"); + addNode(node, "AssignUniqueId", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitGroupReference(GroupReference node, Void context) + public Void visitGroupReference(GroupReference node, Context context) { - addNode(node, "GroupReference", format("[%s]", node.getGroupId()), ImmutableList.of()); + addNode(node, "GroupReference", format("[%s]", node.getGroupId()), ImmutableList.of(), context.getTag()); return null; } @Override - public Void visitApply(ApplyNode node, Void context) + public Void visitApply(ApplyNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "Apply", format("[%s]", node.getCorrelation())); + NodeRepresentation nodeOutput = addNode(node, "Apply", format("[%s]", node.getCorrelation()), context.getTag()); printAssignments(nodeOutput, node.getSubqueryAssignments()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitLateralJoin(LateralJoinNode node, Void context) + public Void visitLateralJoin(LateralJoinNode node, Context context) { - addNode(node, "Lateral", format("[%s]", node.getCorrelation())); + addNode(node, "Lateral", format("[%s]", node.getCorrelation()), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableFunction(TableFunctionNode node, Void context) + public Void visitTableFunction(TableFunctionNode node, Context context) { NodeRepresentation nodeOutput = addNode( node, "TableFunction", - "name"); + "name", + context.getTag()); checkArgument( node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), @@ -1316,7 +1319,7 @@ public Void visitTableFunction(TableFunctionNode node, Void context) // node.getArguments().entrySet().stream() // .forEach(entry -> nodeOutput.appendDetails(entry.getKey() + " => " + formatArgument((ScalarArgument) entry.getValue()))); - return processChildren(node, context); + return processChildren(node, new Context()); } /* @@ -1327,12 +1330,12 @@ private String formatArgument(ScalarArgument argument) */ @Override - public Void visitPlan(PlanNode node, Void context) + public Void visitPlan(PlanNode node, Context context) { throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } - private Void processChildren(PlanNode node, Void context) + private Void processChildren(PlanNode node, Context context) { for (PlanNode child : node.getSources()) { child.accept(this, context); @@ -1391,27 +1394,27 @@ private String formatDomain(Domain domain) return "[" + Joiner.on(", ").join(parts.build()) + "]"; } - public NodeRepresentation addNode(PlanNode node, String name) + public NodeRepresentation addNode(PlanNode node, String name, Optional tag) { - return addNode(node, name, ""); + return addNode(node, name, "", tag); } - public NodeRepresentation addNode(PlanNode node, String name, String identifier) + public NodeRepresentation addNode(PlanNode node, String name, String identifier, Optional tag) { - return addNode(node, name, identifier, node.getSources()); + return addNode(node, name, identifier, node.getSources(), tag); } - public NodeRepresentation addNode(PlanNode node, String name, String identifier, List children) + public NodeRepresentation addNode(PlanNode node, String name, String identifier, List children, Optional tag) { - return addNode(node, name, identifier, ImmutableList.of(node.getId()), children, ImmutableList.of()); + return addNode(node, name, identifier, ImmutableList.of(node.getId()), children, ImmutableList.of(), tag); } - public NodeRepresentation addNode(PlanNode node, String name, List children) + public NodeRepresentation addNode(PlanNode node, String name, List children, Optional tag) { - return addNode(node, name, "", ImmutableList.of(node.getId()), children, ImmutableList.of()); + return addNode(node, name, "", ImmutableList.of(node.getId()), children, ImmutableList.of(), tag); } - public NodeRepresentation addNode(PlanNode rootNode, String name, String identifier, List allNodes, List children, List remoteSources) + public NodeRepresentation addNode(PlanNode rootNode, String name, String identifier, List allNodes, List children, List remoteSources, Optional tag) { List childrenIds = children.stream().map(PlanNode::getId).collect(toImmutableList()); List estimatedStats = allNodes.stream() @@ -1420,6 +1423,9 @@ public NodeRepresentation addNode(PlanNode rootNode, String name, String identif List estimatedCosts = allNodes.stream() .map(nodeId -> estimatedStatsAndCosts.getCosts().getOrDefault(nodeId, PlanCostEstimate.unknown())) .collect(toList()); + name = tag + .map(tagName -> format("[%s] ", tagName)) + .orElse("") + name; NodeRepresentation nodeOutput = new NodeRepresentation( Optional.empty(), @@ -1516,4 +1522,25 @@ private static String formatOutputs(Iterable output .map(input -> input + ":" + input.getType().getDisplayName()) .collect(Collectors.joining(", ")); } + + + public class Context { + private final Optional tag; + + public Context() { + this(Optional.empty()); + } + + public Context(String tag) { + this(Optional.of(tag)); + } + + public Context(Optional tag) { + this.tag = requireNonNull(tag, "tag is null"); + } + + public Optional getTag() { + return tag; + } + } } From 989fd04347e98424509b6e1fa5d17728b495c296 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Wed, 19 Mar 2025 13:12:43 +0000 Subject: [PATCH 03/12] Copy arguments in the constructor Changes adapted from trino/PR#14175 Original commit: 4666472b0188aa26087840cdb587cc6e4495edef Author: kasiafi Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../presto/sql/planner/plan/TableFunctionNode.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 342213274154c..3ded5b219a412 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -22,6 +22,8 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import javax.annotation.concurrent.Immutable; @@ -69,10 +71,10 @@ public TableFunctionNode( { super(sourceLocation, id, statsEquivalentPlanNode); this.name = requireNonNull(name, "name is null"); - this.arguments = requireNonNull(arguments, "arguments is null"); - this.outputVariables = requireNonNull(outputVariables, "properOutputs is null"); - this.sources = requireNonNull(sources, "sources is null"); - this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.outputVariables = ImmutableList.copyOf(outputVariables); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); this.handle = requireNonNull(handle, "handle is null"); } From 235bf2761c4dce5aa84eb91c1fc319fabf0036a9 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Wed, 19 Mar 2025 13:17:46 +0000 Subject: [PATCH 04/12] Refactor TestingTableFunctions Changes adapted from trino/PR#14175 Original commit: 1aea489884346822c812b1a242acc286e3e1248e Author: kasiafi Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../connector/tvf/TestingTableFunctions.java | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java index 391dc44e29029..20eb32585060e 100644 --- a/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -45,10 +45,15 @@ public class TestingTableFunctions { private static final String SCHEMA_NAME = "system"; - private static final ConnectorTableFunctionHandle HANDLE = new ConnectorTableFunctionHandle() {}; + private static final String TABLE_NAME = "table"; + private static final String COLUMN_NAME = "column"; + private static final ConnectorTableFunctionHandle HANDLE = new TestingTableFunctionHandle(); private static final TableFunctionAnalysis ANALYSIS = TableFunctionAnalysis.builder() .handle(HANDLE) - .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN))))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .build(); + private static final TableFunctionAnalysis NO_DESCRIPTOR_ANALYSIS = TableFunctionAnalysis.builder() + .handle(HANDLE) .build(); /** @@ -254,9 +259,7 @@ public OnlyPassThroughFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; } } @@ -277,9 +280,7 @@ public MonomorphicStaticReturnTypeFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; } } @@ -302,9 +303,7 @@ public PolymorphicStaticReturnTypeFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; } } @@ -328,9 +327,26 @@ public PassThroughFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class TestingTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final MockConnectorTableHandle tableHandle; + + public TestingTableFunctionHandle() + { + this.tableHandle = new MockConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + TupleDomain.all(), + Optional.of(ImmutableList.of(new MockConnectorColumnHandle(COLUMN_NAME, BOOLEAN)))); + } + + public MockConnectorTableHandle getTableHandle() + { + return tableHandle; } } } From e97a16e15a8d0eb9c4aeffa404c1ae70377da2b1 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Wed, 19 Mar 2025 14:12:35 +0000 Subject: [PATCH 05/12] Extract WindowNode.Specification as a separate class Changes adapted from trino/PR#14175 Original commit: 80c7fa0519eea07d8417d23908e8d1f8774dc3cd Author: kasiafi Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../sql/planner/CanonicalPlanGenerator.java | 5 +- .../presto/sql/planner/QueryPlanner.java | 3 +- .../optimizations/PlanNodeDecorrelator.java | 4 +- .../UnaliasSymbolReferences.java | 5 +- .../sql/planner/plan/TableFunctionNode.java | 8 +- .../sql/planner/plan/TopNRowNumberNode.java | 10 +-- .../presto/sql/planner/TestCanonicalize.java | 3 +- .../TestEffectivePredicateExtractor.java | 3 +- .../presto/sql/planner/TestTypeValidator.java | 7 +- .../planner/assertions/PlanMatchPattern.java | 3 +- .../assertions/SpecificationProvider.java | 11 +-- .../assertions/TopNRowNumberMatcher.java | 8 +- .../sql/planner/assertions/WindowMatcher.java | 9 +- .../rule/TestMergeAdjacentWindows.java | 11 +-- .../rule/TestPruneWindowColumns.java | 3 +- ...stSwapAdjacentWindowsBySpecifications.java | 23 +++--- .../iterative/rule/test/PlanBuilder.java | 5 +- .../optimizations/TestEliminateSorts.java | 3 +- .../optimizations/TestMergeWindows.java | 27 +++--- .../TestPruneUnreferencedOutputs.java | 3 +- .../optimizations/TestReorderWindows.java | 15 ++-- .../sql/planner/plan/TestWindowNode.java | 3 +- .../plan/DataOrganizationSpecification.java | 82 +++++++++++++++++++ .../facebook/presto/spi/plan/WindowNode.java | 64 ++------------- 24 files changed, 182 insertions(+), 136 deletions(-) create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 5825cbeca2dd6..068b4459ea02d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; @@ -480,7 +481,7 @@ public Optional visitWindow(WindowNode node, Context context) .sorted(comparing(this::writeValueAsString)) .collect(toImmutableSet()); - WindowNode.Specification specification = new WindowNode.Specification( + DataOrganizationSpecification specification = new DataOrganizationSpecification( node.getSpecification().getPartitionBy().stream() .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable)) .sorted(comparing(this::writeValueAsString)) @@ -694,7 +695,7 @@ public Optional visitTopNRowNumber(TopNRowNumberNode node, Context con Optional.empty(), planNodeidAllocator.getNextId(), source.get(), - new WindowNode.Specification( + new DataOrganizationSpecification( partitionBy, node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))), rowNumberVariable, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 0ef7e7d7d663f..71149793bd3d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -29,6 +29,7 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; @@ -1069,7 +1070,7 @@ else if (window.getFrame().isPresent()) { subPlan.getRoot().getSourceLocation(), idAllocator.getNextId(), subPlan.getRoot(), - new WindowNode.Specification( + new DataOrganizationSpecification( partitionByVariables.build(), orderingScheme), ImmutableMap.of(newVariable, function), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 9110ec4ceb128..9915c495762be 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.Ordering; @@ -27,7 +28,6 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TopNNode; -import com.facebook.presto.spi.plan.WindowNode.Specification; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -307,7 +307,7 @@ public Optional visitTopN(TopNNode node, Void context) decorrelatedChildNode.getSourceLocation(), node.getId(), decorrelatedChildNode, - new Specification( + new DataOrganizationSpecification( ImmutableList.copyOf(childDecorrelationResult.variablesToPropagate), Optional.of(orderingScheme)), variableAllocator.newVariable("row_number", BIGINT), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 463836291847b..11b1cd5f48ed0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -813,9 +814,9 @@ private Map canonicalizeAndDistinct(Map sourceLocation, @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable, @JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition, @JsonProperty("partial") boolean partial, @@ -62,7 +62,7 @@ public TopNRowNumberNode( PlanNodeId id, Optional statsEquivalentPlanNode, PlanNode source, - Specification specification, + DataOrganizationSpecification specification, VariableReferenceExpression rowNumberVariable, int maxRowCountPerPartition, boolean partial, @@ -109,7 +109,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java index 1fd0121232144..744d7a4fdaf38 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; @@ -81,7 +82,7 @@ public void testJoin() @Test public void testDuplicatesInWindowOrderBy() { - ExpectedValueProvider specification = specification( + ExpectedValueProvider specification = specification( ImmutableList.of(), ImmutableList.of("A"), ImmutableMap.of("A", SortOrder.ASC_NULLS_LAST)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 46676111a82aa..ae2a450ef2b05 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.JoinNode; @@ -342,7 +343,7 @@ public void testWindow() equals(AV, BV), equals(BV, CV), lessThan(CV, bigintLiteral(10)))), - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(AV), Optional.of(new OrderingScheme( ImmutableList.of(new Ordering(AV, SortOrder.ASC_NULLS_LAST))))), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 07a0529289495..8c7995568caec 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.ProjectNode; @@ -175,7 +176,7 @@ public void testValidWindow() WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, DOUBLE, variableC), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( Optional.empty(), @@ -437,7 +438,7 @@ public void testInvalidWindowFunctionCall() WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableA), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( Optional.empty(), @@ -471,7 +472,7 @@ public void testInvalidWindowFunctionSignature() WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableC), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( Optional.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 8702e3c0f7a1b..af888762f2021 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; @@ -927,7 +928,7 @@ private static List toSymbolAliases(List aliases) .collect(toImmutableList()); } - public static ExpectedValueProvider specification( + public static ExpectedValueProvider specification( List partitionBy, List orderBy, Map orderings) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java index 1caf681620280..dc621f809a8d9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -33,7 +34,7 @@ import static java.util.Objects.requireNonNull; public class SpecificationProvider - implements ExpectedValueProvider + implements ExpectedValueProvider { private final List partitionBy; private final List orderBy; @@ -50,7 +51,7 @@ public class SpecificationProvider } @Override - public WindowNode.Specification getExpectedValue(SymbolAliases aliases) + public DataOrganizationSpecification getExpectedValue(SymbolAliases aliases) { Optional orderingScheme = Optional.empty(); if (!orderBy.isEmpty()) { @@ -64,7 +65,7 @@ public WindowNode.Specification getExpectedValue(SymbolAliases aliases) .collect(toImmutableList()))); } - return new WindowNode.Specification( + return new DataOrganizationSpecification( partitionBy .stream() .map(alias -> new VariableReferenceExpression(Optional.empty(), alias.toSymbol(aliases).getName(), UNKNOWN)) @@ -87,7 +88,7 @@ public String toString() * VariableReferenceExpression::equals to check whether two specification are equivalent once they include VariableReferenceExpression. * TODO Directly use equals once SymbolAlias is converted to something with type information. */ - public static boolean matchSpecification(WindowNode.Specification actual, WindowNode.Specification expected) + public static boolean matchSpecification(DataOrganizationSpecification actual, DataOrganizationSpecification expected) { return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) .equals(expected.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList())) && @@ -104,7 +105,7 @@ public static boolean matchSpecification(WindowNode.Specification actual, Window .orElse(true); } - public static boolean matchSpecification(WindowNode.Specification actual, SpecificationProvider expected) + public static boolean matchSpecification(DataOrganizationSpecification actual, SpecificationProvider expected) { return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) .equals(expected.partitionBy.stream().map(SymbolAlias::toString).collect(toImmutableList())) && diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java index c4d66eb263263..590625570cbd7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java @@ -17,8 +17,8 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.PlanNode; -import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; @@ -38,14 +38,14 @@ public class TopNRowNumberMatcher implements Matcher { - private final Optional> specification; + private final Optional> specification; private final Optional rowNumberSymbol; private final Optional maxRowCountPerPartition; private final Optional partial; private final Optional> hashSymbol; private TopNRowNumberMatcher( - Optional> specification, + Optional> specification, Optional rowNumberSymbol, Optional maxRowCountPerPartition, Optional partial, @@ -125,7 +125,7 @@ public String toString() public static class Builder { private final PlanMatchPattern source; - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private Optional rowNumberSymbol = Optional.empty(); private Optional maxRowCountPerPartition = Optional.empty(); private Optional partial = Optional.empty(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index 02aa69afdd716..ffbe3c37d3e26 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -18,6 +18,7 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -46,13 +47,13 @@ public final class WindowMatcher implements Matcher { private final Optional> prePartitionedInputs; - private final Optional> specification; + private final Optional> specification; private final Optional preSortedOrderPrefix; private final Optional> hashSymbol; private WindowMatcher( Optional> prePartitionedInputs, - Optional> specification, + Optional> specification, Optional preSortedOrderPrefix, Optional> hashSymbol) { @@ -136,7 +137,7 @@ public static class Builder { private final PlanMatchPattern source; private Optional> prePartitionedInputs = Optional.empty(); - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private Optional preSortedOrderPrefix = Optional.empty(); private List windowFunctionMatchers = new LinkedList<>(); private Optional> hashSymbol = Optional.empty(); @@ -164,7 +165,7 @@ public Builder specification( return specification(PlanMatchPattern.specification(partitionBy, orderBy, orderings)); } - public Builder specification(ExpectedValueProvider specification) + public Builder specification(ExpectedValueProvider specification) { requireNonNull(specification, "specification is null"); this.specification = Optional.of(specification); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index f0e9d391ec889..55c64c8c946a3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -93,7 +94,7 @@ public class TestMergeAdjacentWindows private static final FunctionHandle LAG_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("lag", fromTypes(DOUBLE)); private static final FunctionHandle RANK_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("rank", ImmutableList.of()); private static final String columnAAlias = "ALIAS_A"; - private static final ExpectedValueProvider specificationA = + private static final ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); @Test @@ -275,14 +276,14 @@ public void testIntermediateProjectNodes() values(columnAAlias, unusedAlias)))))); } - private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) + private static DataOrganizationSpecification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) { - return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.empty()); + return new DataOrganizationSpecification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.empty()); } - private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName, String sortkey) + private static DataOrganizationSpecification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName, String sortkey) { - return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), + return new DataOrganizationSpecification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.of(new OrderingScheme( ImmutableList.of(new Ordering(planBuilder.variable(sortkey, BIGINT), SortOrder.ASC_NULLS_FIRST))))); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index c4f68ef174e8f..ccbf278341a55 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; @@ -280,7 +281,7 @@ private static PlanNode buildProjectedWindow( .filter(projectionFilter) .collect(toImmutableList())), p.window( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(partitionKey), Optional.of(new OrderingScheme( ImmutableList.of(new Ordering(orderKey, SortOrder.ASC_NULLS_FIRST))))), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 9209e2df9cdc4..b3fa34f67a8d1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -79,7 +80,7 @@ public void doesNotFireOnPlanWithoutWindowFunctions() public void doesNotFireOnPlanWithSingleWindowNode() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) - .on(p -> p.window(new WindowNode.Specification( + .on(p -> p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.empty()), ImmutableMap.of(p.variable("avg_1"), @@ -94,16 +95,16 @@ public void subsetComesFirst() String columnAAlias = "ALIAS_A"; String columnBAlias = "ALIAS_B"; - ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); - ExpectedValueProvider specificationAB = specification(ImmutableList.of(columnAAlias, columnBAlias), ImmutableList.of(), ImmutableMap.of()); + ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); + ExpectedValueProvider specificationAB = specification(ImmutableList.of(columnAAlias, columnBAlias), ImmutableList.of(), ImmutableMap.of()); tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.empty()), ImmutableMap.of(p.variable("avg_1", DOUBLE), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.empty()), ImmutableMap.of(p.variable("avg_2", DOUBLE), newWindowNodeFunction(ImmutableList.of(new Symbol("b")))), @@ -123,11 +124,11 @@ public void dependentWindowsAreNotReordered() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.empty()), ImmutableMap.of(p.variable("avg_1"), newWindowNodeFunction(ImmutableList.of(new Symbol("avg_2")))), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.empty()), ImmutableMap.of(p.variable("avg_2"), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), @@ -168,11 +169,11 @@ public void dependentWindowsAreNotReorderedWithOffset() tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("avg_1"), functionWithOffset), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("startValue"), windowFunction), @@ -213,11 +214,11 @@ public void dependentWindowsWithRangeAreNotReordered() tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("avg_1"), functionWithOffset), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("startValue"), windowFunction), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 6e541d36606aa..0ff91e3ea8fc3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -933,7 +934,7 @@ public PlanBuilder registerVariable(VariableReferenceExpression expression) return this; } - public WindowNode window(WindowNode.Specification specification, Map functions, PlanNode source) + public WindowNode window(DataOrganizationSpecification specification, Map functions, PlanNode source) { return new WindowNode( Optional.empty(), @@ -946,7 +947,7 @@ public WindowNode window(WindowNode.Specification specification, Map functions, VariableReferenceExpression hashVariable, PlanNode source) + public WindowNode window(DataOrganizationSpecification specification, Map functions, VariableReferenceExpression hashVariable, PlanNode source) { return new WindowNode( Optional.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index aa0456e3bcfb3..1b3e0cc4f3625 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.RuleStatsRecorder; @@ -53,7 +54,7 @@ public class TestEliminateSorts private static final String QUANTITY_ALIAS = "QUANTITY"; private static final String TAX_ALIAS = "TAX"; - private static final ExpectedValueProvider windowSpec = specification( + private static final ExpectedValueProvider windowSpec = specification( ImmutableList.of(), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java index 379bb4cfd2780..ac48f5531d741 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.RuleStatsRecorder; @@ -88,8 +89,8 @@ public class TestMergeWindows private static final Optional UNSPECIFIED_FRAME = Optional.empty(); - private final ExpectedValueProvider specificationA; - private final ExpectedValueProvider specificationB; + private final ExpectedValueProvider specificationA; + private final ExpectedValueProvider specificationB; public TestMergeWindows() { @@ -292,12 +293,12 @@ public void testIdenticalWindowSpecificationsAAfilterA() @Test public void testIdenticalWindowSpecificationsDefaultFrame() { - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(ORDERKEY_ALIAS), ImmutableList.of(SHIPDATE_ALIAS), ImmutableMap.of(SHIPDATE_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -328,7 +329,7 @@ public void testMergeDifferentFrames() new FrameBound(FrameBound.Type.UNBOUNDED_PRECEDING), Optional.of(new FrameBound(FrameBound.Type.CURRENT_ROW)))); - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -362,7 +363,7 @@ public void testMergeDifferentFramesWithDefault() new FrameBound(FrameBound.Type.CURRENT_ROW), Optional.of(new FrameBound(FrameBound.Type.UNBOUNDED_FOLLOWING)))); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -391,7 +392,7 @@ public void testMergeRangeFramesWithDefault() new FrameBound(FrameBound.Type.CURRENT_ROW), Optional.of(new FrameBound(FrameBound.Type.UNBOUNDED_FOLLOWING)))); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -433,12 +434,12 @@ public void testNotMergeAcrossJoinBranches() ")" + "SELECT * FROM foo, bar WHERE foo.a = bar.b"; - ExpectedValueProvider leftSpecification = specification( + ExpectedValueProvider leftSpecification = specification( ImmutableList.of(ORDERKEY_ALIAS), ImmutableList.of(SHIPDATE_ALIAS, QUANTITY_ALIAS), ImmutableMap.of(SHIPDATE_ALIAS, SortOrder.ASC_NULLS_LAST, QUANTITY_ALIAS, SortOrder.DESC_NULLS_LAST)); - ExpectedValueProvider rightSpecification = specification( + ExpectedValueProvider rightSpecification = specification( ImmutableList.of(rOrderkeyAlias), ImmutableList.of(rShipdateAlias, rQuantityAlias), ImmutableMap.of(rShipdateAlias, SortOrder.ASC_NULLS_LAST, rQuantityAlias, SortOrder.DESC_NULLS_LAST)); @@ -489,7 +490,7 @@ public void testNotMergeDifferentPartition() "SUM(quantity) over (PARTITION BY quantity ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(QUANTITY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -513,7 +514,7 @@ public void testNotMergeDifferentOrderBy() "SUM(quantity) OVER (PARTITION BY suppkey ORDER BY quantity ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -538,7 +539,7 @@ public void testNotMergeDifferentOrdering() "SUM(discount) over (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.DESC_NULLS_LAST)); @@ -564,7 +565,7 @@ public void testNotMergeDifferentNullOrdering() "SUM(discount) OVER (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_FIRST)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index af11a87d77e33..6a1eff9da903d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -71,7 +72,7 @@ public void windowNodePruning() p.output(ImmutableList.of("user_uuid"), ImmutableList.of(p.variable("user_uuid", VARCHAR)), p.project(Assignments.of(p.variable("user_uuid", VARCHAR), p.variable("user_uuid", VARCHAR)), p.window( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(p.variable("user_uuid", VARCHAR)), Optional.of(new OrderingScheme( ImmutableList.of( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index 1ea58b5627d57..aa6a80c5a883a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.InMemoryExpressionOptimizerProvider; import com.facebook.presto.sql.planner.RuleStatsRecorder; @@ -59,13 +60,13 @@ public class TestReorderWindows private static final Optional commonFrame; - private static final ExpectedValueProvider windowA; - private static final ExpectedValueProvider windowAp; - private static final ExpectedValueProvider windowApp; - private static final ExpectedValueProvider windowB; - private static final ExpectedValueProvider windowC; - private static final ExpectedValueProvider windowD; - private static final ExpectedValueProvider windowE; + private static final ExpectedValueProvider windowA; + private static final ExpectedValueProvider windowAp; + private static final ExpectedValueProvider windowApp; + private static final ExpectedValueProvider windowB; + private static final ExpectedValueProvider windowC; + private static final ExpectedValueProvider windowD; + private static final ExpectedValueProvider windowE; static { ImmutableMap.Builder columns = ImmutableMap.builder(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index 767ce955bb423..fa5a8c6db51dc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -25,6 +25,7 @@ import com.facebook.presto.server.SliceSerializer; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNodeId; @@ -116,7 +117,7 @@ public void testSerializationRoundtrip() Optional.empty()); PlanNodeId id = newId(); - WindowNode.Specification specification = new WindowNode.Specification( + DataOrganizationSpecification specification = new DataOrganizationSpecification( ImmutableList.of(columnA), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(columnB, SortOrder.ASC_NULLS_FIRST))))); CallExpression call = call("sum", functionHandle, BIGINT, new VariableReferenceExpression(Optional.empty(), columnC.getName(), BIGINT)); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java new file mode 100644 index 0000000000000..ad9e1856e48d4 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.plan; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +@Immutable +public class DataOrganizationSpecification +{ + private final List partitionBy; + private final Optional orderingScheme; + + @JsonCreator + public DataOrganizationSpecification( + @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("orderingScheme") Optional orderingScheme) + { + requireNonNull(partitionBy, "partitionBy is null"); + requireNonNull(orderingScheme, "orderingScheme is null"); + + this.partitionBy = unmodifiableList(new ArrayList<>(partitionBy)); + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + } + + @JsonProperty + public List getPartitionBy() + { + return partitionBy; + } + + @JsonProperty + public Optional getOrderingScheme() + { + return orderingScheme; + } + + @Override + public int hashCode() + { + return Objects.hash(partitionBy, orderingScheme); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + DataOrganizationSpecification other = (DataOrganizationSpecification) obj; + + return Objects.equals(this.partitionBy, other.partitionBy) && + Objects.equals(this.orderingScheme, other.orderingScheme); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java index d24da2ad567c6..a9d958b34ca14 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java @@ -47,7 +47,7 @@ public class WindowNode { private final PlanNode source; private final Set prePartitionedInputs; - private final Specification specification; + private final DataOrganizationSpecification specification; private final int preSortedOrderPrefix; private final Map windowFunctions; private final Optional hashVariable; @@ -57,7 +57,7 @@ public WindowNode( @JsonProperty("sourceLocation") Optional sourceLocation, @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("windowFunctions") Map windowFunctions, @JsonProperty("hashVariable") Optional hashVariable, @JsonProperty("prePartitionedInputs") Set prePartitionedInputs, @@ -71,7 +71,7 @@ public WindowNode( PlanNodeId id, Optional statsEquivalentPlanNode, PlanNode source, - Specification specification, + DataOrganizationSpecification specification, Map windowFunctions, Optional hashVariable, Set prePartitionedInputs, @@ -122,7 +122,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } @@ -134,7 +134,7 @@ public List getPartitionBy() public Optional getOrderingScheme() { - return specification.orderingScheme; + return specification.getOrderingScheme(); } @JsonProperty @@ -188,60 +188,6 @@ public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalent return new WindowNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, windowFunctions, hashVariable, prePartitionedInputs, preSortedOrderPrefix); } - @Immutable - public static class Specification - { - private final List partitionBy; - private final Optional orderingScheme; - - @JsonCreator - public Specification( - @JsonProperty("partitionBy") List partitionBy, - @JsonProperty("orderingScheme") Optional orderingScheme) - { - requireNonNull(partitionBy, "partitionBy is null"); - requireNonNull(orderingScheme, "orderingScheme is null"); - - this.partitionBy = unmodifiableList(new ArrayList<>(partitionBy)); - this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); - } - - @JsonProperty - public List getPartitionBy() - { - return partitionBy; - } - - @JsonProperty - public Optional getOrderingScheme() - { - return orderingScheme; - } - - @Override - public int hashCode() - { - return Objects.hash(partitionBy, orderingScheme); - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - Specification other = (Specification) obj; - - return Objects.equals(this.partitionBy, other.partitionBy) && - Objects.equals(this.orderingScheme, other.orderingScheme); - } - } - @Immutable public static class Frame { From c7e1b6fb81a786c54bcad2264a23949e1edca451 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Wed, 19 Mar 2025 14:43:20 +0000 Subject: [PATCH 06/12] Fix format issues --- .../sql/planner/planPrinter/PlanPrinter.java | 16 ++++++++++------ .../presto/sql/planner/TestCanonicalize.java | 1 - .../assertions/SpecificationProvider.java | 1 - .../optimizations/TestEliminateSorts.java | 1 - .../optimizations/TestReorderWindows.java | 1 - .../spi/plan/DataOrganizationSpecification.java | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 4153c20c32f3d..c027ceaa32364 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -1523,23 +1523,27 @@ private static String formatOutputs(Iterable output .collect(Collectors.joining(", ")); } - - public class Context { + public class Context + { private final Optional tag; - public Context() { + public Context() + { this(Optional.empty()); } - public Context(String tag) { + public Context(String tag) + { this(Optional.of(tag)); } - public Context(Optional tag) { + public Context(Optional tag) + { this.tag = requireNonNull(tag, "tag is null"); } - public Optional getTag() { + public Optional getTag() + { return tag; } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java index 744d7a4fdaf38..3880e9c72aaac 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.plan.DataOrganizationSpecification; -import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java index dc621f809a8d9..b8cc9b3b4c55b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; -import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index 1b3e0cc4f3625..57550cc78755b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.plan.DataOrganizationSpecification; -import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index aa6a80c5a883a..a62a2648040de 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -15,7 +15,6 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.plan.DataOrganizationSpecification; -import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.InMemoryExpressionOptimizerProvider; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java index ad9e1856e48d4..30588227f76cf 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java @@ -28,7 +28,7 @@ import static java.util.Objects.requireNonNull; @Immutable -public class DataOrganizationSpecification +public class DataOrganizationSpecification { private final List partitionBy; private final Optional orderingScheme; From c6225ef6d89cf914af8c0d202189b77e3bab2f8f Mon Sep 17 00:00:00 2001 From: mohsaka <135669458+mohsaka@users.noreply.github.com> Date: Wed, 19 Mar 2025 14:09:30 -0700 Subject: [PATCH 07/12] Plan table function invocation with table arguments Changes adapted from trino/PR#14175 Original commit: 8bd17171a8469b9351e2fd7d9f2f49f4af9ea209 Author: kasiafi Modifications were made to adapt to Presto including: Rewritting the UnaliasSymbolReferences off of Unnest Example Add a static coerce function passing in required values based off of current coerce function Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../sql/planner/OrderingTranslator.java | 37 ++ .../presto/sql/planner/QueryPlanner.java | 73 +++- .../presto/sql/planner/RelationPlanner.java | 122 ++++-- .../UnaliasSymbolReferences.java | 49 ++- .../sql/planner/plan/TableFunctionNode.java | 71 +++- .../sql/planner/planPrinter/PlanPrinter.java | 136 ++++++- .../sanity/ValidateDependenciesChecker.java | 32 ++ .../connector/tvf/TestingTableFunctions.java | 39 ++ .../planner/TestTableFunctionInvocation.java | 149 ++++++++ .../planner/assertions/PlanMatchPattern.java | 7 + .../assertions/TableFunctionMatcher.java | 361 ++++++++++++++++++ 11 files changed, 1005 insertions(+), 71 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java new file mode 100644 index 0000000000000..c487a10415472 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.sql.tree.SortItem; + +public class OrderingTranslator +{ + private OrderingTranslator() {} + + public static SortOrder sortItemToSortOrder(SortItem sortItem) + { + if (sortItem.getOrdering() == SortItem.Ordering.ASCENDING) { + if (sortItem.getNullOrdering() == SortItem.NullOrdering.FIRST) { + return SortOrder.ASC_NULLS_FIRST; + } + return SortOrder.ASC_NULLS_LAST; + } + + if (sortItem.getNullOrdering() == SortItem.NullOrdering.FIRST) { + return SortOrder.DESC_NULLS_FIRST; + } + return SortOrder.DESC_NULLS_LAST; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 71149793bd3d8..db2db488a0ca5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -115,6 +115,7 @@ import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.isNumericType; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.OrderingTranslator.sortItemToSortOrder; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder; @@ -513,7 +514,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression * * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed */ - private PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) + public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) { Assignments.Builder assignments = Assignments.builder(); assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); @@ -544,6 +545,64 @@ private PlanAndMappings coerce(PlanBuilder subPlan, List expressions return new PlanAndMappings(subPlan, mappings.build()); } + /** + * Creates a projection with any additional coercions by identity of the provided expressions. + * + * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed + */ + public static PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata, SqlPlannerContext sqlPlannerContextStatic, Session session, SqlParser sqlParserStatic) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); + ImmutableMap.Builder, VariableReferenceExpression> mappings = ImmutableMap.builder(); + for (Expression expression : expressions) { + Type coercion = analysis.getCoercion(expression); + if (coercion != null) { + Type type = analysis.getType(expression); + VariableReferenceExpression variable = newVariable(variableAllocator, expression, coercion); + assignments.put(variable, rowExpression( + new Cast( + subPlan.rewrite(expression), + coercion.getTypeSignature().toString(), + false, + metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(type, coercion)), + sqlPlannerContextStatic, metadata, session, sqlParserStatic, variableAllocator, analysis)); + mappings.put(NodeRef.of(expression), variable); + } + else { + mappings.put(NodeRef.of(expression), subPlan.translate(expression)); + } + } + subPlan = subPlan.withNewRoot( + new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + assignments.build())); + return new PlanAndMappings(subPlan, mappings.build()); + } + + public static OrderingScheme translateOrderingScheme(List items, Function coercions) + { + List coerced = items.stream() + .map(SortItem::getSortKey) + .map(coercions) + .collect(toImmutableList()); + + ImmutableList.Builder variables = ImmutableList.builder(); + Map orders = new HashMap<>(); + for (int i = 0; i < coerced.size(); i++) { + VariableReferenceExpression variable = coerced.get(i); + // for multiple sort items based on the same expression, retain the first one: + // ORDER BY x DESC, x ASC, y --> ORDER BY x DESC, y + if (!orders.containsKey(variable)) { + variables.add(variable); + orders.put(variable, new Ordering(variable, sortItemToSortOrder(items.get(i)))); + } + } + + return new OrderingScheme(new ArrayList(orders.values())); + } + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { ImmutableMap.Builder projections = ImmutableMap.builder(); @@ -1335,6 +1394,18 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } + public static RowExpression rowExpression(Expression expression, SqlPlannerContext context, Metadata metadata, Session session, SqlParser sqlParser, VariableAllocator variableAllocator, Analysis analysis) + { + return toRowExpression( + expression, + metadata, + session, + sqlParser, + variableAllocator, + analysis, + context.getTranslatorContext()); + } + private static List toSymbolReferences(List variables) { return variables.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 0ac04d51669b1..b050600ea072d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -30,11 +30,13 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; @@ -50,13 +52,13 @@ import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.analyzer.RelationType; import com.facebook.presto.sql.analyzer.Scope; -import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; @@ -87,9 +89,7 @@ import com.facebook.presto.sql.tree.SetOperation; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Table; -import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; import com.facebook.presto.sql.tree.TableFunctionInvocation; -import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; @@ -97,6 +97,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; import com.google.common.collect.UnmodifiableIterator; @@ -126,9 +127,10 @@ import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.NONE; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; +import static com.facebook.presto.sql.planner.QueryPlanner.coerce; +import static com.facebook.presto.sql.planner.QueryPlanner.translateOrderingScheme; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; import static com.facebook.presto.sql.tree.Join.Type.INNER; import static com.facebook.presto.sql.tree.Join.Type.LEFT; @@ -235,48 +237,100 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) @Override protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, SqlPlannerContext context) { - node.getArguments().stream() - .forEach(argument -> { - if (argument.getValue() instanceof TableFunctionTableArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions"); - } - if (argument.getValue() instanceof TableFunctionDescriptorArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions"); - } - }); Analysis.TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) + .map(field -> variableAllocator.newVariable(getSourceLocation(node), field.getName().get(), field.getType())) + .collect(toImmutableList()); - // TODO handle input relations: - // 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested. - // 2. for each input relation, prepare the TableArgumentProperties record, consisting of: - // - row or set semantics (from the actualArgument) - // - prune when empty property (from the actualArgument) - // - pass through columns property (from the actualArgument) - // - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources) - // TODO add - argument name - // TODO add - mapping column name => Symbol // TODO mind the fields without names and duplicate field names in RelationType - List sources = ImmutableList.of(); - List inputRelationsProperties = ImmutableList.of(); + outputVariables.addAll(properOutputs); - Scope scope = analysis.getScope(node); + // process sources in order of argument declarations + for (Analysis.TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); - ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); - for (Field field : scope.getRelationType().getAllFields()) { - VariableReferenceExpression variable = variableAllocator.newVariable(getSourceLocation(node), field.getName().get(), field.getType()); - outputVariablesBuilder.add(variable); + PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); + + // map column names to symbols + // note: hidden columns are included in the mapping. They are present both in sourceDescriptor.allFields, and in sourcePlan.fieldMappings + // note: for an aliased relation or a CTE, the field names in the relation type are in the same case as specified in the alias. + // quotes and canonicalization rules are not applied. + ImmutableMultimap.Builder columnMapping = ImmutableMultimap.builder(); + RelationType sourceDescriptor = sourcePlan.getDescriptor(); + for (int i = 0; i < sourceDescriptor.getAllFieldCount(); i++) { + Optional name = sourceDescriptor.getFieldByIndex(i).getName(); + if (name.isPresent()) { + columnMapping.put(name.get(), sourcePlan.getVariable(i)); + } + } + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + // This is from unnest and may not be correct + List partitioningColumns = tableArgument.getPartitionBy().get(); + QueryPlanner.PlanAndMappings copartitionCoercions = coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata, context, session, sqlParser); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } + + // order by + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + // the ordering symbols are not coerced + orderBy = Optional.of(translateOrderingScheme(tableArgument.getOrderBy().get().getSortItems(), sourcePlanBuilder::translate)); + } + + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + sources.add(sourcePlanBuilder.getRoot()); + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + columnMapping.build(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + tableArgument.isPassThroughColumns(), + specification)); + + // add output symbols passed from the table argument + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are included. They are present in sourcePlan.fieldMappings + outputVariables.addAll(sourcePlan.getFieldMappings()); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + // the original symbols for partitioning columns, not coerced + .map(sourcePlanBuilder::translate) + .forEach(outputVariables::add); + } } - List outputVariables = outputVariablesBuilder.build(); PlanNode root = new TableFunctionNode( idAllocator.getNextId(), functionAnalysis.getFunctionName(), functionAnalysis.getArguments(), - outputVariablesBuilder.build(), - sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), - inputRelationsProperties, + properOutputs, + sources.build(), + sourceProperties.build(), + functionAnalysis.getCopartitioningLists(), new TableFunctionHandle(functionAnalysis.getConnectorId(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); - return new RelationPlan(root, scope, outputVariables); + return new RelationPlan(root, analysis.getScope(node), outputVariables.build()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 11b1cd5f48ed0..372c0111182fc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -478,25 +478,46 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont } @Override - public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext context) { - // TODO rewrite sources, and tableArgumentProperties when we add support for input tables - /* Map mapping = new HashMap<>(context.getCorrelationMapping()); SymbolMapper mapper = symbolMapper(mapping); - List newProperOutputs = mapper.map(node.getProperOutputs());*/ + List newProperOutputs = mapper.map(node.getProperOutputs()); + + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanAndMappings newSource = node.getSources().get(i).accept(this, context); + newSources.add(newSource.getRoot()); + + SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); + TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); + properties.getColumnMapping().entries().stream() + .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); + Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + newTableArgumentProperties.add(new TableArgumentProperties( + properties.getArgumentName(), + newColumnMapping.build(), + properties.isRowSemantics(), + properties.isPruneWhenEmpty(), + properties.isPassThroughColumns(), + newSpecification)); + } - return new TableFunctionNode( - node.getSourceLocation(), - node.getId(), - Optional.empty(), - node.getName(), - node.getArguments(), - node.getOutputVariables(), - node.getSources(), - node.getTableArgumentProperties(), - node.getHandle()); + return new PlanAndMappings( + new TableFunctionNode( + node.getId(), + node.getName(), + node.getArguments(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), + node.getHandle()), + mapping); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 42686890fdb9f..061d09366f093 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -24,6 +24,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; import javax.annotation.concurrent.Immutable; @@ -32,6 +34,7 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @Immutable @@ -43,6 +46,7 @@ public class TableFunctionNode private final List outputVariables; private final List sources; private final List tableArgumentProperties; + private final List> copartitioningLists; private final TableFunctionHandle handle; @JsonCreator @@ -53,9 +57,10 @@ public TableFunctionNode( @JsonProperty("outputVariables") List outputVariables, @JsonProperty("sources") List sources, @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, @JsonProperty("handle") TableFunctionHandle handle) { - this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, handle); + this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public TableFunctionNode( @@ -67,6 +72,7 @@ public TableFunctionNode( List outputVariables, List sources, List tableArgumentProperties, + List> copartitioningLists, TableFunctionHandle handle) { super(sourceLocation, id, statsEquivalentPlanNode); @@ -75,6 +81,9 @@ public TableFunctionNode( this.outputVariables = ImmutableList.copyOf(outputVariables); this.sources = ImmutableList.copyOf(sources); this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); this.handle = requireNonNull(handle, "handle is null"); } @@ -90,10 +99,26 @@ public Map getArguments() return arguments; } - @JsonProperty + @Override public List getOutputVariables() { - return outputVariables; + ImmutableList.Builder variables = ImmutableList.builder(); + + variables.addAll(outputVariables); + + for (int i = 0; i < sources.size(); i++) { + TableArgumentProperties sourceProperties = tableArgumentProperties.get(i); + if (sourceProperties.passThroughColumns()) { + variables.addAll(sources.get(i).getOutputVariables()); + } + else { + sourceProperties.specification() + .map(DataOrganizationSpecification::getPartitionBy) + .ifPresent(outputVariables::addAll); + } + } + + return variables.build(); } @JsonProperty @@ -102,6 +127,12 @@ public List getTableArgumentProperties() return tableArgumentProperties; } + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -125,29 +156,35 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { checkArgument(sources.size() == newSources.size(), "wrong number of new children"); - return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, handle); + return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, copartitioningLists, handle); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, handle); + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public static class TableArgumentProperties { + private final String argumentName; + private final Multimap columnMapping; private final boolean rowSemantics; private final boolean pruneWhenEmpty; private final boolean passThroughColumns; - private final DataOrganizationSpecification specification; + private final Optional specification; @JsonCreator public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, + @JsonProperty("columnMapping") Multimap columnMapping, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughColumns") boolean passThroughColumns, - @JsonProperty("specification") DataOrganizationSpecification specification) + @JsonProperty("specification") Optional specification) { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); + this.columnMapping = ImmutableMultimap.copyOf(columnMapping); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; @@ -155,25 +192,37 @@ public TableArgumentProperties( } @JsonProperty - public boolean isRowSemantics() + public String getArgumentName() + { + return argumentName; + } + + @JsonProperty + public Multimap getColumnMapping() + { + return columnMapping; + } + + @JsonProperty + public boolean rowSemantics() { return rowSemantics; } @JsonProperty - public boolean isPruneWhenEmpty() + public boolean pruneWhenEmpty() { return pruneWhenEmpty; } @JsonProperty - public boolean isPassThroughColumns() + public boolean passThroughColumns() { return passThroughColumns; } @JsonProperty - public DataOrganizationSpecification getSpecification() + public Optional specification() { return specification; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index c027ceaa32364..46830bac75f3d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -34,6 +34,9 @@ import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; @@ -96,6 +99,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -118,11 +122,13 @@ import io.airlift.units.Duration; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -134,6 +140,7 @@ import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters; import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -147,10 +154,13 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; public class PlanPrinter @@ -1311,23 +1321,127 @@ public Void visitTableFunction(TableFunctionNode node, Context context) "name", context.getTag()); - checkArgument( - node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), - "Table or descriptor arguments are not yet supported in PlanPrinter"); + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); - // TODO: Add details here for plan printer - // node.getArguments().entrySet().stream() - // .forEach(entry -> nodeOutput.appendDetails(entry.getKey() + " => " + formatArgument((ScalarArgument) entry.getValue()))); + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableArgumentProperties::getArgumentName, identity())); - return processChildren(node, new Context()); + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetails(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetails(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(joining(", ", "Co-partition: [", "]"))); + } + } + + for (int i = 0; i < node.getSources().size(); i++) { + node.getSources().get(i).accept(this, new Context(node.getTableArgumentProperties().get(i).getArgumentName())); + } + + return null; + } + + private String formatArgument(String argumentName, Argument argument, Map tableArguments) + { + if (argument instanceof ScalarArgument) { + ScalarArgument scalarArgument = (ScalarArgument) argument; + return formatScalarArgument(argumentName, scalarArgument); + } + if (argument instanceof DescriptorArgument) { + DescriptorArgument descriptorArgument = (DescriptorArgument) argument; + return formatDescriptorArgument(argumentName, descriptorArgument); + } + else { + TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + argument.getValue()); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow(() -> new IllegalStateException("Missing descriptor")).getFields().stream() + .map(field -> field.getName() + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + StringBuilder properties = new StringBuilder(); + if (argumentProperties.rowSemantics()) { + properties.append("row semantics"); + } + argumentProperties.specification().ifPresent(specification -> { + properties + .append("partition by: [") + .append(Joiner.on(", ").join(specification.getPartitionBy())) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + properties + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + }); + + /* TODO: Come back here eventually + properties.append("required columns: [") + .append(Joiner.on(", ").join(argumentProperties.requiredColumns())) + .append("]"); + */ + if (argumentProperties.pruneWhenEmpty()) { + properties.append(", prune when empty"); + } + + /* TODO: Come back here as well + if (argumentProperties.getPassThroughSpecification().declaredAsPassThrough()) { + properties.append(", pass through columns"); + } + */ + return format("%s => TableArgument{%s}", argumentName, properties); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme) + { + return formatCollection(orderingScheme.getOrderByVariables(), variable -> variable + " " + orderingScheme.getOrdering(variable)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme, int preSortedOrderPrefix) + { + List orderBy = Stream.concat( + orderingScheme.getOrderByVariables().stream() + .limit(preSortedOrderPrefix) + .map(variable -> "<" + variable + " " + orderingScheme.getOrdering(variable) + ">"), + orderingScheme.getOrderByVariables().stream() + .skip(preSortedOrderPrefix) + .map(variable -> variable + " " + orderingScheme.getOrdering(variable))) + .collect(toImmutableList()); + return formatCollection(orderBy, Objects::toString); } - /* - private String formatArgument(ScalarArgument argument) + public String formatCollection(Collection collection, Function formatter) { - return format("ScalarArgument{type=%s, value=%s}", argument.getType(), valuePrinter.castToVarchar(argument.getType(), argument.getValue())); + return collection.stream() + .map(formatter) + .collect(joining(", ", "[", "]")); } - */ @Override public Void visitPlan(PlanNode node, Context context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index db94541ed5bb5..3bf5924664d99 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -129,6 +129,38 @@ public Void visitPlan(PlanNode node, Set boundVaria @Override public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getColumnMapping().values(), + "Invalid node. Input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getColumnMapping().values(), + source.getOutputVariables()); + argumentProperties.specification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + } return null; } diff --git a/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java index 20eb32585060e..fe3d822d4cdd6 100644 --- a/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -331,6 +331,45 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact } } + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + "different_arguments_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return ANALYSIS; + } + } + public static class TestingTableFunctionHandle implements ConnectorTableFunctionHandle { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java new file mode 100644 index 0000000000000..14aab03553a22 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.connector.tvf.MockConnectorFactory; +import com.facebook.presto.connector.tvf.MockConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DifferentArgumentTypesFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TestingTableFunctionHandle; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.Descriptor.Field; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String TESTING_CATALOG = "mock"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new DescriptorArgumentFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Missing columns")))); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "mock", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + "SELECT * FROM TABLE(mock.system.different_arguments_function(" + + "INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1," + + "INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3," + + "INPUT_2 => TABLE(VALUES 1) t2(c2)," + + "ID => BIGINT '2001'," + + "LAYOUT => DESCRIPTOR (x boolean, y bigint)" + + "COPARTITION (t1, t3))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) + .pruneWhenEmpty()) + .addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values("1"))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values("1")))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(mock.system.two_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(mock.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(mock.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index af888762f2021..03b73673abf42 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -668,6 +668,13 @@ public static PlanMatchPattern remoteSource(List sourceFragmentI return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); } + public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern... sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 0000000000000..9eb971c453ed3 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,361 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private final Map arguments; + private final List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + if (expected instanceof DescriptorArgumentValue) { + DescriptorArgumentValue expectedDescriptor = (DescriptorArgumentValue) expected; + if (!(actual instanceof DescriptorArgument) || !expectedDescriptor.getDescriptor().equals(((DescriptorArgument) actual).getDescriptor())) { + return NO_MATCH; + } + } + else if (expected instanceof ScalarArgumentValue) { + ScalarArgumentValue expectedScalar = (ScalarArgumentValue) expected; + if (!(actual instanceof ScalarArgument) || !Objects.equals(expectedScalar.getValue(), ((ScalarArgument) actual).getValue())) { + return NO_MATCH; + } + } + else { + if (!(actual instanceof TableArgument)) { + return NO_MATCH; + } + TableArgumentValue expectedTableArgument = (TableArgumentValue) expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.rowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.pruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != argumentProperties.passThroughColumns()) { + return NO_MATCH; + } + boolean specificationMatches = expectedTableArgument.specification() + .map(specification -> specification.getExpectedValue(symbolAliases)) + .equals(argumentProperties.specification()); + if (!specificationMatches) { + return NO_MATCH; + } + } + } + + if (properOutputs.size() != tableFunctionNode.getOutputVariables().size()) { + return NO_MATCH; + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + + // TODO: Do we need these symbol references or should it be something like VariableReferenceExpression + /* + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), tableFunctionNode.getOutputVariables().get(i).toSymbolReference()); + } + */ + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, copartitioningLists.build())); + } + } + + interface ArgumentValue + { + } + + public static class DescriptorArgumentValue + implements ArgumentValue + { + private final Optional descriptor; + + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + + public Optional getDescriptor() + { + return descriptor; + } + } + + public static class ScalarArgumentValue + implements ArgumentValue + { + private final Object value; + + public ScalarArgumentValue(Object value) + { + this.value = value; + } + + public Object getValue() + { + return value; + } + } + + public static class TableArgumentValue + implements ArgumentValue + { + private final int sourceIndex; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Optional> specification; + + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + } + + public int sourceIndex() + { + return sourceIndex; + } + + public boolean rowSemantics() + { + return rowSemantics; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean passThroughColumns() + { + return passThroughColumns; + } + + public Optional> specification() + { + return specification; + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification); + } + } + } +} From a9603675fa0a2a39147dcd493c9e33705e9d4d1d Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Thu, 20 Mar 2025 15:07:32 +0000 Subject: [PATCH 08/12] Temporary change for unalias. --- .../planner/optimizations/SymbolMapper.java | 2 +- .../UnaliasSymbolReferences.java | 217 ++++++++++++------ .../presto/testing/LocalQueryRunner.java | 4 +- 3 files changed, 150 insertions(+), 73 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 9805efad17939..32191c7e6e44c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -335,7 +335,7 @@ private List mapAndDistinctSymbol(List outputs) return builder.build(); } - private List mapAndDistinctVariable(List outputs) + List mapAndDistinctVariable(List outputs) { Set added = new HashSet<>(); ImmutableList.Builder builder = ImmutableList.builder(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 372c0111182fc..db6847b7c4ceb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -82,10 +82,12 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; @@ -140,7 +142,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider } private static class Rewriter - extends SimplePlanRewriter + extends SimplePlanRewriter { private final Map mapping = new HashMap<>(); private final TypeProvider types; @@ -158,7 +160,7 @@ private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManag } @Override - public PlanNode visitAggregation(AggregationNode node, RewriteContext context) + public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); //TODO: use mapper in other methods @@ -167,26 +169,26 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont } @Override - public PlanNode visitCteReference(CteReferenceNode node, RewriteContext context) + public PlanNode visitCteReference(CteReferenceNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new CteReferenceNode(node.getSourceLocation(), node.getId(), source, node.getCteId()); } - public PlanNode visitCteProducer(CteProducerNode node, RewriteContext context) + public PlanNode visitCteProducer(CteProducerNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); List canonical = Lists.transform(node.getOutputVariables(), this::canonicalize); return new CteProducerNode(node.getSourceLocation(), node.getId(), source, node.getCteId(), node.getRowCountVariable(), canonical); } - public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext context) + public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext context) { // No rewrite on source by cte consumer return node; } - public PlanNode visitSequence(SequenceNode node, RewriteContext context) + public PlanNode visitSequence(SequenceNode node, RewriteContext context) { List cteProducers = node.getCteProducers().stream().map(c -> SimplePlanRewriter.rewriteWith(new Rewriter(types, functionAndTypeManager, warningCollector), c)) @@ -196,7 +198,7 @@ public PlanNode visitSequence(SequenceNode node, RewriteContext context) } @Override - public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) + public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -216,21 +218,21 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) } @Override - public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext context) + public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new ExplainAnalyzeNode(node.getSourceLocation(), node.getId(), source, canonicalize(node.getOutputVariable()), node.isVerbose(), node.getFormat()); } @Override - public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) + public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new MarkDistinctNode(node.getSourceLocation(), node.getId(), source, canonicalize(node.getMarkerVariable()), canonicalizeAndDistinct(node.getDistinctVariables()), canonicalize(node.getHashVariable())); } @Override - public PlanNode visitUnnest(UnnestNode node, RewriteContext context) + public PlanNode visitUnnest(UnnestNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); ImmutableMap.Builder> builder = ImmutableMap.builder(); @@ -241,7 +243,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) } @Override - public PlanNode visitWindow(WindowNode node, RewriteContext context) + public PlanNode visitWindow(WindowNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -301,13 +303,13 @@ private WindowNode.Frame canonicalize(WindowNode.Frame frame) } @Override - public PlanNode visitTableScan(TableScanNode node, RewriteContext context) + public PlanNode visitTableScan(TableScanNode node, RewriteContext context) { return node; } @Override - public PlanNode visitExchange(ExchangeNode node, RewriteContext context) + public PlanNode visitExchange(ExchangeNode node, RewriteContext context) { List sources = node.getSources().stream() .map(context::rewrite) @@ -393,7 +395,7 @@ private List canonicalizeExchangeNodeInputs(Exchang } @Override - public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext context) + public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext context) { return new RemoteSourceNode( node.getSourceLocation(), @@ -408,31 +410,31 @@ public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext co } @Override - public PlanNode visitOffset(OffsetNode node, RewriteContext context) + public PlanNode visitOffset(OffsetNode node, RewriteContext context) { return context.defaultRewrite(node); } @Override - public PlanNode visitLimit(LimitNode node, RewriteContext context) + public PlanNode visitLimit(LimitNode node, RewriteContext context) { return context.defaultRewrite(node); } @Override - public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext context) + public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext context) { return new DistinctLimitNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), node.getLimit(), node.isPartial(), canonicalizeAndDistinct(node.getDistinctVariables()), canonicalize(node.getHashVariable()), node.getTimeoutMillis()); } @Override - public PlanNode visitSample(SampleNode node, RewriteContext context) + public PlanNode visitSample(SampleNode node, RewriteContext context) { return new SampleNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), node.getSampleRatio(), node.getSampleType()); } @Override - public PlanNode visitValues(ValuesNode node, RewriteContext context) + public PlanNode visitValues(ValuesNode node, RewriteContext context) { List> canonicalizedRows = node.getRows().stream() .map(rowExpressions -> rowExpressions.stream() @@ -450,19 +452,19 @@ public PlanNode visitValues(ValuesNode node, RewriteContext context) } @Override - public PlanNode visitDelete(DeleteNode node, RewriteContext context) + public PlanNode visitDelete(DeleteNode node, RewriteContext context) { return new DeleteNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), canonicalize(node.getRowId()), node.getOutputVariables(), node.getInputDistribution()); } @Override - public PlanNode visitUpdate(UpdateNode node, RewriteContext context) + public PlanNode visitUpdate(UpdateNode node, RewriteContext context) { return new UpdateNode(node.getSourceLocation(), node.getId(), node.getSource(), canonicalize(node.getRowId()), node.getColumnValueAndRowIdSymbols(), node.getOutputVariables()); } @Override - public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext context) + public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -470,64 +472,98 @@ public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteCont } @Override - public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) + public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); return mapper.map(node, source); } + @Override - public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext context) + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) { - Map mapping = new HashMap<>(context.getCorrelationMapping()); - SymbolMapper mapper = symbolMapper(mapping); + Map mappings = new HashMap<>(context.get().getCorrelationMapping()); + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); - List newProperOutputs = mapper.map(node.getProperOutputs()); + List newProperOutputs = node.getOutputVariables().stream() + .map(mapper::map). + collect(toImmutableList()); ImmutableList.Builder newSources = ImmutableList.builder(); - ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - PlanAndMappings newSource = node.getSources().get(i).accept(this, context); - newSources.add(newSource.getRoot()); + PlanNode newSource = node.getSources().get(i).accept(this, context); + newSources.add(newSource); + - SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); - TableArgumentProperties properties = node.getTableArgumentProperties().get(i); - ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); + SymbolMapper inputMapper = new SymbolMapper(((PlanAndMappings)newSource).getMappings(), warningCollector); + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); properties.getColumnMapping().entries().stream() .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); - Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); - newTableArgumentProperties.add(new TableArgumentProperties( + Optional newSpecification = Optional.of(new DataOrganizationSpecification( + inputMapper.mapAndDistinctVariable(properties.specification().get().getPartitionBy()), + Optional.of(inputMapper.map(properties.specification().get().getOrderingScheme().get())) + )); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( properties.getArgumentName(), newColumnMapping.build(), - properties.isRowSemantics(), - properties.isPruneWhenEmpty(), - properties.isPassThroughColumns(), + properties.rowSemantics(), + properties.pruneWhenEmpty(), + properties.passThroughColumns(), newSpecification)); } + TableFunctionNode tableFunctionNode = new TableFunctionNode( + node.getId(), + node.getName(), + node.getArguments(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), + node.getHandle()); + return new PlanAndMappings( - new TableFunctionNode( - node.getId(), - node.getName(), - node.getArguments(), - newProperOutputs, - newSources.build(), - newTableArgumentProperties.build(), - node.getCopartitioningLists(), - node.getHandle()), - mapping); + tableFunctionNode, + mappings) + { + @Override + public List getSources() + { + return tableFunctionNode.getSources(); + } + + @Override + public List getOutputVariables() + { + return tableFunctionNode.getOutputVariables(); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return tableFunctionNode.replaceChildren(newChildren); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return tableFunctionNode.assignStatsEquivalentPlanNode(statsEquivalentPlanNode); + } + }; } @Override - public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) + public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { return new RowNumberNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getPartitionBy()), canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), node.isPartial(), canonicalize(node.getHashVariable())); } @Override - public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext context) + public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext context) { return new TopNRowNumberNode( node.getSourceLocation(), @@ -541,7 +577,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext } @Override - public PlanNode visitFilter(FilterNode node, RewriteContext context) + public PlanNode visitFilter(FilterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -549,14 +585,14 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) } @Override - public PlanNode visitProject(ProjectNode node, RewriteContext context) + public PlanNode visitProject(ProjectNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new ProjectNode(node.getSourceLocation(), node.getId(), source, canonicalize(node.getAssignments()), node.getLocality()); } @Override - public PlanNode visitOutput(OutputNode node, RewriteContext context) + public PlanNode visitOutput(OutputNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -565,7 +601,7 @@ public PlanNode visitOutput(OutputNode node, RewriteContext context) } @Override - public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext context) + public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -573,7 +609,7 @@ public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext< } @Override - public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) + public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -581,7 +617,7 @@ public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext co } @Override - public PlanNode visitApply(ApplyNode node, RewriteContext context) + public PlanNode visitApply(ApplyNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getInput()); PlanNode subquery = context.rewrite(node.getSubquery()); @@ -593,7 +629,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) } @Override - public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getInput()); PlanNode subquery = context.rewrite(node.getSubquery()); @@ -603,7 +639,7 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext cont } @Override - public PlanNode visitTopN(TopNNode node, RewriteContext context) + public PlanNode visitTopN(TopNNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -612,7 +648,7 @@ public PlanNode visitTopN(TopNNode node, RewriteContext context) } @Override - public PlanNode visitSort(SortNode node, RewriteContext context) + public PlanNode visitSort(SortNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -620,7 +656,7 @@ public PlanNode visitSort(SortNode node, RewriteContext context) } @Override - public PlanNode visitJoin(JoinNode node, RewriteContext context) + public PlanNode visitJoin(JoinNode node, RewriteContext context) { PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); @@ -655,7 +691,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } @Override - public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) + public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); PlanNode filteringSource = context.rewrite(node.getFilteringSource()); @@ -675,7 +711,7 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) } @Override - public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext context) + public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext context) { PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); @@ -684,13 +720,13 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont } @Override - public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) + public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) { return new IndexSourceNode(node.getSourceLocation(), node.getId(), node.getIndexHandle(), node.getTableHandle(), canonicalize(node.getLookupVariables()), node.getOutputVariables(), node.getAssignments(), node.getCurrentConstraint()); } @Override - public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) + public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) { PlanNode probeSource = context.rewrite(node.getProbeSource()); PlanNode indexSource = context.rewrite(node.getIndexSource()); @@ -699,24 +735,24 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) } @Override - public PlanNode visitUnion(UnionNode node, RewriteContext context) + public PlanNode visitUnion(UnionNode node, RewriteContext context) { return new UnionNode(node.getSourceLocation(), node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationOutputVariables(node.getOutputVariables()), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } @Override - public PlanNode visitIntersect(IntersectNode node, RewriteContext context) + public PlanNode visitIntersect(IntersectNode node, RewriteContext context) { return new IntersectNode(node.getSourceLocation(), node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationOutputVariables(node.getOutputVariables()), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } @Override - public PlanNode visitExcept(ExceptNode node, RewriteContext context) + public PlanNode visitExcept(ExceptNode node, RewriteContext context) { return new ExceptNode(node.getSourceLocation(), node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationOutputVariables(node.getOutputVariables()), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } - private static ImmutableList.Builder rewriteSources(SetOperationNode node, RewriteContext context) + private static ImmutableList.Builder rewriteSources(SetOperationNode node, RewriteContext context) { ImmutableList.Builder rewrittenSources = ImmutableList.builder(); for (PlanNode source : node.getSources()) { @@ -726,7 +762,7 @@ private static ImmutableList.Builder rewriteSources(SetOperationNode n } @Override - public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) + public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -734,7 +770,7 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext cont } @Override - public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) + public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -742,7 +778,7 @@ public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) + public PlanNode visitPlan(PlanNode node, RewriteContext context) { throw new UnsupportedOperationException("Unsupported plan node " + node.getClass().getSimpleName()); } @@ -908,4 +944,43 @@ private List canonicalizeSetOperationOutputVariable return builder.build(); } } + + private static class UnaliasContext + { + // Correlation mapping is a record of how correlation symbols have been mapped in the subplan which provides them. + // All occurrences of correlation symbols within the correlated subquery must be remapped accordingly. + // In case of nested correlation, correlationMappings has required mappings for correlation symbols from all levels of nesting. + private final Map correlationMapping; + + public UnaliasContext(Map correlationMapping) + { + this.correlationMapping = requireNonNull(correlationMapping, "correlationMapping is null"); + } + + public static UnaliasContext empty() + { + return new UnaliasContext(ImmutableMap.of()); + } + + public Map getCorrelationMapping() + { + return correlationMapping; + } + } + + private abstract static class PlanAndMappings extends PlanNode + { + private final Map mappings; + + public PlanAndMappings(PlanNode root, Map mappings) + { + super(root.getSourceLocation(), root.getId(), root.getStatsEquivalentPlanNode()); + this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null")); + } + + public Map getMappings() + { + return mappings; + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index e9f21fb40ccec..fa6e794ff754f 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -777,7 +777,9 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); +// throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override From 0a645ce57b7fe0f378c49e5f5a39087ed68a0937 Mon Sep 17 00:00:00 2001 From: mohsaka <135669458+mohsaka@users.noreply.github.com> Date: Thu, 20 Mar 2025 08:09:51 -0700 Subject: [PATCH 09/12] Remove static functions --- .../presto/sql/planner/QueryPlanner.java | 48 ------------------- .../presto/sql/planner/RelationPlanner.java | 4 +- 2 files changed, 2 insertions(+), 50 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index db2db488a0ca5..be6f4ad57cc57 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -545,42 +545,6 @@ public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, return new PlanAndMappings(subPlan, mappings.build()); } - /** - * Creates a projection with any additional coercions by identity of the provided expressions. - * - * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed - */ - public static PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata, SqlPlannerContext sqlPlannerContextStatic, Session session, SqlParser sqlParserStatic) - { - Assignments.Builder assignments = Assignments.builder(); - assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); - ImmutableMap.Builder, VariableReferenceExpression> mappings = ImmutableMap.builder(); - for (Expression expression : expressions) { - Type coercion = analysis.getCoercion(expression); - if (coercion != null) { - Type type = analysis.getType(expression); - VariableReferenceExpression variable = newVariable(variableAllocator, expression, coercion); - assignments.put(variable, rowExpression( - new Cast( - subPlan.rewrite(expression), - coercion.getTypeSignature().toString(), - false, - metadata.getFunctionAndTypeManager().isTypeOnlyCoercion(type, coercion)), - sqlPlannerContextStatic, metadata, session, sqlParserStatic, variableAllocator, analysis)); - mappings.put(NodeRef.of(expression), variable); - } - else { - mappings.put(NodeRef.of(expression), subPlan.translate(expression)); - } - } - subPlan = subPlan.withNewRoot( - new ProjectNode( - idAllocator.getNextId(), - subPlan.getRoot(), - assignments.build())); - return new PlanAndMappings(subPlan, mappings.build()); - } - public static OrderingScheme translateOrderingScheme(List items, Function coercions) { List coerced = items.stream() @@ -1394,18 +1358,6 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } - public static RowExpression rowExpression(Expression expression, SqlPlannerContext context, Metadata metadata, Session session, SqlParser sqlParser, VariableAllocator variableAllocator, Analysis analysis) - { - return toRowExpression( - expression, - metadata, - session, - sqlParser, - variableAllocator, - analysis, - context.getTranslatorContext()); - } - private static List toSymbolReferences(List variables) { return variables.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index b050600ea072d..01a894b538236 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -129,7 +129,6 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.NONE; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; -import static com.facebook.presto.sql.planner.QueryPlanner.coerce; import static com.facebook.presto.sql.planner.QueryPlanner.translateOrderingScheme; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; import static com.facebook.presto.sql.tree.Join.Type.INNER; @@ -280,7 +279,8 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { // This is from unnest and may not be correct List partitioningColumns = tableArgument.getPartitionBy().get(); - QueryPlanner.PlanAndMappings copartitionCoercions = coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata, context, session, sqlParser); + QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); + QueryPlanner.PlanAndMappings copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); sourcePlanBuilder = copartitionCoercions.getSubPlan(); partitionBy = partitioningColumns.stream() .map(copartitionCoercions::get) From c5a3ba48059c8f90e3e90b449f8c595333fa6e0f Mon Sep 17 00:00:00 2001 From: mohsaka <135669458+mohsaka@users.noreply.github.com> Date: Thu, 20 Mar 2025 18:08:56 -0700 Subject: [PATCH 10/12] Formatting --- .../presto/sql/planner/RelationPlanner.java | 3 +- .../UnaliasSymbolReferences.java | 39 +++++++++---------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 01a894b538236..9b71c5bee784b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -253,7 +253,6 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node // process sources in order of argument declarations for (Analysis.TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { RelationPlan sourcePlan = process(tableArgument.getRelation(), context); - PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); // map column names to symbols @@ -266,6 +265,7 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node Optional name = sourceDescriptor.getFieldByIndex(i).getName(); if (name.isPresent()) { columnMapping.put(name.get(), sourcePlan.getVariable(i)); + sourcePlanBuilder.getTranslations().put(new Identifier(name.get()), sourcePlan.getVariable(i)); } } @@ -277,7 +277,6 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node List partitionBy = ImmutableList.of(); // if there are partitioning columns, they might have to be coerced for copartitioning if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { - // This is from unnest and may not be correct List partitioningColumns = tableArgument.getPartitionBy().get(); QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); QueryPlanner.PlanAndMappings copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index db6847b7c4ceb..770a1be8d0a85 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -87,7 +87,6 @@ import com.google.common.collect.Lists; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; @@ -479,7 +478,6 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) { @@ -487,8 +485,8 @@ public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext newProperOutputs = node.getOutputVariables().stream() - .map(mapper::map). - collect(toImmutableList()); + .map(mapper::map) + .collect(toImmutableList()); ImmutableList.Builder newSources = ImmutableList.builder(); ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); @@ -497,16 +495,14 @@ public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext newColumnMapping = ImmutableMultimap.builder(); properties.getColumnMapping().entries().stream() .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); Optional newSpecification = Optional.of(new DataOrganizationSpecification( inputMapper.mapAndDistinctVariable(properties.specification().get().getPartitionBy()), - Optional.of(inputMapper.map(properties.specification().get().getOrderingScheme().get())) - )); + Optional.of(inputMapper.map(properties.specification().get().getOrderingScheme().get())))); newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( properties.getArgumentName(), newColumnMapping.build(), @@ -526,7 +522,7 @@ public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext getCorrelat { return correlationMapping; } - } - private abstract static class PlanAndMappings extends PlanNode - { - private final Map mappings; - - public PlanAndMappings(PlanNode root, Map mappings) + private abstract static class PlanAndMappings + extends PlanNode { - super(root.getSourceLocation(), root.getId(), root.getStatsEquivalentPlanNode()); - this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null")); - } + private final Map mappings; - public Map getMappings() - { - return mappings; + public PlanAndMappings(PlanNode root, Map mappings) + { + super(root.getSourceLocation(), root.getId(), root.getStatsEquivalentPlanNode()); + this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null")); + } + + public Map getMappings() + { + return mappings; + } } } } From c9fe3ebeb3bfb442a8d11271ceb802b26ac58759 Mon Sep 17 00:00:00 2001 From: mohsaka <135669458+mohsaka@users.noreply.github.com> Date: Thu, 20 Mar 2025 19:00:29 -0700 Subject: [PATCH 11/12] Add Translation via Analyzed table argument's getPartitionBy and the sourceDescriptor --- .../java/com/facebook/presto/sql/planner/RelationPlanner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 9b71c5bee784b..4da635e558ea6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -265,7 +265,7 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node Optional name = sourceDescriptor.getFieldByIndex(i).getName(); if (name.isPresent()) { columnMapping.put(name.get(), sourcePlan.getVariable(i)); - sourcePlanBuilder.getTranslations().put(new Identifier(name.get()), sourcePlan.getVariable(i)); + sourcePlanBuilder.getTranslations().put(tableArgument.getPartitionBy().get().get(i), sourcePlan.getVariable(i)); } } From 4fec942c13e296685ca84f2ff03f6258a89ef6f8 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Fri, 21 Mar 2025 11:47:57 +0000 Subject: [PATCH 12/12] Fix debugging test issues --- .../presto/sql/planner/RelationPlanner.java | 5 +- .../planner/optimizations/SymbolMapper.java | 10 ++- .../UnaliasSymbolReferences.java | 19 +++-- .../sql/planner/plan/TableFunctionNode.java | 7 +- .../planner/TestTableFunctionInvocation.java | 4 +- .../planner/assertions/PlanMatchPattern.java | 5 ++ .../assertions/TableFunctionMatcher.java | 69 +++++++++++++++++-- 7 files changed, 105 insertions(+), 14 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 4da635e558ea6..3b266c8853167 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -265,7 +265,10 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node Optional name = sourceDescriptor.getFieldByIndex(i).getName(); if (name.isPresent()) { columnMapping.put(name.get(), sourcePlan.getVariable(i)); - sourcePlanBuilder.getTranslations().put(tableArgument.getPartitionBy().get().get(i), sourcePlan.getVariable(i)); + Optional> partitionBy = tableArgument.getPartitionBy(); + if (partitionBy.isPresent()) { + sourcePlanBuilder.getTranslations().put(partitionBy.get().get(i), sourcePlan.getVariable(i)); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 32191c7e6e44c..1e56f3eeaa72b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningScheme; @@ -335,7 +336,7 @@ private List mapAndDistinctSymbol(List outputs) return builder.build(); } - List mapAndDistinctVariable(List outputs) + private List mapAndDistinctVariable(List outputs) { Set added = new HashSet<>(); ImmutableList.Builder builder = ImmutableList.builder(); @@ -348,6 +349,13 @@ List mapAndDistinctVariable(List context) { - Map mappings = new HashMap<>(context.get().getCorrelationMapping()); + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap<>(c.getCorrelationMapping())) + .orElseGet(HashMap::new); + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); List newProperOutputs = node.getOutputVariables().stream() @@ -495,14 +499,19 @@ public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext(), + warningCollector); + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); properties.getColumnMapping().entries().stream() .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); - Optional newSpecification = Optional.of(new DataOrganizationSpecification( - inputMapper.mapAndDistinctVariable(properties.specification().get().getPartitionBy()), - Optional.of(inputMapper.map(properties.specification().get().getOrderingScheme().get())))); + + Optional newSpecification = properties.specification().map(inputMapper::mapAndDistinct); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( properties.getArgumentName(), newColumnMapping.build(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 061d09366f093..da5155664eed2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -114,13 +114,18 @@ public List getOutputVariables() else { sourceProperties.specification() .map(DataOrganizationSpecification::getPartitionBy) - .ifPresent(outputVariables::addAll); + .ifPresent(variables::addAll); } } return variables.build(); } + public List getProperOutput() + { + return outputVariables; + } + @JsonProperty public List getTableArgumentProperties() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java index 14aab03553a22..fd755e05bb06b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -107,9 +107,9 @@ public void testTableFunctionInitialPlan() new Field("Y", Optional.of(BIGINT)))))) .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) .properOutputs(ImmutableList.of("OUTPUT")), - anyTree(project(ImmutableMap.of("c1", expression("'a'")), values("1"))), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), - anyTree(project(ImmutableMap.of("c3", expression("'b'")), values("1")))))); + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 03b73673abf42..27be55d920056 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -628,6 +628,11 @@ public static PlanMatchPattern values(String... aliases) return values(ImmutableList.copyOf(aliases)); } + public static PlanMatchPattern values(int rowCount) + { + return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of())); + } + public static PlanMatchPattern values(List aliases, List> expectedRows) { return values(aliases, Optional.of(expectedRows)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java index 9eb971c453ed3..9683099ed220e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.function.table.Argument; @@ -22,7 +23,10 @@ import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.function.table.TableArgument; import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.TableFunctionNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.tree.SymbolReference; @@ -35,6 +39,8 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.BiPredicate; +import java.util.stream.IntStream; import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.MatchResult.match; @@ -120,16 +126,17 @@ else if (expected instanceof ScalarArgumentValue) { expectedTableArgument.passThroughColumns() != argumentProperties.passThroughColumns()) { return NO_MATCH; } - boolean specificationMatches = expectedTableArgument.specification() - .map(specification -> specification.getExpectedValue(symbolAliases)) - .equals(argumentProperties.specification()); + boolean specificationMatches = customDataOrganizationSpecificationEquals( + expectedTableArgument.specification().map(specification -> specification.getExpectedValue(symbolAliases)), + argumentProperties.specification(), + (v1, v2) -> v1.getName().equals(v2.getName())); if (!specificationMatches) { return NO_MATCH; } } } - if (properOutputs.size() != tableFunctionNode.getOutputVariables().size()) { + if (properOutputs.size() != tableFunctionNode.getProperOutput().size()) { return NO_MATCH; } @@ -358,4 +365,58 @@ private TableArgumentValue build() } } } + + private static boolean customDataOrganizationSpecificationEquals( + Optional left, + Optional right, + BiPredicate comparator) + { + if (!left.isPresent() && !right.isPresent()) { + return true; + } + if (!left.isPresent() || !right.isPresent()) { + return false; + } + + DataOrganizationSpecification leftSpecification = left.get(); + DataOrganizationSpecification rightSpecification = right.get(); + + List leftPartitionBy = leftSpecification.getPartitionBy(); + List rightPartitionBy = rightSpecification.getPartitionBy(); + if (leftPartitionBy.size() != rightPartitionBy.size() + || IntStream.range(0, leftPartitionBy.size()) + .anyMatch(i -> !comparator.test(leftPartitionBy.get(i), rightPartitionBy.get(i)))) { + return false; + } + + Optional leftOrderingScheme = leftSpecification.getOrderingScheme(); + Optional rightOrderingScheme = rightSpecification.getOrderingScheme(); + if (!leftOrderingScheme.isPresent() && !rightOrderingScheme.isPresent()) { + return true; + } + if (!leftOrderingScheme.isPresent() || !rightOrderingScheme.isPresent()) { + return false; + } + + List leftOrderBy = leftOrderingScheme.get().getOrderBy(); + List rightOrderBy = rightOrderingScheme.get().getOrderBy(); + + if (leftOrderBy.size() != rightOrderBy.size() + || IntStream.range(0, leftOrderBy.size()) + .anyMatch(i -> !comparator.test(leftOrderBy.get(i).getVariable(), rightOrderBy.get(i).getVariable()) + || !leftOrderBy.get(i).getSortOrder().equals(rightOrderBy.get(i).getSortOrder()))) { + return false; + } + + Map leftOrdering = leftOrderingScheme.get().getOrderingsMap(); + Map rightOrdering = rightOrderingScheme.get().getOrderingsMap(); + if (leftOrdering.size() != rightOrdering.size()) { + return false; + } + + return leftOrdering.entrySet().stream() + .allMatch(entry -> + rightOrdering.entrySet().stream() + .anyMatch(e -> comparator.test(entry.getKey(), e.getKey()) && entry.getValue().equals(e.getValue()))); + } }