diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index ade11574c3974..83239d6cfc70b 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -19,13 +19,17 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.analyzer.AccessControlInfo; import com.facebook.presto.spi.analyzer.AccessControlInfoForTable; import com.facebook.presto.spi.analyzer.AccessControlReferences; import com.facebook.presto.spi.analyzer.AccessControlRole; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.AllowAllAccessControl; @@ -43,6 +47,7 @@ import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Parameter; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; @@ -51,6 +56,7 @@ import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionInvocation; import 
com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -190,6 +196,12 @@ public class Analysis // Keeps track of the subquery we are visiting, so we have access to base query information when processing materialized view status private Optional currentQuerySpecification = Optional.empty(); + // names of tables and aliased relations. All names are resolved case-insensitive. + private final Map, QualifiedName> relationNames = new LinkedHashMap<>(); + private final Map, TableFunctionInvocationAnalysis> tableFunctionAnalyses = new LinkedHashMap<>(); + private final Set> aliasedRelations = new LinkedHashSet<>(); + private final Set> polymorphicTableFunctions = new LinkedHashSet<>(); + public Analysis(@Nullable Statement root, Map, Expression> parameters, boolean isDescribe) { this.root = root; @@ -994,6 +1006,46 @@ public Map> getInvokedFunctions() return functionMap.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableSet.copyOf(entry.getValue()))); } + public void setTableFunctionAnalysis(TableFunctionInvocation node, TableFunctionInvocationAnalysis analysis) + { + tableFunctionAnalyses.put(NodeRef.of(node), analysis); + } + + public TableFunctionInvocationAnalysis getTableFunctionAnalysis(TableFunctionInvocation node) + { + return tableFunctionAnalyses.get(NodeRef.of(node)); + } + + public void setRelationName(Relation relation, QualifiedName name) + { + relationNames.put(NodeRef.of(relation), name); + } + + public QualifiedName getRelationName(Relation relation) + { + return relationNames.get(NodeRef.of(relation)); + } + + public void addAliased(Relation relation) + { + aliasedRelations.add(NodeRef.of(relation)); + } + + public boolean isAliased(Relation relation) + { + return aliasedRelations.contains(NodeRef.of(relation)); + } + + public void addPolymorphicTableFunction(TableFunctionInvocation invocation) + { + 
polymorphicTableFunctions.add(NodeRef.of(invocation)); + } + + public boolean isPolymorphicTableFunction(TableFunctionInvocation invocation) + { + return polymorphicTableFunctions.contains(NodeRef.of(invocation)); + } + @Immutable public static final class Insert { @@ -1177,4 +1229,242 @@ public boolean isFromView() return isFromView; } } + + public static class TableArgumentAnalysis + { + private final String argumentName; + private final Optional name; + private final Relation relation; + private final Optional> partitionBy; // it is allowed to partition by empty list + private final Optional orderBy; + private final boolean pruneWhenEmpty; + private final boolean rowSemantics; + private final boolean passThroughColumns; + + private TableArgumentAnalysis( + String argumentName, + Optional name, + Relation relation, + Optional> partitionBy, + Optional orderBy, + boolean pruneWhenEmpty, + boolean rowSemantics, + boolean passThroughColumns) + { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); + this.name = requireNonNull(name, "name is null"); + this.relation = requireNonNull(relation, "relation is null"); + this.partitionBy = requireNonNull(partitionBy, "partitionBy is null").map(ImmutableList::copyOf); + this.orderBy = requireNonNull(orderBy, "orderBy is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.rowSemantics = rowSemantics; + this.passThroughColumns = passThroughColumns; + } + + public String getArgumentName() + { + return argumentName; + } + + public Optional getName() + { + return name; + } + + public Relation getRelation() + { + return relation; + } + + public Optional> getPartitionBy() + { + return partitionBy; + } + + public Optional getOrderBy() + { + return orderBy; + } + + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean isRowSemantics() + { + return rowSemantics; + } + + public boolean isPassThroughColumns() + { + return passThroughColumns; + } + + public static Builder builder() 
+ { + return new Builder(); + } + + public static final class Builder + { + private String argumentName; + private Optional name = Optional.empty(); + private Relation relation; + private Optional> partitionBy = Optional.empty(); + private Optional orderBy = Optional.empty(); + private boolean pruneWhenEmpty; + private boolean rowSemantics; + private boolean passThroughColumns; + + private Builder() {} + + public Builder withArgumentName(String argumentName) + { + this.argumentName = argumentName; + return this; + } + + public Builder withName(QualifiedName name) + { + this.name = Optional.of(name); + return this; + } + + public Builder withRelation(Relation relation) + { + this.relation = relation; + return this; + } + + public Builder withPartitionBy(List partitionBy) + { + this.partitionBy = Optional.of(partitionBy); + return this; + } + + public Builder withOrderBy(OrderBy orderBy) + { + this.orderBy = Optional.of(orderBy); + return this; + } + + public Builder withPruneWhenEmpty(boolean pruneWhenEmpty) + { + this.pruneWhenEmpty = pruneWhenEmpty; + return this; + } + + public Builder withRowSemantics(boolean rowSemantics) + { + this.rowSemantics = rowSemantics; + return this; + } + + public Builder withPassThroughColumns(boolean passThroughColumns) + { + this.passThroughColumns = passThroughColumns; + return this; + } + + public TableArgumentAnalysis build() + { + return new TableArgumentAnalysis(argumentName, name, relation, partitionBy, orderBy, pruneWhenEmpty, rowSemantics, passThroughColumns); + } + } + } + + public static class TableFunctionInvocationAnalysis + { + private final ConnectorId connectorId; + private final String schemaName; + private final String functionName; + private final Map arguments; + private final List tableArgumentAnalyses; + private final List> copartitioningLists; + private final Map> requiredColumns; + private final int properColumnsCount; + private final ConnectorTableFunctionHandle connectorTableFunctionHandle; + private final 
ConnectorTransactionHandle transactionHandle; + + public TableFunctionInvocationAnalysis( + ConnectorId connectorId, + String schemaName, + String functionName, + Map arguments, + List tableArgumentAnalyses, + Map> requiredColumns, + List> copartitioningLists, + int properColumnsCount, + ConnectorTableFunctionHandle connectorTableFunctionHandle, + ConnectorTransactionHandle transactionHandle) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.connectorTableFunctionHandle = requireNonNull(connectorTableFunctionHandle, "connectorTableFunctionHandle is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + this.tableArgumentAnalyses = ImmutableList.copyOf(tableArgumentAnalyses); + this.requiredColumns = requiredColumns.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableList.copyOf(entry.getValue()))); + this.copartitioningLists = ImmutableList.copyOf(copartitioningLists); + this.properColumnsCount = properColumnsCount; + } + + public ConnectorId getConnectorId() + { + return connectorId; + } + + public String getSchemaName() + { + return schemaName; + } + + public String getFunctionName() + { + return functionName; + } + + public Map getArguments() + { + return arguments; + } + + public List getTableArgumentAnalyses() + { + return tableArgumentAnalyses; + } + + public Map> getRequiredColumns() + { + return requiredColumns; + } + + public List> getCopartitioningLists() + { + return copartitioningLists; + } + + /** + * Proper columns are the columns produced by the table function, as opposed to pass-through columns from input tables. + * Proper columns should be considered the actual result of the table function. 
+ * @return the number of table function's proper columns + */ + public int getProperColumnsCount() + { + return properColumnsCount; + } + + public ConnectorTableFunctionHandle getConnectorTableFunctionHandle() + { + return connectorTableFunctionHandle; + } + + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } + } } diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java index 630f4670f6cc2..15c33950ce14b 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java @@ -86,6 +86,14 @@ public Field(Optional nodeLocation, Optional relati this.aliased = aliased; } + public static Field newUnqualified(Optional name, Type type) + { + requireNonNull(name, "name is null"); + requireNonNull(type, "type is null"); + + return new Field(Optional.empty(), Optional.empty(), name, type, false, Optional.empty(), Optional.empty(), false); + } + public Optional getNodeLocation() { return nodeLocation; diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java index d5492e6bb6932..34966b73938bf 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java @@ -56,6 +56,13 @@ public enum SemanticErrorCode FUNCTION_NOT_FOUND, INVALID_FUNCTION_NAME, DUPLICATE_PARAMETER_NAME, + FUNCTION_IMPLEMENTATION_ERROR, + + MISSING_RETURN_TYPE, + AMBIGUOUS_RETURN_TYPE, + MISSING_ARGUMENT, + INVALID_FUNCTION_ARGUMENT, + INVALID_ARGUMENTS, ORDER_BY_MUST_BE_IN_SELECT, ORDER_BY_MUST_BE_IN_AGGREGATE, @@ -111,4 +118,9 @@ public enum SemanticErrorCode TOO_MANY_GROUPING_SETS, INVALID_OFFSET_ROW_COUNT, + 
INVALID_COPARTITIONING, + INVALID_TABLE_FUNCTION_INVOCATION, + DUPLICATE_RANGE_VARIABLE, + INVALID_COLUMN_REFERENCE, + COLUMN_NOT_FOUND } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java index 511899a2e03f9..60058dd68ff86 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java @@ -46,9 +46,6 @@ import static com.facebook.presto.spi.connector.ConnectorCapabilities.NOT_NULL_COLUMN_CONSTRAINT; import static com.facebook.presto.spi.connector.EmptyConnectorCommitHandle.INSTANCE; -import static com.facebook.presto.spi.transaction.IsolationLevel.READ_COMMITTED; -import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Sets.immutableEnumSet; import static java.util.Objects.requireNonNull; @@ -71,6 +68,7 @@ public class JdbcConnector private final RowExpressionService rowExpressionService; private final JdbcClient jdbcClient; private final List> sessionProperties; + private final JdbcTransactionManager transactionManager; @Inject public JdbcConnector( @@ -85,7 +83,8 @@ public JdbcConnector( StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JdbcClient jdbcClient, - Optional sessionPropertiesProvider) + Optional sessionPropertiesProvider, + JdbcTransactionManager transactionManager) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.jdbcMetadataFactory = requireNonNull(jdbcMetadataFactory, "jdbcMetadataFactory is null"); @@ -99,6 +98,7 @@ public JdbcConnector( this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); this.jdbcClient = requireNonNull(jdbcClient, 
"jdbcClient is null"); this.sessionProperties = requireNonNull(sessionPropertiesProvider, "sessionPropertiesProvider is null").map(JdbcSessionPropertiesProvider::getSessionProperties).orElse(ImmutableList.of()); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); } @Override @@ -121,33 +121,26 @@ public boolean isSingleStatementWritesOnly() @Override public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) { - checkConnectorSupports(READ_COMMITTED, isolationLevel); - JdbcTransactionHandle transaction = new JdbcTransactionHandle(); - transactions.put(transaction, jdbcMetadataFactory.create()); - return transaction; + return transactionManager.beginTransaction(isolationLevel, readOnly); } @Override public ConnectorMetadata getMetadata(ConnectorTransactionHandle transaction) { - JdbcMetadata metadata = transactions.get(transaction); - checkArgument(metadata != null, "no such transaction: %s", transaction); - return new ClassLoaderSafeConnectorMetadata(metadata, getClass().getClassLoader()); + return new ClassLoaderSafeConnectorMetadata(transactionManager.getMetadata(transaction), getClass().getClassLoader()); } @Override public ConnectorCommitHandle commit(ConnectorTransactionHandle transaction) { - checkArgument(transactions.remove(transaction) != null, "no such transaction: %s", transaction); + transactionManager.commit(transaction); return INSTANCE; } @Override public void rollback(ConnectorTransactionHandle transaction) { - JdbcMetadata metadata = transactions.remove(transaction); - checkArgument(metadata != null, "no such transaction: %s", transaction); - metadata.rollback(); + transactionManager.rollback(transaction); } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java index b3b67ec041cd1..ac05905ea081f 100644 --- 
a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java @@ -52,6 +52,7 @@ public void configure(Binder binder) binder.bind(JdbcRecordSetProvider.class).in(Scopes.SINGLETON); binder.bind(JdbcPageSinkProvider.class).in(Scopes.SINGLETON); newOptionalBinder(binder, JdbcSessionPropertiesProvider.class); + binder.bind(JdbcTransactionManager.class).in(Scopes.SINGLETON); binder.bind(JdbcConnector.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(JdbcMetadataConfig.class); } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTransactionManager.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTransactionManager.java new file mode 100644 index 0000000000000..dbcb81f72db7a --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTransactionManager.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.transaction.IsolationLevel; + +import javax.inject.Inject; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import static com.facebook.presto.spi.transaction.IsolationLevel.READ_COMMITTED; +import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class JdbcTransactionManager +{ + private final ConcurrentMap transactions = new ConcurrentHashMap<>(); + private final JdbcMetadataFactory metadataFactory; + + @Inject + public JdbcTransactionManager(JdbcMetadataFactory metadataFactory) + { + this.metadataFactory = requireNonNull(metadataFactory, "metadataFactory is null"); + } + + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + checkConnectorSupports(READ_COMMITTED, isolationLevel); + JdbcTransactionHandle transaction = new JdbcTransactionHandle(); + transactions.put(transaction, metadataFactory.create()); + return transaction; + } + + public JdbcMetadata getMetadata(ConnectorTransactionHandle transaction) + { + JdbcMetadata metadata = transactions.get(transaction); + checkArgument(metadata != null, "no such transaction: %s", transaction); + return metadata; + } + + public void commit(ConnectorTransactionHandle transaction) + { + checkArgument(transactions.remove(transaction) != null, "no such transaction: %s", transaction); + } + + public void rollback(ConnectorTransactionHandle transaction) + { + JdbcMetadata metadata = transactions.remove(transaction); + checkArgument(metadata != null, "no such transaction: %s", transaction); + metadata.rollback(); + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java 
b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java index 2d2962742a6e3..8b73ea66dc137 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java @@ -13,11 +13,14 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.Session; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueries; import io.airlift.tpch.TpchTable; +import org.testng.annotations.Test; import static com.facebook.presto.plugin.jdbc.JdbcQueryRunner.createJdbcQueryRunner; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; public class TestJdbcDistributedQueries extends AbstractTestQueries @@ -33,4 +36,15 @@ protected QueryRunner createQueryRunner() public void testLargeIn() { } + + @Test + public void testNativeQueryParameters() + { + Session session = testSessionBuilder() + .addPreparedStatement("my_query_simple", "SELECT * FROM TABLE(system.query(query => ?))") + .addPreparedStatement("my_query", "SELECT * FROM TABLE(system.query(query => format('SELECT %s FROM %s', ?, ?)))") + .build(); + assertQueryFails(session, "EXECUTE my_query_simple USING 'SELECT 1 a'", "line 1:21: Table function system.query not registered"); + assertQueryFails(session, "EXECUTE my_query USING 'a', '(SELECT 2 a) t'", "line 1:21: Table function system.query not registered"); + } } diff --git a/presto-common/src/main/java/com/facebook/presto/common/Page.java b/presto-common/src/main/java/com/facebook/presto/common/Page.java index 2e941461dca9f..3ae3b8b49de01 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/Page.java +++ b/presto-common/src/main/java/com/facebook/presto/common/Page.java @@ -416,6 +416,22 @@ public Page extractChannels(int[] channels) return wrapBlocksWithoutCopy(positionCount, blocks); } + public Page 
getColumns(int column) + { + return wrapBlocksWithoutCopy(positionCount, new Block[] {this.blocks[column]}); + } + + public Page getColumns(int... columns) + { + requireNonNull(columns, "columns is null"); + + Block[] blocks = new Block[columns.length]; + for (int i = 0; i < columns.length; i++) { + blocks[i] = this.blocks[columns[i]]; + } + return wrapBlocksWithoutCopy(positionCount, blocks); + } + public Page prependColumn(Block column) { if (column.getPositionCount() != positionCount) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java index 56f3e15d72aa6..06e2633c23bdd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java @@ -52,6 +52,7 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; @@ -107,7 +108,6 @@ public class ConnectorManager private final ConnectorPlanOptimizerManager connectorPlanOptimizerManager; private final ConnectorMetadataUpdaterManager connectorMetadataUpdaterManager; private final ConnectorTypeSerdeManager connectorTypeSerdeManager; - private final PageSinkManager pageSinkManager; private final HandleResolver handleResolver; private final InternalNodeManager nodeManager; @@ -211,6 +211,14 @@ public synchronized void addConnectorFactory(ConnectorFactory connectorFactory) ConnectorFactory existingConnectorFactory = connectorFactories.putIfAbsent(connectorFactory.getName(), connectorFactory); 
checkArgument(existingConnectorFactory == null, "Connector %s is already registered", connectorFactory.getName()); handleResolver.addConnectorName(connectorFactory.getName(), connectorFactory.getHandleResolver()); + + connectorFactory.getTableFunctionHandleResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionNamespace(connectorFactory.getName(), resolver); + }); + + connectorFactory.getTableFunctionSplitResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionSplitNamespace(connectorFactory.getName(), resolver); + }); } public synchronized ConnectorId createConnection(String catalogName, String connectorName, Map properties) @@ -331,6 +339,7 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) metadataManager.getSchemaPropertyManager().addProperties(connectorId, connector.getSchemaProperties()); metadataManager.getAnalyzePropertyManager().addProperties(connectorId, connector.getAnalyzeProperties()); metadataManager.getSessionPropertyManager().addConnectorSessionProperties(connectorId, connector.getSessionProperties()); + metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().addTableFunctions(connectorId, connector.getTableFunctions()); } public synchronized void dropConnection(String catalogName) @@ -342,6 +351,7 @@ public synchronized void dropConnection(String catalogName) removeConnectorInternal(connectorId); removeConnectorInternal(createInformationSchemaConnectorId(connectorId)); removeConnectorInternal(createSystemTablesConnectorId(connectorId)); + metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().removeTableFunctions(connectorId); }); } @@ -405,6 +415,7 @@ private static class MaterializedConnector private final ConnectorSplitManager splitManager; private final Set systemTables; private final Set procedures; + private final Set connectorTableFunctions; private final ConnectorPageSourceProvider pageSourceProvider; private final Optional pageSinkProvider; private 
final Optional indexProvider; @@ -435,6 +446,10 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) requireNonNull(procedures, "Connector %s returned a null procedures set"); this.procedures = ImmutableSet.copyOf(procedures); + Set connectorTableFunctions = connector.getTableFunctions(); + requireNonNull(connectorTableFunctions, format("Connector '%s' returned a null table functions set", connectorId)); + this.connectorTableFunctions = ImmutableSet.copyOf(connectorTableFunctions); + ConnectorPageSourceProvider connectorPageSourceProvider = null; try { connectorPageSourceProvider = connector.getPageSourceProvider(); @@ -628,5 +643,10 @@ public List> getAnalyzeProperties() { return analyzeProperties; } + + public Set getTableFunctions() + { + return connectorTableFunctions; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 44a39427b0c84..9614e12956198 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -632,7 +632,7 @@ private PlanRoot runCreateLogicalPlanAsync() private void createQueryScheduler(PlanRoot plan) { - CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits); + CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager); // ensure split sources are closed stateMachine.addStateChangeListener(state -> { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java index 43885270d21a6..9216ff90e0b9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java @@ -1082,7 +1082,10 @@ public ListenableFuture processFor(Duration duration) @Override public String getInfo() { - return (partitionedSplit == null) ? "" : partitionedSplit.getSplit().getInfo().toString(); + if (partitionedSplit != null && partitionedSplit.getSplit() != null && partitionedSplit.getSplit().getInfo() != null) { + return partitionedSplit.getSplit().getInfo().toString(); + } + return ""; } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java index 751d902e46c1a..4b54732d6987a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java @@ -36,6 +36,7 @@ import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -648,4 +649,10 @@ public void addConstraint(Session session, TableHandle tableHandle, TableConstra { delegate.addConstraint(session, tableHandle, tableConstraint); } + + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + return delegate.applyTableFunction(session, handle); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index 4ea19e25698cd..06270566d503d 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -47,11 +47,16 @@ import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation; import com.facebook.presto.spi.function.JavaScalarFunctionImplementation; import com.facebook.presto.spi.function.ScalarFunctionImplementation; +import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlFunctionSupplier; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; import com.facebook.presto.spi.type.TypeManagerContext; import com.facebook.presto.spi.type.TypeManagerFactory; import com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -80,12 +85,14 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.regex.Pattern; import static com.facebook.presto.SystemSessionProperties.isExperimentalFunctionsEnabled; @@ -102,6 +109,7 @@ import static com.facebook.presto.spi.function.FunctionKind.SCALAR; import static com.facebook.presto.spi.function.SqlFunctionVisibility.EXPERIMENTAL; import static 
com.facebook.presto.spi.function.SqlFunctionVisibility.PUBLIC; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static com.facebook.presto.sql.planner.LiteralEncoder.MAGIC_LITERAL_FUNCTION_PREFIX; import static com.facebook.presto.sql.planner.LiteralEncoder.getMagicLiteralFunctionSignature; @@ -113,6 +121,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; @@ -129,6 +138,7 @@ public class FunctionAndTypeManager { private static final Pattern DEFAULT_NAMESPACE_PREFIX_PATTERN = Pattern.compile("[a-z]+\\.[a-z]+"); private final TransactionManager transactionManager; + private final TableFunctionRegistry tableFunctionRegistry; private final BlockEncodingSerde blockEncodingSerde; private final BuiltInTypeAndFunctionNamespaceManager builtInTypeAndFunctionNamespaceManager; private final FunctionInvokerProvider functionInvokerProvider; @@ -145,10 +155,12 @@ public class FunctionAndTypeManager private final CatalogSchemaName defaultNamespace; private final AtomicReference servingTypeManager; private final AtomicReference>> servingTypeManagerParametricTypesSupplier; + private Optional> getTableFunctionProcessorProvider; @Inject public FunctionAndTypeManager( TransactionManager transactionManager, + TableFunctionRegistry tableFunctionRegistry, BlockEncodingSerde blockEncodingSerde, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, @@ -156,6 +168,7 @@ public FunctionAndTypeManager( Set types) { this.transactionManager = requireNonNull(transactionManager, 
"transactionManager is null"); + this.tableFunctionRegistry = requireNonNull(tableFunctionRegistry, "tableFunctionRegistry is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.builtInTypeAndFunctionNamespaceManager = new BuiltInTypeAndFunctionNamespaceManager(blockEncodingSerde, functionsConfig, types, this); this.functionNamespaceManagers.put(JAVA_BUILTIN_NAMESPACE.getCatalogName(), builtInTypeAndFunctionNamespaceManager); @@ -182,6 +195,7 @@ public static FunctionAndTypeManager createTestFunctionAndTypeManager() { return new FunctionAndTypeManager( createTestTransactionManager(), + new TableFunctionRegistry(), new BlockEncodingManager(), new FeaturesConfig(), new FunctionsConfig(), @@ -403,6 +417,11 @@ public void addFunctionNamespaceFactory(FunctionNamespaceManagerFactory factory) handleResolver.addFunctionNamespace(factory.getName(), factory.getHandleResolver()); } + public TableFunctionRegistry getTableFunctionRegistry() + { + return tableFunctionRegistry; + } + public void loadTypeManager(String typeManagerName) { requireNonNull(typeManagerName, "typeManagerName is null"); @@ -432,6 +451,11 @@ public void addTypeManagerFactory(TypeManagerFactory factory) } } + public TransactionManager getTransactionManager() + { + return transactionManager; + } + public void registerBuiltInFunctions(List functions) { builtInTypeAndFunctionNamespaceManager.registerBuiltInFunctions(functions); @@ -607,6 +631,11 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHand return functionNamespaceManager.get().getScalarFunctionImplementation(functionHandle); } + public Optional> getTableFunctionProcessorProvider() + { + return getTableFunctionProcessorProvider; + } + public AggregationFunctionImplementation getAggregateFunctionImplementation(FunctionHandle functionHandle) { Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); @@ -952,4 +981,25 
@@ public String toString() .toString(); } } + + public void setGetTableFunctionProcessorProvider(Optional> getTableFunctionProcessorProvider) + { + this.getTableFunctionProcessorProvider = getTableFunctionProcessorProvider; + } + + private class IdentityFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + Optional inputPage = getOnlyElement(input); + return inputPage.map(TableFunctionProcessorState.Processed::usedInputAndProduced).orElseThrow(NoSuchElementException::new); + }; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index 61eae56f41895..4249087ce1849 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -38,6 +38,7 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(PartitioningHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(MetadataUpdateJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(TableFunctionJacksonHandleModule.class); binder.bind(HandleResolver.class).in(Scopes.SINGLETON); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java index 09992ef314575..6c00b17d4e30b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java @@ -29,13 +29,18 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; 
import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.split.EmptySplitHandleResolver; +import com.google.common.collect.ImmutableSet; import javax.inject.Inject; import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; @@ -50,6 +55,8 @@ public class HandleResolver { private final ConcurrentMap handleResolvers = new ConcurrentHashMap<>(); private final ConcurrentMap functionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap tableFunctionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap tableFunctionSplitResolvers = new ConcurrentHashMap<>(); @Inject public HandleResolver() @@ -80,6 +87,22 @@ public void addFunctionNamespace(String name, FunctionHandleResolver resolver) checkState(existingResolver == null || existingResolver.equals(resolver), "Name %s is already assigned to function resolver: %s", name, existingResolver); } + public void addTableFunctionNamespace(String name, TableFunctionHandleResolver resolver) + { + requireNonNull(name, "name is null"); + requireNonNull(resolver, "resolver is null"); + MaterializedTableFunctionHandleResolver existingResolver = tableFunctionHandleResolvers.putIfAbsent(name, new MaterializedTableFunctionHandleResolver(resolver)); + checkState(existingResolver == null || existingResolver.equals(resolver), "Name %s is already assigned to table function resolver: %s", name, existingResolver); + } + + public void addTableFunctionSplitNamespace(String name, TableFunctionSplitResolver resolver) + { + 
requireNonNull(name, "name is null"); + requireNonNull(resolver, "resolver is null"); + MaterializedTableFunctionSplitResolver existingResolver = tableFunctionSplitResolvers.putIfAbsent(name, new MaterializedTableFunctionSplitResolver(resolver)); + checkState(existingResolver == null || existingResolver.equals(resolver), "Name %s is already assigned to table function resolver: %s", name, existingResolver); + } + public String getId(ConnectorTableHandle tableHandle) { return getId(tableHandle, MaterializedHandleResolver::getTableHandleClass); @@ -97,7 +120,16 @@ public String getId(ColumnHandle columnHandle) public String getId(ConnectorSplit split) { - return getId(split, MaterializedHandleResolver::getSplitClass); + // TODO: Clean up all connectors and make MaterializedTableFunctionSplitResolver + // just a MaterializedSplitResolver and add Connector's split classes to it. + // This was added as different table functions in the connector can have different + // split classes, but be part of the same Connector. 
+ try { + return getId(split, MaterializedHandleResolver::getSplitClass); + } + catch (Exception e) { + return getTableFunctionSplitId(split, MaterializedTableFunctionSplitResolver::getTableFunctionSplitClasses); + } } public String getId(ConnectorIndexHandle indexHandle) @@ -140,6 +172,11 @@ public String getId(ConnectorMetadataUpdateHandle metadataUpdateHandle) return getId(metadataUpdateHandle, MaterializedHandleResolver::getMetadataUpdateHandleClass); } + public String getId(ConnectorTableFunctionHandle tableFunctionHandle) + { + return getTableFunctionId(tableFunctionHandle, MaterializedTableFunctionHandleResolver::getTableFunctionHandleClasses); + } + public Class getTableHandleClass(String id) { return resolverFor(id).getTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -157,6 +194,16 @@ public Class getColumnHandleClass(String id) public Class getSplitClass(String id) { + Optional> tableFunctionSplit; + for (Entry entry : tableFunctionSplitResolvers.entrySet()) { + MaterializedTableFunctionSplitResolver resolver = entry.getValue(); + tableFunctionSplit = resolver.getTableFunctionSplitClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionSplit.isPresent()) { + return tableFunctionSplit.get(); + } + } return resolverFor(id).getSplitClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } @@ -200,6 +247,21 @@ public Class getMetadataUpdateHandleCla return resolverFor(id).getMetadataUpdateHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getTableFunctionHandleClass(String id) + { + Optional> tableFunctionHandle; + for (Entry entry : tableFunctionHandleResolvers.entrySet()) { + MaterializedTableFunctionHandleResolver resolver = entry.getValue(); + tableFunctionHandle = resolver.getTableFunctionHandleClasses().stream() + .filter(handle -> (entry.getKey() + ":" + 
handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionHandle.isPresent()) { + return tableFunctionHandle.get(); + } + } + throw new IllegalArgumentException("No handle resolver for table function namespace: " + id); + } + private MaterializedHandleResolver resolverFor(String id) { MaterializedHandleResolver resolver = handleResolvers.get(id); @@ -242,6 +304,42 @@ private String getFunctionNamespaceId(T handle, Function String getTableFunctionId(ConnectorTableFunctionHandle handle, Function>> getters) + { + for (Entry entry : tableFunctionHandleResolvers.entrySet()) { + try { + Optional id = getters.apply(entry.getValue()).stream() + .filter(clazz -> clazz.isInstance(handle)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + catch (UnsupportedOperationException ignored) { + } + } + throw new IllegalArgumentException("No function namespace for handle: " + handle); + } + + private String getTableFunctionSplitId(ConnectorSplit split, Function>> getters) + { + for (Entry entry : tableFunctionSplitResolvers.entrySet()) { + try { + Optional id = getters.apply(entry.getValue()).stream() + .filter(clazz -> clazz.isInstance(split)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + catch (UnsupportedOperationException ignored) { + } + } + throw new IllegalArgumentException("No function namespace for handle: " + split); + } + private static class MaterializedHandleResolver { private final Optional> tableHandle; @@ -409,4 +507,92 @@ public int hashCode() return Objects.hash(functionHandle); } } + + private static class MaterializedTableFunctionHandleResolver + { + private final Set> tableFunctionHandles; + + public MaterializedTableFunctionHandleResolver(TableFunctionHandleResolver resolver) + { + tableFunctionHandles = getHandleClass(resolver::getTableFunctionHandleClasses); + } + + private static Set> getHandleClass(Supplier>> callable) 
+ { + try { + return callable.get(); + } + catch (UnsupportedOperationException e) { + return ImmutableSet.of(); + } + } + + public Set> getTableFunctionHandleClasses() + { + return tableFunctionHandles; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MaterializedTableFunctionHandleResolver that = (MaterializedTableFunctionHandleResolver) o; + return Objects.equals(tableFunctionHandles, that.tableFunctionHandles); + } + + @Override + public int hashCode() + { + return Objects.hash(tableFunctionHandles); + } + } + + private static class MaterializedTableFunctionSplitResolver + { + private final Set> tableFunctionSplits; + + public MaterializedTableFunctionSplitResolver(TableFunctionSplitResolver resolver) + { + tableFunctionSplits = getSplitClass(resolver::getTableFunctionSplitClasses); + } + + private static Set> getSplitClass(Supplier>> callable) + { + try { + return callable.get(); + } + catch (UnsupportedOperationException e) { + return ImmutableSet.of(); + } + } + + public Set> getTableFunctionSplitClasses() + { + return tableFunctionSplits; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MaterializedTableFunctionSplitResolver that = (MaterializedTableFunctionSplitResolver) o; + return Objects.equals(tableFunctionSplits, that.tableFunctionSplits); + } + + @Override + public int hashCode() + { + return Objects.hash(tableFunctionSplits); + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java index 96ac818709096..5233c67e93fb6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -42,6 +42,7 @@ import 
com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -523,4 +524,6 @@ default TableLayoutFilterCoverage getTableLayoutFilterCoverage(Session session, void dropConstraint(Session session, TableHandle tableHandle, Optional constraintName, Optional columnName); void addConstraint(Session session, TableHandle tableHandle, TableConstraint tableConstraint); + + Optional> applyTableFunction(Session session, TableFunctionHandle handle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java index 87a83d7060ff0..83431a86f65aa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -62,6 +62,7 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -241,7 +242,7 @@ public static MetadataManager createTestMetadataManager(TransactionManager trans { BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); return new MetadataManager( - new FunctionAndTypeManager(transactionManager, blockEncodingManager, featuresConfig, functionsConfig, new 
HandleResolver(), ImmutableSet.of()), + new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), blockEncodingManager, createTestingSessionPropertyManager(), new SchemaPropertyManager(), @@ -1477,6 +1478,18 @@ public void addConstraint(Session session, TableHandle tableHandle, TableConstra metadata.addConstraint(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tableConstraint); } + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + ConnectorId connectorId = handle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + + return metadata.applyTableFunction(session.toConnectorSession(connectorId), handle.getFunctionHandle()) + .map(result -> new TableFunctionApplicationResult<>( + new TableHandle(connectorId, result.getTableHandle(), handle.getTransactionHandle(), Optional.empty()), + result.getColumnHandles())); + } + private ViewDefinition deserializeView(String data) { try { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java new file mode 100644 index 0000000000000..10e34e48fab10 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionHandle +{ + private final ConnectorId connectorId; + private final SchemaFunctionName schemaFunctionName; + private final ConnectorTableFunctionHandle functionHandle; + private final ConnectorTransactionHandle transactionHandle; + + @JsonCreator + public TableFunctionHandle( + @JsonProperty("connectorId") ConnectorId connectorId, + @JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName, + @JsonProperty("functionHandle") ConnectorTableFunctionHandle functionHandle, + @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + } + + @JsonProperty + public ConnectorId getConnectorId() + { + return connectorId; + } + + @JsonProperty + public SchemaFunctionName getSchemaFunctionName() + { + return schemaFunctionName; + } + + @JsonProperty + public ConnectorTableFunctionHandle getFunctionHandle() + { + return functionHandle; + } + + @JsonProperty + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } +} diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java new file mode 100644 index 0000000000000..525a94b4e07cd --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; + +import javax.inject.Inject; + +public class TableFunctionJacksonHandleModule + extends AbstractTypedJacksonModule +{ + @Inject + public TableFunctionJacksonHandleModule(HandleResolver handleResolver) + { + super(ConnectorTableFunctionHandle.class, + handleResolver::getId, + handleResolver::getTableFunctionHandleClass); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java new file mode 100644 index 0000000000000..806215927b736 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionMetadata +{ + private final ConnectorId connectorId; + private final ConnectorTableFunction function; + + public TableFunctionMetadata(ConnectorId connectorId, ConnectorTableFunction function) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.function = requireNonNull(function, "function is null"); + } + + public ConnectorId getConnectorId() + { + return connectorId; + } + + public ConnectorTableFunction getFunction() + { + return function; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java new file mode 100644 index 0000000000000..1da3b4c5e529b --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.spi.function.CatalogSchemaFunctionName; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import static com.facebook.presto.spi.StandardErrorCode.MISSING_CATALOG_NAME; +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public class TableFunctionRegistry +{ + // catalog name in the original case; schema and function name in lowercase + private final Map> tableFunctions = new ConcurrentHashMap<>(); + + public void addTableFunctions(ConnectorId catalogName, Collection functions) + { + 
requireNonNull(catalogName, "catalogName is null"); + requireNonNull(functions, "functions is null"); + + functions.stream() + .forEach(TableFunctionRegistry::validateTableFunction); + + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (ConnectorTableFunction function : functions) { + builder.put( + new SchemaFunctionName( + function.getSchema().toLowerCase(ENGLISH), + function.getName().toLowerCase(ENGLISH)), + new TableFunctionMetadata(catalogName, function)); + } + checkState(tableFunctions.putIfAbsent(catalogName, builder.buildOrThrow()) == null, "Table functions already registered for catalog: " + catalogName); + } + + public void removeTableFunctions(ConnectorId catalogName) + { + tableFunctions.remove(catalogName); + } + + public static List toPath(Session session, QualifiedName name) + { + List parts = name.getParts(); + if (parts.size() > 3) { + throw new PrestoException(StandardErrorCode.FUNCTION_NOT_FOUND, "Invalid function name: " + name); + } + if (parts.size() == 3) { + return ImmutableList.of(new CatalogSchemaFunctionName(parts.get(0), parts.get(1), parts.get(2))); + } + + if (parts.size() == 2) { + String currentCatalog = session.getCatalog() + .orElseThrow(() -> new PrestoException(MISSING_CATALOG_NAME, "Session default catalog must be set to resolve a partial function name: " + name)); + return ImmutableList.of(new CatalogSchemaFunctionName(currentCatalog, parts.get(0), parts.get(1))); + } + + ImmutableList.Builder names = ImmutableList.builder(); + + String currentCatalog = session.getCatalog() + .orElseThrow(() -> new SemanticException(CATALOG_NOT_SPECIFIED, "Catalog must be specified when session catalog is not set")); + String currentSchema = session.getSchema() + .orElseThrow(() -> new SemanticException(SCHEMA_NOT_SPECIFIED, "Schema must be specified when session schema is not set")); + + // add resolved path items + names.add(new CatalogSchemaFunctionName(currentCatalog, currentSchema, parts.get(0))); + return names.build(); + } 
+ + /** + * Resolve table function with given qualified name. + * Table functions are resolved case-insensitive for consistency with existing scalar function resolution. + */ + public TableFunctionMetadata resolve(Session session, QualifiedName qualifiedName) + { + for (CatalogSchemaFunctionName name : toPath(session, qualifiedName)) { + ConnectorId connectorId = new ConnectorId(name.getCatalogName()); + Map catalogFunctions = tableFunctions.get(connectorId); + if (catalogFunctions != null) { + String lowercasedSchemaName = name.getSchemaFunctionName().getSchemaName().toLowerCase(ENGLISH); + String lowercasedFunctionName = name.getSchemaFunctionName().getFunctionName().toLowerCase(ENGLISH); + TableFunctionMetadata function = catalogFunctions.get(new SchemaFunctionName(lowercasedSchemaName, lowercasedFunctionName)); + if (function != null) { + return function; + } + } + } + + return null; + } + + private static void validateTableFunction(ConnectorTableFunction tableFunction) + { + requireNonNull(tableFunction, "tableFunction is null"); + requireNonNull(tableFunction.getName(), "table function name is null"); + requireNonNull(tableFunction.getSchema(), "table function schema name is null"); + requireNonNull(tableFunction.getArguments(), "table function arguments is null"); + requireNonNull(tableFunction.getReturnTypeSpecification(), "table function returnTypeSpecification is null"); + + checkArgument(!tableFunction.getName().isEmpty(), "table function name is empty"); + checkArgument(!tableFunction.getSchema().isEmpty(), "table function schema name is empty"); + + Set argumentNames = new HashSet<>(); + for (ArgumentSpecification specification : tableFunction.getArguments()) { + if (!argumentNames.add(specification.getName())) { + throw new IllegalArgumentException("duplicate argument name: " + specification.getName()); + } + } + long tableArgumentsWithRowSemantics = tableFunction.getArguments().stream() + .filter(specification -> specification instanceof 
TableArgumentSpecification) + .map(TableArgumentSpecification.class::cast) + .filter(TableArgumentSpecification::isRowSemantics) + .count(); + checkArgument(tableArgumentsWithRowSemantics <= 1, "more than one table argument with row semantics"); + // The 'keep when empty' or 'prune when empty' property must not be explicitly specified for a table argument with row semantics. + // Such a table argument is implicitly 'prune when empty'. The TableArgumentSpecification.Builder enforces the 'prune when empty' property + // for a table argument with row semantics. + + if (tableFunction.getReturnTypeSpecification() instanceof DescribedTable) { + DescribedTable describedTable = (DescribedTable) tableFunction.getReturnTypeSpecification(); + checkArgument(describedTable.getDescriptor().isTyped(), "field types missing in returned type specification"); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java new file mode 100644 index 0000000000000..bda83ae6319d4 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; + +import java.util.List; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * This is a class representing empty input to a table function. An EmptyTableFunctionPartition is created + * when the table function has KEEP WHEN EMPTY property, which means that the function should be executed + * even if the input is empty, and all the table arguments are empty relations. + *

+ * An EmptyTableFunctionPartition is created and processed once per node. To avoid duplicated execution, + * a table function having KEEP WHEN EMPTY property must have single distribution. + */ +public class EmptyTableFunctionPartition + implements TableFunctionPartition +{ + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + private final Type[] passThroughTypes; + + public EmptyTableFunctionPartition(TableFunctionDataProcessor tableFunction, int properChannelsCount, int passThroughSourcesCount, List passThroughTypes) + { + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.passThroughTypes = passThroughTypes.toArray(new Type[] {}); + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(() -> { + TableFunctionProcessorState state = tableFunction.process(null); + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendNullsForPassThroughColumns(processed.getResult())); + } + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + }); + } + + private Page appendNullsForPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. 
Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + + Block[] resultBlocks = new Block[properChannelsCount + passThroughTypes.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + // because no input was processed, all pass-through indexes in the result page must be null (there are no input rows they could refer to). + // for performance reasons this is not checked. All pass-through columns are filled with nulls. + int channel = properChannelsCount; + for (Type type : passThroughTypes) { + resultBlocks[channel] = RunLengthEncodedBlock.create(type, null, page.getPositionCount()); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java new file mode 100644 index 0000000000000..ca66af5e97073 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.execution.ScheduledSplit; +import com.facebook.presto.metadata.Split; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.UpdatablePageSource; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Blocked; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class LeafTableFunctionOperator + implements SourceOperator +{ + public static class LeafTableFunctionOperatorFactory + implements SourceOperatorFactory + { + private final int operatorId; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + private boolean closed; + + public LeafTableFunctionOperatorFactory(int operatorId, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorId = 
operatorId; + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public SourceOperator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, LeafTableFunctionOperator.class.getSimpleName()); + return new LeafTableFunctionOperator(operatorContext, sourceId, tableFunctionProvider, functionHandle); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + } + + private final OperatorContext operatorContext; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + + private ConnectorSplit currentSplit; + private final List pendingSplits = new ArrayList<>(); + private boolean noMoreSplits; + + private TableFunctionSplitProcessor processor; + private boolean processorUsedData; + private boolean processorFinishedSplit = true; + private ListenableFuture processorBlocked = NOT_BLOCKED; + + public LeafTableFunctionOperator(OperatorContext operatorContext, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + private void resetProcessor() + { + this.processor = tableFunctionProvider.getSplitProcessor(functionHandle); + 
this.processorUsedData = false; + this.processorFinishedSplit = false; + this.processorBlocked = NOT_BLOCKED; + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public boolean needsInput() + { + return false; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(getClass().getName() + " does not take input"); + } + + @Override + public Supplier> addSplit(ScheduledSplit split) + { + Split curSplit = requireNonNull(split, "split is null").getSplit(); + checkState(!noMoreSplits, "no more splits expected"); + ConnectorSplit curConnectorSplit = curSplit.getConnectorSplit(); + pendingSplits.add(curConnectorSplit); + return Optional::empty; + } + + @Override + public void noMoreSplits() + { + noMoreSplits = true; + } + + @Override + public Page getOutput() + { + if (processorFinishedSplit) { + // start processing a new split + if (pendingSplits.isEmpty()) { + // no more splits to process at the moment + return null; + } + currentSplit = pendingSplits.remove(0); + resetProcessor(); + } + else { + // a split is being processed + requireNonNull(currentSplit, "currentSplit is null"); + } + + TableFunctionProcessorState state = processor.process(processorUsedData ? null : currentSplit); + if (state == FINISHED) { + processorFinishedSplit = true; + } + if (state instanceof Blocked) { + Blocked blocked = (Blocked) state; + processorBlocked = toListenableFuture(blocked.getFuture()); + } + if (state instanceof Processed) { + Processed processed = (Processed) state; + if (processed.isUsedInput()) { + processorUsedData = true; + } + if (processed.getResult() != null) { + return processed.getResult(); + } + } + return null; + } + + @Override + public ListenableFuture isBlocked() + { + return processorBlocked; + } + + @Override + public void finish() + { + // this method is redundant. 
the operator takes no input at all. noMoreSplits() should be called instead. + } + + @Override + public boolean isFinished() + { + return processorFinishedSplit && pendingSplits.isEmpty() && noMoreSplits; + } + + @Override + public void close() + throws Exception + { + // TODO + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java new file mode 100644 index 0000000000000..0ad8f695b2faa --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; + +import javax.annotation.Nullable; + +import static com.facebook.presto.operator.WorkProcessor.ProcessState.finished; +import static com.facebook.presto.operator.WorkProcessor.ProcessState.ofResult; +import static com.facebook.presto.operator.WorkProcessor.ProcessState.yield; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class PageBuffer +{ + @Nullable + private Page page; + private boolean finished; + + public WorkProcessor pages() + { + return WorkProcessor.create(() -> { + if (isFinished() && isEmpty()) { + return finished(); + } + + if (!isEmpty()) { + Page result = page; + page = null; + return ofResult(result); + } + + return yield(); + }); + } + + public boolean isEmpty() + { + return page == null; + } + + public boolean isFinished() + { + return finished; + } + + public void add(Page page) + { + checkState(isEmpty(), "page buffer is not empty"); + checkState(!isFinished(), "page buffer is finished"); + requireNonNull(page, "page is null"); + + if (page.getPositionCount() == 0) { + return; + } + + this.page = page; + } + + public void finish() + { + finished = true; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java index 14203bf9f10bd..3b27de227cc7e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java @@ -59,6 +59,15 @@ public interface PagesHashStrategy */ boolean positionEqualsRow(int leftBlockIndex, int leftPosition, int rightPosition, Page rightPage); + /** + * Compares the hashed columns in this PagesHashStrategy to the values in the specified page. 
The + * values are compared positionally, so {@code rightPage} must have the same number of entries as + * the hashed columns and each entry is expected to be the same type. + * {@code rightPage} is used if join uses filter function and must contain all columns from probe side of join. + * The values are compared under "not distinct from" semantics. + */ + boolean positionNotDistinctFromRow(int leftBlockIndex, int leftPosition, int rightPosition, Page rightPage); + /** * Compares the hashed columns in this PagesHashStrategy to the values in the specified page. The * values are compared positionally, so {@code rightPage} must have the same number of entries as diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java index bd2f5ffbc817a..b3cc42a3ce998 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -271,9 +271,9 @@ public void swap(int a, int b) valueAddresses.swap(a, b); } - public int buildPage(int position, int[] outputChannels, PageBuilder pageBuilder) + public int buildPage(int position, int endPosition, int[] outputChannels, PageBuilder pageBuilder) { - while (!pageBuilder.isFull() && position < positionCount) { + while (!pageBuilder.isFull() && position < endPosition) { long pageAddress = valueAddresses.get(position); int blockIndex = decodeSliceIndex(pageAddress); int blockPosition = decodePosition(pageAddress); @@ -563,10 +563,29 @@ protected Page computeNext() } public Iterator getSortedPages() + { + return getSortedPagesFromRange(0, positionCount); + } + + /** + * Get sorted pages from the specified section of the PagesIndex. 
+ * + * @param start start position of the section, inclusive + * @param end end position of the section, exclusive + * @return iterator of pages + */ + public Iterator getSortedPages(int start, int end) + { + checkArgument(start >= 0 && end <= positionCount, "position range out of bounds"); + checkArgument(start <= end, "invalid position range"); + return getSortedPagesFromRange(start, end); + } + + private Iterator getSortedPagesFromRange(int start, int end) { return new AbstractIterator() { - private int currentPosition; + private int currentPosition = start; private final PageBuilder pageBuilder = new PageBuilder(types); private final int[] outputChannels = new int[types.size()]; @@ -577,7 +596,7 @@ public Iterator getSortedPages() @Override public Page computeNext() { - currentPosition = buildPage(currentPosition, outputChannels, pageBuilder); + currentPosition = buildPage(currentPosition, end, outputChannels, pageBuilder); if (pageBuilder.isEmpty()) { return endOfData(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java new file mode 100644 index 0000000000000..c5920f44f21dc --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java @@ -0,0 +1,438 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Ints; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.common.Utils.checkState; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class RegularTableFunctionPartition + implements TableFunctionPartition +{ + private final PagesIndex pagesIndex; + private final int partitionStart; + private final int partitionEnd; + private final Iterator sortedPages; + + private final TableFunctionDataProcessor tableFunction; + private final 
int properChannelsCount; + private final int passThroughSourcesCount; + + // channels required by the table function, listed by source in order of argument declarations + private final int[][] requiredChannels; + + // for each input channel, the end position of actual data in that channel (exclusive) relative to partition. The remaining rows are "filler" rows, and should not be passed to table function or passed-through + private final int[] endOfData; + + // a builder for each pass-through column, in order of argument declarations + private final PassThroughColumnProvider[] passThroughProviders; + + // number of processed input positions from partition start. all sources have been processed up to this position, except the sources whose partitions ended earlier. + private int processedPositions; + + public RegularTableFunctionPartition( + PagesIndex pagesIndex, + int partitionStart, + int partitionEnd, + TableFunctionDataProcessor tableFunction, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications) + + { + checkArgument(pagesIndex.getPositionCount() != 0, "PagesIndex is empty for regular table function partition"); + this.pagesIndex = pagesIndex; + this.partitionStart = partitionStart; + this.partitionEnd = partitionEnd; + this.sortedPages = pagesIndex.getSortedPages(partitionStart, partitionEnd); + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(Ints::toArray) + .toArray(int[][]::new); + this.endOfData = findEndOfData(markerChannels, requiredChannels, passThroughSpecifications); + for (List channels : requiredChannels) { + checkState( + channels.stream() + .mapToInt(channel -> endOfData[channel]) + .distinct() + .count() <= 1, + "end-of-data position is inconsistent within a table 
function source"); + } + this.passThroughProviders = new PassThroughColumnProvider[passThroughSpecifications.size()]; + for (int i = 0; i < passThroughSpecifications.size(); i++) { + passThroughProviders[i] = createColumnProvider(passThroughSpecifications.get(i)); + } + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(new WorkProcessor.Process() + { + List> inputPages = prepareInputPages(); + + @Override + public WorkProcessor.ProcessState process() + { + TableFunctionProcessorState state = tableFunction.process(inputPages); + boolean functionGotNoData = inputPages == null; + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.isUsedInput()) { + inputPages = prepareInputPages(); + } + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendPassThroughColumns(processed.getResult())); + } + if (functionGotNoData) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + } + return WorkProcessor.ProcessState.blocked(immediateFuture(null)); + } + }); + } + + /** + * Iterate over the partition by page and extract pages for each table function source from the input page. + * For each source, project the columns required by the table function. + * If for some source all data in the partition has been consumed, Optional.empty() is returned for that source. + * It happens when the partition of this source is shorter than the partition of some other source. + * The overall length of the table function partition is equal to the length of the longest source partition. 
+ * When all sources are fully consumed, this method returns null. + *

+ * NOTE: There are two types of table function's source semantics: set and row. The two types of sources should be handled + * by the TableFunctionDataProcessor in different ways. For a source with set semantics, the whole partition can be used for computations, + * while for a source with row semantics, each row should be processed independently from all other rows. + * To enforce that behavior, we could pass to the TableFunctionDataProcessor only one row from a table with row semantics. + * However, for performance reasons, we handle sources with row and set semantics in the same way: the TableFunctionDataProcessor + * gets a page of data from each source. The TableFunctionDataProcessor is responsible for using the provided data according + * to the declared source semantics (set or row). + * + * @return A List containing: + * - Optional Page for every source that is not fully consumed + * - Optional.empty() for every source that is fully consumed + * or null if all sources are fully consumed. 
+ */ + private List> prepareInputPages() + { + if (!sortedPages.hasNext()) { + return null; + } + + Page inputPage = sortedPages.next(); + ImmutableList.Builder> sourcePages = ImmutableList.builder(); + + for (int[] channelsForSource : requiredChannels) { + if (channelsForSource.length == 0) { + sourcePages.add(Optional.of(new Page(inputPage.getPositionCount()))); + } + else { + int endOfDataForSource = endOfData[channelsForSource[0]]; // end-of-data position is validated to be consistent for all channels from source + if (endOfDataForSource <= processedPositions) { + // all data for this source was already processed + sourcePages.add(Optional.empty()); + } + else { + Block[] sourceBlocks = new Block[channelsForSource.length]; + if (endOfDataForSource < processedPositions + inputPage.getPositionCount()) { + // data for this source ends within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel).getRegion(0, endOfDataForSource - processedPositions); + } + } + else { + // data for this source does not end within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel); + } + } + sourcePages.add(Optional.of(new Page(sourceBlocks))); + } + } + } + + processedPositions += inputPage.getPositionCount(); + + return sourcePages.build(); + } + + /** + * There are two types of table function's source semantics: set and row. + *

+ * For a source with set semantics, the table function result depends on the whole partition, + * so it is not always possible to associate an output row with a specific input row. + * The TableFunctionDataProcessor can return null as the pass-through index to indicate that + * the output row is not associated with any row from the given source. + *

+ * For a source with row semantics, the output is determined on a row-by-row basis, so every + * output row is associated with a specific input row. In such a case, the pass-through index + * should never be null. + *

+ * In our implementation, we handle sources with row and set semantics in the same way. + * For performance reasons, we do not validate the null pass-through indexes. + * The TableFunctionDataProcessor is responsible for using the pass-through capability + * accordingly to the declared source semantics (set or rows). + */ + private Page appendPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + // TODO is it possible to verify types of columns returned by TF? + + Block[] resultBlocks = new Block[properChannelsCount + passThroughProviders.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + int channel = properChannelsCount; + for (PassThroughColumnProvider provider : passThroughProviders) { + resultBlocks[channel] = provider.getPassThroughColumn(page); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } + + private int[] findEndOfData(Optional> markerChannels, List> requiredChannels, List passThroughSpecifications) + { + Set referencedChannels = ImmutableSet.builder() + .addAll(requiredChannels.stream() + .flatMap(Collection::stream) + .collect(toImmutableList())) + .addAll(passThroughSpecifications.stream() + .map(PassThroughColumnSpecification::getInputChannel) + .collect(toImmutableList())) + .build(); + + if (referencedChannels.isEmpty()) { + // no 
required or pass-through channels + return null; + } + + int maxInputChannel = referencedChannels.stream() + .mapToInt(Integer::intValue) + .max() + .orElseThrow(NoSuchElementException::new); + + int[] result = new int[maxInputChannel + 1]; + Arrays.fill(result, -1); + + // if table function had one source, adding a marker channel was not necessary. + // end-of-data position is equal to partition end for each input channel + if (!markerChannels.isPresent()) { + referencedChannels.stream() + .forEach(channel -> result[channel] = partitionEnd - partitionStart); + return result; + } + + // if table function had more than one source, the markers map shall be present, and it shall contain mapping for each input channel + ImmutableMap.Builder endOfDataPerMarkerBuilder = ImmutableMap.builder(); + for (int markerChannel : ImmutableSet.copyOf(markerChannels.orElseThrow(NoSuchElementException::new).values())) { + endOfDataPerMarkerBuilder.put(markerChannel, findFirstNullPosition(markerChannel)); + } + Map endOfDataPerMarker = endOfDataPerMarkerBuilder.buildOrThrow(); + referencedChannels.stream() + .forEach(channel -> result[channel] = endOfDataPerMarker.get(markerChannels.orElseThrow(NoSuchElementException::new).get(channel)) - partitionStart); + + return result; + } + + private int findFirstNullPosition(int markerChannel) + { + if (pagesIndex.isNull(markerChannel, partitionStart)) { + return partitionStart; + } + if (!pagesIndex.isNull(markerChannel, partitionEnd - 1)) { + return partitionEnd; + } + + int start = partitionStart; + int end = partitionEnd; + // value at start is not null, value at end is null + while (end - start > 1) { + int mid = start + end >>> 1; + if (pagesIndex.isNull(markerChannel, mid)) { + end = mid; + } + else { + start = mid; + } + } + return end; + } + + public static class PassThroughColumnSpecification + { + private final boolean isPartitioningColumn; + private final int inputChannel; + private final int indexChannel; + + public 
PassThroughColumnSpecification(boolean isPartitioningColumn, int inputChannel, int indexChannel) + { + this.isPartitioningColumn = isPartitioningColumn; + this.inputChannel = inputChannel; + this.indexChannel = indexChannel; + } + + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + + public int getInputChannel() + { + return inputChannel; + } + + public int getIndexChannel() + { + return indexChannel; + } + } + + private PassThroughColumnProvider createColumnProvider(PassThroughColumnSpecification specification) + { + if (specification.isPartitioningColumn()) { + return new PartitioningColumnProvider(pagesIndex.getSingleValueBlock(specification.getInputChannel(), partitionStart)); + } + return new NonPartitioningColumnProvider(specification.getInputChannel(), specification.getIndexChannel()); + } + + private interface PassThroughColumnProvider + { + Block getPassThroughColumn(Page page); + } + + private static class PartitioningColumnProvider + implements PassThroughColumnProvider + { + private final Block partitioningValue; + + private PartitioningColumnProvider(Block partitioningValue) + { + this.partitioningValue = requireNonNull(partitioningValue, "partitioningValue is null"); + } + + @Override + public Block getPassThroughColumn(Page page) + { + return new RunLengthEncodedBlock(partitioningValue, page.getPositionCount()); + } + + public Block getPartitioningValue() + { + return partitioningValue; + } + } + + private final class NonPartitioningColumnProvider + implements PassThroughColumnProvider + { + private final int inputChannel; + private final Type type; + private final int indexChannel; + + public NonPartitioningColumnProvider(int inputChannel, int indexChannel) + { + this.inputChannel = inputChannel; + this.type = pagesIndex.getType(inputChannel); + this.indexChannel = indexChannel; + } + + @Override + public Block getPassThroughColumn(Page page) + { + Block indexes = page.getBlock(indexChannel); + BlockBuilder builder = 
type.createBlockBuilder(null, page.getPositionCount()); + for (int position = 0; position < page.getPositionCount(); position++) { + if (indexes.isNull(position)) { + builder.appendNull(); + } + else { + // table function returns index from partition start + long index = BIGINT.getLong(indexes, position); + // validate index + if (index < 0 || index >= endOfData[inputChannel] || index >= processedPositions) { + int end = min(endOfData[inputChannel], processedPositions) - 1; + if (end >= 0) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, format("Index of a pass-through row: %s out of processed portion of partition [0, %s]", index, end)); + } + else { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "Index of a pass-through row must be null when no input data from the partition was processed. Actual: " + index); + } + } + // index in PagesIndex + long absoluteIndex = partitionStart + index; + pagesIndex.appendTo(inputChannel, toIntExact(absoluteIndex), builder); + } + } + + return builder.build(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java b/presto-main-base/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java index 8005ba436e667..c92884e1ce60c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java @@ -165,6 +165,12 @@ public boolean positionEqualsRow(int leftBlockIndex, int leftPosition, int right return true; } + @Override + public boolean positionNotDistinctFromRow(int leftBlockIndex, int leftPosition, int rightPosition, Page rightPage) + { + return false; + } + @Override public boolean positionEqualsRowIgnoreNulls(int leftBlockIndex, int leftPosition, int rightPosition, Page rightPage) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java new file mode 100644 index 0000000000000..7586c1e9fba92 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java @@ -0,0 +1,643 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkPositionIndex; 
+import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.concat; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; + +public class TableFunctionOperator + implements Operator +{ + public static class TableFunctionOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + + // a provider of table function processor to be called once per partition + private final TableFunctionProcessorProvider tableFunctionProvider; + + // all information necessary to execute the table function collected during analysis + private final ConnectorTableFunctionHandle functionHandle; + + // number of proper columns produced by the table function + private final int properChannelsCount; + + // number of input tables declared as pass-through + private final int passThroughSourcesCount; + + // columns required by the table function, in order of input tables + private final List> requiredChannels; + + // map from input channel to marker channel + // for each input table, there is a channel that marks which rows contain original data, and which are "filler" rows. + // the "filler" rows are part of the algorithm, and they should not be processed by the table function, or passed-through. + // In this map, every original column from the input table is associated with the appropriate marker. 
+ private final Optional> markerChannels; + + // necessary information to build a pass-through column, for all pass-through columns, ordered as expected on the output + // it includes columns from sources declared as pass-through as well as partitioning columns from other sources + private final List passThroughSpecifications; + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // partitioning channels from all sources + private final List partitionChannels; + + // subset of partition channels that are already grouped + private final List prePartitionedChannels; + + // channels necessary to sort all sources: + // - for a single source, these are the source's sort channels + // - for multiple sources, this is a single synthesized row number channel + private final List sortChannels; + private final List sortOrders; + + // number of leading sort channels that are already sorted + private final int preSortedPrefix; + + private final List sourceTypes; + private final int expectedPositions; + private final PagesIndex.Factory pagesIndexFactory; + + private boolean closed; + + public TableFunctionOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(planNodeId, "planNodeId is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle 
is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorId = operatorId; + this.planNodeId = planNodeId; + this.tableFunctionProvider = tableFunctionProvider; + this.functionHandle = functionHandle; + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerChannels = markerChannels.map(ImmutableMap::copyOf); + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.pruneWhenEmpty = pruneWhenEmpty; + this.partitionChannels = ImmutableList.copyOf(partitionChannels); + this.prePartitionedChannels = ImmutableList.copyOf(prePartitionedChannels); + this.sortChannels = 
ImmutableList.copyOf(sortChannels); + this.sortOrders = ImmutableList.copyOf(sortOrders); + this.preSortedPrefix = preSortedPrefix; + this.sourceTypes = ImmutableList.copyOf(sourceTypes); + this.expectedPositions = expectedPositions; + this.pagesIndexFactory = pagesIndexFactory; + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TableFunctionOperator.class.getSimpleName()); + return new TableFunctionOperator( + operatorContext, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new TableFunctionOperatorFactory( + operatorId, + planNodeId, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + } + + private final OperatorContext operatorContext; + + private final PageBuffer pageBuffer = new PageBuffer(); + private final WorkProcessor outputPages; + private final boolean processEmptyInput; + + @Nullable + private Page pendingInput; + private boolean operatorFinishing; + + public TableFunctionOperator( + OperatorContext operatorContext, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> 
requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(operatorContext, "operatorContext is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorContext = operatorContext; + + this.processEmptyInput = !pruneWhenEmpty; + + PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); + HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels, 
prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix); + + this.outputPages = WorkProcessor.create(new PagesSource()) + .transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput)) + .flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions( + groupPagesIndex, + hashStrategies, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + processEmptyInput)) + .flatMap(TableFunctionPartition::toOutputPages); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public void finish() + { + operatorFinishing = true; + } + + @Override + public boolean isFinished() + { + return outputPages.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + if (outputPages.isBlocked()) { + return outputPages.getBlockedFuture(); + } + + return NOT_BLOCKED; + } + + @Override + public boolean needsInput() + { + return pendingInput == null && !operatorFinishing; + } + + @Override + public void addInput(Page page) + { + requireNonNull(page, "page is null"); + checkState(pendingInput == null, "Operator already has pending input"); + + if (page.getPositionCount() == 0) { + return; + } + + pendingInput = page; + } + + @Override + public Page getOutput() + { + if (!outputPages.process()) { + return null; + } + + if (outputPages.isFinished()) { + return null; + } + + return outputPages.getResult(); + } + + private static class HashStrategies + { + final PagesHashStrategy prePartitionedStrategy; + final PagesHashStrategy remainingPartitionStrategy; + final PagesHashStrategy preSortedStrategy; + final List remainingPartitionAndSortChannels; + final List remainingSortOrders; + final int[] prePartitionedChannelsArray; + + public HashStrategies( + PagesIndex pagesIndex, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int 
preSortedPrefix) + { + this.prePartitionedStrategy = pagesIndex.createPagesHashStrategy(prePartitionedChannels, OptionalInt.empty()); + + List remainingPartitionChannels = partitionChannels.stream() + .filter(channel -> !prePartitionedChannels.contains(channel)) + .collect(toImmutableList()); + this.remainingPartitionStrategy = pagesIndex.createPagesHashStrategy(remainingPartitionChannels, OptionalInt.empty()); + + List preSortedChannels = sortChannels.stream() + .limit(preSortedPrefix) + .collect(toImmutableList()); + this.preSortedStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); + + if (preSortedPrefix > 0) { + // preSortedPrefix > 0 implies that all partition channels are already pre-partitioned (enforced by check in the constructor), so we only need to do the remaining sort + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedPrefix)); + this.remainingSortOrders = ImmutableList.copyOf(Iterables.skip(sortOrders, preSortedPrefix)); + } + else { + // we need to sort by the remaining partition channels so that the input is fully partitioned, + // and then we need to sort by all the sort channels so that the input is fully sorted + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(concat(remainingPartitionChannels, sortChannels)); + this.remainingSortOrders = ImmutableList.copyOf(concat(nCopies(remainingPartitionChannels.size(), ASC_NULLS_LAST), sortOrders)); + } + + this.prePartitionedChannelsArray = Ints.toArray(prePartitionedChannels); + } + } + + private class PartitionAndSort + implements WorkProcessor.Transformation + { + private final PagesIndex pagesIndex; + private final HashStrategies hashStrategies; + private final LocalMemoryContext memoryContext; + + private boolean resetPagesIndex; + private int inputPosition; + private boolean processEmptyInput; + + public PartitionAndSort(PagesIndex pagesIndex, HashStrategies hashStrategies, boolean processEmptyInput) +
{ + this.pagesIndex = pagesIndex; + this.hashStrategies = hashStrategies; + this.memoryContext = operatorContext.aggregateUserMemoryContext().newLocalMemoryContext(PartitionAndSort.class.getSimpleName()); + this.processEmptyInput = processEmptyInput; + } + + @Override + public WorkProcessor.TransformationState process(Optional input) + { + if (resetPagesIndex) { + pagesIndex.clear(); + updateMemoryUsage(); + resetPagesIndex = false; + } + + if (!input.isPresent() && pagesIndex.getPositionCount() == 0) { + if (processEmptyInput) { + // it can only happen at the first call to process(), which implies that there is no input. Empty PagesIndex can be passed on only once. + processEmptyInput = false; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + else { + memoryContext.close(); + return WorkProcessor.TransformationState.finished(); + } + } + + // there is input, so we are not interested in processing empty input + processEmptyInput = false; + + if (input.isPresent()) { + // append rows from input which belong to the current group wrt pre-partitioned columns + // it might be one or more partitions + inputPosition = appendCurrentGroup(pagesIndex, hashStrategies, input.get(), inputPosition); + updateMemoryUsage(); + + if (inputPosition >= input.get().getPositionCount()) { + inputPosition = 0; + return WorkProcessor.TransformationState.needsMoreData(); + } + } + + // we have unused input or the input is finished. 
we have buffered a full group + // the group contains one or more partitions, as it was determined by the pre-partitioned columns + // sorting serves two purposes: + // - sort by the remaining partition channels so that the input is fully partitioned, + // - sort by all the sort channels so that the input is fully sorted + sortCurrentGroup(pagesIndex, hashStrategies); + resetPagesIndex = true; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + + void updateMemoryUsage() + { + memoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes()); + } + } + + private static int appendCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies, Page page, int startPosition) + { + checkArgument(page.getPositionCount() > startPosition); + + PagesHashStrategy prePartitionedStrategy = hashStrategies.prePartitionedStrategy; + Page prePartitionedPage = page.getColumns(hashStrategies.prePartitionedChannelsArray); + + if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(prePartitionedStrategy, 0, startPosition, prePartitionedPage)) { + // we are within the current group. 
find the position where the pre-grouped columns change + int groupEnd = findGroupEnd(prePartitionedPage, prePartitionedStrategy, startPosition); + + // add the section of the page that contains values for the current group + pagesIndex.addPage(page.getRegion(startPosition, groupEnd - startPosition)); + + if (page.getPositionCount() - groupEnd > 0) { + // the remaining part of the page contains the next group + return groupEnd; + } + // page fully consumed: it contains the current group only + return page.getPositionCount(); + } + + // we had previous results buffered, but the remaining page starts with new group values + return startPosition; + } + + private static void sortCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies) + { + PagesHashStrategy preSortedStrategy = hashStrategies.preSortedStrategy; + List remainingPartitionAndSortChannels = hashStrategies.remainingPartitionAndSortChannels; + List remainingSortOrders = hashStrategies.remainingSortOrders; + + if (pagesIndex.getPositionCount() > 1 && !remainingPartitionAndSortChannels.isEmpty()) { + int startPosition = 0; + while (startPosition < pagesIndex.getPositionCount()) { + int endPosition = findGroupEnd(pagesIndex, preSortedStrategy, startPosition); + pagesIndex.sort(remainingPartitionAndSortChannels, remainingSortOrders, startPosition, endPosition); + startPosition = endPosition; + } + } + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(page.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page)); + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static
int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition)); + } + + /** + * @param startPosition - inclusive + * @param endPosition - exclusive + * @param comparator - returns true if positions given as parameters are equal + * @return the end of the group position exclusive (the position the very next group starts) + */ + @VisibleForTesting + static int findEndPosition(int startPosition, int endPosition, PositionComparator comparator) + { + checkArgument(startPosition >= 0, "startPosition must be greater or equal than zero: %s", startPosition); + checkArgument(startPosition < endPosition, "startPosition (%s) must be less than endPosition (%s)", startPosition, endPosition); + + int left = startPosition; + int right = endPosition; + + while (right - left > 1) { + int middle = (left + right) >>> 1; + + if (comparator.test(startPosition, middle)) { + left = middle; + } + else { + right = middle; + } + } + + return right; + } + + private interface PositionComparator + { + boolean test(int first, int second); + } + + private WorkProcessor pagesIndexToTableFunctionPartitions( + PagesIndex pagesIndex, + HashStrategies hashStrategies, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean processEmptyInput) + { + // pagesIndex contains the full grouped and sorted data for one or more partitions + + PagesHashStrategy remainingPartitionStrategy = 
hashStrategies.remainingPartitionStrategy; + + return WorkProcessor.create(new WorkProcessor.Process() + { + private int partitionStart; + private boolean processEmpty = processEmptyInput; + + @Override + public WorkProcessor.ProcessState process() + { + if (partitionStart == pagesIndex.getPositionCount()) { + if (processEmpty && pagesIndex.getPositionCount() == 0) { + // empty PagesIndex can only be passed once as the result of PartitionAndSort. Neither this nor any future instance of Process will ever get an empty PagesIndex again. + processEmpty = false; + return WorkProcessor.ProcessState.ofResult(new EmptyTableFunctionPartition( + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + passThroughSpecifications.stream() + .map(RegularTableFunctionPartition.PassThroughColumnSpecification::getInputChannel) + .map(pagesIndex::getType) + .collect(toImmutableList()))); + } + return WorkProcessor.ProcessState.finished(); + } + + // there is input, so we are not interested in processing empty input + processEmpty = false; + + int partitionEnd = findGroupEnd(pagesIndex, remainingPartitionStrategy, partitionStart); + + RegularTableFunctionPartition partition = new RegularTableFunctionPartition( + pagesIndex, + partitionStart, + partitionEnd, + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications); + + partitionStart = partitionEnd; + return WorkProcessor.ProcessState.ofResult(partition); + } + }); + } + + private class PagesSource + implements WorkProcessor.Process + { + @Override + public WorkProcessor.ProcessState process() + { + if (operatorFinishing && pendingInput == null) { + return WorkProcessor.ProcessState.finished(); + } + + if (pendingInput != null) { + Page result = pendingInput; + pendingInput = null; + return WorkProcessor.ProcessState.ofResult(result); + } + + return 
WorkProcessor.ProcessState.yield(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java new file mode 100644 index 0000000000000..1876b352bd251 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; + +public interface TableFunctionPartition +{ + WorkProcessor toOutputPages(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java index af0f0f3bee089..3946f31d1335f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java @@ -289,6 +289,10 @@ public void installPlugin(Plugin plugin) } log.info("Registering connector %s", connectorFactory.getName()); connectorManager.addConnectorFactory(connectorFactory); + + if (connectorFactory.getTableFunctionProcessorProvider().isPresent()) { + metadata.getFunctionAndTypeManager().setGetTableFunctionProcessorProvider(connectorFactory.getTableFunctionProcessorProvider()); + } } for (Class functionClass : plugin.getFunctions()) { diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java index 69cf25fe60273..4f776743db45b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -33,14 +34,14 @@ public class CloseableSplitSourceProvider { private static final Logger log = Logger.get(CloseableSplitSourceProvider.class); - private final SplitSourceProvider delegate; + private final SplitManager delegate; @GuardedBy("this") private List splitSources = new ArrayList<>(); @GuardedBy("this") private boolean closed; - public CloseableSplitSourceProvider(SplitSourceProvider delegate) + public CloseableSplitSourceProvider(SplitManager delegate) { this.delegate = requireNonNull(delegate, "delegate is null"); } @@ -54,6 +55,15 @@ public synchronized SplitSource getSplits(Session session, TableHandle tableHand return splitSource; } + @Override + public synchronized SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle) + { + checkState(!closed, "split source provider is closed"); + SplitSource splitSource = delegate.getSplitsForTableFunction(session, tableFunctionHandle); + splitSources.add(splitSource); + return splitSource; + } + @Override public synchronized void close() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java index adb189379ed36..a4595387158b5 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java @@ -18,6 +18,7 @@ import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.metadata.TableLayoutResult; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; @@ -105,4 +106,19 @@ private ConnectorSplitManager getConnectorSplitManager(ConnectorId connectorId) checkArgument(result != null, "No split manager for connector '%s'", connectorId); return result; } + + /** Returns the splits of a table function invocation, requested from the split manager registered for the connector named by the handle; the lookup goes through getConnectorSplitManager so an unregistered connector fails with a descriptive IllegalArgumentException rather than a NullPointerException. */ + public SplitSource getSplitsForTableFunction(Session session, TableFunctionHandle function) + { + ConnectorId connectorId = function.getConnectorId(); + ConnectorSplitManager splitManager = getConnectorSplitManager(connectorId); + + ConnectorSplitSource source = splitManager.getSplits( + function.getTransactionHandle(), + session.toConnectorSession(connectorId), + function.getSchemaFunctionName(), + function.getFunctionHandle()); + + return new ConnectorAwareSplitSource(connectorId, function.getTransactionHandle(), source); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java index 617fba7093613..f835db17eafa5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import
com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -21,4 +22,6 @@ public interface SplitSourceProvider { SplitSource getSplits(Session session, TableHandle tableHandle, SplitSchedulingStrategy splitSchedulingStrategy, WarningCollector warningCollector); + + SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java index 35d1793e3e500..014273a0d959b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java @@ -113,7 +113,9 @@ public Analysis analyzeSemantic(Statement statement, boolean isDescribe) Analysis analysis = new Analysis(rewrittenStatement, parameterLookup, isDescribe); metadataExtractor.populateMetadataHandle(session, rewrittenStatement, analysis.getMetadataHandle()); - StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); + + // TODO: We do not need TransactionManager in the StatementAnalyzer as metadata has it through FunctionAndTypeManager. 
+ StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, metadata.getFunctionAndTypeManager().getTransactionManager(), session, warningCollector); analyzer.analyze(rewrittenStatement, Optional.empty()); analyzeForUtilizedColumns(analysis, analysis.getStatement(), warningCollector); analysis.populateTableColumnAndSubfieldReferencesForAccessControl(isCheckAccessControlOnUtilizedColumnsOnly(session), isCheckAccessControlWithSubfields(session)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index 002b667eb68d9..b136da3465c6a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -111,6 +111,7 @@ import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; +import com.facebook.presto.transaction.NoOpTransactionManager; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -2024,7 +2025,7 @@ private static ExpressionAnalyzer create( { return new ExpressionAnalyzer( metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), - node -> new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector), + node -> new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, new NoOpTransactionManager(), session, warningCollector), Optional.of(session.getSessionFunctions()), session.getTransactionId(), session.getSqlFunctionProperties(), @@ -2047,7 +2048,7 @@ private static ExpressionAnalyzer create( { return new ExpressionAnalyzer( metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), - node -> new StatementAnalyzer(analysis, 
metadata, sqlParser, accessControl, session, warningCollector), + node -> new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, new NoOpTransactionManager(), session, warningCollector), Optional.of(session.getSessionFunctions()), session.getTransactionId(), session.getSqlFunctionProperties(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index d52540a1660f0..90588ef2c7d38 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -31,24 +31,41 @@ import com.facebook.presto.common.type.TimestampWithTimeZoneType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.CatalogMetadata; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.OperatorNotFoundException; +import com.facebook.presto.metadata.TableFunctionMetadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AccessControlInfoForTable; import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import 
com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; @@ -58,6 +75,8 @@ import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.MaterializedViewUtils; import com.facebook.presto.sql.SqlFormatterUtil; +import com.facebook.presto.sql.analyzer.Analysis.TableArgumentAnalysis; +import com.facebook.presto.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.ExpressionInterpreter; @@ -92,6 +111,7 @@ import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; +import com.facebook.presto.sql.tree.EmptyTableTreatment; import com.facebook.presto.sql.tree.Except; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.Explain; @@ -125,6 
+145,7 @@ import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; +import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.Prepare; import com.facebook.presto.sql.tree.Property; import com.facebook.presto.sql.tree.QualifiedName; @@ -156,6 +177,10 @@ import com.facebook.presto.sql.tree.StartTransaction; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionArgument; +import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; +import com.facebook.presto.sql.tree.TableFunctionInvocation; +import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TruncateTable; import com.facebook.presto.sql.tree.Union; @@ -169,6 +194,7 @@ import com.facebook.presto.sql.tree.With; import com.facebook.presto.sql.tree.WithQuery; import com.facebook.presto.sql.util.AstUtils; +import com.facebook.presto.transaction.TransactionManager; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -205,8 +231,6 @@ import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; import static com.facebook.presto.spi.StandardWarningCode.REDUNDANT_ORDER_BY; import static com.facebook.presto.spi.analyzer.AccessControlRole.TABLE_CREATE; @@ -216,6 +240,9 @@ import static 
com.facebook.presto.spi.connector.ConnectorTableVersion.VersionType; import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.function.FunctionKind.WINDOW; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.OnlyPassThrough.ONLY_PASS_THROUGH; import static com.facebook.presto.sql.MaterializedViewUtils.buildOwnerSession; import static com.facebook.presto.sql.MaterializedViewUtils.generateBaseTablePredicates; import static com.facebook.presto.sql.MaterializedViewUtils.generateFalsePredicates; @@ -238,24 +265,36 @@ import static com.facebook.presto.sql.analyzer.RefreshMaterializedViewPredicateAnalyzer.extractTablePredicates; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.hasReferencesToScope; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.AMBIGUOUS_ATTRIBUTE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.AMBIGUOUS_RETURN_TYPE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.COLUMN_NAME_NOT_SPECIFIED; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.COLUMN_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.COLUMN_TYPE_UNKNOWN; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_PARAMETER_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_RANGE_VARIABLE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static 
com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_NOT_FOUND; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ARGUMENTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_COLUMN_REFERENCE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_COPARTITIONING; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_FUNCTION_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_TABLE_FUNCTION_INVOCATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MATERIALIZED_VIEW_ALREADY_EXISTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MATERIALIZED_VIEW_IS_RECURSIVE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_COLUMN_ALIASES; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_SET_COLUMN_TYPES; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ARGUMENT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ATTRIBUTE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_COLUMN; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_MATERIALIZED_VIEW; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_RETURN_TYPE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; import static 
com.facebook.presto.sql.analyzer.SemanticErrorCode.MUST_BE_WINDOW_FUNCTION; @@ -277,6 +316,7 @@ import static com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator.isDeterministic; import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; +import static com.facebook.presto.sql.tree.DereferenceExpression.getQualifiedName; import static com.facebook.presto.sql.tree.ExplainFormat.Type.JSON; import static com.facebook.presto.sql.tree.ExplainFormat.Type.TEXT; import static com.facebook.presto.sql.tree.ExplainType.Type.DISTRIBUTED; @@ -300,6 +340,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getLast; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Collections.emptyList; @@ -320,6 +361,7 @@ class StatementAnalyzer private final Session session; private final SqlParser sqlParser; private final AccessControl accessControl; + private final TransactionManager transactionManager; private final WarningCollector warningCollector; private final MetadataResolver metadataResolver; @@ -328,6 +370,7 @@ public StatementAnalyzer( Metadata metadata, SqlParser sqlParser, AccessControl accessControl, + TransactionManager transactionManager, Session session, WarningCollector warningCollector) { @@ -335,6 +378,7 @@ public StatementAnalyzer( this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.session = requireNonNull(session, "session is null"); 
this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.metadataResolver = requireNonNull(metadata.getMetadataResolver(session), "metadataResolver is null"); @@ -590,6 +634,7 @@ protected Scope visitDelete(Delete node, Optional scope) metadata, sqlParser, new AllowAllAccessControl(), + transactionManager, session, warningCollector); @@ -693,7 +738,7 @@ protected Scope visitCreateView(CreateView node, Optional scope) QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName()); // analyze the query that creates the view - StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, transactionManager, session, warningCollector); Scope queryScope = analyzer.analyze(node.getQuery(), scope); @@ -762,7 +807,7 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView node, Optio viewName)); // Use AllowAllAccessControl; otherwise Analyzer will check SELECT permission on the materialized view, which is not necessary. 
- StatementAnalyzer viewAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, new AllowAllAccessControl(), session, warningCollector); + StatementAnalyzer viewAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, new AllowAllAccessControl(), transactionManager, session, warningCollector); Scope viewScope = viewAnalyzer.analyze(node.getTarget(), scope); Map tablePredicates = extractTablePredicates(viewName, node.getWhere(), viewScope, metadata, session); @@ -776,6 +821,7 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView node, Optio metadata, sqlParser, accessControl, + transactionManager, buildOwnerSession(session, view.getOwner(), metadata.getSessionPropertyManager(), viewName.getCatalogName(), view.getSchema()), warningCollector); queryAnalyzer.analyze(refreshQuery, Scope.create()); @@ -805,7 +851,7 @@ private Optional analyzeBaseTableForRefreshMaterializedView(Table QualifiedObjectName viewName = createQualifiedObjectName(session, refreshMaterializedView.getTarget(), refreshMaterializedView.getTarget().getName()); // Use AllowAllAccessControl; otherwise Analyzer will check SELECT permission on the materialized view, which is not necessary. 
- StatementAnalyzer viewAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, new AllowAllAccessControl(), session, warningCollector); + StatementAnalyzer viewAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, new AllowAllAccessControl(), transactionManager, session, warningCollector); Scope viewScope = viewAnalyzer.analyze(refreshMaterializedView.getTarget(), scope); Map tablePredicates = extractTablePredicates(viewName, refreshMaterializedView.getWhere(), viewScope, metadata, session); @@ -1230,7 +1276,7 @@ else if (expressionType instanceof MapType) { outputFields.add(Field.newUnqualified(expression.getLocation(), Optional.empty(), ((MapType) expressionType).getValueType())); } else { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Cannot unnest type: " + expressionType); + throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Cannot unnest type: " + expressionType); } } if (node.isWithOrdinality()) { @@ -1242,11 +1288,563 @@ else if (expressionType instanceof MapType) { @Override protected Scope visitLateral(Lateral node, Optional scope) { - StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, transactionManager, session, warningCollector); Scope queryScope = analyzer.analyze(node.getQuery(), scope); return createAndAssignScope(node, scope, queryScope.getRelationType()); } + @Override + protected Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional scope) + { + TableFunctionMetadata tableFunctionMetadata = metadata.getFunctionAndTypeManager().getTableFunctionRegistry().resolve(session, node.getName()); + if (tableFunctionMetadata == null) { + throw new SemanticException(FUNCTION_NOT_FOUND, node, "Table function %s not registered", node.getName()); + } + + ConnectorTableFunction function = tableFunctionMetadata.getFunction(); 
+ ConnectorId connectorId = tableFunctionMetadata.getConnectorId(); + + QualifiedObjectName functionName = new QualifiedObjectName(connectorId.getCatalogName(), function.getSchema(), function.getName()); + //accessControl.checkCanExecuteFunction(SecurityContext.of(session), functionName); + + Node errorLocation = node; + if (!node.getArguments().isEmpty()) { + errorLocation = node.getArguments().get(0); + } + ArgumentsAnalysis argumentsAnalysis = analyzeArguments(function.getArguments(), node.getArguments(), scope, errorLocation); + + CatalogMetadata registrationCatalogMetadata = transactionManager.getOptionalCatalogMetadata(session.getRequiredTransactionId(), connectorId.getCatalogName()).orElseThrow(() -> new IllegalStateException("Missing catalog metadata")); + ConnectorTransactionHandle transactionHandle = transactionManager.getConnectorTransaction( + session.getRequiredTransactionId(), registrationCatalogMetadata.getConnectorId()); + + TableFunctionAnalysis functionAnalysis = function.analyze(session.toConnectorSession(connectorId), transactionHandle, argumentsAnalysis.getPassedArguments()); + List> copartitioningLists = analyzeCopartitioning(node.getCopartitioning(), argumentsAnalysis.getTableArgumentAnalyses()); + + // determine the result relation type per SQL standard ISO/IEC 9075-2, 4.33 SQL-invoked routines, p. 123, 413, 414 + ReturnTypeSpecification returnTypeSpecification = function.getReturnTypeSpecification(); + if (returnTypeSpecification == GENERIC_TABLE || !argumentsAnalysis.getTableArgumentAnalyses().isEmpty()) { + analysis.addPolymorphicTableFunction(node); + } + Optional analyzedProperColumnsDescriptor = functionAnalysis.getReturnedType(); + Descriptor properColumnsDescriptor; + if (returnTypeSpecification == ONLY_PASS_THROUGH) { + if (analysis.isAliased(node)) { + // According to SQL standard ISO/IEC 9075-2, 7.6 , p. 409, + // table alias is prohibited for a table function with ONLY PASS THROUGH returned type. 
+ throw new SemanticException(INVALID_TABLE_FUNCTION_INVOCATION, node, "Alias specified for table function with ONLY PASS THROUGH return type"); + } + if (analyzedProperColumnsDescriptor.isPresent()) { + // If a table function has ONLY PASS THROUGH returned type, it does not produce any proper columns, + // so the function's analyze() method should not return the proper columns descriptor. + throw new SemanticException(AMBIGUOUS_RETURN_TYPE, node, "Returned relation type for table function %s is ambiguous", node.getName()); + } + properColumnsDescriptor = null; + } + else if (returnTypeSpecification == GENERIC_TABLE) { + // According to SQL standard ISO/IEC 9075-2, 7.6
, p. 409, + // table alias is mandatory for a polymorphic table function invocation which produces proper columns. + // We don't enforce this requirement. + properColumnsDescriptor = analyzedProperColumnsDescriptor + .orElseThrow(() -> new SemanticException(MISSING_RETURN_TYPE, node, "Cannot determine returned relation type for table function " + node.getName())); + } + else { + // returned type is statically declared at function declaration and cannot be overridden + // According to SQL standard ISO/IEC 9075-2, 7.6
, p. 409, + // table alias is mandatory for a polymorphic table function invocation which produces proper columns. + // We don't enforce this requirement. + if (analyzedProperColumnsDescriptor.isPresent()) { + // If a table function has statically declared returned type, it is returned in TableFunctionMetadata + // so the function's analyze() method should not return the proper columns descriptor. + throw new SemanticException(AMBIGUOUS_RETURN_TYPE, node, "Returned relation type for table function %s is ambiguous", node.getName()); + } + properColumnsDescriptor = ((ReturnTypeSpecification.DescribedTable) returnTypeSpecification).getDescriptor(); + } + + // validate the required input columns + Map> requiredColumns = functionAnalysis.getRequiredColumns(); + Map tableArgumentsByName = argumentsAnalysis.getTableArgumentAnalyses().stream() + .collect(toImmutableMap(TableArgumentAnalysis::getArgumentName, Function.identity())); + Set allInputs = ImmutableSet.copyOf(tableArgumentsByName.keySet()); + requiredColumns.forEach((name, columns) -> { + if (!allInputs.contains(name)) { + throw new SemanticException(FUNCTION_IMPLEMENTATION_ERROR, "Table function %s specifies required columns from table argument %s which cannot be found", node.getName(), name); + } + if (columns.isEmpty()) { + throw new SemanticException(FUNCTION_IMPLEMENTATION_ERROR, "Table function %s specifies empty list of required columns from table argument %s", node.getName(), name); + } + // the scope is recorded, because table arguments are already analyzed + Scope inputScope = analysis.getScope(tableArgumentsByName.get(name).getRelation()); + columns.stream() + .filter(column -> column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .findFirst() + .ifPresent(column -> { + throw new SemanticException(FUNCTION_IMPLEMENTATION_ERROR, "Invalid index: %s of required column from table argument %s", column, name); + }); + }); + Set 
requiredInputs = ImmutableSet.copyOf(requiredColumns.keySet()); + allInputs.stream() + .filter(input -> !requiredInputs.contains(input)) + .findFirst() + .ifPresent(input -> { + throw new SemanticException(FUNCTION_IMPLEMENTATION_ERROR, "Table function %s does not specify required input columns from table argument %s", node.getName(), input); + }); + + // The result relation type of a table function consists of: + // 1. columns created by the table function, called the proper columns. + // 2. passed columns from input tables: + // - for tables with the "pass through columns" option, these are all columns of the table, + // - for tables without the "pass through columns" option, these are the partitioning columns of the table, if any. + ImmutableList.Builder fields = ImmutableList.builder(); + + // proper columns first + if (properColumnsDescriptor != null) { + properColumnsDescriptor.getFields().stream() + // per spec, field names are mandatory + .map(field -> Field.newUnqualified(Optional.empty(), field.getName(), field.getType().orElseThrow(() -> new IllegalStateException("missing returned type for proper field")))) + .forEach(fields::add); + } + + // next, columns derived from table arguments, in order of argument declarations + List tableArgumentNames = function.getArguments().stream() + .filter(argumentSpecification -> argumentSpecification instanceof TableArgumentSpecification) + .map(ArgumentSpecification::getName) + .collect(toImmutableList()); + + // table arguments in order of argument declarations + ImmutableList.Builder orderedTableArguments = ImmutableList.builder(); + + for (String name : tableArgumentNames) { + TableArgumentAnalysis argument = tableArgumentsByName.get(name); + orderedTableArguments.add(argument); + Scope argumentScope = analysis.getScope(argument.getRelation()); + if (argument.isPassThroughColumns()) { + argumentScope.getRelationType().getAllFields().stream() + .forEach(fields::add); + } + else if 
(argument.getPartitionBy().isPresent()) { + argument.getPartitionBy().get().stream() + .map(expression -> validateAndGetInputField(expression, argumentScope)) + .forEach(fields::add); + } + } + + analysis.setTableFunctionAnalysis(node, new TableFunctionInvocationAnalysis( + connectorId, + function.getSchema(), + function.getName(), + argumentsAnalysis.getPassedArguments(), + orderedTableArguments.build(), + functionAnalysis.getRequiredColumns(), + copartitioningLists, + properColumnsDescriptor == null ? 0 : properColumnsDescriptor.getFields().size(), + functionAnalysis.getHandle(), + transactionHandle)); + + return createAndAssignScope(node, scope, fields.build()); + } + + private ArgumentsAnalysis analyzeArguments(List argumentSpecifications, List arguments, Optional scope, Node errorLocation) + { + if (argumentSpecifications.size() < arguments.size()) { + throw new SemanticException(INVALID_ARGUMENTS, errorLocation, "Too many arguments. Expected at most %s arguments, got %s arguments", argumentSpecifications.size(), arguments.size()); + } + + if (argumentSpecifications.isEmpty()) { + return new ArgumentsAnalysis(ImmutableMap.of(), ImmutableList.of()); + } + + boolean argumentsPassedByName = !arguments.isEmpty() && arguments.stream().allMatch(argument -> argument.getName().isPresent()); + boolean argumentsPassedByPosition = arguments.stream().allMatch(argument -> !argument.getName().isPresent()); + if (!argumentsPassedByName && !argumentsPassedByPosition) { + throw new SemanticException(INVALID_ARGUMENTS, errorLocation, "All arguments must be passed by name or all must be passed positionally"); + } + + ImmutableMap.Builder passedArguments = ImmutableMap.builder(); + ImmutableList.Builder tableArgumentAnalyses = ImmutableList.builder(); + if (argumentsPassedByName) { + Map argumentSpecificationsByName = new HashMap<>(); + for (ArgumentSpecification argumentSpecification : argumentSpecifications) { + if 
(argumentSpecificationsByName.put(argumentSpecification.getName(), argumentSpecification) != null) { + // this should never happen, because the argument names are validated at function registration time + throw new IllegalStateException("Duplicate argument specification for name: " + argumentSpecification.getName()); + } + } + Set uniqueArgumentNames = new HashSet<>(); + for (TableFunctionArgument argument : arguments) { + String argumentName = argument.getName().orElseThrow(() -> new IllegalStateException("Missing table function argument name")).getCanonicalValue(); + if (!uniqueArgumentNames.add(argumentName)) { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Duplicate argument name: " + argumentName); + } + ArgumentSpecification argumentSpecification = argumentSpecificationsByName.remove(argumentName); + if (argumentSpecification == null) { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Unexpected argument name: " + argumentName); + } + ArgumentAnalysis argumentAnalysis = analyzeArgument(argumentSpecification, argument, scope); + passedArguments.put(argumentSpecification.getName(), argumentAnalysis.getArgument()); + argumentAnalysis.getTableArgumentAnalysis().ifPresent(tableArgumentAnalyses::add); + } + // apply defaults for not specified arguments + for (Map.Entry entry : argumentSpecificationsByName.entrySet()) { + ArgumentSpecification argumentSpecification = entry.getValue(); + passedArguments.put(argumentSpecification.getName(), analyzeDefault(argumentSpecification, errorLocation)); + } + } + else { + for (int i = 0; i < arguments.size(); i++) { + TableFunctionArgument argument = arguments.get(i); + ArgumentSpecification argumentSpecification = argumentSpecifications.get(i); // TODO args passed positionally - can one only pass some prefix of args? 
+ ArgumentAnalysis argumentAnalysis = analyzeArgument(argumentSpecification, argument, scope); + passedArguments.put(argumentSpecification.getName(), argumentAnalysis.getArgument()); + argumentAnalysis.getTableArgumentAnalysis().ifPresent(tableArgumentAnalyses::add); + } + // apply defaults for not specified arguments + for (int i = arguments.size(); i < argumentSpecifications.size(); i++) { + ArgumentSpecification argumentSpecification = argumentSpecifications.get(i); + passedArguments.put(argumentSpecification.getName(), analyzeDefault(argumentSpecification, errorLocation)); + } + } + + return new ArgumentsAnalysis(passedArguments.buildOrThrow(), tableArgumentAnalyses.build()); + } + + private ArgumentAnalysis analyzeArgument(ArgumentSpecification argumentSpecification, TableFunctionArgument argument, Optional scope) + { + String actualType; + if (argument.getValue() instanceof TableFunctionTableArgument) { + actualType = "table"; + } + else if (argument.getValue() instanceof TableFunctionDescriptorArgument) { + actualType = "descriptor"; + } + else if (argument.getValue() instanceof Expression) { + actualType = "expression"; + } + else { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Unexpected table function argument type: %s", argument.getClass().getSimpleName()); + } + + if (argumentSpecification instanceof TableArgumentSpecification) { + if (!(argument.getValue() instanceof TableFunctionTableArgument)) { + if (argument.getValue() instanceof FunctionCall) { + // probably an attempt to pass a table function call, which is not supported, and was parsed as a function call + throw new SemanticException(NOT_SUPPORTED, argument, "Invalid table argument %s. Table functions are not allowed as table function arguments", argumentSpecification.getName()); + } + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s.
Expected table, got %s", argumentSpecification.getName(), actualType); + } + return analyzeTableArgument(argument, (TableArgumentSpecification) argumentSpecification, scope); + } + if (argumentSpecification instanceof DescriptorArgumentSpecification) { + if (!(argument.getValue() instanceof TableFunctionDescriptorArgument)) { + if (argument.getValue() instanceof FunctionCall && ((FunctionCall) argument.getValue()).getName().hasSuffix(QualifiedName.of("descriptor"))) { // function name is always compared case-insensitive + // malformed descriptor which parsed as a function call + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid descriptor argument %s. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'", (Object) argumentSpecification.getName()); + } + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. Expected descriptor, got %s", argumentSpecification.getName(), actualType); + } + return analyzeDescriptorArgument((TableFunctionDescriptorArgument) argument.getValue()); + } + if (argumentSpecification instanceof ScalarArgumentSpecification) { + if (!(argument.getValue() instanceof Expression)) { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Expected expression, got %s", argumentSpecification.getName(), actualType); + } + Expression expression = (Expression) argument.getValue(); + // 'descriptor' as a function name is not allowed in this context + if (expression instanceof FunctionCall && ((FunctionCall) expression).getName().hasSuffix(QualifiedName.of("descriptor"))) { // function name is always compared case-insensitive + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "'descriptor' function is not allowed as a table function argument"); + } + return analyzeScalarArgument(expression, ((ScalarArgumentSpecification) argumentSpecification).getType()); + } + throw new IllegalStateException("Unexpected argument specification: " + argumentSpecification.getClass().getSimpleName()); + } + + private Argument analyzeDefault(ArgumentSpecification argumentSpecification, Node errorLocation) + { + if (argumentSpecification.isRequired()) { + throw new SemanticException(MISSING_ARGUMENT, errorLocation, "Missing argument: " + argumentSpecification.getName()); + } + + checkArgument(!(argumentSpecification instanceof TableArgumentSpecification), "invalid table argument specification: default set"); + + if (argumentSpecification instanceof DescriptorArgumentSpecification) { + return DescriptorArgument.builder() + .descriptor((Descriptor) argumentSpecification.getDefaultValue()) + .build(); + } + if (argumentSpecification instanceof ScalarArgumentSpecification) { + return ScalarArgument.builder() + .type(((ScalarArgumentSpecification) argumentSpecification).getType()) + .value(argumentSpecification.getDefaultValue()) + .build(); + } + + throw new IllegalStateException("Unexpected argument specification: " + argumentSpecification.getClass().getSimpleName()); + } + + private ArgumentAnalysis analyzeTableArgument(TableFunctionArgument argument, TableArgumentSpecification argumentSpecification, Optional scope) + { + TableFunctionTableArgument tableArgument = (TableFunctionTableArgument) argument.getValue(); 
+ + TableArgument.Builder argumentBuilder = TableArgument.builder(); + TableArgumentAnalysis.Builder analysisBuilder = TableArgumentAnalysis.builder(); + analysisBuilder.withArgumentName(argumentSpecification.getName()); + + // process the relation + Relation relation = tableArgument.getTable(); + analysisBuilder.withRelation(relation); + Scope argumentScope = process(relation, scope); + QualifiedName relationName = analysis.getRelationName(relation); + if (relationName != null) { + analysisBuilder.withName(relationName); + } + + argumentBuilder.rowType(RowType.from(argumentScope.getRelationType().getVisibleFields().stream() + .map(field -> new RowType.Field(field.getName(), field.getType())) + .collect(toImmutableList()))); + + // analyze PARTITION BY + if (tableArgument.getPartitionBy().isPresent()) { + if (argumentSpecification.isRowSemantics()) { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. Partitioning specified for table argument with row semantics", argumentSpecification.getName()); + } + List partitionBy = tableArgument.getPartitionBy().get(); + analysisBuilder.withPartitionBy(partitionBy); + partitionBy.stream() + .forEach(partitioningColumn -> { + validateAndGetInputField(partitioningColumn, argumentScope); + Type type = analyzeExpression(partitioningColumn, argumentScope).getType(partitioningColumn); + if (!type.isComparable()) { + throw new SemanticException(TYPE_MISMATCH, partitioningColumn, "%s is not comparable, and therefore cannot be used in PARTITION BY", type); + } + }); + argumentBuilder.partitionBy(partitionBy.stream() + // each expression is either an Identifier or a DereferenceExpression + .map(Expression::toString) + .collect(toImmutableList())); + } + + // analyze ORDER BY + if (tableArgument.getOrderBy().isPresent()) { + if (argumentSpecification.isRowSemantics()) { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Ordering specified for table argument with row semantics", argumentSpecification.getName()); + } + OrderBy orderBy = tableArgument.getOrderBy().get(); + analysisBuilder.withOrderBy(orderBy); + orderBy.getSortItems().stream() + .map(SortItem::getSortKey) + .forEach(orderingColumn -> { + validateAndGetInputField(orderingColumn, argumentScope); + Type type = analyzeExpression(orderingColumn, argumentScope).getType(orderingColumn); + if (!type.isOrderable()) { + throw new SemanticException(TYPE_MISMATCH, orderingColumn, "%s is not orderable, and therefore cannot be used in ORDER BY", type); + } + }); + argumentBuilder.orderBy(orderBy.getSortItems().stream() + // each sort key is either an Identifier or a DereferenceExpression + .map(sortItem -> sortItem.getSortKey().toString()) + .collect(toImmutableList())); + } + + // analyze the PRUNE/KEEP WHEN EMPTY property + boolean pruneWhenEmpty = argumentSpecification.isPruneWhenEmpty(); + if (tableArgument.getEmptyTableTreatment().isPresent()) { + if (argumentSpecification.isRowSemantics()) { + throw new SemanticException(INVALID_FUNCTION_ARGUMENT, tableArgument.getEmptyTableTreatment().get(), "Invalid argument %s. 
Empty behavior specified for table argument with row semantics", argumentSpecification.getName()); + } + pruneWhenEmpty = tableArgument.getEmptyTableTreatment().get().getTreatment() == EmptyTableTreatment.Treatment.PRUNE; + } + analysisBuilder.withPruneWhenEmpty(pruneWhenEmpty); + + // record remaining properties + analysisBuilder.withRowSemantics(argumentSpecification.isRowSemantics()); + analysisBuilder.withPassThroughColumns(argumentSpecification.isPassThroughColumns()); + + return new ArgumentAnalysis(argumentBuilder.build(), Optional.of(analysisBuilder.build())); + } + + private ArgumentAnalysis analyzeDescriptorArgument(TableFunctionDescriptorArgument argument) + { + return new ArgumentAnalysis( + argument.getDescriptor() + .map(descriptor -> DescriptorArgument.builder() + .descriptor(new Descriptor(descriptor.getFields().stream() + .map(field -> new Descriptor.Field( + field.getName().getCanonicalValue(), + field.getType().map(type -> { + try { + return functionAndTypeResolver.getType(parseTypeSignature(type)); + } + catch (IllegalArgumentException e) { + throw new SemanticException(TYPE_MISMATCH, field, "Unknown type: %s", type); + } + }))) + .collect(toImmutableList()))) + .build()) + .orElse(NULL_DESCRIPTOR), + Optional.empty()); + } + + private Field validateAndGetInputField(Expression expression, Scope inputScope) + { + QualifiedName qualifiedName; + if (expression instanceof Identifier) { + qualifiedName = QualifiedName.of(ImmutableList.of(((Identifier) expression).getValue())); + } + else if (expression instanceof DereferenceExpression) { + qualifiedName = getQualifiedName((DereferenceExpression) expression); + } + else { + throw new SemanticException(INVALID_COLUMN_REFERENCE, expression, "Expected column reference. 
Actual: %s", expression); + } + Optional field = inputScope.tryResolveField(expression, qualifiedName); + if (!field.isPresent() || !field.get().isLocal()) { + throw new SemanticException(COLUMN_NOT_FOUND, expression, "Column %s is not present in the input relation", expression); + } + + return field.get().getField(); + } + + private ArgumentAnalysis analyzeScalarArgument(Expression expression, Type type) + { + // inline parameters + Expression inlined = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { + @Override + public Expression rewriteParameter(Parameter node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (analysis.isDescribe()) { + // We cannot handle DESCRIBE when a table function argument involves a parameter. + // In DESCRIBE, the parameter values are not known. We cannot pass a dummy value for a parameter. + // The value of a table function argument can affect the returned relation type. The returned + // relation type can affect the assumed types for other parameters in the query. + throw new SemanticException(NOT_SUPPORTED, node, "DESCRIBE is not supported if a table function uses parameters"); + } + return analysis.getParameters().get(NodeRef.of(node)); + } + }, expression); + // currently, only constant arguments are supported + Object constantValue = ExpressionInterpreter.evaluateConstantExpression(inlined, type, metadata, session, analysis.getParameters()); + return new ArgumentAnalysis( + ScalarArgument.builder() + .type(type) + .value(constantValue) + .build(), + Optional.empty()); + } + + private List> analyzeCopartitioning(List> copartitioning, List tableArgumentAnalyses) + { + // map table arguments by relation names. usa a multimap, because multiple arguments can have the same value, e.g. 
input_1 => tpch.tiny.orders, input_2 => tpch.tiny.orders + ImmutableMultimap.Builder unqualifiedInputsBuilder = ImmutableMultimap.builder(); + ImmutableMultimap.Builder qualifiedInputsBuilder = ImmutableMultimap.builder(); + tableArgumentAnalyses.stream() + .filter(argument -> argument.getName().isPresent()) + .forEach(argument -> { + QualifiedName name = argument.getName().get(); + if (name.getParts().size() == 1) { + unqualifiedInputsBuilder.put(name, argument); + } + else if (name.getParts().size() == 3) { + qualifiedInputsBuilder.put(name, argument); + } + else { + throw new IllegalStateException("relation name should be unqualified or fully qualified"); + } + }); + Multimap unqualifiedInputs = unqualifiedInputsBuilder.build(); + Multimap qualifiedInputs = qualifiedInputsBuilder.build(); + + ImmutableList.Builder> copartitionBuilder = ImmutableList.builder(); + Set referencedArguments = new HashSet<>(); + for (List nameList : copartitioning) { + ImmutableList.Builder copartitionListBuilder = ImmutableList.builder(); + + // resolve copartition tables as references to table arguments + for (QualifiedName name : nameList) { + Collection candidates = emptyList(); + if (name.getParts().size() == 1) { + // try to match unqualified name. 
it might be a reference to a CTE or an aliased relation + candidates = unqualifiedInputs.get(name); + } + if (candidates.isEmpty()) { + // qualify the name using current schema and catalog + // Since we lost the Identifier context, create a new one here + QualifiedObjectName fullyQualifiedName = createQualifiedObjectName(session, new Identifier(name.getOriginalParts().get(0)), name); + candidates = qualifiedInputs.get(QualifiedName.of(fullyQualifiedName.getCatalogName(), fullyQualifiedName.getSchemaName(), fullyQualifiedName.getObjectName())); + } + if (candidates.isEmpty()) { + throw new SemanticException(INVALID_COPARTITIONING, "No table argument found for name: " + name); + } + if (candidates.size() > 1) { + throw new SemanticException(INVALID_COPARTITIONING, "Ambiguous reference: multiple table arguments found for name: " + name); + } + TableArgumentAnalysis argument = getOnlyElement(candidates); + if (!referencedArguments.add(argument.getArgumentName())) { + // multiple references to argument in COPARTITION clause are implicitly prohibited by + // ISO/IEC TR REPORT 19075-7, p.33, Feature B203, “More than one copartition specification” + throw new SemanticException(INVALID_COPARTITIONING, "Multiple references to table argument: %s in COPARTITION clause", name); + } + copartitionListBuilder.add(argument); + } + List copartitionList = copartitionListBuilder.build(); + + // analyze partitioning columns + copartitionList.stream() + .filter(argument -> !argument.getPartitionBy().isPresent()) + .findFirst().ifPresent(unpartitioned -> { + throw new SemanticException(INVALID_COPARTITIONING, unpartitioned.getRelation(), "Table %s referenced in COPARTITION clause is not partitioned", unpartitioned.getName().orElseThrow(() -> new IllegalStateException("Missing unpartitioned TableArgumentAnalysis name"))); + }); + // TODO make sure that copartitioned tables cannot have empty partitioning lists. 
+ // ISO/IEC TR REPORT 19075-7, 4.5 Partitioning and ordering, p.25 is not clear: "With copartitioning, the copartitioned table arguments must have the same number of partitioning columns, + // and corresponding partitioning columns must be comparable. The DBMS effectively performs a full outer equijoin on the copartitioning columns" + copartitionList.stream() + .filter(argument -> argument.getPartitionBy().orElseThrow(() -> new IllegalStateException("PartitionBy not present in copartitionList")).isEmpty()) + .findFirst().ifPresent(partitionedOnEmpty -> { + // table is partitioned but no partitioning columns are specified (single partition) + throw new SemanticException(INVALID_COPARTITIONING, partitionedOnEmpty.getRelation(), "No partitioning columns specified for table %s referenced in COPARTITION clause", partitionedOnEmpty.getName().orElseThrow(() -> new IllegalStateException("Missing partitionedOnEmpty TableArgumentAnalysis name"))); + }); + List> partitioningColumns = copartitionList.stream() + .map(TableArgumentAnalysis::getPartitionBy) + .map(opt -> opt.orElseThrow(() -> new IllegalStateException("PartitionBy not present in partitioningColumns"))) + .collect(toImmutableList()); + if (partitioningColumns.stream() + .map(List::size) + .distinct() + .count() > 1) { + throw new SemanticException(INVALID_COPARTITIONING, "Numbers of partitioning columns in copartitioned tables do not match"); + } + + // coerce corresponding copartition columns to common supertype + for (int index = 0; index < partitioningColumns.get(0).size(); index++) { + Type commonSuperType = analysis.getType(partitioningColumns.get(0).get(index)); + // find common supertype + for (List columnList : partitioningColumns) { + Optional superType = functionAndTypeResolver.getCommonSuperType(commonSuperType, analysis.getType(columnList.get(index))); + if (!superType.isPresent()) { + throw new SemanticException(TYPE_MISMATCH, "Partitioning columns in copartitioned tables have incompatible types"); + 
} + commonSuperType = superType.get(); + } + for (List columnList : partitioningColumns) { + Expression column = columnList.get(index); + Type type = analysis.getType(column); + if (!type.equals(commonSuperType)) { + if (!functionAndTypeResolver.canCoerce(type, commonSuperType)) { + throw new SemanticException(TYPE_MISMATCH, column, "Cannot coerce column of type %s to common supertype: %s", type.getDisplayName(), commonSuperType.getDisplayName()); + } + analysis.addCoercion(column, commonSuperType, functionAndTypeResolver.isTypeOnlyCoercion(type, commonSuperType)); + } + } + } + + // record the resolved copartition arguments by argument names + copartitionBuilder.add(copartitionList.stream() + .map(TableArgumentAnalysis::getArgumentName) + .collect(toImmutableList())); + } + + return copartitionBuilder.build(); + } + @Override protected Scope visitTable(Table table, Optional scope) { @@ -1258,6 +1856,7 @@ protected Scope visitTable(Table table, Optional scope) if (withQuery.isPresent()) { Query query = withQuery.get().getQuery(); analysis.registerNamedQuery(table, query, false); + analysis.setRelationName(table, table.getName()); // re-alias the fields with the name assigned to the query in the WITH declaration RelationType queryDescriptor = analysis.getOutputDescriptor(query); @@ -1297,12 +1896,12 @@ protected Scope visitTable(Table table, Optional scope) field.isAliased())) .collect(toImmutableList()); } - return createAndAssignScope(table, scope, fields); } } QualifiedObjectName name = createQualifiedObjectName(session, table, table.getName()); + analysis.setRelationName(table, QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getObjectName())); if (name.getObjectName().isEmpty()) { throw new SemanticException(MISSING_TABLE, table, "Table name is empty"); } @@ -1423,7 +2022,7 @@ private Optional processTableVersion(Table table, QualifiedObjectNa analysis.recordSubqueries(table, expressionAnalysis); Type stateExprType = 
expressionAnalysis.getType(stateExpr); if (stateExprType == UNKNOWN) { - throw new PrestoException(INVALID_ARGUMENTS, format("Table version AS OF/BEFORE expression cannot be NULL for %s", name.toString())); + throw new PrestoException(StandardErrorCode.INVALID_ARGUMENTS, format("Table version AS OF/BEFORE expression cannot be NULL for %s", name.toString())); } Object evalStateExpr = evaluateConstantExpression(stateExpr, stateExprType, metadata, session, analysis.getParameters()); if (tableVersionType == TIMESTAMP) { @@ -1651,10 +2250,18 @@ private MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName mat @Override protected Scope visitAliasedRelation(AliasedRelation relation, Optional scope) { + analysis.setRelationName(relation, QualifiedName.of(relation.getAlias().getValue())); + analysis.addAliased(relation.getRelation()); Scope relationScope = process(relation.getRelation(), scope); + RelationType relationType = relationScope.getRelationType(); + + // special-handle table function invocation + if (relation.getRelation() instanceof TableFunctionInvocation) { + return createAndAssignScope(relation, scope, + aliasTableFunctionInvocation(relation, relationType, (TableFunctionInvocation) relation.getRelation())); + } // todo this check should be inside of TupleDescriptor.withAlias, but the exception needs the node object - RelationType relationType = relationScope.getRelationType(); if (relation.getColumnNames() != null) { int totalColumns = relationType.getVisibleFieldCount(); if (totalColumns != relation.getColumnNames().size()) { @@ -1674,6 +2281,85 @@ protected Scope visitAliasedRelation(AliasedRelation relation, Optional s return createAndAssignScope(relation, scope, descriptor); } + // As described by the SQL standard ISO/IEC 9075-2, 7.6
, p. 409 + private RelationType aliasTableFunctionInvocation(AliasedRelation relation, RelationType relationType, TableFunctionInvocation function) + { + TableFunctionInvocationAnalysis tableFunctionAnalysis = analysis.getTableFunctionAnalysis(function); + int properColumnsCount = tableFunctionAnalysis.getProperColumnsCount(); + + // check that relation alias is different from range variables of all table arguments + tableFunctionAnalysis.getTableArgumentAnalyses().stream() + .map(TableArgumentAnalysis::getName) + .filter(Optional::isPresent) + .map(Optional::get) + .filter(name -> name.hasSuffix(QualifiedName.of(ImmutableList.of(relation.getAlias().getValue())))) + .findFirst() + .ifPresent(name -> { + throw new SemanticException(DUPLICATE_RANGE_VARIABLE, relation.getAlias(), "Relation alias: %s is a duplicate of input table name: %s", relation.getAlias(), name); + }); + + // build the new relation type. the alias must be applied to the proper columns only, + // and it must not shadow the range variables exposed by the table arguments + ImmutableList.Builder fieldsBuilder = ImmutableList.builder(); + // first, put the table function's proper columns with alias + if (relation.getColumnNames() != null) { + // check that number of column aliases matches number of table function's proper columns + if (properColumnsCount != relation.getColumnNames().size()) { + throw new SemanticException(MISMATCHED_COLUMN_ALIASES, relation, "Column alias list has %s entries but table function has %s proper columns", relation.getColumnNames().size(), properColumnsCount); + } + for (int i = 0; i < properColumnsCount; i++) { + // proper columns are not hidden, so we don't need to skip hidden fields + Field field = relationType.getFieldByIndex(i); + fieldsBuilder.add(Field.newQualified( + field.getNodeLocation(), + QualifiedName.of(ImmutableList.of(relation.getAlias().getValue())), + Optional.of(relation.getColumnNames().get(i).getCanonicalValue()), // although the canonical name is 
recorded, fields are resolved case-insensitive + field.getType(), + field.isHidden(), + field.getOriginTable(), + field.getOriginColumnName(), + field.isAliased())); + } + } + else { + for (int i = 0; i < properColumnsCount; i++) { + Field field = relationType.getFieldByIndex(i); + fieldsBuilder.add(Field.newQualified( + field.getNodeLocation(), + QualifiedName.of(ImmutableList.of(relation.getAlias().getValue())), + field.getName(), + field.getType(), + field.isHidden(), + field.getOriginTable(), + field.getOriginColumnName(), + field.isAliased())); + } + } + + // append remaining fields. They are not being aliased, so hidden fields are included + for (int i = properColumnsCount; i < relationType.getAllFieldCount(); i++) { + fieldsBuilder.add(relationType.getFieldByIndex(i)); + } + + List fields = fieldsBuilder.build(); + + // check that there are no duplicate names within the table function's proper columns + Set names = new HashSet<>(); + fields.subList(0, properColumnsCount).stream() + .map(Field::getName) + .filter(Optional::isPresent) + .map(Optional::get) + // field names are resolved case-insensitive + .map(name -> name.toLowerCase(ENGLISH)) + .forEach(name -> { + if (!names.add(name)) { + throw new SemanticException(DUPLICATE_COLUMN_NAME, relation.getRelation(), "Duplicate name of table function proper column: " + name); + } + }); + + return new RelationType(fields); + } + @Override protected Scope visitSampledRelation(SampledRelation relation, Optional scope) { @@ -1717,13 +2403,37 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional s analysis.setSampleRatio(relation, samplePercentageValue / 100); Scope relationScope = process(relation.getRelation(), scope); + + // TABLESAMPLE cannot be applied to a polymorphic table function (SQL standard ISO/IEC 9075-2, 7.6
, p. 409) + // Note: the below method finds a table function immediately nested in SampledRelation, or aliased. + // Potentially, a table function could be also nested with intervening PatternRecognitionRelation. + // Such case is handled in visitPatternRecognitionRelation(). + validateNoNestedTableFunction(relation.getRelation(), "sample"); + return createAndAssignScope(relation, scope, relationScope.getRelationType()); } + // this method should run after the `base` relation is processed, so that it is + // determined whether the table function is polymorphic + private void validateNoNestedTableFunction(Relation base, String context) + { + TableFunctionInvocation tableFunctionInvocation = null; + if (base instanceof TableFunctionInvocation) { + tableFunctionInvocation = (TableFunctionInvocation) base; + } + else if (base instanceof AliasedRelation && + ((AliasedRelation) base).getRelation() instanceof TableFunctionInvocation) { + tableFunctionInvocation = (TableFunctionInvocation) ((AliasedRelation) base).getRelation(); + } + if (tableFunctionInvocation != null && analysis.isPolymorphicTableFunction(tableFunctionInvocation)) { + throw new SemanticException(INVALID_TABLE_FUNCTION_INVOCATION, base, "Cannot apply %s to polymorphic table function invocation", context); + } + } + @Override protected Scope visitTableSubquery(TableSubquery node, Optional scope) { - StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, transactionManager, session, warningCollector); Scope queryScope = analyzer.analyze(node.getQuery(), scope); return createAndAssignScope(node, scope, queryScope.getRelationType()); } @@ -2085,6 +2795,7 @@ protected Scope visitUpdate(Update update, Optional scope) metadata, sqlParser, new AllowAllAccessControl(), + transactionManager, session, warningCollector); @@ -2465,7 +3176,7 @@ public 
Expression rewriteIdentifier(Identifier reference, Void context, Expressi } if (expressions.size() == 1) { - return Iterables.getOnlyElement(expressions); + return getOnlyElement(expressions); } // otherwise, couldn't resolve name against output aliases, so fall through... @@ -2636,7 +3347,7 @@ else if (item instanceof SingleColumn) { name = QualifiedName.of(((Identifier) expression).getValue()); } else if (expression instanceof DereferenceExpression) { - name = DereferenceExpression.getQualifiedName((DereferenceExpression) expression); + name = getQualifiedName((DereferenceExpression) expression); } if (name != null) { @@ -2889,7 +3600,7 @@ private RelationType analyzeView(Query query, QualifiedObjectName name, Optional .setStartTime(session.getStartTime()); session.getConnectorProperties().forEach((connectorId, properties) -> properties.forEach((k, v) -> viewSessionBuilder.setConnectionProperty(connectorId, k, v))); Session viewSession = viewSessionBuilder.build(); - StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, viewAccessControl, viewSession, warningCollector); + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, viewAccessControl, transactionManager, viewSession, warningCollector); Scope queryScope = analyzer.analyze(query, Scope.create()); return queryScope.getRelationType().withAlias(name.getObjectName(), null); } @@ -3118,4 +3829,48 @@ private static boolean hasScopeAsLocalParent(Scope root, Scope parent) return false; } + + private static final class ArgumentAnalysis + { + private final Argument argument; + private final Optional tableArgumentAnalysis; + + public ArgumentAnalysis(Argument argument, Optional tableArgumentAnalysis) + { + this.argument = requireNonNull(argument, "argument is null"); + this.tableArgumentAnalysis = requireNonNull(tableArgumentAnalysis, "tableArgumentAnalysis is null"); + } + + public Argument getArgument() + { + return argument; + } + + public Optional 
getTableArgumentAnalysis() + { + return tableArgumentAnalysis; + } + } + + private static final class ArgumentsAnalysis + { + private final Map passedArguments; + private final List tableArgumentAnalyses; + + public ArgumentsAnalysis(Map passedArguments, List tableArgumentAnalyses) + { + this.passedArguments = ImmutableMap.copyOf(requireNonNull(passedArguments, "passedArguments is null")); + this.tableArgumentAnalyses = ImmutableList.copyOf(requireNonNull(tableArgumentAnalyses, "tableArgumentAnalyses is null")); + } + + public Map getPassedArguments() + { + return passedArguments; + } + + public List getTableArgumentAnalyses() + { + return tableArgumentAnalyses; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index 348e031539142..88e95b7252a61 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -47,6 +47,8 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -270,6 +272,22 @@ public PlanNode visitValues(ValuesNode node, RewriteContext return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanNode 
visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + // context is mutable. The leaf node should set the PartitioningHandle. + context.get().addSourceDistribution(node.getId(), SOURCE_DISTRIBUTION, metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitExchange(ExchangeNode exchange, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 5825cbeca2dd6..068b4459ea02d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; @@ -480,7 +481,7 @@ public Optional visitWindow(WindowNode node, Context context) .sorted(comparing(this::writeValueAsString)) .collect(toImmutableSet()); - WindowNode.Specification specification = new WindowNode.Specification( + DataOrganizationSpecification specification = new DataOrganizationSpecification( node.getSpecification().getPartitionBy().stream() .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable)) .sorted(comparing(this::writeValueAsString)) @@ -694,7 +695,7 @@ public Optional visitTopNRowNumber(TopNRowNumberNode node, Context con Optional.empty(), planNodeidAllocator.getNextId(), source.get(), - new WindowNode.Specification( + new DataOrganizationSpecification( 
partitionBy, node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))), rowNumberVariable, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 2667c097010be..889d676d534f6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -69,6 +69,7 @@ import com.facebook.presto.operator.JoinBridgeManager; import com.facebook.presto.operator.JoinOperatorFactory; import com.facebook.presto.operator.JoinOperatorFactory.OuterOperatorFactoryResult; +import com.facebook.presto.operator.LeafTableFunctionOperator; import com.facebook.presto.operator.LimitOperator.LimitOperatorFactory; import com.facebook.presto.operator.LocalPlannerAware; import com.facebook.presto.operator.LookupJoinOperators; @@ -87,6 +88,7 @@ import com.facebook.presto.operator.PartitionFunction; import com.facebook.presto.operator.PartitionedLookupSourceFactory; import com.facebook.presto.operator.PipelineExecutionStrategy; +import com.facebook.presto.operator.RegularTableFunctionPartition; import com.facebook.presto.operator.RemoteProjectOperator.RemoteProjectOperatorFactory; import com.facebook.presto.operator.RowNumberOperator; import com.facebook.presto.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -100,6 +102,7 @@ import com.facebook.presto.operator.StreamingAggregationOperator.StreamingAggregationOperatorFactory; import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TableFinishOperator.PageSinkCommitter; +import com.facebook.presto.operator.TableFunctionOperator; import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory; import 
com.facebook.presto.operator.TableWriterMergeOperator.TableWriterMergeOperatorFactory; import com.facebook.presto.operator.TaskContext; @@ -138,15 +141,18 @@ import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation; +import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.aggregation.LambdaProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -210,6 +216,8 @@ import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -245,6 +253,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -350,10 +359,12 @@ 
import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; import static io.airlift.units.DataSize.Unit.BYTE; +import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -1213,6 +1224,99 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext return new PhysicalOperation(operatorFactory, outputMappings.build(), context, source); } + @Override + public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecutionPlanContext context) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + + @Override + public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context) + { + Function getTableFunctionProcessProvider = metadata.getFunctionAndTypeManager().getTableFunctionProcessorProvider().orElseThrow(NoSuchElementException::new); + TableFunctionProcessorProvider processorProvider = getTableFunctionProcessProvider.apply(node.getHandle().getSchemaFunctionName()); + + if (!node.getSource().isPresent()) { + OperatorFactory operatorFactory = new LeafTableFunctionOperator.LeafTableFunctionOperatorFactory(context.getNextOperatorId(), node.getId(), processorProvider, node.getHandle().getFunctionHandle()); + return new PhysicalOperation(operatorFactory, makeLayout(node), context, Optional.empty(), UNGROUPED_EXECUTION); + } + + PhysicalOperation source = 
node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + + int properChannelsCount = node.getProperOutputs().size(); + + long passThroughSourcesCount = node.getPassThroughSpecifications().stream() + .filter(TableFunctionNode.PassThroughSpecification::isDeclaredAsPassThrough) + .count(); + + List> requiredChannels = node.getRequiredVariables().stream() + .map(list -> getChannelsForVariables(list, source.getLayout())) + .collect(toImmutableList()); + + Optional> markerChannels = node.getMarkerVariables() + .map(map -> map.entrySet().stream() + .collect(toImmutableMap(entry -> source.getLayout().get(entry.getKey()), entry -> source.getLayout().get(entry.getValue())))); + + int channel = properChannelsCount; + ImmutableList.Builder passThroughColumnSpecifications = ImmutableList.builder(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + // the table function produces one index channel for each source declared as pass-through. They are laid out after the proper channels. + int indexChannel = specification.isDeclaredAsPassThrough() ? 
channel++ : -1; + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + passThroughColumnSpecifications.add(new RegularTableFunctionPartition.PassThroughColumnSpecification(column.isPartitioningColumn(), source.getLayout().get(column.getOutputVariables()), indexChannel)); + } + } + + List partitionChannels = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .map(list -> getChannelsForVariables(list, source.getLayout())) + .orElse(ImmutableList.of()); + + List sortChannels = ImmutableList.of(); + List sortOrders = ImmutableList.of(); + if (node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).isPresent()) { + OrderingScheme orderingScheme = node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).orElseThrow(NoSuchElementException::new); + sortChannels = getChannelsForVariables(orderingScheme.getOrderByVariables(), source.getLayout()); + sortOrders = orderingScheme.getOrderingsMap().values().stream().collect(toImmutableList()); + } + + OperatorFactory operator = new TableFunctionOperator.TableFunctionOperatorFactory( + context.getNextOperatorId(), + node.getId(), + processorProvider, + node.getHandle().getFunctionHandle(), + properChannelsCount, + toIntExact(passThroughSourcesCount), + requiredChannels, + markerChannels, + passThroughColumnSpecifications.build(), + node.isPruneWhenEmpty(), + partitionChannels, + getChannelsForVariables(ImmutableList.copyOf(node.getPrePartitioned()), source.getLayout()), + sortChannels, + sortOrders, + node.getPreSorted(), + source.getTypes(), + 10_000, + pagesIndexFactory); + + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + for (int i = 0; i < node.getProperOutputs().size(); i++) { + outputMappings.put(node.getProperOutputs().get(i), i); + } + List passThroughVariables = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) 
+ .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(toImmutableList()); + int outputChannel = properChannelsCount; + for (VariableReferenceExpression passThroughVariable : passThroughVariables) { + outputMappings.put(passThroughVariable, outputChannel++); + } + + return new PhysicalOperation(operator, outputMappings.buildOrThrow(), context, source); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { @@ -2868,7 +2972,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl Map aggregationMap = aggregation.getAggregations().entrySet() .stream().collect( - ImmutableMap.toImmutableMap( + toImmutableMap( Map.Entry::getKey, entry -> createAggregation(entry.getValue()))); if (groupingVariables.isEmpty()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java new file mode 100644 index 0000000000000..c487a10415472 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OrderingTranslator.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.sql.tree.SortItem; + +public class OrderingTranslator +{ + private OrderingTranslator() {} + + public static SortOrder sortItemToSortOrder(SortItem sortItem) + { + if (sortItem.getOrdering() == SortItem.Ordering.ASCENDING) { + if (sortItem.getNullOrdering() == SortItem.NullOrdering.FIRST) { + return SortOrder.ASC_NULLS_FIRST; + } + return SortOrder.ASC_NULLS_LAST; + } + + if (sortItem.getNullOrdering() == SortItem.NullOrdering.FIRST) { + return SortOrder.DESC_NULLS_FIRST; + } + return SortOrder.DESC_NULLS_LAST; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 8e0f97689bcef..8f7a43581e5f9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -48,6 +48,7 @@ import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations; import com.facebook.presto.sql.planner.iterative.rule.ImplementOffset; +import com.facebook.presto.sql.planner.iterative.rule.ImplementTableFunctionSource; import com.facebook.presto.sql.planner.iterative.rule.InlineProjections; import com.facebook.presto.sql.planner.iterative.rule.InlineProjectionsOnValues; import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions; @@ -129,6 +130,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RewriteConstantArrayContainsToInExpression; import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; +import 
com.facebook.presto.sql.planner.iterative.rule.RewriteTableFunctionToTableScan; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; import com.facebook.presto.sql.planner.iterative.rule.ScaledWriterRule; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCardinalityMap; @@ -401,6 +403,7 @@ public PlanOptimizers( .addAll(predicatePushDownRules) .addAll(columnPruningRules) .addAll(ImmutableSet.of( + new ImplementTableFunctionSource(metadata), new MergeDuplicateAggregation(metadata.getFunctionAndTypeManager()), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -825,6 +828,14 @@ public PlanOptimizers( costCalculator, ImmutableSet.of(new ScaledWriterRule()))); + builder.add( + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of(new RewriteTableFunctionToTableScan(metadata)))); + if (!noExchange) { builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 0ef7e7d7d663f..4f259360a6b67 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -29,6 +29,7 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; @@ -114,6 +115,7 @@ import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.isNumericType; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; import static 
com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.OrderingTranslator.sortItemToSortOrder; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder; @@ -142,7 +144,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -class QueryPlanner +public class QueryPlanner { private final Analysis analysis; private final VariableAllocator variableAllocator; @@ -512,7 +514,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression * * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed */ - private PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) + public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) { Assignments.Builder assignments = Assignments.builder(); assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); @@ -543,6 +545,28 @@ private PlanAndMappings coerce(PlanBuilder subPlan, List expressions return new PlanAndMappings(subPlan, mappings.build()); } + public static OrderingScheme translateOrderingScheme(List items, Function coercions) + { + List coerced = items.stream() + .map(SortItem::getSortKey) + .map(coercions) + .collect(toImmutableList()); + + ImmutableList.Builder variables = ImmutableList.builder(); + Map orders = new HashMap<>(); + for (int i = 0; i < coerced.size(); i++) { + VariableReferenceExpression variable = coerced.get(i); + // for multiple sort items based on the
same expression, retain the first one: + // ORDER BY x DESC, x ASC, y --> ORDER BY x DESC, y + if (!orders.containsKey(variable)) { + variables.add(variable); + orders.put(variable, new Ordering(variable, sortItemToSortOrder(items.get(i)))); + } + } + // HashMap.values() has no defined iteration order: emit the orderings in first-occurrence order recorded in variables + return new OrderingScheme(variables.build().stream().map(orders::get).collect(toImmutableList())); + } + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { ImmutableMap.Builder projections = ImmutableMap.builder(); @@ -1069,7 +1093,7 @@ else if (window.getFrame().isPresent()) { subPlan.getRoot().getSourceLocation(), idAllocator.getNextId(), subPlan.getRoot(), - new WindowNode.Specification( + new DataOrganizationSpecification( partitionByVariables.build(), orderingScheme), ImmutableMap.of(newVariable, function), @@ -1343,6 +1367,11 @@ private static List toSymbolReferences(List new NodeLocation(location.getLine(), location.getColumn())), variable.getName()); + } + public static class PlanAndMappings { private final PlanBuilder subPlan; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index e260ea6b98d82..ce9e04cf16ee0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -21,19 +21,23 @@ import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.constraints.TableConstraint; +import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.spi.plan.AggregationNode; import
com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; @@ -54,6 +58,10 @@ import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; @@ -82,8 +90,10 @@ import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SampledRelation; import com.facebook.presto.sql.tree.SetOperation; +import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionInvocation; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; @@ -91,6 +101,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import 
com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; import com.google.common.collect.UnmodifiableIterator; @@ -106,6 +117,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.getCteMaterializationStrategy; @@ -122,6 +134,7 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.NONE; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; +import static com.facebook.presto.sql.planner.QueryPlanner.translateOrderingScheme; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; import static com.facebook.presto.sql.tree.Join.Type.INNER; import static com.facebook.presto.sql.tree.Join.Type.LEFT; @@ -225,6 +238,113 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) return new RelationPlan(root, scope, outputVariables); } + @Override + protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, SqlPlannerContext context) + { + Analysis.TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) + .map(field -> variableAllocator.newVariable(getSourceLocation(node), field.getName().get(), field.getType())) + .collect(toImmutableList()); + + 
outputVariables.addAll(properOutputs); + + // process sources in order of argument declarations + for (Analysis.TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); + PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); + + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(sourcePlan::getVariable) + .collect(toImmutableList()); + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + List partitioningColumns = tableArgument.getPartitionBy().get(); + sourcePlanBuilder = sourcePlanBuilder.appendProjections(partitioningColumns, variableAllocator, idAllocator, session, metadata, sqlParser, analysis, context); + QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); + QueryPlanner.PlanAndMappings copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } + + // order by + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + List orderByColumns = tableArgument.getOrderBy().get().getSortItems().stream().map(SortItem::getSortKey).collect(Collectors.toList()); + sourcePlanBuilder = sourcePlanBuilder.appendProjections(orderByColumns, 
variableAllocator, idAllocator, session, metadata, sqlParser, analysis, context); + // the ordering symbols are not coerced + orderBy = Optional.of(translateOrderingScheme(tableArgument.getOrderBy().get().getSortItems(), sourcePlanBuilder::translate)); + } + + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + // add output symbols passed from the table argument + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are included. They are present in sourcePlan.fieldMappings + outputVariables.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(variable -> new PassThroughColumn(variable, partitionBy.contains(variable))) + .forEach(passThroughColumns::add); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + // the original symbols for partitioning columns, not coerced + .map(sourcePlanBuilder::translate) + .forEach(variable -> { + outputVariables.add(variable); + passThroughColumns.add(new PassThroughColumn(variable, true)); + }); + } + + sources.add(sourcePlanBuilder.getRoot()); + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + new PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), + requiredColumns, + specification)); + } + + PlanNode root = new TableFunctionNode( + idAllocator.getNextId(), + functionAnalysis.getFunctionName(), + functionAnalysis.getArguments(), + properOutputs, + sources.build(), + sourceProperties.build(), + functionAnalysis.getCopartitioningLists(), + new 
TableFunctionHandle( + functionAnalysis.getConnectorId(), + new SchemaFunctionName(functionAnalysis.getSchemaName(), functionAnalysis.getFunctionName()), + functionAnalysis.getConnectorTableFunctionHandle(), + functionAnalysis.getTransactionHandle())); + + return new RelationPlan(root, analysis.getScope(node), outputVariables.build()); + } + @Override protected RelationPlan visitAliasedRelation(AliasedRelation node, SqlPlannerContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java index d34e01cda201a..830a7902f8cf0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java @@ -22,9 +22,11 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.NoSuchElementException; import java.util.function.Consumer; public class SchedulingOrderVisitor @@ -88,5 +90,17 @@ public Void visitTableScan(TableScanNode node, Consumer schedulingOr schedulingOrder.accept(node.getId()); return null; } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Consumer schedulingOrder) + { + if (!node.getSource().isPresent()) { + schedulingOrder.accept(node.getId()); + } + else { + node.getSource().orElseThrow(NoSuchElementException::new).accept(this, schedulingOrder); + } + return null; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java index 
c4707b487744e..3a05cef7de79f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java @@ -58,6 +58,7 @@ import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -67,6 +68,7 @@ import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.function.Supplier; @@ -283,6 +285,21 @@ public Map visitWindow(WindowNode node, Context context return node.getSource().accept(this, context); } + @Override + public Map visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) + { + if (!node.getSource().isPresent()) { + // this is a source node, so produce splits + SplitSource splitSource = splitSourceProvider.getSplits( + session, + node.getHandle()); + splitSources.add(splitSource); + return ImmutableMap.of(node.getId(), splitSource); + } + + return node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + } + @Override public Map visitRowNumber(RowNumberNode node, Context context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java new file mode 100644 index 0000000000000..fa523e397ccad --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java @@ -0,0 +1,1018 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.WindowNode.Frame; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import 
com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.plan.JoinType.FULL; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.ROWS; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.relational.Expressions.coalesce; +import static 
com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *

+ * It rewrites TableFunctionNode with potentially many sources + * into a TableFunctionProcessorNode. The new node has one + * source being a combination of the original sources. + *

+ * The original sources are combined with joins. The join + * conditions depend on the prune when empty property, and on + * the co-partitioning of sources. + *

+ * The resulting source should be partitioned and ordered + * according to combined schemas from the component sources. + *

+ * Example transformation for two sources, both with set semantics + * and KEEP WHEN EMPTY property: + *

+ * - TableFunction foo
+ *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
+ *      - source T2(a2, b2) PARTITION BY a2
+ * 
+ * Is transformed into: + *
+ * - TableFunctionDataProcessor foo
+ *      PARTITION BY (a1, a2), ORDER BY combined_row_number
+ *      - Project
+ *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
+ *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
+ *          - Project
+ *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
+ *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
+ *              - FULL Join
+ *                  [table1_row_number = table2_row_number OR
+ *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
+ *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
+ *                  - Window [PARTITION BY a1 ORDER BY b1]
+ *                      table1_row_number <= row_number()
+ *                      table1_partition_size <= count()
+ *                          - source T1(a1, b1)
+ *                  - Window [PARTITION BY a2]
+ *                      table2_row_number <= row_number()
+ *                      table2_partition_size <= count()
+ *                          - source T2(a2, b2)
+ * 
+ */ +public class ImplementTableFunctionSource + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public ImplementTableFunctionSource(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. 
+ TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.pruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.specification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + + // Create call expression for row_number + FunctionHandle rowNumberFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("row_number")), + ImmutableList.of()); + + FunctionMetadata rowNumberFunctionMetadata = functionAndTypeManager.getFunctionMetadata(rowNumberFunctionHandle); + CallExpression rowNumberFunction = new CallExpression("row_number", rowNumberFunctionHandle, functionAndTypeManager.getType(rowNumberFunctionMetadata.getReturnType()), ImmutableList.of()); + + // Create call expression for count + FunctionHandle countFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("count")), + ImmutableList.of()); + + FunctionMetadata countFunctionMetadata = functionAndTypeManager.getFunctionMetadata(countFunctionHandle); + CallExpression countFunction = new CallExpression("count", 
countFunctionHandle, functionAndTypeManager.getType(countFunctionMetadata.getReturnType()), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context, metadata)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithVariables finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + else { + NodeWithVariables first = intermediateResultSources.get(0); + NodeWithVariables second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context, metadata); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithVariables joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context, metadata); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. 
+ // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + VariableReferenceExpression finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context, metadata); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.variableToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. + // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. 
+        ImmutableList.Builder newOrderings = ImmutableList.builder();
+        newOrderings.add(new Ordering(finalRowNumberSymbol, ASC_NULLS_LAST));
+        Optional finalOrderBy = Optional.of(new OrderingScheme(newOrderings.build()));
+
+        // derive the prune when empty property
+        boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::pruneWhenEmpty);
+
+        // Combine the pass through specifications from all sources
+        List passThroughSpecifications = node.getTableArgumentProperties().stream()
+                .map(TableArgumentProperties::getPassThroughSpecification)
+                .collect(toImmutableList());
+
+        // Combine the required symbols from all sources
+        List> requiredVariables = node.getTableArgumentProperties().stream()
+                .map(TableArgumentProperties::getRequiredColumns)
+                .collect(toImmutableList());
+
+        return Result.ofPlanNode(new TableFunctionProcessorNode(
+                node.getId(),
+                node.getName(),
+                node.getProperOutputs(),
+                Optional.of(marked.node()),
+                pruneWhenEmpty,
+                passThroughSpecifications,
+                requiredVariables,
+                Optional.of(markerSymbols),
+                Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)),
+                ImmutableSet.of(),
+                0,
+                Optional.empty(),
+                node.getHandle()));
+    }
+    /** Pairs sources and properties positionally (Streams.zip) and indexes each pair by the argument name taken from its properties. */
+    private static Map mapSourcesByName(List sources, List properties)
+    {
+        return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new)
+                .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity()));
+    }
+    /** Wraps the source in a WindowNode that computes the given row_number and count calls over the argument's specification, falling back to UNORDERED_SINGLE_PARTITION when the specification is absent. */
+    private static NodeWithVariables planWindowFunctionsForSource(
+            PlanNode source,
+            TableArgumentProperties argumentProperties,
+            CallExpression rowNumberFunction,
+            CallExpression countFunction,
+            Context context)
+    {
+        String argumentName = argumentProperties.getArgumentName();
+
+        VariableReferenceExpression rowNumber = context.getVariableAllocator().newVariable(argumentName + "_row_number", BIGINT);
+        Map rowNumberSymbolMapping = source.getOutputVariables().stream()
+                .collect(toImmutableMap(identity(), symbol -> rowNumber)); // every output variable of this source maps to the same row-number variable
+
+        VariableReferenceExpression partitionSize = context.getVariableAllocator().newVariable(argumentName + "_partition_size", BIGINT);
+
+        // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified.
+        // If the source has row semantics, its specification is empty. Currently, such source is processed
+        // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size.
+        DataOrganizationSpecification specification = argumentProperties.specification().orElse(UNORDERED_SINGLE_PARTITION);
+
+        PlanNode window = new WindowNode(
+                source.getSourceLocation(),
+                context.getIdAllocator().getNextId(),
+                source,
+                specification,
+                ImmutableMap.of(
+                        rowNumber, new WindowNode.Function(rowNumberFunction, FULL_FRAME, false),
+                        partitionSize, new WindowNode.Function(countFunction, FULL_FRAME, false)),
+                Optional.empty(),
+                ImmutableSet.of(),
+                0);
+
+        return new NodeWithVariables(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.pruneWhenEmpty(), rowNumberSymbolMapping);
+    }
+    /** Plans the window functions for each listed source and combines them pairwise, left to right, into a single co-partitioned node; the list must hold at least two sources. */
+    private static NodeWithVariables copartition(
+            List sourceList,
+            CallExpression rowNumberFunction,
+            CallExpression countFunction,
+            Context context,
+            Metadata metadata)
+    {
+        checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables");
+
+        // Reorder the co-partitioned sources to process the sources with prune when empty property first.
+        // It allows to use inner or side joins instead of outer joins.
+        sourceList = sourceList.stream()
+                .sorted(Comparator.comparingInt(source -> source.properties().pruneWhenEmpty() ? 
-1 : 1))
+                .collect(toImmutableList());
+
+        NodeWithVariables first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context);
+        NodeWithVariables second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context);
+        JoinedNodes copartitioned = copartition(first, second, context, metadata);
+
+        for (int i = 2; i < sourceList.size(); i++) {
+            NodeWithVariables copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata);
+            NodeWithVariables next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context);
+            copartitioned = copartition(copartitionedWithSymbols, next, context, metadata);
+        }
+
+        return appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata);
+    }
+    /** Joins two row-numbered sources whose partition columns are pairwise NOT DISTINCT and whose row numbers align; the join type follows the prune-when-empty flags of both sides. Both sides must have the same, non-zero number of partition columns. */
+    private static JoinedNodes copartition(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata)
+    {
+        checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match");
+
+        // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists.
+        // Co-partitioning tables with empty partition by would be ineffective.
+        checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns");
+
+        FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
+
+        Optional copartitionConjuncts = Streams.zip(
+                left.partitionBy.stream(),
+                right.partitionBy.stream(),
+                (leftColumn, rightColumn) -> new CallExpression("NOT",
+                        functionResolution.notFunction(),
+                        BOOLEAN,
+                        ImmutableList.of(
+                                new CallExpression(IS_DISTINCT_FROM.name(),
+                                        functionResolution.comparisonFunction(IS_DISTINCT_FROM, leftColumn.getType(), rightColumn.getType()), // resolve for the actual column types; partition keys are not necessarily INTEGER
+                                        BOOLEAN,
+                                        ImmutableList.of(leftColumn, rightColumn)))))
+                .map(expr -> expr)
+                .reduce((expr, conjunct) -> new SpecialFormExpression(SpecialFormExpression.Form.AND,
+                        BOOLEAN,
+                        ImmutableList.of(expr, conjunct)));
+
+        // Align matching partitions (co-partitions) from left and right source, according to row number.
+        // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other.
+        // If one or both sources are ordered, the row number reflects the ordering.
+        // The second and third disjunct in the join condition account for the situation when partitions have different sizes.
+        // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition.
+        //
+        // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ...
+        // AND (
+        //      R1 = R2
+        //      OR
+        //      (R1 > S2 AND R2 = 1)
+        //      OR
+        //      (R2 > S1 AND R1 = 1))
+
+        SpecialFormExpression orExpression = new SpecialFormExpression(SpecialFormExpression.Form.OR,
+                BOOLEAN,
+                ImmutableList.of(
+                        new CallExpression(EQUAL.name(),
+                                functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT),
+                                BOOLEAN,
+                                ImmutableList.of(left.rowNumber, right.rowNumber)),
+                        new SpecialFormExpression(SpecialFormExpression.Form.OR,
+                                BOOLEAN,
+                                ImmutableList.of(
+                                        new SpecialFormExpression(SpecialFormExpression.Form.AND,
+                                                BOOLEAN,
+                                                ImmutableList.of(
+                                                        new CallExpression(GREATER_THAN.name(),
+                                                                functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT),
+                                                                BOOLEAN,
+                                                                ImmutableList.of(left.rowNumber, right.partitionSize)),
+                                                        new CallExpression(EQUAL.name(),
+                                                                functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT),
+                                                                BOOLEAN,
+                                                                ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))),
+                                        new SpecialFormExpression(SpecialFormExpression.Form.AND,
+                                                BOOLEAN,
+                                                ImmutableList.of(
+                                                        new CallExpression(GREATER_THAN.name(),
+                                                                functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT),
+                                                                BOOLEAN,
+                                                                ImmutableList.of(right.rowNumber, left.partitionSize)),
+                                                        new CallExpression(EQUAL.name(),
+                                                                functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT),
+                                                                BOOLEAN,
+                                                                ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT)))))))));
+        RowExpression joinCondition = copartitionConjuncts.map(
+                conjunct -> new SpecialFormExpression(SpecialFormExpression.Form.AND,
+                        BOOLEAN,
+                        ImmutableList.of(conjunct, orExpression)))
+                .orElse(orExpression);
+
+        // The join type depends on the prune when empty property of the sources.
+        // If a source is prune when empty, we should not process any co-partition which is not present in this source,
+        // so effectively the other source becomes inner side of the join.
+        //
+        // example:
+        // table T1 partition by P1      table T2 partition by P2
+        //   P1    C1                      P2    C2
+        //  ----------                    ----------
+        //   1     'a'                     2     'c'
+        //   2     'b'                     3     'd'
+        //
+        // co-partitioning results:
+        // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3'
+        //   P1    C1    P2    C2
+        //  ------------------------
+        //   1     'a'   null  null
+        //   2     'b'   2     'c'
+        //
+        // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1'
+        //   P1    C1    P2    C2
+        //  ------------------------
+        //   2     'b'   2     'c'
+        //   null  null  3     'd'
+        //
+        // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3'
+        //   P1    C1    P2    C2
+        //  ------------------------
+        //   2     'b'   2     'c'
+        //
+        // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions
+        //   P1    C1    P2    C2
+        //  ------------------------
+        //   1     'a'   null  null
+        //   2     'b'   2     'c'
+        //   null  null  3     'd'
+        JoinType joinType;
+        if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) {
+            joinType = INNER;
+        }
+        else if (left.pruneWhenEmpty()) {
+            joinType = LEFT;
+        }
+        else if (right.pruneWhenEmpty()) {
+            joinType = RIGHT;
+        }
+        else {
+            joinType = FULL;
+        }
+
+        return new JoinedNodes(
+                new JoinNode(
+                        Optional.empty(),
+                        context.getIdAllocator().getNextId(),
+                        joinType,
+                        left.node(),
+                        right.node(),
+                        ImmutableList.of(),
+                        Stream.concat(left.node().getOutputVariables().stream(),
+                                        right.node().getOutputVariables().stream())
+                                .collect(Collectors.toList()),
+                        Optional.of(joinCondition),
+                        Optional.empty(),
+                        Optional.empty(),
+                        Optional.empty(),
+                        ImmutableMap.of()),
+                left.rowNumber(),
+                left.partitionSize(),
+                left.partitionBy(),
+                left.pruneWhenEmpty(),
+                left.rowNumberSymbolsMapping(),
+                right.rowNumber(),
+                right.partitionSize(),
+                right.partitionBy(),
+                right.pruneWhenEmpty(),
+                right.rowNumberSymbolsMapping());
+    }
+
+    private static NodeWithVariables appendHelperSymbolsForCopartitionedNodes(
+            JoinedNodes copartitionedNodes,
+            Context context,
+            Metadata metadata)
+    {
+        
checkArgument(copartitionedNodes.leftPartitionBy().size() == copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftRowNumber(), + copartitionedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftPartitionSize(), + copartitionedNodes.rightPartitionSize())); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. 
+ ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + VariableReferenceExpression leftColumn = copartitionedNodes.leftPartitionBy().get(i); + VariableReferenceExpression rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getVariableAllocator().getVariables().get(leftColumn.getName()); + + VariableReferenceExpression joinedColumn = context.getVariableAllocator().newVariable("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, coalesce(leftColumn, rightColumn)); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putIdentities(copartitionedNodes.joinedNode().getOutputVariables()) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + // Align rows from left and right source according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. + // If one or both sources are ordered, the row number reflects the ordering. 
+ // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. + // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + RowExpression joinCondition = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + 
joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context, Metadata metadata) + { + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftRowNumber(), + joinedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftPartitionSize(), + joinedNodes.rightPartitionSize())); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putIdentities(joinedNodes.joinedNode().getOutputVariables()) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set variables, VariableReferenceExpression referenceSymbol, Context context, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putIdentities(node.getOutputVariables()); + + ImmutableMap.Builder variablesToMarkers = 
ImmutableMap.builder();
+
+        for (VariableReferenceExpression variable : variables) {
+            VariableReferenceExpression marker = context.getVariableAllocator().newVariable("marker", BIGINT);
+            variablesToMarkers.put(variable, marker);
+            RowExpression ifExpression = new SpecialFormExpression(
+                    IF,
+                    BIGINT,
+                    ImmutableList.of(
+                            new CallExpression(
+                                    EQUAL.name(),
+                                    metadata.getFunctionAndTypeManager().resolveOperator(
+                                            OperatorType.EQUAL,
+                                            fromTypes(BIGINT, BIGINT)),
+                                    BOOLEAN,
+                                    ImmutableList.of(variable, referenceSymbol)),
+                            variable,
+                            new ConstantExpression(null, BIGINT)));
+            assignments.put(marker, ifExpression);
+        }
+
+        PlanNode project = new ProjectNode(
+                context.getIdAllocator().getNextId(),
+                node,
+                assignments.build());
+
+        return new NodeWithMarkers(project, variablesToMarkers.buildOrThrow());
+    }
+    /** Immutable pair of a source plan node and the table-argument properties it was zipped with; both components are required to be non-null. */
+    private static class SourceWithProperties
+    {
+        private final PlanNode source;
+        private final TableArgumentProperties properties;
+
+        public SourceWithProperties(PlanNode source, TableArgumentProperties properties)
+        {
+            this.source = requireNonNull(source, "source is null");
+            this.properties = requireNonNull(properties, "properties is null");
+        }
+
+        public PlanNode source()
+        {
+            return source;
+        }
+
+        public TableArgumentProperties properties()
+        {
+            return properties;
+        }
+    }
+    /** A plan node together with the helper state tracked for it: row-number and partition-size variables, partition-by columns, the prune-when-empty flag, and the mapping from source output variables to row-number variables. Collections are defensively copied. */
+    public static final class NodeWithVariables
+    {
+        private final PlanNode node;
+        private final VariableReferenceExpression rowNumber;
+        private final VariableReferenceExpression partitionSize;
+        private final List partitionBy;
+        private final boolean pruneWhenEmpty;
+        private final Map rowNumberSymbolsMapping;
+
+        public NodeWithVariables(PlanNode node, VariableReferenceExpression rowNumber, VariableReferenceExpression partitionSize,
+                List partitionBy, boolean pruneWhenEmpty,
+                Map rowNumberSymbolsMapping)
+        {
+            this.node = requireNonNull(node, "node is null");
+            this.rowNumber = requireNonNull(rowNumber, "rowNumber is null");
+            this.partitionSize = requireNonNull(partitionSize, "partitionSize is null");
+            this.partitionBy = ImmutableList.copyOf(requireNonNull(partitionBy, "partitionBy is null")); // named null check, consistent with JoinedNodes
+            this.pruneWhenEmpty = pruneWhenEmpty;
+            this.rowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rowNumberSymbolsMapping, "rowNumberSymbolsMapping is null"));
+        }
+
+        public PlanNode node()
+        {
+            return node;
+        }
+
+        public VariableReferenceExpression rowNumber()
+        {
+            return rowNumber;
+        }
+
+        public VariableReferenceExpression partitionSize()
+        {
+            return partitionSize;
+        }
+
+        public List partitionBy()
+        {
+            return partitionBy;
+        }
+
+        public boolean pruneWhenEmpty()
+        {
+            return pruneWhenEmpty;
+        }
+
+        public Map rowNumberSymbolsMapping()
+        {
+            return rowNumberSymbolsMapping;
+        }
+    }
+    /** A joined plan node plus the helper state of its left and right inputs, kept side by side until an appendHelperSymbols* method folds them into a single NodeWithVariables. */
+    private static class JoinedNodes
+    {
+        private final PlanNode joinedNode;
+        private final VariableReferenceExpression leftRowNumber;
+        private final VariableReferenceExpression leftPartitionSize;
+        private final List leftPartitionBy;
+        private final boolean leftPruneWhenEmpty;
+        private final Map leftRowNumberSymbolsMapping;
+        private final VariableReferenceExpression rightRowNumber;
+        private final VariableReferenceExpression rightPartitionSize;
+        private final List rightPartitionBy;
+        private final boolean rightPruneWhenEmpty;
+        private final Map rightRowNumberSymbolsMapping;
+
+        public JoinedNodes(
+                PlanNode joinedNode,
+                VariableReferenceExpression leftRowNumber,
+                VariableReferenceExpression leftPartitionSize,
+                List leftPartitionBy,
+                boolean leftPruneWhenEmpty,
+                Map leftRowNumberSymbolsMapping,
+                VariableReferenceExpression rightRowNumber,
+                VariableReferenceExpression rightPartitionSize,
+                List rightPartitionBy,
+                boolean rightPruneWhenEmpty,
+                Map rightRowNumberSymbolsMapping)
+        {
+            this.joinedNode = requireNonNull(joinedNode, "joinedNode is null");
+            this.leftRowNumber = requireNonNull(leftRowNumber, "leftRowNumber is null");
+            this.leftPartitionSize = requireNonNull(leftPartitionSize, "leftPartitionSize is null");
+            this.leftPartitionBy = ImmutableList.copyOf(requireNonNull(leftPartitionBy, "leftPartitionBy is null"));
+            this.leftPruneWhenEmpty = leftPruneWhenEmpty;
+            this.leftRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(leftRowNumberSymbolsMapping, "leftRowNumberSymbolsMapping is null"));
+            this.rightRowNumber = requireNonNull(rightRowNumber, "rightRowNumber is null");
+            this.rightPartitionSize = requireNonNull(rightPartitionSize, "rightPartitionSize is null");
+            this.rightPartitionBy = ImmutableList.copyOf(requireNonNull(rightPartitionBy, "rightPartitionBy is null"));
+            this.rightPruneWhenEmpty = rightPruneWhenEmpty;
+            this.rightRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rightRowNumberSymbolsMapping, "rightRowNumberSymbolsMapping is null"));
+        }
+
+        public PlanNode joinedNode()
+        {
+            return joinedNode;
+        }
+        public VariableReferenceExpression leftRowNumber()
+        {
+            return leftRowNumber;
+        }
+        public VariableReferenceExpression leftPartitionSize()
+        {
+            return leftPartitionSize;
+        }
+        public List leftPartitionBy()
+        {
+            return leftPartitionBy;
+        }
+        public boolean leftPruneWhenEmpty()
+        {
+            return leftPruneWhenEmpty;
+        }
+        public Map leftRowNumberSymbolsMapping()
+        {
+            return leftRowNumberSymbolsMapping;
+        }
+        public VariableReferenceExpression rightRowNumber()
+        {
+            return rightRowNumber;
+        }
+        public VariableReferenceExpression rightPartitionSize()
+        {
+            return rightPartitionSize;
+        }
+        public List rightPartitionBy()
+        {
+            return rightPartitionBy;
+        }
+        public boolean rightPruneWhenEmpty()
+        {
+            return rightPruneWhenEmpty;
+        }
+        public Map rightRowNumberSymbolsMapping()
+        {
+            return rightRowNumberSymbolsMapping;
+        }
+    }
+    /** A plan node plus an immutable copy of the mapping from each variable to the marker variable created for it. */
+    private static class NodeWithMarkers
+    {
+        private final PlanNode node;
+        private final Map variableToMarker;
+
+        public NodeWithMarkers(PlanNode node, Map variableToMarker)
+        {
+            this.node = requireNonNull(node, "node is null");
+            this.variableToMarker = ImmutableMap.copyOf(requireNonNull(variableToMarker, "variableToMarker is null")); // message now names the actual parameter
+        }
+
+        public PlanNode node()
+        {
+            return node;
+        }
+
+    
public Map variableToMarker()
+        {
+            return variableToMarker;
+        }
+    }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java
new file mode 100644
index 0000000000000..d028eb3a1ed08
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.common.predicate.TupleDomain;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.TableHandle;
+import com.facebook.presto.spi.connector.TableFunctionApplicationResult;
+import com.facebook.presto.spi.plan.TableScanNode;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.List;
+import java.util.Optional;
+
+import static com.facebook.presto.matching.Pattern.empty;
+import static com.facebook.presto.sql.planner.plan.Patterns.sources;
+import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor;
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+/** Planner rule that replaces a TableFunctionProcessorNode having no sources with a TableScanNode, when Metadata.applyTableFunction returns a result for the node's handle. */
+public class RewriteTableFunctionToTableScan
+        implements Rule
+{
+    private static final Pattern PATTERN = tableFunctionProcessor()
+            .with(empty(sources())); // match only processor nodes whose source list is empty
+
+    private final Metadata metadata;
+
+    public RewriteTableFunctionToTableScan(Metadata metadata)
+    {
+        this.metadata = requireNonNull(metadata, "metadata is null");
+    }
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+    /** Returns Result.empty() when applyTableFunction yields nothing; otherwise builds a TableScanNode over the returned table handle, assigning the node's output variables to the returned column handles by position. */
+    @Override
+    public Result apply(TableFunctionProcessorNode node, Captures captures, Context context)
+    {
+        Optional> result = metadata.applyTableFunction(context.getSession(), node.getHandle());
+
+        if (!result.isPresent()) {
+            return Result.empty();
+        }
+
+        List columnHandles = result.get().getColumnHandles();
+        checkState(node.getOutputVariables().size() == columnHandles.size(), "returned table does not match the node's output"); // positional pairing below needs equal sizes
+        ImmutableMap.Builder assignments = ImmutableMap.builder();
+        for (int i = 0; i < columnHandles.size(); i++) {
+            assignments.put(node.getOutputVariables().get(i), columnHandles.get(i));
+        }
+
+        return Result.ofPlanNode(new TableScanNode(
+                node.getSourceLocation(),
+                node.getId(), // the scan reuses the id of the node it replaces
+                result.get().getTableHandle(),
+                node.getOutputVariables(),
+                assignments.buildOrThrow(),
+                TupleDomain.all(),
+                TupleDomain.all(), Optional.empty()));
+    }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java
index 7fe4f5d08937a..a8f9fbd2d425c 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java
@@ -75,6 +75,8 @@
 import com.facebook.presto.sql.planner.plan.RowNumberNode;
 import com.facebook.presto.sql.planner.plan.SequenceNode;
 import com.facebook.presto.sql.planner.plan.StatisticsWriterNode;
+import com.facebook.presto.sql.planner.plan.TableFunctionNode;
+import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode;
 import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
 import com.facebook.presto.sql.planner.plan.UnnestNode;
 import com.google.common.annotations.VisibleForTesting;
@@ -96,6 +98,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.NoSuchElementException;
 import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
@@ -138,6 +141,7 @@
 import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.singleStreamPartition;
 import static com.facebook.presto.sql.planner.optimizations.LocalProperties.grouped;
 import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable;
+import static com.facebook.presto.sql.planner.optimizations.PreferredProperties.partitionedWithLocal;
import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_MATERIALIZED; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; @@ -288,7 +292,7 @@ public PlanWithProperties visitAggregation(AggregationNode node, PreferredProper // from partial aggregations (enforced in `ValidateAggregationWithDefaultValues.java`). Therefore, we don't have preference on what the child will return. if (!node.getGroupingKeys().isEmpty() && !hasMixedGroupingSets) { AggregationPartitioningMergingStrategy aggregationPartitioningMergingStrategy = getAggregationPartitioningMergingStrategy(session); - preferredProperties = PreferredProperties.partitionedWithLocal(partitioningRequirement, grouped(node.getGroupingKeys())) + preferredProperties = partitionedWithLocal(partitioningRequirement, grouped(node.getGroupingKeys())) .mergeWithParent(parentPreferredProperties, shouldAggregationMergePartitionPreferences(aggregationPartitioningMergingStrategy)); if (aggregationPartitioningMergingStrategy.isAdoptingMergedPreference()) { @@ -349,7 +353,7 @@ private Function partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification().orElseThrow(NoSuchElementException::new).getOrderingScheme().ifPresent(orderingScheme -> desiredProperties.addAll(orderingScheme.toLocalProperties())); + + PlanWithProperties child = planChild(node, partitionedWithLocal(ImmutableSet.copyOf(partitionBy), desiredProperties)); + + // TODO do not gather if already gathered + if (!node.isPruneWhenEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else if 
(!isStreamPartitionedOn(child.getProperties(), partitionBy) && + !isNodePartitionedOn(child.getProperties(), partitionBy)) { + if (partitionBy.isEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else { + child = withDerivedProperties( + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode(), Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionBy), node.getHashSymbol()), + child.getProperties()); + } + } + + return rebaseAndDeriveProperties(node, child); + } + @Override public PlanWithProperties visitRowNumber(RowNumberNode node, PreferredProperties preferredProperties) { @@ -440,7 +497,7 @@ public PlanWithProperties visitRowNumber(RowNumberNode node, PreferredProperties PlanWithProperties child = planChild( node, - PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(node.getPartitionBy()), grouped(node.getPartitionBy())) + partitionedWithLocal(ImmutableSet.copyOf(node.getPartitionBy()), grouped(node.getPartitionBy())) .mergeWithParent(preferredProperties, !isExactPartitioningPreferred(session))); // TODO: add config option/session property to force parallel plan if child is unpartitioned and window has a PARTITION BY clause @@ -485,7 +542,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr addExchange = partial -> gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, partial); } else { - preferredChildProperties = PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(node.getPartitionBy()), grouped(node.getPartitionBy())) + preferredChildProperties = partitionedWithLocal(ImmutableSet.copyOf(node.getPartitionBy()), grouped(node.getPartitionBy())) .mergeWithParent(preferredProperties, !isExactPartitioningPreferred(session)); addExchange = partial -> partitionedExchange( idAllocator.getNextId(), @@ -558,7 +615,7 @@ public PlanWithProperties visitDelete(DeleteNode node, 
PreferredProperties prefe PlanWithProperties child = planChild( node, - PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(inputDistribution.getPartitionBy()), desiredProperties) + partitionedWithLocal(ImmutableSet.copyOf(inputDistribution.getPartitionBy()), desiredProperties) .mergeWithParent(preferredProperties, !isExactPartitioningPreferred(session))); if (!isStreamPartitionedOn(child.getProperties(), inputDistribution.getPartitionBy()) && @@ -598,7 +655,7 @@ private PlanWithProperties planSortWithPartition(SortNode node, PreferredPropert PlanWithProperties child = planChild( node, - PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(node.getPartitionBy()), desiredProperties) + partitionedWithLocal(ImmutableSet.copyOf(node.getPartitionBy()), desiredProperties) .mergeWithParent(preferredProperties, !isExactPartitioningPreferred(session))); if (!isStreamPartitionedOn(child.getProperties(), node.getPartitionBy()) && @@ -1208,7 +1265,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, PreferredProperties // Only prefer grouping on join columns if no parent local property preferences List> desiredLocalProperties = preferredProperties.getLocalProperties().isEmpty() ? 
grouped(joinColumns) : ImmutableList.of(); - PlanWithProperties probeSource = accept(node.getProbeSource(), PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(joinColumns), desiredLocalProperties) + PlanWithProperties probeSource = accept(node.getProbeSource(), partitionedWithLocal(ImmutableSet.copyOf(joinColumns), desiredLocalProperties) .mergeWithParent(preferredProperties, true)); ActualProperties probeProperties = probeSource.getProperties(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 80b111c3a39c0..be548ad819f32 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -56,6 +57,8 @@ import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.collect.ImmutableList; @@ -64,6 +67,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; import 
java.util.Optional; import java.util.Set; @@ -107,6 +111,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -489,6 +494,82 @@ public PlanWithProperties visitDelete(DeleteNode node, StreamPreferredProperties return deriveProperties(result, child.getProperties()); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, StreamPreferredProperties parentPreferences) + { + if (!node.getSource().isPresent()) { + return deriveProperties(node, ImmutableList.of()); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. 
Source's properties do not hold after the TableFunctionProcessorNode + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), StreamPreferredProperties.any(), StreamPreferredProperties.any()); + return rebaseAndDeriveProperties(node, ImmutableList.of(child)); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + StreamPreferredProperties childRequirements; + if (!node.isPruneWhenEmpty()) { + childRequirements = singleStream(); + } + else { + childRequirements = parentPreferences + .constrainTo(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()) + .withDefaultParallelism(session) + .withPartitioning(partitionBy); + } + + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), childRequirements, childRequirements); + + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).ifPresent(orderingScheme -> desiredProperties.addAll(orderingScheme.toLocalProperties())); + Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); + + Set prePartitionedInputs = ImmutableSet.of(); + if (!partitionBy.isEmpty()) { + Optional> groupingRequirement = matchIterator.next(); + Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); + prePartitionedInputs = partitionBy.stream() + .filter(symbol -> !unPartitionedInputs.contains(symbol)) + .collect(toImmutableSet()); + } + + int preSortedOrderPrefix = 0; + if (prePartitionedInputs.equals(ImmutableSet.copyOf(partitionBy))) { + while (matchIterator.hasNext() && !matchIterator.next().isPresent()) { + preSortedOrderPrefix++; + } + } + + TableFunctionProcessorNode result = new 
TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child.getNode()), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + prePartitionedInputs, + preSortedOrderPrefix, + node.getHashSymbol(), + node.getHandle()); + + return deriveProperties(result, child.getProperties()); + } + @Override public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 9110ec4ceb128..9915c495762be 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.Ordering; @@ -27,7 +28,6 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TopNNode; -import com.facebook.presto.spi.plan.WindowNode.Specification; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -307,7 +307,7 @@ public Optional visitTopN(TopNNode node, Void context) decorrelatedChildNode.getSourceLocation(), node.getId(), decorrelatedChildNode, - new 
Specification( + new DataOrganizationSpecification( ImmutableList.copyOf(childDecorrelationResult.variablesToPropagate), Optional.of(orderingScheme)), variableAllocator.newVariable("row_number", BIGINT), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index d62c9e17bf1fc..ced8e6b7b0cf5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -67,6 +68,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -105,6 +108,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.stream.Collectors.toMap; public class PropertyDerivations @@ -263,6 +267,50 @@ public ActualProperties visitWindow(WindowNode node, 
List inpu .build(); } + @Override + public ActualProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public ActualProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + ImmutableList.Builder> localProperties = ImmutableList.builder(); + + if (node.getSource().isPresent()) { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + // Only the partitioning properties of the source are passed-through, because the pass-through mechanism preserves the partitioning values. + // Sorting properties might be broken because input rows can be shuffled or nulls can be inserted as the result of pass-through. + // Constant properties might be broken because nulls can be inserted as the result of pass-through. + if (!node.getPrePartitioned().isEmpty()) { + GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitioned()); + for (LocalProperty localProperty : properties.getLocalProperties()) { + if (!prePartitionedProperty.isSimplifiedBy(localProperty)) { + break; + } + localProperties.add(localProperty); + } + } + } + + List partitionBy = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .orElse(ImmutableList.of()); + if (!partitionBy.isEmpty()) { + localProperties.add(new GroupingProperty<>(partitionBy)); + } + + // TODO add global single stream property when there's Specification present with no partitioning columns + + return ActualProperties.builder() + .local(localProperties.build()) + .build() + // Crop properties to output columns. + .translateVariable(variable -> node.getOutputVariables().contains(variable) ? 
Optional.of(variable) : Optional.empty()); + } + @Override public ActualProperties visitGroupId(GroupIdNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index a595f1768d6b6..3f6179ae9776b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -64,6 +64,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -985,5 +986,11 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext> context) + { + return node; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 9bb82623312ca..525a44c3b0597 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -57,6 +57,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import 
com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -66,11 +68,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -555,6 +559,32 @@ public StreamProperties visitWindow(WindowNode node, List inpu return Iterables.getOnlyElement(inputProperties); } + @Override + public StreamProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public StreamProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + if (!node.getSource().isPresent()) { + return StreamProperties.singleStream(); // TODO allow multiple; return partitioning properties + } + + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + + Set passThroughInputs = Sets.intersection(ImmutableSet.copyOf(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()), ImmutableSet.copyOf(node.getOutputVariables())); + StreamProperties translatedProperties = properties.translate(column -> { + if (passThroughInputs.contains(column)) { + return Optional.of(column); + } + return Optional.empty(); + }); + + return translatedProperties.unordered(true); + } + @Override public StreamProperties visitRowNumber(RowNumberNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 9805efad17939..3e36e8fd21f7b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningScheme; @@ -37,6 +38,8 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -51,12 +54,14 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -110,6 
+115,13 @@ public VariableReferenceExpression map(VariableReferenceExpression variable) return new VariableReferenceExpression(variable.getSourceLocation(), canonical, types.get(new SymbolReference(getNodeLocation(variable.getSourceLocation()), canonical))); } + public List map(List symbols) + { + return symbols.stream() + .map(this::map) + .collect(toImmutableList()); + } + public Expression map(Expression value) { return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() @@ -135,6 +147,27 @@ public RowExpression rewriteVariableReference(VariableReferenceExpression variab }, value); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newOrderings = ImmutableList.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + VariableReferenceExpression variable = orderingScheme.getOrderBy().get(i).getVariable(); + VariableReferenceExpression canonical = map(variable); + if (added.add(canonical)) { + newOrderings.add(new Ordering(canonical, orderingScheme.getOrdering(variable))); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new OrderingScheme(newOrderings.build()), newPreSorted); + } + public OrderingScheme map(OrderingScheme orderingScheme) { // SymbolMapper inlines symbol with multiple level reference (SymbolInliner only inline single level). @@ -299,6 +332,64 @@ public TableWriterMergeNode map(TableWriterMergeNode node, PlanNode source) node.getStatisticsAggregation().map(this::map)); } + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + // rewrite and deduplicate pass-through specifications + // note: Potentially, pass-through symbols from different sources might be recognized as semantically identical, and rewritten + // to the same symbol. 
Currently, we retrieve the first occurrence of a symbol, and skip all the following occurrences. + // For better performance, we could pick the occurrence with "isPartitioningColumn" property, since the pass-through mechanism + // is more efficient for partitioning columns which are guaranteed to be constant within partition. + // TODO choose a partitioning column to be retrieved while deduplicating + ImmutableList.Builder newPassThroughSpecifications = ImmutableList.builder(); + Set newPassThroughVariables = new HashSet<>(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + ImmutableList.Builder newColumns = ImmutableList.builder(); + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + VariableReferenceExpression newVariable = map(column.getOutputVariables()); + if (newPassThroughVariables.add(newVariable)) { + newColumns.add(new TableFunctionNode.PassThroughColumn(newVariable, column.isPartitioningColumn())); + } + } + newPassThroughSpecifications.add(new TableFunctionNode.PassThroughSpecification(specification.isDeclaredAsPassThrough(), newColumns.build())); + } + + // rewrite required symbols without deduplication. 
the table function expects specific input layout + List> newRequiredVariables = node.getRequiredVariables().stream() + .map(this::map) + .collect(toImmutableList()); + + // rewrite and deduplicate marker mapping + Optional> newMarkerVariables = node.getMarkerVariables() + .map(mapping -> mapping.entrySet().stream() + .collect(toImmutableMap( + entry -> map(entry.getKey()), + entry -> map(entry.getValue()), + (first, second) -> { + checkState(first.equals(second), "Ambiguous marker symbols: %s and %s", first, second); + return first; + }))); + + // rewrite and deduplicate specification + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + map(node.getProperOutputs()), + Optional.of(source), + node.isPruneWhenEmpty(), + newPassThroughSpecifications.build(), + newRequiredVariables, + newMarkerVariables, + newSpecification.map(SpecificationWithPreSortedPrefix::getSpecification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + newSpecification.map(SpecificationWithPreSortedPrefix::getPreSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), + node.getHandle()); + } + private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source) { return new PartitioningScheme(translateVariable(scheme.getPartitioning(), this::map), @@ -348,6 +439,25 @@ private List mapAndDistinctVariable(List newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getOrderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getPreSorted).orElse(preSorted)); + } + + DataOrganizationSpecification 
mapAndDistinct(DataOrganizationSpecification specification) + { + return new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + specification.getOrderingScheme().map(this::map)); + } + public static SymbolMapper.Builder builder(WarningCollector warningCollector) { return new Builder(warningCollector); @@ -379,4 +489,48 @@ public void put(VariableReferenceExpression from, VariableReferenceExpression to mappingsBuilder.put(from, to); } } + + private static class OrderingSchemeWithPreSortedPrefix + { + private final OrderingScheme orderingScheme; + private final int preSorted; + + public OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + + public OrderingScheme getOrderingScheme() + { + return orderingScheme; + } + + public int getPreSorted() + { + return preSorted; + } + } + + private static class SpecificationWithPreSortedPrefix + { + private final DataOrganizationSpecification specification; + private final int preSorted; + + public SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + + public DataOrganizationSpecification getSpecification() + { + return specification; + } + + public int getPreSorted() + { + return preSorted; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 5f90e0015dde6..50a169960375b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ 
-23,6 +23,7 @@ import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -71,6 +72,10 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -80,6 +85,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -138,7 +144,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider } private static class Rewriter - extends SimplePlanRewriter + extends SimplePlanRewriter { private final Map mapping = new HashMap<>(); private final TypeProvider types; @@ -155,8 +161,13 @@ private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManag this.warningCollector = warningCollector; } + public Map getMapping() + { + return mapping; + } + @Override - public PlanNode visitAggregation(AggregationNode node, RewriteContext context) + public PlanNode 
visitAggregation(AggregationNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); //TODO: use mapper in other methods @@ -165,26 +176,26 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont } @Override - public PlanNode visitCteReference(CteReferenceNode node, RewriteContext context) + public PlanNode visitCteReference(CteReferenceNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new CteReferenceNode(node.getSourceLocation(), node.getId(), source, node.getCteId()); } - public PlanNode visitCteProducer(CteProducerNode node, RewriteContext context) + public PlanNode visitCteProducer(CteProducerNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); List canonical = Lists.transform(node.getOutputVariables(), this::canonicalize); return new CteProducerNode(node.getSourceLocation(), node.getId(), source, node.getCteId(), node.getRowCountVariable(), canonical); } - public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext context) + public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext context) { // No rewrite on source by cte consumer return node; } - public PlanNode visitSequence(SequenceNode node, RewriteContext context) + public PlanNode visitSequence(SequenceNode node, RewriteContext context) { List cteProducers = node.getCteProducers().stream().map(c -> SimplePlanRewriter.rewriteWith(new Rewriter(types, functionAndTypeManager, warningCollector), c)) @@ -194,7 +205,7 @@ public PlanNode visitSequence(SequenceNode node, RewriteContext context) } @Override - public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) + public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -214,21 +225,21 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) } @Override - public PlanNode 
visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext context) + public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new ExplainAnalyzeNode(node.getSourceLocation(), node.getId(), source, canonicalize(node.getOutputVariable()), node.isVerbose(), node.getFormat()); } @Override - public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) + public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new MarkDistinctNode(node.getSourceLocation(), node.getId(), source, canonicalize(node.getMarkerVariable()), canonicalizeAndDistinct(node.getDistinctVariables()), canonicalize(node.getHashVariable())); } @Override - public PlanNode visitUnnest(UnnestNode node, RewriteContext context) + public PlanNode visitUnnest(UnnestNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); ImmutableMap.Builder> builder = ImmutableMap.builder(); @@ -239,7 +250,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) } @Override - public PlanNode visitWindow(WindowNode node, RewriteContext context) + public PlanNode visitWindow(WindowNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -299,13 +310,13 @@ private WindowNode.Frame canonicalize(WindowNode.Frame frame) } @Override - public PlanNode visitTableScan(TableScanNode node, RewriteContext context) + public PlanNode visitTableScan(TableScanNode node, RewriteContext context) { return node; } @Override - public PlanNode visitExchange(ExchangeNode node, RewriteContext context) + public PlanNode visitExchange(ExchangeNode node, RewriteContext context) { List sources = node.getSources().stream() .map(context::rewrite) @@ -391,7 +402,7 @@ private List canonicalizeExchangeNodeInputs(Exchang } @Override - public PlanNode 
visitRemoteSource(RemoteSourceNode node, RewriteContext context) + public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext context) { return new RemoteSourceNode( node.getSourceLocation(), @@ -406,31 +417,31 @@ public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext co } @Override - public PlanNode visitOffset(OffsetNode node, RewriteContext context) + public PlanNode visitOffset(OffsetNode node, RewriteContext context) { return context.defaultRewrite(node); } @Override - public PlanNode visitLimit(LimitNode node, RewriteContext context) + public PlanNode visitLimit(LimitNode node, RewriteContext context) { return context.defaultRewrite(node); } @Override - public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext context) + public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext context) { return new DistinctLimitNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), node.getLimit(), node.isPartial(), canonicalizeAndDistinct(node.getDistinctVariables()), canonicalize(node.getHashVariable()), node.getTimeoutMillis()); } @Override - public PlanNode visitSample(SampleNode node, RewriteContext context) + public PlanNode visitSample(SampleNode node, RewriteContext context) { return new SampleNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), node.getSampleRatio(), node.getSampleType()); } @Override - public PlanNode visitValues(ValuesNode node, RewriteContext context) + public PlanNode visitValues(ValuesNode node, RewriteContext context) { List> canonicalizedRows = node.getRows().stream() .map(rowExpressions -> rowExpressions.stream() @@ -448,19 +459,19 @@ public PlanNode visitValues(ValuesNode node, RewriteContext context) } @Override - public PlanNode visitDelete(DeleteNode node, RewriteContext context) + public PlanNode visitDelete(DeleteNode node, RewriteContext context) { return new DeleteNode(node.getSourceLocation(), node.getId(), 
context.rewrite(node.getSource()), canonicalize(node.getRowId()), node.getOutputVariables(), node.getInputDistribution()); } @Override - public PlanNode visitUpdate(UpdateNode node, RewriteContext context) + public PlanNode visitUpdate(UpdateNode node, RewriteContext context) { return new UpdateNode(node.getSourceLocation(), node.getId(), node.getSource(), canonicalize(node.getRowId()), node.getColumnValueAndRowIdSymbols(), node.getOutputVariables()); } @Override - public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext context) + public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -468,7 +479,7 @@ public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteCont } @Override - public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) + public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -476,13 +487,134 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont } @Override - public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap<>(c.getCorrelationMapping())) + .orElseGet(HashMap::new); + + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); + + List newProperOutputs = node.getOutputVariables().stream() + .map(mapper::map) + .collect(toImmutableList()); + + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode 
newSource = node.getSources().get(i).accept(this, context); + newSources.add(newSource); + + SymbolMapper inputMapper = new SymbolMapper( + newSource instanceof UnaliasContext.PlanAndMappings + ? ((UnaliasContext.PlanAndMappings) newSource).getMappings() + : new HashMap<>(), + warningCollector); + + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + + Optional newSpecification = properties.specification().map(inputMapper::mapAndDistinct); + PassThroughSpecification newPassThroughSpecification = new PassThroughSpecification( + properties.getPassThroughSpecification().isDeclaredAsPassThrough(), + properties.getPassThroughSpecification().getColumns().stream() + .map(column -> new PassThroughColumn( + inputMapper.map(column.getOutputVariables()), + column.isPartitioningColumn())) + .collect(toImmutableList())); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( + properties.getArgumentName(), + properties.rowSemantics(), + properties.pruneWhenEmpty(), + newPassThroughSpecification, + inputMapper.map(properties.getRequiredColumns()), + newSpecification)); + } + + TableFunctionNode tableFunctionNode = new TableFunctionNode( + node.getId(), + node.getName(), + node.getArguments(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), + node.getHandle()); + + return new UnaliasContext.PlanAndMappings( + tableFunctionNode, + mappings) + { + @Override + public List getSources() + { + return tableFunctionNode.getSources(); + } + + @Override + public List getOutputVariables() + { + return tableFunctionNode.getOutputVariables(); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return tableFunctionNode.replaceChildren(newChildren); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return tableFunctionNode.assignStatsEquivalentPlanNode(statsEquivalentPlanNode); + } 
+ }; + } + + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap<>(c.getCorrelationMapping())) + .orElseGet(HashMap::new); + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); + + TableFunctionProcessorNode rewrittenTableFunctionProcessor = new TableFunctionProcessorNode( + node.getId(), + node.getName(), + mapper.map(node.getProperOutputs()), + Optional.empty(), + node.isPruneWhenEmpty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()); + + return rewrittenTableFunctionProcessor; + } + + PlanNode rewrittenSource = node.getSource().get().accept(this, context); + Map mappings = ((Rewriter) context.getNodeRewriter()).getMapping(); + SymbolMapper mapper = new SymbolMapper(mappings, types, warningCollector); + + return mapper.map(node, rewrittenSource); + } + + @Override + public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { return new RowNumberNode(node.getSourceLocation(), node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getPartitionBy()), canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), node.isPartial(), canonicalize(node.getHashVariable())); } @Override - public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext context) + public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext context) { return new TopNRowNumberNode( node.getSourceLocation(), @@ -496,7 +628,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext } @Override - public PlanNode visitFilter(FilterNode node, RewriteContext context) + public PlanNode visitFilter(FilterNode node, RewriteContext context) { PlanNode source = 
context.rewrite(node.getSource()); @@ -504,14 +636,14 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) } @Override - public PlanNode visitProject(ProjectNode node, RewriteContext context) + public PlanNode visitProject(ProjectNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); return new ProjectNode(node.getSourceLocation(), node.getId(), source, canonicalize(node.getAssignments()), node.getLocality()); } @Override - public PlanNode visitOutput(OutputNode node, RewriteContext context) + public PlanNode visitOutput(OutputNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -520,7 +652,7 @@ public PlanNode visitOutput(OutputNode node, RewriteContext context) } @Override - public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext context) + public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -528,7 +660,7 @@ public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext< } @Override - public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) + public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -536,7 +668,7 @@ public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext co } @Override - public PlanNode visitApply(ApplyNode node, RewriteContext context) + public PlanNode visitApply(ApplyNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getInput()); PlanNode subquery = context.rewrite(node.getSubquery()); @@ -548,7 +680,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) } @Override - public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) { PlanNode source = 
context.rewrite(node.getInput()); PlanNode subquery = context.rewrite(node.getSubquery()); @@ -558,7 +690,7 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext cont } @Override - public PlanNode visitTopN(TopNNode node, RewriteContext context) + public PlanNode visitTopN(TopNNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -567,7 +699,7 @@ public PlanNode visitTopN(TopNNode node, RewriteContext context) } @Override - public PlanNode visitSort(SortNode node, RewriteContext context) + public PlanNode visitSort(SortNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); @@ -575,7 +707,7 @@ public PlanNode visitSort(SortNode node, RewriteContext context) } @Override - public PlanNode visitJoin(JoinNode node, RewriteContext context) + public PlanNode visitJoin(JoinNode node, RewriteContext context) { PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); @@ -610,7 +742,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } @Override - public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) + public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); PlanNode filteringSource = context.rewrite(node.getFilteringSource()); @@ -630,7 +762,7 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) } @Override - public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext context) + public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext context) { PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); @@ -639,13 +771,13 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont } @Override - public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) + public PlanNode visitIndexSource(IndexSourceNode node, 
RewriteContext context) { return new IndexSourceNode(node.getSourceLocation(), node.getId(), node.getIndexHandle(), node.getTableHandle(), canonicalize(node.getLookupVariables()), node.getOutputVariables(), node.getAssignments(), node.getCurrentConstraint()); } @Override - public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) + public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) { PlanNode probeSource = context.rewrite(node.getProbeSource()); PlanNode indexSource = context.rewrite(node.getIndexSource()); @@ -654,24 +786,24 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) } @Override - public PlanNode visitUnion(UnionNode node, RewriteContext context) + public PlanNode visitUnion(UnionNode node, RewriteContext context) { return new UnionNode(node.getSourceLocation(), node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationOutputVariables(node.getOutputVariables()), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } @Override - public PlanNode visitIntersect(IntersectNode node, RewriteContext context) + public PlanNode visitIntersect(IntersectNode node, RewriteContext context) { return new IntersectNode(node.getSourceLocation(), node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationOutputVariables(node.getOutputVariables()), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } @Override - public PlanNode visitExcept(ExceptNode node, RewriteContext context) + public PlanNode visitExcept(ExceptNode node, RewriteContext context) { return new ExceptNode(node.getSourceLocation(), node.getId(), rewriteSources(node, context).build(), canonicalizeSetOperationOutputVariables(node.getOutputVariables()), canonicalizeSetOperationVariableMap(node.getVariableMapping())); } - private static ImmutableList.Builder rewriteSources(SetOperationNode node, RewriteContext context) + private static ImmutableList.Builder 
rewriteSources(SetOperationNode node, RewriteContext context) { ImmutableList.Builder rewrittenSources = ImmutableList.builder(); for (PlanNode source : node.getSources()) { @@ -681,7 +813,7 @@ private static ImmutableList.Builder rewriteSources(SetOperationNode n } @Override - public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) + public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -689,7 +821,7 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext cont } @Override - public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) + public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); @@ -697,7 +829,7 @@ public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) + public PlanNode visitPlan(PlanNode node, RewriteContext context) { throw new UnsupportedOperationException("Unsupported plan node " + node.getClass().getSimpleName()); } @@ -790,9 +922,9 @@ private Map canonicalizeAndDistinct(Map canonicalizeSetOperationOutputVariable return builder.build(); } } + + private static class UnaliasContext + { + // Correlation mapping is a record of how correlation symbols have been mapped in the subplan which provides them. + // All occurrences of correlation symbols within the correlated subquery must be remapped accordingly. + // In case of nested correlation, correlationMappings has required mappings for correlation symbols from all levels of nesting. 
+ private final Map correlationMapping; + + public UnaliasContext(Map correlationMapping) + { + this.correlationMapping = requireNonNull(correlationMapping, "correlationMapping is null"); + } + + public static UnaliasContext empty() + { + return new UnaliasContext(ImmutableMap.of()); + } + + public Map getCorrelationMapping() + { + return correlationMapping; + } + + private abstract static class PlanAndMappings + extends PlanNode + { + private final Map mappings; + + public PlanAndMappings(PlanNode root, Map mappings) + { + super(root.getSourceLocation(), root.getId(), root.getStatsEquivalentPlanNode()); + this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null")); + } + + public Map getMappings() + { + return mappings; + } + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java index 337e9eb39df03..308d5223f64a9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java @@ -51,7 +51,6 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.Math.toIntExact; @@ -116,7 +115,7 @@ public boolean isPlanChanged() @Override public PlanNode visitWindow(WindowNode node, RewriteContext context) { - checkState(node.getWindowFunctions().size() == 1, "WindowFilterPushdown requires that WindowNodes contain exactly one window function"); +// 
checkState(node.getWindowFunctions().size() == 1, "WindowFilterPushdown requires that WindowNodes contain exactly one window function"); PlanNode rewrittenSource = context.rewrite(node.getSource()); if (canReplaceWithRowNumber(node, metadata.getFunctionAndTypeManager())) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index 43608a98ec271..bd5549ce46ec7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -136,4 +136,14 @@ public R visitSequence(SequenceNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunction(TableFunctionNode node, C context) + { + return visitPlan(node, context); + } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index 21a54b9b325a6..1e9e2b6ae5afd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -228,6 +228,16 @@ public static Pattern window() return typeOf(WindowNode.class); } + public static Pattern tableFunction() + { + return typeOf(TableFunctionNode.class); + } + + public static Pattern tableFunctionProcessor() + { + return typeOf(TableFunctionProcessorNode.class); + } + public static Pattern rowNumber() { return typeOf(RowNumberNode.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 22d4f18e42ff9..f87c1a1bba5c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -61,6 +61,11 @@ public C get() return userContext; } + public SimplePlanRewriter getNodeRewriter() + { + return nodeRewriter; + } + /** * Invoke the rewrite logic recursively on children of the given node and swap it * out with an identical copy with the rewritten children diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java new file mode 100644 index 0000000000000..a44404000a740 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.annotation.concurrent.Immutable; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +@Immutable +public class TableFunctionNode + extends InternalPlanNode +{ + private final String name; + private final Map arguments; + private final List outputVariables; + private final List sources; + private final List tableArgumentProperties; + private final List> copartitioningLists; + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("sources") List sources, + @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, + @JsonProperty("handle") TableFunctionHandle handle) + { + this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); + } + + public TableFunctionNode( + Optional 
sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + String name, + Map arguments, + List outputVariables, + List sources, + List tableArgumentProperties, + List> copartitioningLists, + TableFunctionHandle handle) + { + super(sourceLocation, id, statsEquivalentPlanNode); + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.outputVariables = ImmutableList.copyOf(outputVariables); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public Map getArguments() + { + return arguments; + } + + @Override + public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + + variables.addAll(outputVariables); + + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + public List getProperOutputs() + { + return outputVariables; + } + + @JsonProperty + public List getTableArgumentProperties() + { + return tableArgumentProperties; + } + + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return sources; + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitTableFunction(this, context); + } + + @Override + public PlanNode replaceChildren(List newSources) + { + 
checkArgument(sources.size() == newSources.size(), "wrong number of new children"); + return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, copartitioningLists, handle); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); + } + + public static class TableArgumentProperties + { + private final String argumentName; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final PassThroughSpecification passThroughSpecification; + private final List requiredColumns; + private final Optional specification; + + @JsonCreator + public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, + @JsonProperty("rowSemantics") boolean rowSemantics, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, + @JsonProperty("requiredColumns") List requiredColumns, + @JsonProperty("specification") Optional specification) + { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); + this.requiredColumns = ImmutableList.copyOf(requiredColumns); + this.specification = requireNonNull(specification, "specification is null"); + } + + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + + @JsonProperty + public boolean rowSemantics() + { + return rowSemantics; + } + + @JsonProperty + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public PassThroughSpecification getPassThroughSpecification() 
+ { + return passThroughSpecification; + } + + @JsonProperty + public List getRequiredColumns() + { + return requiredColumns; + } + + @JsonProperty + public Optional specification() + { + return specification; + } + } + + public static class PassThroughSpecification + { + private final boolean declaredAsPassThrough; + private final List columns; + + @JsonCreator + public PassThroughSpecification( + @JsonProperty("declaredAsPassThrough") boolean declaredAsPassThrough, + @JsonProperty("columns") List columns) + { + this.declaredAsPassThrough = declaredAsPassThrough; + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + checkArgument( + declaredAsPassThrough || this.columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + + @JsonProperty + public boolean isDeclaredAsPassThrough() + { + return declaredAsPassThrough; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + } + + public static class PassThroughColumn + { + private final VariableReferenceExpression outputVariables; + private final boolean partitioningColumn; + + @JsonCreator + public PassThroughColumn( + @JsonProperty("outputVariables") VariableReferenceExpression outputVariables, + @JsonProperty("partitioningColumn") boolean partitioningColumn) + { + this.outputVariables = requireNonNull(outputVariables, "symbol is null"); + this.partitioningColumn = partitioningColumn; + } + + @JsonProperty + public VariableReferenceExpression getOutputVariables() + { + return outputVariables; + } + + @JsonProperty + public boolean isPartitioningColumn() + { + return partitioningColumn; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 0000000000000..9a4a1d1c88123 
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.plan;
+
+import com.facebook.presto.metadata.TableFunctionHandle;
+import com.facebook.presto.spi.plan.DataOrganizationSpecification;
+import com.facebook.presto.spi.plan.OrderingScheme;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.PlanNodeId;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.presto.spi.function.table.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Plan node for processing a table function over at most one pre-planned source.
+ * <p>
+ * Output variables are the function's proper outputs followed by the pass-through
+ * columns of every {@link PassThroughSpecification}, in list order.
+ */
+public class TableFunctionProcessorNode
+        extends InternalPlanNode
+{
+    private final String name;
+
+    // symbols produced by the function
+    private final List<VariableReferenceExpression> properOutputs;
+
+    // pre-planned sources
+    private final Optional<PlanNode> source;
+    // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution?
+
+    // specifies whether the function should be pruned or executed when the input is empty
+    // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY
+    private final boolean pruneWhenEmpty;
+
+    // all source symbols to be produced on output, ordered as table argument specifications
+    private final List<PassThroughSpecification> passThroughSpecifications;
+
+    // symbols required from each source, ordered as table argument specifications
+    private final List<List<VariableReferenceExpression>> requiredVariables;
+
+    // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid
+    // for processing or for pass-through. null value in the marker column indicates that the value at the same
+    // position in the source column should not be processed or passed-through.
+    // the mapping is only present if there are two or more sources.
+    private final Optional<Map<VariableReferenceExpression, VariableReferenceExpression>> markerVariables;
+
+    private final Optional<DataOrganizationSpecification> specification;
+    private final Set<VariableReferenceExpression> prePartitioned;
+    private final int preSorted;
+    private final Optional<VariableReferenceExpression> hashSymbol;
+
+    private final TableFunctionHandle handle;
+
+    /**
+     * JSON deserialization entry point; creates a node without a stats-equivalent plan node.
+     */
+    @JsonCreator
+    public TableFunctionProcessorNode(
+            @JsonProperty("id") PlanNodeId id,
+            @JsonProperty("name") String name,
+            @JsonProperty("properOutputs") List<VariableReferenceExpression> properOutputs,
+            @JsonProperty("source") Optional<PlanNode> source,
+            @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty,
+            @JsonProperty("passThroughSpecifications") List<PassThroughSpecification> passThroughSpecifications,
+            @JsonProperty("requiredVariables") List<List<VariableReferenceExpression>> requiredVariables,
+            @JsonProperty("markerVariables") Optional<Map<VariableReferenceExpression, VariableReferenceExpression>> markerVariables,
+            @JsonProperty("specification") Optional<DataOrganizationSpecification> specification,
+            @JsonProperty("prePartitioned") Set<VariableReferenceExpression> prePartitioned,
+            @JsonProperty("preSorted") int preSorted,
+            @JsonProperty("hashSymbol") Optional<VariableReferenceExpression> hashSymbol,
+            @JsonProperty("handle") TableFunctionHandle handle)
+    {
+        this(id, Optional.empty(), name, properOutputs, source, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle);
+    }
+
+    // Full constructor: also carries the stats-equivalent plan node so that
+    // assignStatsEquivalentPlanNode() and replaceChildren() can propagate it.
+    private TableFunctionProcessorNode(
+            PlanNodeId id,
+            Optional<PlanNode> statsEquivalentPlanNode,
+            String name,
+            List<VariableReferenceExpression> properOutputs,
+            Optional<PlanNode> source,
+            boolean pruneWhenEmpty,
+            List<PassThroughSpecification> passThroughSpecifications,
+            List<List<VariableReferenceExpression>> requiredVariables,
+            Optional<Map<VariableReferenceExpression, VariableReferenceExpression>> markerVariables,
+            Optional<DataOrganizationSpecification> specification,
+            Set<VariableReferenceExpression> prePartitioned,
+            int preSorted,
+            Optional<VariableReferenceExpression> hashSymbol,
+            TableFunctionHandle handle)
+    {
+        super(Optional.empty(), id, statsEquivalentPlanNode);
+        this.name = requireNonNull(name, "name is null");
+        this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null"));
+        this.source = requireNonNull(source, "source is null");
+        this.pruneWhenEmpty = pruneWhenEmpty;
+        this.passThroughSpecifications = ImmutableList.copyOf(requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"));
+        this.requiredVariables = requireNonNull(requiredVariables, "requiredVariables is null").stream()
+                .map(ImmutableList::copyOf)
+                .collect(toImmutableList());
+        this.markerVariables = requireNonNull(markerVariables, "markerVariables is null").map(ImmutableMap::copyOf);
+        this.specification = requireNonNull(specification, "specification is null");
+        this.prePartitioned = ImmutableSet.copyOf(requireNonNull(prePartitioned, "prePartitioned is null"));
+
+        // pre-partitioned symbols must be a subset of the partitioning symbols (empty when there is no specification)
+        Set<VariableReferenceExpression> partitionBy = specification
+                .map(DataOrganizationSpecification::getPartitionBy)
+                .map(ImmutableSet::copyOf)
+                .orElse(ImmutableSet.of());
+        checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list");
+
+        // preSorted may not exceed the number of ordering symbols (0 when there is no ordering scheme)
+        this.preSorted = preSorted;
+        checkArgument(
+                specification
+                        .flatMap(DataOrganizationSpecification::getOrderingScheme)
+                        .map(OrderingScheme::getOrderBy)
+                        .map(List::size)
+                        .orElse(0) >= preSorted,
+                "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols");
+        checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned");
+
+        this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null");
+        this.handle = requireNonNull(handle, "handle is null");
+    }
+
+    @JsonProperty
+    public String getName()
+    {
+        return name;
+    }
+
+    @JsonProperty
+    public List<VariableReferenceExpression> getProperOutputs()
+    {
+        return properOutputs;
+    }
+
+    @JsonProperty
+    public Optional<PlanNode> getSource()
+    {
+        return source;
+    }
+
+    @JsonProperty
+    public boolean isPruneWhenEmpty()
+    {
+        return pruneWhenEmpty;
+    }
+
+    @JsonProperty
+    public List<PassThroughSpecification> getPassThroughSpecifications()
+    {
+        return passThroughSpecifications;
+    }
+
+    @JsonProperty
+    public List<List<VariableReferenceExpression>> getRequiredVariables()
+    {
+        return requiredVariables;
+    }
+
+    @JsonProperty
+    public Optional<Map<VariableReferenceExpression, VariableReferenceExpression>> getMarkerVariables()
+    {
+        return markerVariables;
+    }
+
+    @JsonProperty
+    public Optional<DataOrganizationSpecification> getSpecification()
+    {
+        return specification;
+    }
+
+    @JsonProperty
+    public Set<VariableReferenceExpression> getPrePartitioned()
+    {
+        return prePartitioned;
+    }
+
+    @JsonProperty
+    public int getPreSorted()
+    {
+        return preSorted;
+    }
+
+    @JsonProperty
+    public Optional<VariableReferenceExpression> getHashSymbol()
+    {
+        return hashSymbol;
+    }
+
+    @JsonProperty
+    public TableFunctionHandle getHandle()
+    {
+        return handle;
+    }
+
+    @JsonProperty
+    @Override
+    public List<PlanNode> getSources()
+    {
+        return source.map(ImmutableList::of).orElse(ImmutableList.of());
+    }
+
+    @Override
+    public List<VariableReferenceExpression> getOutputVariables()
+    {
+        ImmutableList.Builder<VariableReferenceExpression> variables = ImmutableList.builder();
+
+        variables.addAll(properOutputs);
+
+        passThroughSpecifications.stream()
+                .map(PassThroughSpecification::getColumns)
+                .flatMap(Collection::stream)
+                .map(TableFunctionNode.PassThroughColumn::getOutputVariables)
+                .forEach(variables::add);
+
+        return variables.build();
+    }
+
+    @Override
+    public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
+    {
+        return new TableFunctionProcessorNode(getId(), statsEquivalentPlanNode, name, properOutputs, source, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle);
+    }
+
+    @Override
+    public PlanNode replaceChildren(List<PlanNode> newSources)
+    {
+        Optional<PlanNode> newSource = newSources.isEmpty() ? Optional.empty() : Optional.of(getOnlyElement(newSources));
+        return new TableFunctionProcessorNode(getId(), getStatsEquivalentPlanNode(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle);
+    }
+
+    @Override
+    public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
+    {
+        return visitor.visitTableFunctionProcessor(this, context);
+    }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java
index bc6ee14a5e81f..d7e66afb664e4 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java
@@ -14,10 +14,10 @@
 package com.facebook.presto.sql.planner.plan;
 
 import com.facebook.presto.spi.SourceLocation;
+import com.facebook.presto.spi.plan.DataOrganizationSpecification;
 import com.facebook.presto.spi.plan.OrderingScheme;
 import com.facebook.presto.spi.plan.PlanNode;
 import com.facebook.presto.spi.plan.PlanNodeId;
-import com.facebook.presto.spi.plan.WindowNode.Specification;
 import com.facebook.presto.spi.relation.VariableReferenceExpression;
 import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; @@ -37,7 +37,7 @@ public final class TopNRowNumberNode extends InternalPlanNode { private final PlanNode source; - private final Specification specification; + private final DataOrganizationSpecification specification; private final VariableReferenceExpression rowNumberVariable; private final int maxRowCountPerPartition; private final boolean partial; @@ -48,7 +48,7 @@ public TopNRowNumberNode( Optional sourceLocation, @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable, @JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition, @JsonProperty("partial") boolean partial, @@ -62,7 +62,7 @@ public TopNRowNumberNode( PlanNodeId id, Optional statsEquivalentPlanNode, PlanNode source, - Specification specification, + DataOrganizationSpecification specification, VariableReferenceExpression rowNumberVariable, int maxRowCountPerPartition, boolean partial, @@ -109,7 +109,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 941d24f21f243..403192f71a38a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -34,6 +34,9 @@ import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; +import 
com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; @@ -95,6 +98,9 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -108,6 +114,7 @@ import com.google.common.base.Functions; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; @@ -117,11 +124,13 @@ import io.airlift.units.Duration; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -133,6 +142,7 @@ import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters; import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static 
com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -146,11 +156,14 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; public class PlanPrinter @@ -203,7 +216,7 @@ private PlanPrinter( this.formatter = rowExpression -> rowExpressionFormatter.formatRowExpression(connectorSession, rowExpression); Visitor visitor = new Visitor(stageExecutionStrategy, types, estimatedStatsAndCosts, session, stats); - planRoot.accept(visitor, null); + planRoot.accept(visitor, new Context()); } public String toText(boolean verbose, int level, boolean verboseOptimizerInfo) @@ -489,7 +502,7 @@ public static String graphvizDistributedPlan(StageInfo stageInfo, FunctionAndTyp } private class Visitor - extends InternalPlanVisitor + extends InternalPlanVisitor { private final Optional stageExecutionStrategy; private final TypeProvider types; @@ -507,14 +520,14 @@ public Visitor(Optional stageExecutionStrategy, TypePr } @Override - public Void visitExplainAnalyze(ExplainAnalyzeNode node, Void context) + public Void visitExplainAnalyze(ExplainAnalyzeNode node, Context context) { - addNode(node, "ExplainAnalyze"); - return processChildren(node, context); + addNode(node, 
"ExplainAnalyze", context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitJoin(JoinNode node, Void context) + public Void visitJoin(JoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (EquiJoinClause clause : node.getCriteria()) { @@ -525,12 +538,12 @@ public Void visitJoin(JoinNode node, Void context) NodeRepresentation nodeOutput; if (node.isCrossJoin()) { checkState(joinExpressions.isEmpty()); - nodeOutput = addNode(node, "CrossJoin"); + nodeOutput = addNode(node, "CrossJoin", context.getTag()); } else { nodeOutput = addNode(node, node.getType().getJoinLabel(), - format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable()))); + format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable())), context.getTag()); } node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetailsLine("Distribution: %s", distributionType)); @@ -540,51 +553,51 @@ public Void visitJoin(JoinNode node, Void context) getSortExpressionContext(node, functionAndTypeManager) .ifPresent(sortContext -> nodeOutput.appendDetails("SortExpression[%s]", formatter.apply(sortContext.getSortExpression()))); - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitSpatialJoin(SpatialJoinNode node, Void context) + public Void visitSpatialJoin(SpatialJoinNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, node.getType().getJoinLabel(), - format("[%s]", formatter.apply(node.getFilter()))); + format("[%s]", formatter.apply(node.getFilter())), context.getTag()); nodeOutput.appendDetailsLine("Distribution: %s", node.getDistributionType()); - node.getLeft().accept(this, context); - node.getRight().accept(this, 
context); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitSemiJoin(SemiJoinNode node, Void context) + public Void visitSemiJoin(SemiJoinNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, "SemiJoin", format("[%s = %s]%s", node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable(), - formatHash(node.getSourceHashVariable(), node.getFilteringSourceHashVariable()))); + formatHash(node.getSourceHashVariable(), node.getFilteringSourceHashVariable())), context.getTag()); node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetailsLine("Distribution: %s", distributionType)); if (!node.getDynamicFilters().isEmpty()) { nodeOutput.appendDetails(getDynamicFilterAssignments(node)); } - node.getSource().accept(this, context); - node.getFilteringSource().accept(this, context); + node.getSource().accept(this, new Context()); + node.getFilteringSource().accept(this, new Context()); return null; } @Override - public Void visitIndexSource(IndexSourceNode node, Void context) + public Void visitIndexSource(IndexSourceNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, "IndexSource", - format("[%s, lookup = %s]", node.getIndexHandle(), node.getLookupVariables())); + format("[%s, lookup = %s]", node.getIndexHandle(), node.getLookupVariables()), context.getTag()); for (Map.Entry entry : node.getAssignments().entrySet()) { if (node.getOutputVariables().contains(entry.getKey())) { @@ -595,7 +608,7 @@ public Void visitIndexSource(IndexSourceNode node, Void context) } @Override - public Void visitIndexJoin(IndexJoinNode node, Void context) + public Void visitIndexJoin(IndexJoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { @@ -606,15 +619,15 @@ public Void visitIndexJoin(IndexJoinNode node, Void context) addNode(node, format("%sIndexJoin", 
node.getType().getJoinLabel()), - format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashVariable(), node.getIndexHashVariable()))); - node.getProbeSource().accept(this, context); - node.getIndexSource().accept(this, context); + format("[%s]%s", Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashVariable(), node.getIndexHashVariable())), context.getTag()); + node.getProbeSource().accept(this, new Context()); + node.getIndexSource().accept(this, new Context()); return null; } @Override - public Void visitMergeJoin(MergeJoinNode node, Void context) + public Void visitMergeJoin(MergeJoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (EquiJoinClause clause : node.getCriteria()) { @@ -624,32 +637,32 @@ public Void visitMergeJoin(MergeJoinNode node, Void context) addNode(node, "MergeJoin", - format("[type: %s], [%s]%s", node.getType().getJoinLabel(), Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable()))); - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + format("[type: %s], [%s]%s", node.getType().getJoinLabel(), Joiner.on(" AND ").join(joinExpressions), formatHash(node.getLeftHashVariable(), node.getRightHashVariable())), context.getTag()); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitLimit(LimitNode node, Void context) + public Void visitLimit(LimitNode node, Context context) { addNode(node, format("Limit%s", node.isPartial() ? 
"Partial" : ""), - format("[%s]", node.getCount())); - return processChildren(node, context); + format("[%s]", node.getCount()), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitDistinctLimit(DistinctLimitNode node, Void context) + public Void visitDistinctLimit(DistinctLimitNode node, Context context) { addNode(node, format("DistinctLimit%s", node.isPartial() ? "Partial" : ""), - format("[%s]%s", node.getLimit(), formatHash(node.getHashVariable()))); - return processChildren(node, context); + format("[%s]%s", node.getLimit(), formatHash(node.getHashVariable())), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitAggregation(AggregationNode node, Void context) + public Void visitAggregation(AggregationNode node, Context context) { String type = ""; if (node.getStep() != AggregationNode.Step.SINGLE) { @@ -667,13 +680,13 @@ public Void visitAggregation(AggregationNode node, Void context) } NodeRepresentation nodeOutput = addNode(node, - format("Aggregate%s%s%s", type, key, formatHash(node.getHashVariable()))); + format("Aggregate%s%s%s", type, key, formatHash(node.getHashVariable())), context.getTag()); for (Map.Entry entry : node.getAggregations().entrySet()) { nodeOutput.appendDetailsLine("%s := %s%s", entry.getKey(), formatAggregation(entry.getValue()), formatSourceLocation(entry.getValue().getCall().getSourceLocation(), entry.getKey().getSourceLocation())); } - return processChildren(node, context); + return processChildren(node, new Context()); } private String formatAggregation(AggregationNode.Aggregation aggregation) @@ -700,7 +713,7 @@ private String formatAggregation(AggregationNode.Aggregation aggregation) } @Override - public Void visitGroupId(GroupIdNode node, Void context) + public Void visitGroupId(GroupIdNode node, Context context) { // grouping sets are easier to understand in terms of inputs List> inputGroupingSetSymbols = node.getGroupingSets().stream() @@ 
-709,27 +722,27 @@ public Void visitGroupId(GroupIdNode node, Void context) .collect(Collectors.toList())) .collect(Collectors.toList()); - NodeRepresentation nodeOutput = addNode(node, "GroupId", format("%s", inputGroupingSetSymbols)); + NodeRepresentation nodeOutput = addNode(node, "GroupId", format("%s", inputGroupingSetSymbols), context.getTag()); for (Map.Entry mapping : node.getGroupingColumns().entrySet()) { nodeOutput.appendDetailsLine("%s := %s%s", mapping.getKey(), mapping.getValue(), formatSourceLocation(mapping.getValue().getSourceLocation(), mapping.getKey().getSourceLocation())); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitMarkDistinct(MarkDistinctNode node, Void context) + public Void visitMarkDistinct(MarkDistinctNode node, Context context) { addNode(node, "MarkDistinct", - format("[distinct=%s marker=%s]%s", formatOutputs(node.getDistinctVariables()), node.getMarkerVariable(), formatHash(node.getHashVariable()))); + format("[distinct=%s marker=%s]%s", formatOutputs(node.getDistinctVariables()), node.getMarkerVariable(), formatHash(node.getHashVariable())), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitWindow(WindowNode node, Void context) + public Void visitWindow(WindowNode node, Context context) { List partitionBy = Lists.transform(node.getPartitionBy(), Functions.toStringFunction()); @@ -769,7 +782,7 @@ public Void visitWindow(WindowNode node, Void context) .collect(Collectors.joining(", ")))); } - NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable()))); + NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable())), context.getTag()); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { CallExpression call = 
entry.getValue().getFunctionCall(); @@ -783,11 +796,11 @@ public Void visitWindow(WindowNode node, Void context) frameInfo, formatSourceLocation(entry.getValue().getFunctionCall().getSourceLocation(), entry.getKey().getSourceLocation())); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) + public Void visitTopNRowNumber(TopNRowNumberNode node, Context context) { List partitionBy = node.getPartitionBy().stream() .map(Functions.toStringFunction()) @@ -803,15 +816,15 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode(node, format("TopNRowNumber%s", node.isPartial() ? "Partial" : ""), - format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashVariable()))); + format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashVariable())), context.getTag()); nodeOutput.appendDetailsLine("%s := %s%s", node.getRowNumberVariable(), "row_number()", formatSourceLocation(node.getRowNumberVariable().getSourceLocation())); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRowNumber(RowNumberNode node, Void context) + public Void visitRowNumber(RowNumberNode node, Context context) { List partitionBy = Lists.transform(node.getPartitionBy(), Functions.toStringFunction()); List args = new ArrayList<>(); @@ -825,24 +838,24 @@ public Void visitRowNumber(RowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode(node, format("RowNumber%s", node.isPartial() ? 
"Partial" : ""), - format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable()))); + format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashVariable())), context.getTag()); nodeOutput.appendDetailsLine("%s := %s%s", node.getRowNumberVariable(), "row_number()", formatSourceLocation(node.getRowNumberVariable().getSourceLocation())); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableScan(TableScanNode node, Void context) + public Void visitTableScan(TableScanNode node, Context context) { TableHandle table = node.getTable(); NodeRepresentation nodeOutput; if (stageExecutionStrategy.isPresent()) { nodeOutput = addNode(node, "TableScan", - format("[%s, grouped = %s]", table, stageExecutionStrategy.get().isScanGroupedExecution(node.getId()))); + format("[%s, grouped = %s]", table, stageExecutionStrategy.get().isScanGroupedExecution(node.getId())), context.getTag()); } else { - nodeOutput = addNode(node, "TableScan", format("[%s]", table)); + nodeOutput = addNode(node, "TableScan", format("[%s]", table), context.getTag()); } PlanNodeStats nodeStats = stats.map(s -> s.get(node.getId())).orElse(null); printTableScanInfo(nodeOutput, node, nodeStats); @@ -850,49 +863,49 @@ public Void visitTableScan(TableScanNode node, Void context) } @Override - public Void visitSequence(SequenceNode node, Void context) + public Void visitSequence(SequenceNode node, Context context) { NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "Sequence"); + nodeOutput = addNode(node, "Sequence", context.getTag()); nodeOutput.appendDetails(getCteExecutionOrder(node)); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCteConsumer(CteConsumerNode node, Void context) + public Void visitCteConsumer(CteConsumerNode node, Context context) { NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "CteConsumer"); + 
nodeOutput = addNode(node, "CteConsumer", context.getTag()); nodeOutput.appendDetailsLine("CTE_NAME: %s", node.getCteId()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCteProducer(CteProducerNode node, Void context) + public Void visitCteProducer(CteProducerNode node, Context context) { NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "CteProducer"); + nodeOutput = addNode(node, "CteProducer", context.getTag()); nodeOutput.appendDetailsLine("CTE_NAME: %s", node.getCteId()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCteReference(CteReferenceNode node, Void context) + public Void visitCteReference(CteReferenceNode node, Context context) { - addNode(node, "CteReference"); - return processChildren(node, context); + addNode(node, "CteReference", context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitValues(ValuesNode node, Void context) + public Void visitValues(ValuesNode node, Context context) { NodeRepresentation nodeOutput; if (node.getValuesNodeLabel().isPresent()) { - nodeOutput = addNode(node, format("Values converted from TableScan[%s]", node.getValuesNodeLabel().get())); + nodeOutput = addNode(node, format("Values converted from TableScan[%s]", node.getValuesNodeLabel().get()), context.getTag()); } else { - nodeOutput = addNode(node, "Values"); + nodeOutput = addNode(node, "Values", context.getTag()); } for (List row : node.getRows()) { nodeOutput.appendDetailsLine("(" + row.stream().map(formatter::apply).collect(Collectors.joining(", ")) + ")"); @@ -901,13 +914,13 @@ public Void visitValues(ValuesNode node, Void context) } @Override - public Void visitFilter(FilterNode node, Void context) + public Void visitFilter(FilterNode node, Context context) { return visitScanFilterAndProjectInfo(node, Optional.of(node), Optional.empty(), context); } @Override - public 
Void visitProject(ProjectNode node, Void context) + public Void visitProject(ProjectNode node, Context context) { if (node.getSource() instanceof FilterNode) { return visitScanFilterAndProjectInfo(node, Optional.of((FilterNode) node.getSource()), Optional.of(node), context); @@ -920,7 +933,7 @@ private Void visitScanFilterAndProjectInfo( PlanNode node, Optional filterNode, Optional projectNode, - Void context) + Context context) { checkState(projectNode.isPresent() || filterNode.isPresent()); @@ -994,7 +1007,8 @@ private Void visitScanFilterAndProjectInfo( format(formatString, arguments.toArray(new Object[0])), allNodes, ImmutableList.of(sourceNode), - ImmutableList.of()); + ImmutableList.of(), + context.getTag()); if (projectNode.isPresent()) { printAssignments(nodeOutput, projectNode.get().getAssignments()); @@ -1006,7 +1020,7 @@ private Void visitScanFilterAndProjectInfo( return null; } - sourceNode.accept(this, context); + sourceNode.accept(this, new Context()); return null; } @@ -1074,18 +1088,18 @@ else if (predicate.isNone()) { } @Override - public Void visitUnnest(UnnestNode node, Void context) + public Void visitUnnest(UnnestNode node, Context context) { addNode(node, "Unnest", - format("[replicate=%s, unnest=%s]", formatOutputs(node.getReplicateVariables()), formatOutputs(node.getUnnestVariables().keySet()))); - return processChildren(node, context); + format("[replicate=%s, unnest=%s]", formatOutputs(node.getReplicateVariables()), formatOutputs(node.getUnnestVariables().keySet())), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitOutput(OutputNode node, Void context) + public Void visitOutput(OutputNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "Output", format("[%s]", Joiner.on(", ").join(node.getColumnNames()))); + NodeRepresentation nodeOutput = addNode(node, "Output", format("[%s]", Joiner.on(", ").join(node.getColumnNames())), context.getTag()); for (int i = 0; i < 
node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); VariableReferenceExpression variable = node.getOutputVariables().get(i); @@ -1093,22 +1107,22 @@ public Void visitOutput(OutputNode node, Void context) nodeOutput.appendDetailsLine("%s := %s%s", name, variable, formatSourceLocation(variable.getSourceLocation())); } } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTopN(TopNNode node, Void context) + public Void visitTopN(TopNNode node, Context context) { Iterable keys = Iterables.transform(node.getOrderingScheme().getOrderByVariables(), input -> input + " " + node.getOrderingScheme().getOrdering(input)); addNode(node, format("TopN%s", node.getStep() == TopNNode.Step.PARTIAL ? "Partial" : ""), - format("[%s by (%s)]", node.getCount(), Joiner.on(", ").join(keys))); - return processChildren(node, context); + format("[%s by (%s)]", node.getCount(), Joiner.on(", ").join(keys)), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitSort(SortNode node, Void context) + public Void visitSort(SortNode node, Context context) { Iterable keys = Iterables.transform(node.getOrderingScheme().getOrderByVariables(), input -> input + " " + node.getOrderingScheme().getOrdering(input)); @@ -1116,52 +1130,53 @@ public Void visitSort(SortNode node, Void context) if (!node.getPartitionBy().isEmpty()) { detail = format("%s[Partition by %s]", detail, Joiner.on(", ").join(node.getPartitionBy())); } - addNode(node, format("%sSort", node.isPartial() ? "Partial" : ""), detail); + addNode(node, format("%sSort", node.isPartial() ? 
"Partial" : ""), detail, context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRemoteSource(RemoteSourceNode node, Void context) + public Void visitRemoteSource(RemoteSourceNode node, Context context) { addNode(node, format("Remote%s", node.getOrderingScheme().isPresent() ? "Merge" : "Source"), format("[%s]", Joiner.on(',').join(node.getSourceFragmentIds())), ImmutableList.of(), ImmutableList.of(), - node.getSourceFragmentIds()); + node.getSourceFragmentIds(), + context.getTag()); return null; } @Override - public Void visitUnion(UnionNode node, Void context) + public Void visitUnion(UnionNode node, Context context) { - addNode(node, "Union"); + addNode(node, "Union", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitIntersect(IntersectNode node, Void context) + public Void visitIntersect(IntersectNode node, Context context) { - addNode(node, "Intersect"); + addNode(node, "Intersect", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitExcept(ExceptNode node, Void context) + public Void visitExcept(ExceptNode node, Context context) { - addNode(node, "Except"); + addNode(node, "Except", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableWriter(TableWriterNode node, Void context) + public Void visitTableWriter(TableWriterNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "TableWriter"); + NodeRepresentation nodeOutput = addNode(node, "TableWriter", context.getTag()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); VariableReferenceExpression variable = node.getColumns().get(i); @@ -1174,40 +1189,40 @@ public Void visitTableWriter(TableWriterNode 
node, Void context) .orElse(0); nodeOutput.appendDetailsLine("Statistics collected: %s", statisticsCollected); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableWriteMerge(TableWriterMergeNode node, Void context) + public Void visitTableWriteMerge(TableWriterMergeNode node, Context context) { - addNode(node, "TableWriterMerge"); - return processChildren(node, context); + addNode(node, "TableWriterMerge", context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) + public Void visitStatisticsWriterNode(StatisticsWriterNode node, Context context) { - addNode(node, "StatisticsWriter", format("[%s]", node.getTableHandle())); - return processChildren(node, context); + addNode(node, "StatisticsWriter", format("[%s]", node.getTableHandle()), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitTableFinish(TableFinishNode node, Void context) + public Void visitTableFinish(TableFinishNode node, Context context) { - addNode(node, "TableCommit", format("[%s]", node.getTarget())); - return processChildren(node, context); + addNode(node, "TableCommit", format("[%s]", node.getTarget()), context.getTag()); + return processChildren(node, new Context()); } @Override - public Void visitSample(SampleNode node, Void context) + public Void visitSample(SampleNode node, Context context) { - addNode(node, "Sample", format("[%s: %s]", node.getSampleType(), node.getSampleRatio())); + addNode(node, "Sample", format("[%s: %s]", node.getSampleType(), node.getSampleRatio()), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitExchange(ExchangeNode node, Void context) + public Void visitExchange(ExchangeNode node, Context context) { if (node.getOrderingScheme().isPresent()) { OrderingScheme 
orderingScheme = node.getOrderingScheme().get(); @@ -1218,7 +1233,7 @@ public Void visitExchange(ExchangeNode node, Void context) addNode(node, format("%sMerge", UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, node.getScope().toString())), - format("[%s]", Joiner.on(", ").join(orderBy))); + format("[%s]", Joiner.on(", ").join(orderBy)), context.getTag()); } else if (node.getScope().isLocal()) { addNode(node, @@ -1227,7 +1242,7 @@ else if (node.getScope().isLocal()) { node.getPartitioningScheme().getPartitioning().getHandle(), node.getPartitioningScheme().isReplicateNullsAndAny() ? " - REPLICATE NULLS AND ANY" : "", formatHash(node.getPartitioningScheme().getHashColumn()), - Joiner.on(", ").join(node.getPartitioningScheme().getPartitioning().getArguments()))); + Joiner.on(", ").join(node.getPartitioningScheme().getPartitioning().getArguments())), context.getTag()); } else { addNode(node, @@ -1236,83 +1251,257 @@ else if (node.getScope().isLocal()) { node.getType(), node.getPartitioningScheme().getEncoding(), node.getPartitioningScheme().isReplicateNullsAndAny() ? 
" - REPLICATE NULLS AND ANY" : "", - formatHash(node.getPartitioningScheme().getHashColumn()))); + formatHash(node.getPartitioningScheme().getHashColumn())), context.getTag()); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitDelete(DeleteNode node, Void context) + public Void visitDelete(DeleteNode node, Context context) { - addNode(node, "Delete"); + addNode(node, "Delete", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitUpdate(UpdateNode node, Void context) + public Void visitUpdate(UpdateNode node, Context context) { - addNode(node, "Update"); + addNode(node, "Update", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitMetadataDelete(MetadataDeleteNode node, Void context) + public Void visitMetadataDelete(MetadataDeleteNode node, Context context) { - addNode(node, "MetadataDelete", format("[%s]", node.getTableHandle())); + addNode(node, "MetadataDelete", format("[%s]", node.getTableHandle()), context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + public Void visitEnforceSingleRow(EnforceSingleRowNode node, Context context) { - addNode(node, "EnforceSingleRow"); + addNode(node, "EnforceSingleRow", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitAssignUniqueId(AssignUniqueId node, Void context) + public Void visitAssignUniqueId(AssignUniqueId node, Context context) { - addNode(node, "AssignUniqueId"); + addNode(node, "AssignUniqueId", context.getTag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void 
visitGroupReference(GroupReference node, Void context) + public Void visitGroupReference(GroupReference node, Context context) { - addNode(node, "GroupReference", format("[%s]", node.getGroupId()), ImmutableList.of()); + addNode(node, "GroupReference", format("[%s]", node.getGroupId()), ImmutableList.of(), context.getTag()); return null; } @Override - public Void visitApply(ApplyNode node, Void context) + public Void visitApply(ApplyNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "Apply", format("[%s]", node.getCorrelation())); + NodeRepresentation nodeOutput = addNode(node, "Apply", format("[%s]", node.getCorrelation()), context.getTag()); printAssignments(nodeOutput, node.getSubqueryAssignments()); - return processChildren(node, context); + return processChildren(node, new Context()); + } + + @Override + public Void visitLateralJoin(LateralJoinNode node, Context context) + { + addNode(node, "Lateral", format("[%s]", node.getCorrelation()), context.getTag()); + + return processChildren(node, new Context()); + } + + @Override + public Void visitTableFunction(TableFunctionNode node, Context context) + { + NodeRepresentation nodeOutput = addNode( + node, + "TableFunction", + "name", + context.getTag()); + + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); + + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableArgumentProperties::getArgumentName, identity())); + + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetails(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetails(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(joining(", ", "Co-partition: [", "]"))); + } + } + + for (int i = 0; i < node.getSources().size(); i++) { + node.getSources().get(i).accept(this, new 
Context(node.getTableArgumentProperties().get(i).getArgumentName())); + } + + return null; + } + + private String formatArgument(String argumentName, Argument argument, Map tableArguments) + { + if (argument instanceof ScalarArgument) { + ScalarArgument scalarArgument = (ScalarArgument) argument; + return formatScalarArgument(argumentName, scalarArgument); + } + if (argument instanceof DescriptorArgument) { + DescriptorArgument descriptorArgument = (DescriptorArgument) argument; + return formatDescriptorArgument(argumentName, descriptorArgument); + } + else { + TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + argument.getValue()); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow(() -> new IllegalStateException("Missing descriptor")).getFields().stream() + .map(field -> field.getName() + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + StringBuilder properties = new StringBuilder(); + if (argumentProperties.rowSemantics()) { + properties.append("row semantics"); + } + argumentProperties.specification().ifPresent(specification -> { + properties + .append("partition by: [") + .append(Joiner.on(", ").join(specification.getPartitionBy())) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { 
+ properties + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + }); + + properties.append("required columns: [") + .append(Joiner.on(", ").join(argumentProperties.getRequiredColumns())) + .append("]"); + + properties.append("required columns: [") + .append(Joiner.on(", ").join(argumentProperties.getRequiredColumns())) + .append("]"); + + if (argumentProperties.pruneWhenEmpty()) { + properties.append(", prune when empty"); + } + + if (argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + properties.append(", pass through columns"); + } + + return format("%s => TableArgument{%s}", argumentName, properties); } @Override - public Void visitLateralJoin(LateralJoinNode node, Void context) + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) { - addNode(node, "Lateral", format("[%s]", node.getCorrelation())); + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + + descriptor.put("name", node.getName()); - return processChildren(node, context); + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(node.getProperOutputs()))); + + node.getSpecification().ifPresent(specification -> { + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + .collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if (!prePartitioned.isEmpty()) { + builder.append(prePartitioned.stream() + .map(VariableReferenceExpression::toString) + .collect(joining(", ", "<", ">"))); + if (!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(notPrePartitioned)); + } + descriptor.put("partitionBy", format("[%s]", 
builder)); + } + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); + }); + + addNode(node, "TableFunctionDataProcessor" + descriptor.build(), context.getTag()); + + return processChildren(node, new Context()); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme) + { + return formatCollection(orderingScheme.getOrderByVariables(), variable -> variable + " " + orderingScheme.getOrdering(variable)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme, int preSortedOrderPrefix) + { + List orderBy = Stream.concat( + orderingScheme.getOrderByVariables().stream() + .limit(preSortedOrderPrefix) + .map(variable -> "<" + variable + " " + orderingScheme.getOrdering(variable) + ">"), + orderingScheme.getOrderByVariables().stream() + .skip(preSortedOrderPrefix) + .map(variable -> variable + " " + orderingScheme.getOrdering(variable))) + .collect(toImmutableList()); + return formatCollection(orderBy, Objects::toString); + } + + public String formatCollection(Collection collection, Function formatter) + { + return collection.stream() + .map(formatter) + .collect(joining(", ", "[", "]")); } @Override - public Void visitPlan(PlanNode node, Void context) + public Void visitPlan(PlanNode node, Context context) { throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } - private Void processChildren(PlanNode node, Void context) + private Void processChildren(PlanNode node, Context context) { for (PlanNode child : node.getSources()) { child.accept(this, context); @@ -1371,27 +1560,27 @@ private String formatDomain(Domain domain) return "[" + Joiner.on(", ").join(parts.build()) + "]"; } - public NodeRepresentation addNode(PlanNode node, String name) + public NodeRepresentation addNode(PlanNode node, String name, Optional tag) { - return addNode(node, name, ""); + return addNode(node, name, "", tag); } - 
public NodeRepresentation addNode(PlanNode node, String name, String identifier) + public NodeRepresentation addNode(PlanNode node, String name, String identifier, Optional tag) { - return addNode(node, name, identifier, node.getSources()); + return addNode(node, name, identifier, node.getSources(), tag); } - public NodeRepresentation addNode(PlanNode node, String name, String identifier, List children) + public NodeRepresentation addNode(PlanNode node, String name, String identifier, List children, Optional tag) { - return addNode(node, name, identifier, ImmutableList.of(node.getId()), children, ImmutableList.of()); + return addNode(node, name, identifier, ImmutableList.of(node.getId()), children, ImmutableList.of(), tag); } - public NodeRepresentation addNode(PlanNode node, String name, List children) + public NodeRepresentation addNode(PlanNode node, String name, List children, Optional tag) { - return addNode(node, name, "", ImmutableList.of(node.getId()), children, ImmutableList.of()); + return addNode(node, name, "", ImmutableList.of(node.getId()), children, ImmutableList.of(), tag); } - public NodeRepresentation addNode(PlanNode rootNode, String name, String identifier, List allNodes, List children, List remoteSources) + public NodeRepresentation addNode(PlanNode rootNode, String name, String identifier, List allNodes, List children, List remoteSources, Optional tag) { List childrenIds = children.stream().map(PlanNode::getId).collect(toImmutableList()); List estimatedStats = allNodes.stream() @@ -1400,6 +1589,9 @@ public NodeRepresentation addNode(PlanNode rootNode, String name, String identif List estimatedCosts = allNodes.stream() .map(nodeId -> estimatedStatsAndCosts.getCosts().getOrDefault(nodeId, PlanCostEstimate.unknown())) .collect(toList()); + name = tag + .map(tagName -> format("[%s] ", tagName)) + .orElse("") + name; NodeRepresentation nodeOutput = new NodeRepresentation( Optional.empty(), @@ -1496,4 +1688,29 @@ private static String 
formatOutputs(Iterable output .map(input -> input + ":" + input.getType().getDisplayName()) .collect(Collectors.joining(", ")); } + + public class Context + { + private final Optional tag; + + public Context() + { + this(Optional.empty()); + } + + public Context(String tag) + { + this(Optional.of(tag)); + } + + public Context(Optional tag) + { + this.tag = requireNonNull(tag, "tag is null"); + } + + public Optional getTag() + { + return tag; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index ecc264a2518af..a7689e14cc0de 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -68,6 +68,9 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -83,6 +86,7 @@ import static com.facebook.presto.spi.plan.JoinNode.checkLeftOutputVariablesBeforeRight; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables; import static com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import static 
com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -113,9 +117,135 @@ public Visitor() @Override public Void visitPlan(PlanNode node, Set boundVariables) { + // TODO: Michael: Is this okay? Trino's TypeValidator's Visitor extends off of + // SimplePlanVisitor. This is what is in SimplePlanVisitor's visitPlan. May need + // to change this later. + /* + for (PlanNode source : node.getSources()) { + source.accept(this, boundVariables); + } + return null; + */ throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } + @Override + public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) + { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getRequiredColumns(), + "Invalid node. Required input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getRequiredColumns(), + source.getOutputVariables()); + argumentProperties.specification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. 
Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + Set passThroughVariable = argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughVariable, + "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + passThroughVariable, + source.getOutputVariables()); + } + return null; + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundVariables) + { + if (!node.getSource().isPresent()) { + return null; + } + + PlanNode source = node.getSource().get(); + source.accept(this, boundVariables); + + Set inputs = createInputs(source, boundVariables); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputVariables()); + + Set requiredSymbols = node.getRequiredVariables().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputVariables()); + + node.getMarkerVariables().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputVariables()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. 
Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputVariables()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + + return null; + } + @Override public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundVariables) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java index c33430616a779..a97b68a741d5e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java @@ -420,6 +420,9 @@ else if (SMALLINT.equals(type)) { else if (BIGINT.equals(type)) { return constant(Long.parseLong(node.getValue()), BIGINT); } + else if (INTEGER.equals(type)) { + return constant(Long.parseLong(node.getValue()), INTEGER); + } } catch (NumberFormatException e) { throw new SemanticException(SemanticErrorCode.INVALID_LITERAL, node, format("Invalid formatted generic %s literal: %s", type, node)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 04d4bb4ecff5d..884a1f217ea14 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -108,6 +108,7 @@ import com.facebook.presto.metadata.QualifiedTablePrefix; import com.facebook.presto.metadata.SchemaPropertyManager; import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.Driver; @@ -438,7 +439,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, featuresConfig.setIgnoreStatsCalculatorFailures(false); this.metadata = new MetadataManager( - new FunctionAndTypeManager(transactionManager, blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), + new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), blockEncodingManager, createTestingSessionPropertyManager( new SystemSessionProperties( @@ -777,7 +778,9 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); +// throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 2c725b65abbd4..b35ce1d2c989d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -65,6 +65,7 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import 
com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnnestNode; @@ -131,6 +132,7 @@ private enum NodeType ANALYZE_FINISH, EXPLAIN_ANALYZE, UPDATE, + TABLE_FUNCTION } private static final Map NODE_COLORS = immutableEnumMap(ImmutableMap.builder() @@ -162,6 +164,7 @@ private enum NodeType .put(NodeType.ANALYZE_FINISH, "plum") .put(NodeType.EXPLAIN_ANALYZE, "cadetblue1") .put(NodeType.UPDATE, "blue") + .put(NodeType.TABLE_FUNCTION, "mediumorchid") .build()); static { @@ -648,6 +651,13 @@ public Void visitApply(ApplyNode node, Void context) return null; } + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + printNode(node, "Table Function Processor", node.getHandle().getSchemaFunctionName().toString(), NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + return null; + } + @Override public Void visitAssignUniqueId(AssignUniqueId node, Void context) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorColumnHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorColumnHandle.java new file mode 100644 index 0000000000000..84fbd45bfa9b9 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorColumnHandle.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class MockConnectorColumnHandle + implements ColumnHandle +{ + private final String name; + private final Type type; + + @JsonCreator + public MockConnectorColumnHandle( + @JsonProperty("name") String name, + @JsonProperty("type") Type type) + { + this.name = requireNonNull(name, "name is null"); + this.type = requireNonNull(type, "type is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public Type getType() + { + return type; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("type", type) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if ((o == null) || (getClass() != o.getClass())) { + return false; + } + MockConnectorColumnHandle other = (MockConnectorColumnHandle) o; + return Objects.equals(name, other.name) && + Objects.equals(type, other.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorFactory.java 
b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorFactory.java new file mode 100644 index 0000000000000..0139f3e22d6d1 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorFactory.java @@ -0,0 +1,589 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.ConnectorViewDefinition; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.InMemoryRecordSet; +import com.facebook.presto.spi.NodeProvider; +import 
com.facebook.presto.spi.RecordPageSource; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.transaction.IsolationLevel; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchRecordSetProvider; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.assertj.core.util.Sets; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static 
com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class MockConnectorFactory + implements ConnectorFactory +{ + private final Function> listSchemaNames; + private final BiFunction> listTables; + private final BiFunction> getViews; + private final BiFunction> getColumnHandles; + private final Optional> getTableFunctionProcessorProvider; + private final MockTableFunctionHandleResolver tableFunctionHandleResolver; + private final MockTableFunctionSplitResolver tableFunctionSplitResolver; + private final Supplier getTableStatistics; + private final ApplyTableFunction applyTableFunction; + private final Set tableFunctions; + private final Map> tableFunctionSplitsSources; + + public MockConnectorFactory( + Function> listSchemaNames, + BiFunction> listTables, + BiFunction> getViews, + BiFunction> getColumnHandles, + Supplier getTableStatistics, + ApplyTableFunction applyTableFunction, + Set tableFunctions, + MockTableFunctionHandleResolver tableFunctionHandleResolver, + MockTableFunctionSplitResolver tableFunctionSplitResolver, + Optional> getTableFunctionProcessorProvider, + Map> tableFunctionSplitsSources) + { + this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); + this.listTables = requireNonNull(listTables, "listTables is null"); + this.getViews = requireNonNull(getViews, "getViews is null"); + this.getColumnHandles = requireNonNull(getColumnHandles, "getColumnHandles is null"); + this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); + this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is 
null"); + this.tableFunctionHandleResolver = requireNonNull(tableFunctionHandleResolver, "tableFunctionHandleResolver is null"); + this.tableFunctionSplitResolver = requireNonNull(tableFunctionSplitResolver, "tableFunctionSplitResolver is null"); + this.getTableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionSplitsSources = ImmutableMap.copyOf(tableFunctionSplitsSources); + } + + @Override + public String getName() + { + return "mock"; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new MockHandleResolver(); + } + + @Override + public Connector create(String catalogName, Map config, ConnectorContext context) + { + return new MockConnector(context, listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, getTableFunctionProcessorProvider, tableFunctionSplitsSources); + } + + public static Builder builder() + { + return new Builder(); + } + + public static Function> defaultGetColumns() + { + return table -> IntStream.range(0, 100) + .boxed() + .map(i -> ColumnMetadata.builder().setName("column_" + i).setType(createUnboundedVarcharType()).build()) + .collect(toImmutableList()); + } + + @Override + public Optional> getTableFunctionProcessorProvider() + { + return getTableFunctionProcessorProvider; + } + + @Override + public Optional getTableFunctionHandleResolver() + { + return Optional.of(tableFunctionHandleResolver); + } + + @Override + public Optional getTableFunctionSplitResolver() + { + return Optional.of(tableFunctionSplitResolver); + } + + @FunctionalInterface + public interface ApplyTableFunction + { + Optional> apply(ConnectorSession session, ConnectorTableFunctionHandle handle); + } + + public static class MockConnector + implements Connector + { + private static final String DELETE_ROW_ID = "delete_row_id"; + private static final String UPDATE_ROW_ID = "update_row_id"; + 
private static final String MERGE_ROW_ID = "merge_row_id"; + + private final ConnectorContext context; + private final Function> listSchemaNames; + private final BiFunction> listTables; + private final BiFunction> getViews; + private final BiFunction> getColumnHandles; + private final Supplier getTableStatistics; + private final ApplyTableFunction applyTableFunction; + private final Set tableFunctions; + private final Optional> getTableFunctionProcessorProvider; + private final Map> tableFunctionSplitsSources; + + public MockConnector( + ConnectorContext context, + Function> listSchemaNames, + BiFunction> listTables, + BiFunction> getViews, + BiFunction> getColumnHandles, + Supplier getTableStatistics, + ApplyTableFunction applyTableFunction, + Set tableFunctions, + Optional> getTableFunctionProcessorProvider, + Map> tableFunctionSplitsSources) + { + this.context = requireNonNull(context, "context is null"); + this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); + this.listTables = requireNonNull(listTables, "listTables is null"); + this.getViews = requireNonNull(getViews, "getViews is null"); + this.getColumnHandles = requireNonNull(getColumnHandles, "getColumnHandles is null"); + this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); + this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is null"); + this.getTableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionSplitsSources = ImmutableMap.copyOf(tableFunctionSplitsSources); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return MockConnectorTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transaction) + { + 
return new MockConnectorMetadata(); + } + + public enum MockConnectorSplit + implements ConnectorSplit + { + MOCK_CONNECTOR_SPLIT; + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return null; + } + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new ConnectorSplitManager() + { + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + return new FixedSplitSource(Collections.singleton(MockConnectorSplit.MOCK_CONNECTOR_SPLIT)); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, SchemaFunctionName name, ConnectorTableFunctionHandle functionHandle) + { + Function splitSourceProvider = tableFunctionSplitsSources.get(name); + requireNonNull(splitSourceProvider, "missing ConnectorSplitSource for table function " + name); + return splitSourceProvider.apply(functionHandle); + } + }; + } + + @Override + public ConnectorRecordSetProvider getRecordSetProvider() + { + return new TpchRecordSetProvider(); + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return new MockConnectorPageSourceProvider(); + } + + @Override + public Set getTableFunctions() + { + return tableFunctions; + } + + public Optional> getGetTableFunctionProcessorProvider() + { + return getTableFunctionProcessorProvider; + } + + private class MockConnectorMetadata + implements ConnectorMetadata + { + @Override + public List listSchemaNames(ConnectorSession session) + { + return listSchemaNames.apply(session); + } + + @Override + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName 
tableName) + { + return new ConnectorTableHandle() {}; + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) + { + MockConnectorTableHandle table = (MockConnectorTableHandle) tableHandle; + return new ConnectorTableMetadata( + table.getTableName(), + defaultGetColumns().apply(table.getTableName()), + ImmutableMap.of()); + } + + @Override + public List listTables(ConnectorSession session, String schemaNameOrNull) + { + return listTables.apply(session, schemaNameOrNull); + } + + public void setTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle, Map properties) + { + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return (Map) (Map) getColumnHandles.apply(session, tableHandle); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + if (columnHandle instanceof MockConnectorColumnHandle) { + MockConnectorColumnHandle mockColumnHandle = (MockConnectorColumnHandle) columnHandle; + return ColumnMetadata.builder().setName(mockColumnHandle.getName()).setType(mockColumnHandle.getType()).build(); + } + else { + TpchColumnHandle tpchColumnHandle = (TpchColumnHandle) columnHandle; + return ColumnMetadata.builder().setName(tpchColumnHandle.getColumnName()).setType(tpchColumnHandle.getType()).build(); + } + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + return listTables(session, prefix.getSchemaName()).stream() + .collect(toImmutableMap(table -> table, table -> IntStream.range(0, 100) + .boxed() + .map(i -> ColumnMetadata.builder().setName("column_" + i).setType(createUnboundedVarcharType()).build()) + .collect(toImmutableList()))); + } + + @Override + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle 
table, + Constraint constraint, + Optional> desiredColumns) + { + // TODO: Currently not supporting constraints + MockTableLayoutHandle mock = new MockTableLayoutHandle((MockConnectorTableHandle) table, TupleDomain.none()); + return new ConnectorTableLayoutResult(new ConnectorTableLayout(mock, + Optional.empty(), + mock.getPredicate(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Collections.emptyList(), + Optional.empty()), TupleDomain.none()); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + MockTableLayoutHandle mock = (MockTableLayoutHandle) handle; + return new ConnectorTableLayout(handle); + } + + @Override + public Map getViews(ConnectorSession session, SchemaTablePrefix prefix) + { + return getViews.apply(session, prefix); + } + + @Override + public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Optional tableLayoutHandle, List columnHandles, Constraint constraint) + { + return getTableStatistics.get(); + } + + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + return applyTableFunction.apply(session, handle); + } + } + + private class MockConnectorPageSourceProvider + implements ConnectorPageSourceProvider + { + @Override + //public ConnectorPageSource createPageSource(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorSplit split, ConnectorTableHandle table, List columns, SplitContext splitContext) + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, ConnectorTableLayoutHandle layout, List columns, SplitContext splitContext, RuntimeStats runtimeStats) + { + MockConnectorTableHandle handle = ((MockTableLayoutHandle) layout).getTable(); + SchemaTableName tableName = handle.getTableName(); + List projection = columns.stream() + 
.map(MockConnectorColumnHandle.class::cast) + .collect(toImmutableList()); + List types = columns.stream() + .map(MockConnectorColumnHandle.class::cast) + .map(MockConnectorColumnHandle::getType) + .collect(toImmutableList()); + Map columnIndexes = getColumnIndexes(tableName); + /* + List> records = data.apply(tableName).stream() + .map(record -> { + ImmutableList.Builder projectedRow = ImmutableList.builder(); + for (MockConnectorColumnHandle column : projection) { + String columnName = column.getName(); + if (columnName.equals(DELETE_ROW_ID) || columnName.equals(UPDATE_ROW_ID) || columnName.equals(MERGE_ROW_ID)) { + projectedRow.add(0); + continue; + } + Integer index = columnIndexes.get(columnName); + requireNonNull(index, "index is null"); + projectedRow.add(record.get(index)); + } + return projectedRow.build(); + }) + .collect(toImmutableList());*/ + + return new MockConnectorPageSource(new RecordPageSource(new InMemoryRecordSet(types, ImmutableList.of()))); + } + + private Map getColumnIndexes(SchemaTableName tableName) + { + ImmutableMap.Builder columnIndexes = ImmutableMap.builder(); + List columnMetadata = defaultGetColumns().apply(tableName); + for (int index = 0; index < columnMetadata.size(); index++) { + columnIndexes.put(columnMetadata.get(index).getName(), index); + } + return columnIndexes.buildOrThrow(); + } + } + } + + public static class MockTableFunctionHandleResolver + implements TableFunctionHandleResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionHandleClasses() + { + return handles; + } + + public void addTableFunctionHandle(Class tableFunctionHandleClass) + { + handles.add(tableFunctionHandleClass); + } + } + + public static class MockTableFunctionSplitResolver + implements TableFunctionSplitResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionSplitClasses() + { + return handles; + } + + public void addSplitClass(Class splitClass) + { + 
handles.add(splitClass); + } + } + + public static final class Builder + { + private Function> listSchemaNames = (session) -> ImmutableList.of(); + private BiFunction> listTables = (session, schemaName) -> ImmutableList.of(); + private BiFunction> getViews = (session, schemaTablePrefix) -> ImmutableMap.of(); + private BiFunction> getColumnHandles = (session, tableHandle) -> { + MockConnectorTableHandle table = (MockConnectorTableHandle) tableHandle; + return defaultGetColumns().apply(table.getTableName()).stream() + .collect(toImmutableMap(ColumnMetadata::getName, column -> + new MockConnectorColumnHandle(column.getName(), column.getType()))); + }; + private Optional> getTableFunctionProcessorProvider = Optional.empty(); + private Supplier getTableStatistics = TableStatistics::empty; + private ApplyTableFunction applyTableFunction = (session, handle) -> Optional.empty(); + private Set tableFunctions = ImmutableSet.of(); + private MockTableFunctionHandleResolver tableFunctionHandleResolver = new MockTableFunctionHandleResolver(); + private MockTableFunctionSplitResolver tableFunctionSplitResolver = new MockTableFunctionSplitResolver(); + private final Map> tableFunctionSplitsSources = new HashMap<>(); + + public Builder withListSchemaNames(Function> listSchemaNames) + { + this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); + return this; + } + + public Builder withListTables(BiFunction> listTables) + { + this.listTables = requireNonNull(listTables, "listTables is null"); + return this; + } + + public Builder withGetViews(BiFunction> getViews) + { + this.getViews = requireNonNull(getViews, "getViews is null"); + return this; + } + + public Builder withGetColumnHandles(BiFunction> getColumnHandles) + { + this.getColumnHandles = requireNonNull(getColumnHandles, "getColumnHandles is null"); + return this; + } + + public Builder withGetTableStatistics(Supplier getTableStatistics) + { + this.getTableStatistics = 
requireNonNull(getTableStatistics, "getTableStatistics is null"); + return this; + } + + public Builder withApplyTableFunction(ApplyTableFunction applyTableFunction) + { + this.applyTableFunction = applyTableFunction; + return this; + } + + public Builder withTableFunctions(Iterable tableFunctions) + { + this.tableFunctions = ImmutableSet.copyOf(tableFunctions); + return this; + } + + public Builder withTableFunctionResolver(Class tableFunctionHandleclass) + { + this.tableFunctionHandleResolver.addTableFunctionHandle(tableFunctionHandleclass); + return this; + } + + public Builder withTableFunctionSplitResolver(Class splitClass) + { + this.tableFunctionSplitResolver.addSplitClass(splitClass); + return this; + } + + public Builder withGetTableFunctionProcessorProvider(Optional> getTableFunctionProcessorProvider) + { + this.getTableFunctionProcessorProvider = getTableFunctionProcessorProvider; + return this; + } + + public Builder withTableFunctionSplitSource(SchemaFunctionName name, Function sourceProvider) + { + tableFunctionSplitsSources.put(name, sourceProvider); + return this; + } + + public MockConnectorFactory build() + { + return new MockConnectorFactory(listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionHandleResolver, tableFunctionSplitResolver, getTableFunctionProcessorProvider, tableFunctionSplitsSources); + } + + private static T notSupported() + { + throw new UnsupportedOperationException(); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorPageSource.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorPageSource.java new file mode 100644 index 0000000000000..3d77b1b797519 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorPageSource.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.Page; +import com.facebook.presto.spi.ConnectorPageSource; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +public class MockConnectorPageSource + implements ConnectorPageSource +{ + private final ConnectorPageSource delegate; + + public MockConnectorPageSource(ConnectorPageSource delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + @Override + public long getCompletedBytes() + { + return delegate.getCompletedBytes(); + } + + @Override + public long getCompletedPositions() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return delegate.getReadTimeNanos(); + } + + @Override + public boolean isFinished() + { + return delegate.isFinished(); + } + + @Override + public Page getNextPage() + { + return delegate.getNextPage(); + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public void close() + throws IOException + { + delegate.close(); + } + + @Override + public CompletableFuture isBlocked() + { + return delegate.isBlocked(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorPlugin.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorPlugin.java new file mode 100644 index 0000000000000..a5abe3bdfec9a --- /dev/null +++ 
b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorPlugin.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; + +import static java.util.Objects.requireNonNull; + +public class MockConnectorPlugin + implements Plugin +{ + private final ConnectorFactory connectorFactory; + + public MockConnectorPlugin(ConnectorFactory connectorFactory) + { + this.connectorFactory = requireNonNull(connectorFactory, "connectorFactory is null"); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(connectorFactory); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorTableHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorTableHandle.java new file mode 100644 index 0000000000000..b29826b684dd7 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorTableHandle.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class MockConnectorTableHandle + implements ConnectorTableHandle +{ + private final SchemaTableName tableName; + private final TupleDomain constraint; + private final Optional> columns; + + public MockConnectorTableHandle(SchemaTableName tableName) + { + this(tableName, TupleDomain.all(), Optional.empty()); + } + + @JsonCreator + public MockConnectorTableHandle( + @JsonProperty SchemaTableName tableName, + @JsonProperty("constraint") TupleDomain constraint, + @JsonProperty("columns") Optional> columns) + { + this.tableName = requireNonNull(tableName, "tableName is null"); + this.constraint = requireNonNull(constraint, "constraint is null"); + requireNonNull(columns, "columns is null"); + this.columns = columns.map(ImmutableList::copyOf); + } + + @JsonProperty + public SchemaTableName getTableName() + { + return tableName; + } + + @JsonProperty + public TupleDomain getConstraint() + { + return constraint; + } + + @JsonProperty + public Optional> getColumns() + { + return columns; + } + + @Override + 
public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MockConnectorTableHandle other = (MockConnectorTableHandle) o; + return Objects.equals(tableName, other.tableName) && + Objects.equals(constraint, other.constraint) && + Objects.equals(columns, other.columns); + } + + @Override + public int hashCode() + { + return Objects.hash(tableName, constraint, columns); + } + + @Override + public String toString() + { + return tableName.toString(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorTransactionHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorTransactionHandle.java new file mode 100644 index 0000000000000..201e4f4dcb984 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockConnectorTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum MockConnectorTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockHandleResolver.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockHandleResolver.java new file mode 100644 index 0000000000000..dd8ed2929bd0a --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockHandleResolver.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.tpch.TpchPartitioningHandle; + +public class MockHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return MockConnectorTableHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return MockConnectorColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return MockConnectorFactory.MockConnector.MockConnectorSplit.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return MockTableLayoutHandle.class; + } + + @Override + public Class getTransactionHandleClass() + { + return MockConnectorTransactionHandle.class; + } + + @Override + public Class getPartitioningHandleClass() + { + return TpchPartitioningHandle.class; + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockTableLayoutHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockTableLayoutHandle.java new file mode 100644 index 0000000000000..d70b8e202d7d1 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/MockTableLayoutHandle.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.plan.PlanCanonicalizationStrategy; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; + +import java.util.Optional; + +public class MockTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final MockConnectorTableHandle table; + private final TupleDomain predicate; + + @JsonCreator + public MockTableLayoutHandle(@JsonProperty("table") MockConnectorTableHandle table, @JsonProperty("predicate") TupleDomain predicate) + { + this.table = table; + this.predicate = predicate; + } + + @JsonProperty + public MockConnectorTableHandle getTable() + { + return table; + } + + @JsonProperty + public TupleDomain getPredicate() + { + return predicate; + } + + @Override + public String toString() + { + return table.toString(); + } + + @Override + public Object getIdentifier(Optional split, PlanCanonicalizationStrategy strategy) + { + return ImmutableMap.builder() + .put("table", table) + .put("predicate", predicate.canonicalize(ignored -> false)) + .build(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java 
b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java new file mode 100644 index 0000000000000..9b328aeed51b8 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -0,0 +1,1353 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.ScalarArgument; +import 
com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.ConstantFunction.ConstantFunctionSplit.DEFAULT_SPLIT_SIZE; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.OnlyPassThrough.ONLY_PASS_THROUGH; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static 
com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInput; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public class TestingTableFunctions +{ + private static final String SCHEMA_NAME = "system"; + private static final String TABLE_NAME = "table"; + private static final String COLUMN_NAME = "column"; + private static final ConnectorTableFunctionHandle HANDLE = new TestingTableFunctionHandle(); + private static final TableFunctionAnalysis ANALYSIS = TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .build(); + private static final TableFunctionAnalysis NO_DESCRIPTOR_ANALYSIS = TableFunctionAnalysis.builder() + .handle(HANDLE) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); + + /** + * A table function returning a table with single empty column of type BOOLEAN. + * The argument `COLUMN` is the column name. + * The argument `IGNORED` is ignored. + * Both arguments are optional. 
+ */ + public static class SimpleTableFunction + extends AbstractConnectorTableFunction + { + private static final String SCHEMA_NAME = "system"; + private static final String FUNCTION_NAME = "simple_table_function"; + private static final String TABLE_NAME = "simple_table"; + + public SimpleTableFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + Arrays.asList( + ScalarArgumentSpecification.builder() + .name("COLUMN") + .type(VARCHAR) + .defaultValue(utf8Slice("col")) + .build(), + ScalarArgumentSpecification.builder() + .name("IGNORED") + .type(BIGINT) + .defaultValue(0L) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument argument = (ScalarArgument) arguments.get("COLUMN"); + String columnName = ((Slice) argument.getValue()).toStringUtf8(); + + return TableFunctionAnalysis.builder() + .handle(new SimpleTableFunctionHandle(getSchema(), TABLE_NAME, columnName)) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(columnName, Optional.of(BOOLEAN))))) + .build(); + } + + public static class SimpleTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final MockConnectorTableHandle tableHandle; + + public SimpleTableFunctionHandle(String schema, String table, String column) + { + this.tableHandle = new MockConnectorTableHandle( + new SchemaTableName(schema, table), + TupleDomain.all(), + Optional.of(ImmutableList.of(new MockConnectorColumnHandle(column, BOOLEAN)))); + } + + public MockConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } + } + + public static class TwoScalarArgumentsFunction + extends AbstractConnectorTableFunction + { + public TwoScalarArgumentsFunction() + { + super( + SCHEMA_NAME, + "two_arguments_function", + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name("TEXT") + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + 
.name("NUMBER") + .type(BIGINT) + .defaultValue(null) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return ANALYSIS; + } + } + + public static class TableArgumentFunction + extends AbstractConnectorTableFunction + { + public TableArgumentFunction() + { + super( + SCHEMA_NAME, + "table_argument_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); + } + } + + public static class TableArgumentRowSemanticsFunction + extends AbstractConnectorTableFunction + { + public TableArgumentRowSemanticsFunction() + { + super( + SCHEMA_NAME, + "table_argument_row_semantics_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .rowSemantics() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); + } + } + + public static class DescriptorArgumentFunction + extends AbstractConnectorTableFunction + { + public DescriptorArgumentFunction() + { + super( + SCHEMA_NAME, + "descriptor_argument_function", + ImmutableList.of( + DescriptorArgumentSpecification.builder() + .name("SCHEMA") + .defaultValue(null) + .build()), + GENERIC_TABLE); + } + + 
@Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return ANALYSIS; + } + } + + public static class TwoTableArgumentsFunction + extends AbstractConnectorTableFunction + { + public TwoTableArgumentsFunction() + { + super( + SCHEMA_NAME, + "two_table_arguments_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT1") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT2") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT1", ImmutableList.of(0)) + .requiredColumns("INPUT2", ImmutableList.of(0)) + .build(); + } + } + + public static class OnlyPassThroughFunction + extends AbstractConnectorTableFunction + { + public OnlyPassThroughFunction() + { + super( + SCHEMA_NAME, + "only_pass_through_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class MonomorphicStaticReturnTypeFunction + extends AbstractConnectorTableFunction + { + public MonomorphicStaticReturnTypeFunction() + { + super( + SCHEMA_NAME, + "monomorphic_static_return_type_function", + ImmutableList.of(), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("a", "b"), + ImmutableList.of(BOOLEAN, INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) 
+ { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .build(); + } + } + + public static class PolymorphicStaticReturnTypeFunction + extends AbstractConnectorTableFunction + { + public PolymorphicStaticReturnTypeFunction() + { + super( + SCHEMA_NAME, + "polymorphic_static_return_type_function", + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("a", "b"), + ImmutableList.of(BOOLEAN, INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class PassThroughFunction + extends AbstractConnectorTableFunction + { + public PassThroughFunction() + { + super( + SCHEMA_NAME, + "pass_through_function", + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("x"), + ImmutableList.of(BOOLEAN)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + "different_arguments_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + 
GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } + + public static class RequiredColumnsFunction + extends AbstractConnectorTableFunction + { + public RequiredColumnsFunction() + { + super( + SCHEMA_NAME, + "required_columns_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0, 1)) + .build(); + } + } + + public static class TestingTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final MockConnectorTableHandle tableHandle; + + public TestingTableFunctionHandle() + { + this.tableHandle = new MockConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + TupleDomain.all(), + Optional.of(ImmutableList.of(new MockConnectorColumnHandle(COLUMN_NAME, BOOLEAN)))); + } + + public MockConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } + + // for testing execution by operator + + public static class IdentityFunction + extends AbstractConnectorTableFunction + { + public IdentityFunction() + { + super( + SCHEMA_NAME, + "identity_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + 
GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + List inputColumns = ((TableArgument) arguments.get("INPUT")).getRowType().getFields(); + Descriptor returnedType = new Descriptor(inputColumns.stream() + .map(field -> new Descriptor.Field(field.getName().orElse("anonymous_column"), Optional.of(field.getType()))) + .collect(toImmutableList())); + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + .returnedType(returnedType) + .requiredColumns("INPUT", IntStream.range(0, inputColumns.size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class IdentityFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + Optional inputPage = getOnlyElement(input); + return inputPage.map(TableFunctionProcessorState.Processed::usedInputAndProduced).orElseThrow(NoSuchElementException::new); + }; + } + } + } + + public static class IdentityPassThroughFunction + extends AbstractConnectorTableFunction + { + public IdentityPassThroughFunction() + { + super( + SCHEMA_NAME, + "identity_pass_through_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class IdentityPassThroughFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public 
TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new IdentityPassThroughFunctionDataProcessor(); + } + } + + public static class IdentityPassThroughFunctionDataProcessor + implements TableFunctionDataProcessor + { + private long processedPositions; // stateful + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + return usedInputAndProduced(new Page(builder.build())); + } + } + } + + public static class RepeatFunction + extends AbstractConnectorTableFunction + { + public RepeatFunction() + { + super( + SCHEMA_NAME, + "repeat", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(2L) + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new RepeatFunctionHandle((long) count.getValue())) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class RepeatFunctionHandle + implements 
ConnectorTableFunctionHandle + { + private final long count; + + @JsonCreator + public RepeatFunctionHandle(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class RepeatFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new RepeatFunctionDataProcessor(((RepeatFunctionHandle) handle).getCount()); + } + } + + public static class RepeatFunctionDataProcessor + implements TableFunctionDataProcessor + { + private final long count; + + // stateful + private long processedPositions; + private long processedRounds; + private Block indexes; + boolean usedData; + + public RepeatFunctionDataProcessor(long count) + { + this.count = count; + } + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + if (processedRounds < count && indexes != null) { + processedRounds++; + return produced(new Page(indexes)); + } + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + if (processedRounds == 0) { + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + indexes = builder.build(); + usedData = true; + } + else { + usedData = false; + } + processedRounds++; + + Page result = new Page(indexes); + + if (processedRounds == count) { + processedRounds = 0; + indexes = null; + } + + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + } + } + + public static class EmptyOutputFunction + extends AbstractConnectorTableFunction + { + public EmptyOutputFunction() + 
{ + super( + SCHEMA_NAME, + "empty_output", + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputDataProcessor(); + } + } + + // returns an empty Page (one column, zero rows) for each Page of input + private static class EmptyOutputDataProcessor + implements TableFunctionDataProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class EmptyOutputWithPassThroughFunction + extends AbstractConnectorTableFunction + { + public EmptyOutputWithPassThroughFunction() + { + super( + SCHEMA_NAME, + "empty_output_with_pass_through", + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .passThroughColumns() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + 
.handle(new EmptyTableFunctionHandle("")) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputWithPassThroughProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputWithPassThroughDataProcessor(); + } + } + + // returns an empty Page (one proper column and pass-through, zero rows) for each Page of input + private static class EmptyOutputWithPassThroughDataProcessor + implements TableFunctionDataProcessor + { + // one proper channel, and one pass-through index channel + private static final Page EMPTY_PAGE = new Page( + BOOLEAN.createBlockBuilder(null, 0).build(), + BIGINT.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class TestInputsFunction + extends AbstractConnectorTableFunction + { + public TestInputsFunction() + { + super( + SCHEMA_NAME, + "test_inputs_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT_1") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_4") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + 
.requiredColumns("INPUT_1", IntStream.range(0, ((TableArgument) arguments.get("INPUT_1")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_2", IntStream.range(0, ((TableArgument) arguments.get("INPUT_2")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_3", IntStream.range(0, ((TableArgument) arguments.get("INPUT_3")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_4", IntStream.range(0, ((TableArgument) arguments.get("INPUT_4")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputsFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder resultBuilder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(resultBuilder, true); + + Page result = new Page(resultBuilder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class PassThroughInputFunction + extends AbstractConnectorTableFunction + { + public PassThroughInputFunction() + { + super( + SCHEMA_NAME, + "pass_through", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .passThroughColumns() + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of( + new Descriptor.Field("input_1_present", Optional.of(BOOLEAN)), + new Descriptor.Field("input_2_present", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + 
.requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .build(); + } + + public static class PassThroughInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new PassThroughInputDataProcessor(); + } + } + + private static class PassThroughInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean input1Present; + private boolean input2Present; + private int input1EndIndex; + private int input2EndIndex; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + + // proper column input_1_present + BlockBuilder input1Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input1Builder, input1Present); + + // proper column input_2_present + BlockBuilder input2Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input2Builder, input2Present); + + // pass-through index for input_1 + BlockBuilder input1PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input1Present) { + input1PassThroughBuilder.writeLong(input1EndIndex - 1); + } + else { + input1PassThroughBuilder.appendNull(); + } + + // pass-through index for input_2 + BlockBuilder input2PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input2Present) { + input2PassThroughBuilder.writeLong(input2EndIndex - 1); + } + else { + input2PassThroughBuilder.appendNull(); + } + + return produced(new Page(input1Builder.build(), input2Builder.build(), input1PassThroughBuilder.build(), input2PassThroughBuilder.build())); + } + input.get(0).ifPresent(page -> { + input1Present = true; + input1EndIndex += page.getPositionCount(); + }); + input.get(1).ifPresent(page -> { + input2Present = true; + input2EndIndex += page.getPositionCount(); + }); + return 
usedInput(); + } + } + } + + public static class TestInputFunction + extends AbstractConnectorTableFunction + { + public TestInputFunction() + { + super( + SCHEMA_NAME, + "test_input", + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("got_input", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new TestInputDataProcessor(); + } + } + + private static class TestInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean processorGotInput; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, processorGotInput); + return produced(new Page(builder.build())); + } + processorGotInput = true; + return usedInput(); + } + } + } + + public static class TestSingleInputRowSemanticsFunction + extends AbstractConnectorTableFunction + { + public TestSingleInputRowSemanticsFunction() + { + super( + SCHEMA_NAME, + "test_single_input_function", + ImmutableList.of(TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT") + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new 
Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestSingleInputFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, true); + Page result = new Page(builder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class ConstantFunction + extends AbstractConnectorTableFunction + { + public ConstantFunction() + { + super( + SCHEMA_NAME, + "constant", + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name("VALUE") + .type(INTEGER) + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(1L) + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("constant_column"), + ImmutableList.of(INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new ConstantFunctionHandle((Long) ((ScalarArgument) arguments.get("VALUE")).getValue(), (long) count.getValue())) + .build(); + } + 
+ public static class ConstantFunctionHandle + implements ConnectorTableFunctionHandle + { + private final Long value; + private final long count; + + @JsonCreator + public ConstantFunctionHandle(@JsonProperty("value") Long value, @JsonProperty("count") long count) + { + this.value = value; + this.count = count; + } + + @JsonProperty + public Long getValue() + { + return value; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class ConstantFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new ConstantFunctionProcessor(((ConstantFunctionHandle) handle).getValue()); + } + } + + public static class ConstantFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final int PAGE_SIZE = 1000; + + private final Long value; + + private long fullPagesCount; + private long processedPages; + private int reminder; + private Block block; + + public ConstantFunctionProcessor(Long value) + { + this.value = value; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + boolean usedData = false; + + if (split != null) { + long count = ((ConstantFunctionSplit) split).getCount(); + this.fullPagesCount = count / PAGE_SIZE; + this.reminder = toIntExact(count % PAGE_SIZE); + if (fullPagesCount > 0) { + BlockBuilder builder = INTEGER.createBlockBuilder(null, PAGE_SIZE); + if (value == null) { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + else { + BlockBuilder builder = INTEGER.createBlockBuilder(null, reminder); + if (value == null) { + for (int i = 0; i < reminder; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < reminder; i++) { + builder.writeInt(toIntExact(value)); + 
} + } + this.block = builder.build(); + } + usedData = true; + } + + if (processedPages < fullPagesCount) { + processedPages++; + Page result = new Page(block); + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + if (reminder > 0) { + Page result = new Page(block.getRegion(0, toIntExact(reminder))); + reminder = 0; + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + return FINISHED; + } + } + + public static ConnectorSplitSource getConstantFunctionSplitSource(ConstantFunctionHandle handle) + { + long splitSize = DEFAULT_SPLIT_SIZE; + ImmutableList.Builder splits = ImmutableList.builder(); + for (long i = 0; i < handle.getCount() / splitSize; i++) { + splits.add(new ConstantFunctionSplit(splitSize)); + } + long remainingSize = handle.getCount() % splitSize; + if (remainingSize > 0) { + splits.add(new ConstantFunctionSplit(remainingSize)); + } + return new FixedSplitSource(splits.build()); + } + + public static final class ConstantFunctionSplit + implements ConnectorSplit + { + private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(ConstantFunctionSplit.class).instanceSize()); + public static final int DEFAULT_SPLIT_SIZE = 5500; + + private final long count; + + @JsonCreator + public ConstantFunctionSplit(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return count; + } + } + } + + public static class EmptySourceFunction + extends AbstractConnectorTableFunction + { + public EmptySourceFunction() + { + super( + SCHEMA_NAME, + "empty_source", + ImmutableList.of(), + new DescribedTable(new Descriptor(ImmutableList.of(new 
Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new EmptyTableFunctionHandle("")) + .build(); + } + + public static class EmptySourceFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptySourceFunctionProcessor(); + } + } + + public static class EmptySourceFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split == null) { + return FINISHED; + } + + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + @JsonInclude(JsonInclude.Include.ALWAYS) + public static class EmptyTableFunctionHandle + implements ConnectorTableFunctionHandle + { + public final String dummy; + + @JsonCreator + public EmptyTableFunctionHandle(@JsonProperty("dummy") String dummy) + { + this.dummy = dummy; + } + + @JsonProperty + public String getDummy() + { + return dummy; + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java index 3c267a101dd86..c830bd2429cc6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java @@ -37,6 +37,7 @@ import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import 
com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -673,4 +674,10 @@ public void addConstraint(Session session, TableHandle tableHandle, TableConstra { throw new UnsupportedOperationException(); } + + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + return Optional.empty(); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java index 01e0ca48371ff..30311ee939a05 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java @@ -22,6 +22,16 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.connector.informationSchema.InformationSchemaConnector; import com.facebook.presto.connector.system.SystemConnector; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.MonomorphicStaticReturnTypeFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.OnlyPassThroughFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.PassThroughFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.PolymorphicStaticReturnTypeFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.RequiredColumnsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TableArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TableArgumentRowSemanticsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; 
+import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoTableArgumentsFunction; import com.facebook.presto.execution.warnings.WarningCollectorConfig; import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; @@ -60,6 +70,7 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.intellij.lang.annotations.Language; import org.testng.annotations.BeforeClass; @@ -150,6 +161,20 @@ public void setup() metadata.getFunctionAndTypeManager().createFunction(SQL_FUNCTION_SQUARE, true); + metadata.getFunctionAndTypeManager().getTableFunctionRegistry().addTableFunctions( + TPCH_CONNECTOR_ID, + ImmutableSet.of( + new TwoScalarArgumentsFunction(), + new TableArgumentFunction(), + new TableArgumentRowSemanticsFunction(), + new DescriptorArgumentFunction(), + new TwoTableArgumentsFunction(), + new OnlyPassThroughFunction(), + new MonomorphicStaticReturnTypeFunction(), + new PolymorphicStaticReturnTypeFunction(), + new PassThroughFunction(), + new RequiredColumnsFunction())); + Catalog tpchTestCatalog = createTestingCatalog(TPCH_CATALOG, TPCH_CONNECTOR_ID); catalogManager.registerCatalog(tpchTestCatalog); metadata.getAnalyzePropertyManager().addProperties(TPCH_CONNECTOR_ID, tpchTestCatalog.getConnector(TPCH_CONNECTOR_ID).getAnalyzeProperties()); @@ -506,7 +531,12 @@ protected void assertFails(SemanticErrorCode error, int line, int column, @Langu protected void assertFails(SemanticErrorCode error, String message, @Language("SQL") String query) { - assertFails(CLIENT_SESSION, error, message, query); + assertFails(CLIENT_SESSION, error, message, query, false); + } + + protected void assertFailsExact(SemanticErrorCode error, String message, @Language("SQL") String query) + { + assertFails(CLIENT_SESSION, error, 
message, query, true); } protected void assertFails(Session session, SemanticErrorCode error, @Language("SQL") String query) @@ -514,6 +544,11 @@ protected void assertFails(Session session, SemanticErrorCode error, @Language(" assertFails(session, error, Optional.empty(), query); } + protected void assertFails(Session session, SemanticErrorCode error, String message, @Language("SQL") String query) + { + assertFails(session, error, message, query, false); + } + private void assertFails(Session session, SemanticErrorCode error, Optional location, @Language("SQL") String query) { try { @@ -542,7 +577,7 @@ private void assertFails(Session session, SemanticErrorCode error, Optional 'foo'))"); + analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', number => 1))"); + + assertFails(INVALID_ARGUMENTS, + "line 1:51: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_arguments_function('foo', number => 1))"); + + assertFails(INVALID_ARGUMENTS, + "line 1:51: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', 1))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:66: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', text => 'bar'))"); + + // argument names are resolved in the canonical form + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:66: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', TeXt => 'bar'))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:66: Unexpected argument name: BAR", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', bar => 'bar'))"); + + assertFails(MISSING_ARGUMENT, + "line 1:51: Missing argument: TEXT", + "SELECT * FROM TABLE(system.two_arguments_function(number => 1))"); + } + + 
@Test + public void testTableArgument() + { + // cannot pass a table function as the argument + assertFails(NOT_SUPPORTED, + "line 1:52: Invalid table argument INPUT. Table functions are not allowed as table function arguments", + "SELECT * FROM TABLE(system.table_argument_function(input => my_schema.my_table_function(1)))"); + + assertThatThrownBy(() -> analyze("SELECT * FROM TABLE(system.table_argument_function(input => my_schema.my_table_function(arg => 1)))")) + .isInstanceOf(ParsingException.class) + .hasMessageContaining("line 1:93: mismatched input '=>'."); + + // cannot pass a table function as the argument, also preceding nested table function with TABLE is incorrect + assertThatThrownBy(() -> analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(my_schema.my_table_function(1))))")) + .isInstanceOf(ParsingException.class) + .hasMessageContaining("line 1:94: mismatched input '('."); + + // a table passed as the argument must be preceded with TABLE + analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(t1)))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. Expected table, got expression", + "SELECT * FROM TABLE(system.table_argument_function(input => t1))"); + + // a query passed as the argument must be preceded with TABLE + analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT * FROM t1)))"); + + assertThatThrownBy(() -> analyze("SELECT * FROM TABLE(system.table_argument_function(input => SELECT * FROM t1))")) + .isInstanceOf(ParsingException.class) + .hasMessageContaining("line 1:61: mismatched input 'SELECT'."); + + // query passed as the argument is correlated + analyze("SELECT * FROM t1 CROSS JOIN LATERAL (SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 WHERE a > 0))))"); + + // wrong argument type + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. 
Expected table, got expression", + "SELECT * FROM TABLE(system.table_argument_function(input => 'foo'))"); + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. Expected table, got descriptor", + "SELECT * FROM TABLE(system.table_argument_function(input => DESCRIPTOR(x int, y int)))"); + } + + @Test + public void testTableArgumentProperties() + { + analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(t1) PARTITION BY a KEEP WHEN EMPTY ORDER BY b))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:66: Invalid argument INPUT. Partitioning specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) PARTITION BY a))"); + + assertFails(COLUMN_NOT_FOUND, + "line 1:92: Column b is not present in the input relation", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) PARTITION BY b))"); + + assertFails(INVALID_COLUMN_REFERENCE, + "line 1:88: Expected column reference. Actual: 1", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) ORDER BY 1))"); + + assertFails(TYPE_MISMATCH, + "line 1:104: HyperLogLog is not comparable, and therefore cannot be used in PARTITION BY", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT approx_set(1) a) PARTITION BY a))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, "line 1:66: Invalid argument INPUT. Ordering specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) ORDER BY a))"); + + assertFails(COLUMN_NOT_FOUND, + "line 1:88: Column b is not present in the input relation", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) ORDER BY b))"); + + assertFails(INVALID_COLUMN_REFERENCE, + "line 1:88: Expected column reference. 
Actual: 1", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) ORDER BY 1))"); + + assertFails(TYPE_MISMATCH, + "line 1:100: HyperLogLog is not orderable, and therefore cannot be used in ORDER BY", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT approx_set(1) a) ORDER BY a))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:85: Invalid argument INPUT. Empty behavior specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) PRUNE WHEN EMPTY))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:85: Invalid argument INPUT. Empty behavior specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) KEEP WHEN EMPTY))"); + } + + @Test + public void testDescriptorArgument() + { + analyze("SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(x integer, y boolean)))"); + + assertFailsExact(INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid descriptor argument SCHEMA. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(1 + 2)))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid argument SCHEMA. Expected descriptor, got expression", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => 1))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid argument SCHEMA. 
Expected descriptor, got table", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => TABLE(t1)))"); + + assertFails(TYPE_MISMATCH, + "line 1:78: Unknown type: verybigint", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(x verybigint)))"); + } + + @Test + public void testScalarArgument() + { + analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:64: Invalid argument NUMBER. Expected expression, got descriptor", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:64: 'descriptor' function is not allowed as a table function argument", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); + + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:64: Invalid argument NUMBER. Expected expression, got table", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => TABLE(t1)))"); + + assertFails(EXPRESSION_NOT_CONSTANT, + "line 1:74: Constant expression cannot contain a subquery", + "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => (SELECT 1)))"); + } + + @Test + public void testCopartitioning() + { + // TABLE(t1) is matched by fully qualified name: tpch.s1.t1. It matches the second copartition item s1.t1. + // Aliased relation TABLE(SELECT 1, 2) t1(x, y) is matched by unqualified name. It matches the first copartition item t1. + analyze("SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(SELECT 1, 2) t1(x, y) PARTITION BY (x, y)" + + "COPARTITION (t1, s1.t1)))"); + + // Copartition items t1, t2 are first matched to arguments by unqualified names, and when no match is found, by fully qualified names. + // TABLE(tpch.s1.t1) is matched by fully qualified name. 
It matches the first copartition item t1. + // TABLE(s1.t2) is matched by unqualified name: tpch.s1.t2. It matches the second copartition item t2. + analyze("SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(tpch.s1.t1) PARTITION BY (a, b)," + + "input2 => TABLE(s1.t2) PARTITION BY (a, b)" + + "COPARTITION (t1, t2)))"); + + assertFails(INVALID_COPARTITIONING, + "No table argument found for name: s1.foo", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t2) PARTITION BY (a, b)" + + "COPARTITION (t1, s1.foo)))"); + + // Both table arguments are matched by fully qualified name: tpch.s1.t1 + assertFails(INVALID_COPARTITIONING, "Ambiguous reference: multiple table arguments found for name: t1", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t1) PARTITION BY (a, b)" + + "COPARTITION (t1, t2)))"); + + // Both table arguments are matched by unqualified name: t1 + assertFails(INVALID_COPARTITIONING, + "Ambiguous reference: multiple table arguments found for name: t1", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(SELECT 1, 2) t1(a, b) PARTITION BY (a, b)," + + "input2 => TABLE(SELECT 3, 4) t1(c, d) PARTITION BY (c, d)" + + "COPARTITION (t1, t2)))"); + + assertFails(INVALID_COPARTITIONING, + "Multiple references to table argument: t1 in COPARTITION clause", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t2) PARTITION BY (a, b)" + + "COPARTITION (t1, t1)))"); + } + + @Test + public void testCopartitionColumns() + { + assertFails(INVALID_COPARTITIONING, + "line 1:67: Table tpch.s1.t1 referenced in COPARTITION clause is not partitioned", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1)," + + "input2 => TABLE(t2) PARTITION BY (a, b)" + + 
"COPARTITION (t1, t2)))"); + + assertFails(INVALID_COPARTITIONING, + "line 1:67: No partitioning columns specified for table tpch.s1.t1 referenced in COPARTITION clause", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY ()," + + "input2 => TABLE(t2) PARTITION BY ()" + + "COPARTITION (t1, t2)))"); + + assertFails(INVALID_COPARTITIONING, + "Numbers of partitioning columns in copartitioned tables do not match", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t2) PARTITION BY (a)" + + "COPARTITION (t1, t2)))"); + + assertFails(TYPE_MISMATCH, + "Partitioning columns in copartitioned tables have incompatible types", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(SELECT 1) t1(a) PARTITION BY (a)," + + "input2 => TABLE(SELECT 'x') t2(b) PARTITION BY (b)" + + "COPARTITION (t1, t2)))"); + } + + @Test + public void testNullArguments() + { + // cannot pass null for table argument + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. Expected table, got expression", + "SELECT * FROM TABLE(system.table_argument_function(input => null))"); + + // the wrong way to pass null for descriptor + assertFails(INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid argument SCHEMA. 
Expected descriptor, got expression", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => null))"); + + // the right way to pass null for descriptor + analyze("SELECT * FROM TABLE(system.descriptor_argument_function(schema => CAST(null AS DESCRIPTOR)))"); + + // the default value for the argument schema is null + analyze("SELECT * FROM TABLE(system.descriptor_argument_function())"); + + analyze("SELECT * FROM TABLE(system.two_arguments_function(null, null))"); + + // the default value for the second argument is null + analyze("SELECT * FROM TABLE(system.two_arguments_function('a'))"); + } + + @Test + public void testTableFunctionInvocationContext() + { + // cannot specify relation alias for table function with ONLY PASS THROUGH return type + assertFails(INVALID_TABLE_FUNCTION_INVOCATION, + "line 1:21: Alias specified for table function with ONLY PASS THROUGH return type", + "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) f(x)"); + + // per SQL standard, relation alias is required for table function with GENERIC TABLE return type. We don't require it. + analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x)"); + analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1))"); + + // per SQL standard, relation alias is required for table function with statically declared return type, only if the function is polymorphic. + // We don't require aliasing polymorphic functions. 
+ analyze("SELECT * FROM TABLE(system.monomorphic_static_return_type_function())"); + analyze("SELECT * FROM TABLE(system.monomorphic_static_return_type_function()) f(x, y)"); + analyze("SELECT * FROM TABLE(system.polymorphic_static_return_type_function(input => TABLE(t1)))"); + analyze("SELECT * FROM TABLE(system.polymorphic_static_return_type_function(input => TABLE(t1))) f(x, y)"); + + // sampled + assertFails(INVALID_TABLE_FUNCTION_INVOCATION, + "line 1:21: Cannot apply sample to polymorphic table function invocation", + "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) TABLESAMPLE BERNOULLI (10)"); + +// // row pattern matching +// assertFails(INVALID_TABLE_FUNCTION_INVOCATION, +// "line 2:12: Cannot apply row pattern matching to polymorphic table function invocation", +// "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) MATCH_RECOGNIZE( PATTERN (a*) DEFINE a AS true)"); + + // aliased + sampled + assertFails(INVALID_TABLE_FUNCTION_INVOCATION, + "line 1:15: Cannot apply sample to polymorphic table function invocation", + "SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); + +// // aliased + row pattern matching +// assertFails(INVALID_TABLE_FUNCTION_INVOCATION, +// "line 2:6: Cannot apply row pattern matching to polymorphic table function invocation", +// "SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x) MATCH_RECOGNIZE( PATTERN (a*) DEFINE a AS true ) t(y)"); +// +// // row pattern matching + sampled +// assertFails(INVALID_TABLE_FUNCTION_INVOCATION, +// "line 2:12: Cannot apply row pattern matching to polymorphic table function invocation", +// "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) MATCH_RECOGNIZE( PATTERN (a*) DEFINE a AS true) TABLESAMPLE BERNOULLI (10)"); +// +// // aliased + row pattern matching + sampled +// assertFails(INVALID_TABLE_FUNCTION_INVOCATION, +// "line 2:6: Cannot apply row pattern matching to polymorphic table function 
invocation", +// "SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x) MATCH_RECOGNIZE( PATTERN (a*) DEFINE a AS true ) t(y) TABLESAMPLE BERNOULLI (10)"); + } + + @Test + public void testTableFunctionAliasing() + { + // case-insensitive name matching + assertFails(DUPLICATE_RANGE_VARIABLE, + "line 1:64: Relation alias: T1 is a duplicate of input table name: tpch.s1.t1", + "SELECT * FROM TABLE(system.table_argument_function(TABLE(t1))) T1(x)"); + + assertFails(DUPLICATE_RANGE_VARIABLE, + "line 1:76: Relation alias: t1 is a duplicate of input table name: t1", + "SELECT * FROM TABLE(system.table_argument_function(TABLE(SELECT 1) T1(a))) t1(x)"); + + analyze("SELECT * FROM TABLE(system.table_argument_function(TABLE(t1) t2)) T1(x)"); + + // the original returned relation type is ("column" : BOOLEAN) + analyze("SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias"); + + analyze("SELECT column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + + analyze("SELECT table_alias.column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + + assertFails(MISSING_ATTRIBUTE, + "line 1:8: Column 'column' cannot be resolved", + "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + + assertFails(MISMATCHED_COLUMN_ALIASES, + "line 1:20: Column alias list has 3 entries but table function has 1 proper columns", + "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(col1, col2, col3)"); + + // the original returned relation type is ("a" : BOOLEAN, "b" : INTEGER) + analyze("SELECT column_alias_1, column_alias_2 FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(column_alias_1, column_alias_2)"); + + assertFails(DUPLICATE_COLUMN_NAME, + "line 1:21: Duplicate name of table function proper column: col", + "SELECT * FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(col, col)"); + + 
// case-insensitive name matching + assertFails(DUPLICATE_COLUMN_NAME, + "line 1:21: Duplicate name of table function proper column: col", + "SELECT * FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(col, COL)"); + + // pass-through columns of an input table must not be aliased, and must be referenced by the original range variables of their corresponding table arguments + // the function pass_through_function has one proper column ("x" : BOOLEAN), and one table argument with pass-through property + // the alias applies only to the proper column + analyze("SELECT table_alias.x, t1.a, t1.b, t1.c, t1.d FROM TABLE(system.pass_through_function(TABLE(t1))) table_alias"); + + analyze("SELECT table_alias.x, arg_alias.a, arg_alias.b, arg_alias.c, arg_alias.d FROM TABLE(system.pass_through_function(TABLE(t1) arg_alias)) table_alias"); + + assertFails(MISSING_ATTRIBUTE, + "line 1:23: 't1.a' cannot be resolved", + "SELECT table_alias.x, t1.a FROM TABLE(system.pass_through_function(TABLE(t1) arg_alias)) table_alias"); + + assertFails(MISSING_ATTRIBUTE, + "line 1:23: 'table_alias.a' cannot be resolved", + "SELECT table_alias.x, table_alias.a FROM TABLE(system.pass_through_function(TABLE(t1))) table_alias"); + } + + @Test + public void testTableFunctionRequiredColumns() + { + // the function required_columns_function specifies columns 0 and 1 from table argument "INPUT" as required. + analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(t1)))"); + + analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1, 2, 3)))"); + + assertFails(FUNCTION_IMPLEMENTATION_ERROR, + "Invalid index: 1 of required column from table argument INPUT", + "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1)))"); + + // table s1.t5 has two columns. The second column is hidden. Table function can require a hidden column. 
+ analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java index 1fd0121232144..3880e9c72aaac 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestCanonicalize.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; -import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; @@ -81,7 +81,7 @@ public void testJoin() @Test public void testDuplicatesInWindowOrderBy() { - ExpectedValueProvider specification = specification( + ExpectedValueProvider specification = specification( ImmutableList.of(), ImmutableList.of("A"), ImmutableMap.of("A", SortOrder.ASC_NULLS_LAST)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 46676111a82aa..ae2a450ef2b05 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import 
com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.JoinNode; @@ -342,7 +343,7 @@ public void testWindow() equals(AV, BV), equals(BV, CV), lessThan(CV, bigintLiteral(10)))), - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(AV), Optional.of(new OrderingScheme( ImmutableList.of(new Ordering(AV, SortOrder.ASC_NULLS_LAST))))), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java new file mode 100644 index 0000000000000..f822ee57e556f --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.connector.tvf.MockConnectorFactory; +import com.facebook.presto.connector.tvf.MockConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DifferentArgumentTypesFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TestingTableFunctionHandle; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.Descriptor.Field; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static 
com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String TESTING_CATALOG = "mock"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new DescriptorArgumentFunction(), + new TestingTableFunctions.TwoTableArgumentsFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Missing columns")))); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "mock", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + "SELECT * FROM TABLE(mock.system.different_arguments_function(" + + "INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1," + + "INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3," + + "INPUT_2 => TABLE(VALUES 1) t2(c2)," + + "ID => BIGINT '2001'," + + "LAYOUT => DESCRIPTOR (x boolean, y bigint)" + + "COPARTITION (t1, t3))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + 
.specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughVariables(ImmutableSet.of("c1")) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) + .pruneWhenEmpty() + .passThroughVariables(ImmutableSet.of("c3"))) + .addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughVariables(ImmutableSet.of("c2")) + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); + } + + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan("SELECT * FROM TABLE(mock.system.two_table_arguments_function(" + + "INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1," + + "INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 " + + "COPARTITION (t1, t2))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + 
.properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new LongLiteral("1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("2")))))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(mock.system.two_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(mock.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(mock.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 07a0529289495..8c7995568caec 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import 
com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.ProjectNode; @@ -175,7 +176,7 @@ public void testValidWindow() WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, DOUBLE, variableC), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( Optional.empty(), @@ -437,7 +438,7 @@ public void testInvalidWindowFunctionCall() WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableA), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( Optional.empty(), @@ -471,7 +472,7 @@ public void testInvalidWindowFunctionSignature() WindowNode.Function function = new WindowNode.Function(call("sum", functionHandle, BIGINT, variableC), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( Optional.empty(), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 7d79307417df8..6e524258e432e 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; @@ -627,6 +628,11 @@ public static PlanMatchPattern values(String... aliases) return values(ImmutableList.copyOf(aliases)); } + public static PlanMatchPattern values(int rowCount) + { + return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of())); + } + public static PlanMatchPattern values(List aliases, List> expectedRows) { return values(aliases, Optional.of(expectedRows)); @@ -667,6 +673,27 @@ public static PlanMatchPattern remoteSource(List sourceFragmentI return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); } + public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern... 
sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); @@ -927,7 +954,7 @@ private static List toSymbolAliases(List aliases) .collect(toImmutableList()); } - public static ExpectedValueProvider specification( + public static ExpectedValueProvider specification( List partitionBy, List orderBy, Map orderings) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java index 1caf681620280..b8cc9b3b4c55b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpecificationProvider.java @@ -14,9 +14,9 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; -import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -33,7 +33,7 
@@ import static java.util.Objects.requireNonNull; public class SpecificationProvider - implements ExpectedValueProvider + implements ExpectedValueProvider { private final List partitionBy; private final List orderBy; @@ -50,7 +50,7 @@ public class SpecificationProvider } @Override - public WindowNode.Specification getExpectedValue(SymbolAliases aliases) + public DataOrganizationSpecification getExpectedValue(SymbolAliases aliases) { Optional orderingScheme = Optional.empty(); if (!orderBy.isEmpty()) { @@ -64,7 +64,7 @@ public WindowNode.Specification getExpectedValue(SymbolAliases aliases) .collect(toImmutableList()))); } - return new WindowNode.Specification( + return new DataOrganizationSpecification( partitionBy .stream() .map(alias -> new VariableReferenceExpression(Optional.empty(), alias.toSymbol(aliases).getName(), UNKNOWN)) @@ -87,7 +87,7 @@ public String toString() * VariableReferenceExpression::equals to check whether two specification are equivalent once they include VariableReferenceExpression. * TODO Directly use equals once SymbolAlias is converted to something with type information. 
*/ - public static boolean matchSpecification(WindowNode.Specification actual, WindowNode.Specification expected) + public static boolean matchSpecification(DataOrganizationSpecification actual, DataOrganizationSpecification expected) { return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) .equals(expected.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList())) && @@ -104,7 +104,7 @@ public static boolean matchSpecification(WindowNode.Specification actual, Window .orElse(true); } - public static boolean matchSpecification(WindowNode.Specification actual, SpecificationProvider expected) + public static boolean matchSpecification(DataOrganizationSpecification actual, SpecificationProvider expected) { return actual.getPartitionBy().stream().map(VariableReferenceExpression::getName).collect(toImmutableList()) .equals(expected.partitionBy.stream().map(SymbolAlias::toString).collect(toImmutableList())) && diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 0000000000000..8251e6c2410f2 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,392 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private final Map arguments; + private final 
List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + if (expected instanceof DescriptorArgumentValue) { + DescriptorArgumentValue expectedDescriptor = (DescriptorArgumentValue) expected; + if (!(actual instanceof DescriptorArgument) || !expectedDescriptor.getDescriptor().equals(((DescriptorArgument) actual).getDescriptor())) { + return NO_MATCH; + } + } + else if (expected instanceof ScalarArgumentValue) { + ScalarArgumentValue expectedScalar = (ScalarArgumentValue) expected; + if (!(actual instanceof ScalarArgument) || 
!Objects.equals(expectedScalar.getValue(), ((ScalarArgument) actual).getValue())) { + return NO_MATCH; + } + } + else { + if (!(actual instanceof TableArgument)) { + return NO_MATCH; + } + TableArgumentValue expectedTableArgument = (TableArgumentValue) expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.rowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.pruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + return NO_MATCH; + } + + if (expectedTableArgument.specification().isPresent() != argumentProperties.specification().isPresent()) { + return NO_MATCH; + } + if (!expectedTableArgument.specification() + .map(expectedSpecification -> matchSpecification(argumentProperties.specification().get(), expectedSpecification.getExpectedValue(symbolAliases))) + .orElse(true)) { + return NO_MATCH; + } + + Set expectedPassThrough = expectedTableArgument.passThroughVariables().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(var -> toSymbolReference(var.getOutputVariables())) + .collect(toImmutableSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + } + } + + if (properOutputs.size() != tableFunctionNode.getProperOutputs().size()) { + return NO_MATCH; + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + + // TODO: Do we need these symbol references or should it be something like 
VariableReferenceExpression + /* + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), tableFunctionNode.getOutputVariables().get(i).toSymbolReference()); + } + */ + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, 
copartitioningLists.build())); + } + } + + interface ArgumentValue + { + } + + public static class DescriptorArgumentValue + implements ArgumentValue + { + private final Optional descriptor; + + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + + public Optional getDescriptor() + { + return descriptor; + } + } + + public static class ScalarArgumentValue + implements ArgumentValue + { + private final Object value; + + public ScalarArgumentValue(Object value) + { + this.value = value; + } + + public Object getValue() + { + return value; + } + } + + public static class TableArgumentValue + implements ArgumentValue + { + private final int sourceIndex; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Optional> specification; + private final Set passThroughVariables; + + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification, Set passThroughVariables) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + this.passThroughVariables = ImmutableSet.copyOf(passThroughVariables); + } + + public int sourceIndex() + { + return sourceIndex; + } + + public boolean rowSemantics() + { + return rowSemantics; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean passThroughColumns() + { + return 
passThroughColumns; + } + + public Set passThroughVariables() + { + return passThroughVariables; + } + + public Optional> specification() + { + return specification; + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + private Set passThroughVariables = ImmutableSet.of(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder passThroughVariables(Set variables) + { + this.passThroughVariables = variables; + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughVariables); + } + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 index 0000000000000..015cffb412903 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.QueryPlanner; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; 
+import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List properOutputs; + private final Set passThroughSymbols; + private final Optional> markerSymbols; + private final Optional> specification; + + private TableFunctionProcessorMatcher( + String name, + List properOutputs, + Set passThroughSymbols, + Optional> markerSymbols, + Optional> specification) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.passThroughSymbols = ImmutableSet.copyOf(passThroughSymbols); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = (TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + Set expectedPassThrough = passThroughSymbols.stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + 
.map(QueryPlanner::toSymbolReference) + .collect(toImmutableSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerVariables().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map actualMapping = tableFunctionProcessorNode.getMarkerVariables().get().entrySet().stream() + .collect(toImmutableMap(entry -> toSymbolReference(entry.getKey()), entry -> toSymbolReference(entry.getValue()))); + if (!expectedMapping.equals(actualMapping)) { + return NO_MATCH; + } + } + + if (specification.isPresent() != tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!matchSpecification(specification.get().getExpectedValue(symbolAliases), tableFunctionProcessorNode.getSpecification().orElseThrow(NoSuchElementException::new))) { + return NO_MATCH; + } + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), toSymbolReference(tableFunctionProcessorNode.getProperOutputs().get(i))); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .toString(); + } + + public static class Builder + { + private final Optional source; + private String name; + private List properOutputs = ImmutableList.of(); + private Set passThroughSymbols = 
ImmutableSet.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional> specification = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(Set passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, markerSymbols, specification)); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java index c4d66eb263263..590625570cbd7 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TopNRowNumberMatcher.java @@ -17,8 +17,8 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.PlanNode; -import 
com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; @@ -38,14 +38,14 @@ public class TopNRowNumberMatcher implements Matcher { - private final Optional> specification; + private final Optional> specification; private final Optional rowNumberSymbol; private final Optional maxRowCountPerPartition; private final Optional partial; private final Optional> hashSymbol; private TopNRowNumberMatcher( - Optional> specification, + Optional> specification, Optional rowNumberSymbol, Optional maxRowCountPerPartition, Optional partial, @@ -125,7 +125,7 @@ public String toString() public static class Builder { private final PlanMatchPattern source; - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private Optional rowNumberSymbol = Optional.empty(); private Optional maxRowCountPerPartition = Optional.empty(); private Optional partial = Optional.empty(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index 02aa69afdd716..ffbe3c37d3e26 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -18,6 +18,7 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -46,13 +47,13 @@ public final class WindowMatcher implements Matcher { private 
final Optional> prePartitionedInputs; - private final Optional> specification; + private final Optional> specification; private final Optional preSortedOrderPrefix; private final Optional> hashSymbol; private WindowMatcher( Optional> prePartitionedInputs, - Optional> specification, + Optional> specification, Optional preSortedOrderPrefix, Optional> hashSymbol) { @@ -136,7 +137,7 @@ public static class Builder { private final PlanMatchPattern source; private Optional> prePartitionedInputs = Optional.empty(); - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private Optional preSortedOrderPrefix = Optional.empty(); private List windowFunctionMatchers = new LinkedList<>(); private Optional> hashSymbol = Optional.empty(); @@ -164,7 +165,7 @@ public Builder specification( return specification(PlanMatchPattern.specification(partitionBy, orderBy, orderings)); } - public Builder specification(ExpectedValueProvider specification) + public Builder specification(ExpectedValueProvider specification) { requireNonNull(specification, "specification is null"); this.specification = Optional.of(specification); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java new file mode 100644 index 0000000000000..40a04f7874193 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java @@ -0,0 +1,1323 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_FIRST; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; + +public class TestImplementTableFunctionSource + extends BaseRuleTest +{ + @Test + public void testNoSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.variable("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")), + values("c"))); + + // pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = 
p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(true, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c")), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new 
ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + 
.properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, 
null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f"))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + 
VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + VariableReferenceExpression h = p.variable("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(h, DESC_NULLS_FIRST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = 
combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = input_3_row_number OR " + + "(combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + 
.addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("g", "h"))))))); + } + + @Test + public void testTwoCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + 
.addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f"))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + 
"c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + 
.addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))))))); + + // only the left source is prune when empty, so sources are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for 
co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR " + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. 
they are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), 
input_2_row_number, input_1_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM c) " + + "AND (" + + " input_2_row_number = input_1_row_number OR" + + " (input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR" + + " input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new 
PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d)" + + " AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' 
OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new 
PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (combined_partition_column_1_2 IS DISTINCT FROM e) " + + "AND (" + + " combined_row_number_1_2 = input_3_row_number OR" + + " (combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = 
BIGINT '1' OR" + + " input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1'))"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + 
.addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e"))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(g, DESC_NULLS_FIRST)))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + 
ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", "marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, null)"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + "combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + JoinType.LEFT, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = combined_row_number_3_4 OR " + + "(combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR " + + "combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper and partitioning symbols for 
co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), 
input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (e IS DISTINCT FROM f) " + + "AND ( " + + "input_3_row_number = input_4_row_number OR " + + "(input_3_row_number > input_4_partition_size AND input_4_row_number = BIGINT '1' OR " + + "input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e")), + window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + values("f", "g"))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), 
+ Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, 
input_1_partition_size)")), + join(// join nodes using helper symbols + JoinType.INNER, + ImmutableList.of(), + Optional.of("combined_row_number_2_3 = input_1_row_number OR " + + "(combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR " + + "input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM e) " + + "AND ( " + + "input_2_row_number = input_3_row_number OR " + + "(input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e")))), + 
window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c", TINYINT); + VariableReferenceExpression cCoerced = p.variable("c_coerced", INTEGER); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e", INTEGER); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + Assignments.builder() + .put(c, p.rowExpression("c")) + .put(d, p.rowExpression("d")) + .put(cCoerced, p.rowExpression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + 
.matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c_coerced IS DISTINCT FROM e) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), 
+ // input_1 + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f"))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e", "f")) + 
.markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND NOT (d IS DISTINCT FROM f) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d")), + window(// append helper symbols for source input_2 + builder -> 
builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f"))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", 
expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d")), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f"))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index f0e9d391ec889..55c64c8c946a3 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -93,7 +94,7 @@ public class TestMergeAdjacentWindows private static final FunctionHandle LAG_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("lag", fromTypes(DOUBLE)); private static final FunctionHandle RANK_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("rank", ImmutableList.of()); private static final String columnAAlias = "ALIAS_A"; - private static final ExpectedValueProvider specificationA = + private static final ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); @Test @@ -275,14 +276,14 @@ public void testIntermediateProjectNodes() values(columnAAlias, unusedAlias)))))); } - private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) + private static DataOrganizationSpecification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) { - return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.empty()); + return new DataOrganizationSpecification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.empty()); } - private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName, String sortkey) + private static 
DataOrganizationSpecification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName, String sortkey) { - return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), + return new DataOrganizationSpecification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.of(new OrderingScheme( ImmutableList.of(new Ordering(planBuilder.variable(sortkey, BIGINT), SortOrder.ASC_NULLS_FIRST))))); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index c4f68ef174e8f..ccbf278341a55 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; @@ -280,7 +281,7 @@ private static PlanNode buildProjectedWindow( .filter(projectionFilter) .collect(toImmutableList())), p.window( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(partitionKey), Optional.of(new OrderingScheme( ImmutableList.of(new Ordering(orderKey, SortOrder.ASC_NULLS_FIRST))))), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 9209e2df9cdc4..b3fa34f67a8d1 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -79,7 +80,7 @@ public void doesNotFireOnPlanWithoutWindowFunctions() public void doesNotFireOnPlanWithSingleWindowNode() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) - .on(p -> p.window(new WindowNode.Specification( + .on(p -> p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.empty()), ImmutableMap.of(p.variable("avg_1"), @@ -94,16 +95,16 @@ public void subsetComesFirst() String columnAAlias = "ALIAS_A"; String columnBAlias = "ALIAS_B"; - ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); - ExpectedValueProvider specificationAB = specification(ImmutableList.of(columnAAlias, columnBAlias), ImmutableList.of(), ImmutableMap.of()); + ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); + ExpectedValueProvider specificationAB = specification(ImmutableList.of(columnAAlias, columnBAlias), ImmutableList.of(), ImmutableMap.of()); tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.empty()), ImmutableMap.of(p.variable("avg_1", DOUBLE), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), - 
p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.empty()), ImmutableMap.of(p.variable("avg_2", DOUBLE), newWindowNodeFunction(ImmutableList.of(new Symbol("b")))), @@ -123,11 +124,11 @@ public void dependentWindowsAreNotReordered() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.empty()), ImmutableMap.of(p.variable("avg_1"), newWindowNodeFunction(ImmutableList.of(new Symbol("avg_2")))), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.empty()), ImmutableMap.of(p.variable("avg_2"), newWindowNodeFunction(ImmutableList.of(new Symbol("a")))), @@ -168,11 +169,11 @@ public void dependentWindowsAreNotReorderedWithOffset() tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("avg_1"), functionWithOffset), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("startValue"), windowFunction), @@ -213,11 +214,11 @@ public void dependentWindowsWithRangeAreNotReordered() tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( 
ImmutableList.of(p.variable("a")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("avg_1"), functionWithOffset), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.variable("a"), p.variable("b")), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))), ImmutableMap.of(p.variable("startValue"), windowFunction), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 590601b3d7f9b..4cf0286f48154 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -19,6 +19,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.IndexHandle; @@ -27,11 +28,14 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteProducerNode; +import 
com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -84,6 +88,7 @@ import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; @@ -933,7 +938,7 @@ public PlanBuilder registerVariable(VariableReferenceExpression expression) return this; } - public WindowNode window(WindowNode.Specification specification, Map functions, PlanNode source) + public WindowNode window(DataOrganizationSpecification specification, Map functions, PlanNode source) { return new WindowNode( Optional.empty(), @@ -946,7 +951,7 @@ public WindowNode window(WindowNode.Specification specification, Map functions, VariableReferenceExpression hashVariable, PlanNode source) + public WindowNode window(DataOrganizationSpecification specification, Map functions, VariableReferenceExpression hashVariable, PlanNode source) { return new WindowNode( Optional.empty(), @@ -959,6 +964,25 @@ public WindowNode window(WindowNode.Specification specification, Map properOutputs, + List sources, + List tableArgumentProperties, + List> copartitioningLists) + + { + return new TableFunctionNode( + idAllocator.getNextId(), + name, + ImmutableMap.of(), + properOutputs, + sources, + tableArgumentProperties, + copartitioningLists, + new TableFunctionHandle(new ConnectorId("connector_id"), new SchemaFunctionName("system", name), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create())); + } + public RowNumberNode rowNumber(List partitionBy, Optional maxRowCountPerPartition, 
VariableReferenceExpression rowNumberVariable, PlanNode source) { return new RowNumberNode( diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index aa0456e3bcfb3..57550cc78755b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; -import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; @@ -53,7 +53,7 @@ public class TestEliminateSorts private static final String QUANTITY_ALIAS = "QUANTITY"; private static final String TAX_ALIAS = "TAX"; - private static final ExpectedValueProvider windowSpec = specification( + private static final ExpectedValueProvider windowSpec = specification( ImmutableList.of(), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java index 379bb4cfd2780..ac48f5531d741 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.common.block.SortOrder; +import 
com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.RuleStatsRecorder; @@ -88,8 +89,8 @@ public class TestMergeWindows private static final Optional UNSPECIFIED_FRAME = Optional.empty(); - private final ExpectedValueProvider specificationA; - private final ExpectedValueProvider specificationB; + private final ExpectedValueProvider specificationA; + private final ExpectedValueProvider specificationB; public TestMergeWindows() { @@ -292,12 +293,12 @@ public void testIdenticalWindowSpecificationsAAfilterA() @Test public void testIdenticalWindowSpecificationsDefaultFrame() { - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(ORDERKEY_ALIAS), ImmutableList.of(SHIPDATE_ALIAS), ImmutableMap.of(SHIPDATE_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -328,7 +329,7 @@ public void testMergeDifferentFrames() new FrameBound(FrameBound.Type.UNBOUNDED_PRECEDING), Optional.of(new FrameBound(FrameBound.Type.CURRENT_ROW)))); - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -362,7 +363,7 @@ public void testMergeDifferentFramesWithDefault() new FrameBound(FrameBound.Type.CURRENT_ROW), Optional.of(new FrameBound(FrameBound.Type.UNBOUNDED_FOLLOWING)))); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), 
ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -391,7 +392,7 @@ public void testMergeRangeFramesWithDefault() new FrameBound(FrameBound.Type.CURRENT_ROW), Optional.of(new FrameBound(FrameBound.Type.UNBOUNDED_FOLLOWING)))); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -433,12 +434,12 @@ public void testNotMergeAcrossJoinBranches() ")" + "SELECT * FROM foo, bar WHERE foo.a = bar.b"; - ExpectedValueProvider leftSpecification = specification( + ExpectedValueProvider leftSpecification = specification( ImmutableList.of(ORDERKEY_ALIAS), ImmutableList.of(SHIPDATE_ALIAS, QUANTITY_ALIAS), ImmutableMap.of(SHIPDATE_ALIAS, SortOrder.ASC_NULLS_LAST, QUANTITY_ALIAS, SortOrder.DESC_NULLS_LAST)); - ExpectedValueProvider rightSpecification = specification( + ExpectedValueProvider rightSpecification = specification( ImmutableList.of(rOrderkeyAlias), ImmutableList.of(rShipdateAlias, rQuantityAlias), ImmutableMap.of(rShipdateAlias, SortOrder.ASC_NULLS_LAST, rQuantityAlias, SortOrder.DESC_NULLS_LAST)); @@ -489,7 +490,7 @@ public void testNotMergeDifferentPartition() "SUM(quantity) over (PARTITION BY quantity ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(QUANTITY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -513,7 +514,7 @@ public void testNotMergeDifferentOrderBy() "SUM(quantity) OVER (PARTITION BY suppkey ORDER BY quantity ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( 
ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -538,7 +539,7 @@ public void testNotMergeDifferentOrdering() "SUM(discount) over (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.DESC_NULLS_LAST)); @@ -564,7 +565,7 @@ public void testNotMergeDifferentNullOrdering() "SUM(discount) OVER (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_FIRST)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index af11a87d77e33..6a1eff9da903d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -71,7 +72,7 @@ public void windowNodePruning() 
p.output(ImmutableList.of("user_uuid"), ImmutableList.of(p.variable("user_uuid", VARCHAR)), p.project(Assignments.of(p.variable("user_uuid", VARCHAR), p.variable("user_uuid", VARCHAR)), p.window( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(p.variable("user_uuid", VARCHAR)), Optional.of(new OrderingScheme( ImmutableList.of( diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index 1ea58b5627d57..a62a2648040de 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.common.block.SortOrder; -import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.sql.InMemoryExpressionOptimizerProvider; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; @@ -59,13 +59,13 @@ public class TestReorderWindows private static final Optional commonFrame; - private static final ExpectedValueProvider windowA; - private static final ExpectedValueProvider windowAp; - private static final ExpectedValueProvider windowApp; - private static final ExpectedValueProvider windowB; - private static final ExpectedValueProvider windowC; - private static final ExpectedValueProvider windowD; - private static final ExpectedValueProvider windowE; + private static final ExpectedValueProvider windowA; + private static final ExpectedValueProvider windowAp; + private static final ExpectedValueProvider windowApp; + private static final ExpectedValueProvider windowB; + private static final 
ExpectedValueProvider windowC; + private static final ExpectedValueProvider windowD; + private static final ExpectedValueProvider windowE; static { ImmutableMap.Builder columns = ImmutableMap.builder(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index 767ce955bb423..fa5a8c6db51dc 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -25,6 +25,7 @@ import com.facebook.presto.server.SliceSerializer; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNodeId; @@ -116,7 +117,7 @@ public void testSerializationRoundtrip() Optional.empty()); PlanNodeId id = newId(); - WindowNode.Specification specification = new WindowNode.Specification( + DataOrganizationSpecification specification = new DataOrganizationSpecification( ImmutableList.of(columnA), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(columnB, SortOrder.ASC_NULLS_FIRST))))); CallExpression call = call("sum", functionHandle, BIGINT, new VariableReferenceExpression(Optional.empty(), columnC.getName(), BIGINT)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java b/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java index dfc25091c6ae0..e0727d18c2b35 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java +++ b/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java @@ -24,6 +24,7 @@ import com.facebook.presto.common.type.UnknownType; import 
com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.google.common.collect.ImmutableMap; @@ -67,6 +68,7 @@ public abstract class AbstractTestType private static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); protected static final FunctionAndTypeManager functionAndTypeManager = new FunctionAndTypeManager( createTestTransactionManager(), + new TableFunctionRegistry(), blockEncodingSerde, new FeaturesConfig(), new FunctionsConfig(), diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 77a61c0480d97..cf8b178cf77a7 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -311,7 +311,7 @@ io.netty netty-common - + com.squareup.okhttp3 mockwebserver diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index ecce8a92ce850..715e333790d8a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -104,6 +104,7 @@ import com.facebook.presto.metadata.StaticCatalogStoreConfig; import com.facebook.presto.metadata.StaticFunctionNamespaceStore; import com.facebook.presto.metadata.StaticFunctionNamespaceStoreConfig; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.ExchangeClientConfig; @@ -634,6 +635,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon binder.bind(StaticFunctionNamespaceStore.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(StaticFunctionNamespaceStoreConfig.class); 
binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(MetadataManager.class).in(Scopes.SINGLETON); if (serverConfig.isCatalogServerEnabled() && serverConfig.isCoordinator()) { diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 index e5ee6b44b9a40..0982c5b2d6b20 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 @@ -345,6 +345,7 @@ relationPrimary | UNNEST '(' expression (',' expression)* ')' (WITH ORDINALITY)? #unnest | LATERAL '(' query ')' #lateral | '(' relation ')' #parenthesizedRelation + | TABLE '(' tableFunctionCall ')' #tableFunctionInvocation ; expression @@ -473,6 +474,40 @@ type | INTERVAL from=intervalField TO to=intervalField ; +tableFunctionCall + : qualifiedName '(' (tableFunctionArgument (',' tableFunctionArgument)*)? + (COPARTITION copartitionTables (',' copartitionTables)*)? ')' + ; + +tableFunctionArgument + : (identifier '=>')? (tableArgument | descriptorArgument | expression) // descriptor before expression to avoid parsing descriptor as a function call + ; + +tableArgument + : tableArgumentRelation + (PARTITION BY ('(' (expression (',' expression)*)? ')' | expression))? + (PRUNE WHEN EMPTY | KEEP WHEN EMPTY)? + (ORDER BY ('(' sortItem (',' sortItem)* ')' | sortItem))? + ; + +tableArgumentRelation + : TABLE '(' qualifiedName ')' (AS? identifier columnAliases?)? #tableArgumentTable + | TABLE '(' query ')' (AS? identifier columnAliases?)? #tableArgumentQuery + ; + +descriptorArgument + : DESCRIPTOR '(' descriptorField (',' descriptorField)* ')' + | CAST '(' NULL AS DESCRIPTOR ')' + ; + +descriptorField + : identifier type? 
+ ; + +copartitionTables + : '(' qualifiedName ',' qualifiedName (',' qualifiedName)* ')' + ; + typeParameter : INTEGER_VALUE | type ; @@ -632,20 +667,20 @@ nonReserved // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved : ADD | ADMIN | ALL | ANALYZE | ANY | ARRAY | ASC | AT | BEFORE | BERNOULLI - | CALL | CALLED | CASCADE | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CURRENT | CURRENT_ROLE - | DATA | DATE | DAY | DEFINER | DESC | DETERMINISTIC | DISABLED | DISTRIBUTED - | ENABLED | ENFORCED | EXCLUDING | EXPLAIN | EXTERNAL + | CALL | CALLED | CASCADE | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | COPARTITION | CURRENT | CURRENT_ROLE + | DATA | DATE | DAY | DEFINER | DESC | DESCRIPTOR | DETERMINISTIC | DISABLED | DISTRIBUTED + | EMPTY | ENABLED | ENFORCED | EXCLUDING | EXPLAIN | EXTERNAL | FETCH | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS | GRANT | GRANTED | GRANTS | GRAPHVIZ | GROUPS | HOUR | IF | IGNORE | INCLUDING | INPUT | INTERVAL | INVOKER | IO | ISOLATION | JSON - | KEY + | KEEP | KEY | LANGUAGE | LAST | LATERAL | LEVEL | LIMIT | LOGICAL | MAP | MATERIALIZED | MINUTE | MONTH | NAME | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS | OF | OFFSET | ONLY | OPTION | ORDINALITY | OUTPUT | OVER - | PARTITION | PARTITIONS | POSITION | PRECEDING | PRIMARY | PRIVILEGES | PROPERTIES + | PARTITION | PARTITIONS | POSITION | PRECEDING | PRIMARY | PRIVILEGES | PROPERTIES | PRUNE | RANGE | READ | REFRESH | RELY | RENAME | REPEATABLE | REPLACE | RESET | RESPECT | RESTRICT | RETURN | RETURNS | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS | SCHEMA | SCHEMAS | SECOND | SECURITY | SERIALIZABLE | SESSION | SET | SETS | SQL | SHOW | SOME | START | STATS | SUBSTRING | SYSTEM | SYSTEM_TIME | SYSTEM_VERSION @@ -685,6 +720,7 @@ COMMIT: 'COMMIT'; COMMITTED: 'COMMITTED'; CONSTRAINT: 'CONSTRAINT'; CREATE: 'CREATE'; +COPARTITION: 'COPARTITION'; CROSS: 'CROSS'; CUBE: 'CUBE'; 
CURRENT: 'CURRENT'; @@ -701,12 +737,14 @@ DEFINER: 'DEFINER'; DELETE: 'DELETE'; DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; +DESCRIPTOR: 'DESCRIPTOR'; DETERMINISTIC: 'DETERMINISTIC'; DISABLED: 'DISABLED'; DISTINCT: 'DISTINCT'; DISTRIBUTED: 'DISTRIBUTED'; DROP: 'DROP'; ELSE: 'ELSE'; +EMPTY: 'EMPTY'; ENABLED: 'ENABLED'; END: 'END'; ENFORCED: 'ENFORCED'; @@ -754,6 +792,7 @@ IS: 'IS'; ISOLATION: 'ISOLATION'; JSON: 'JSON'; JOIN: 'JOIN'; +KEEP: 'KEEP'; KEY: 'KEY'; LANGUAGE: 'LANGUAGE'; LAST: 'LAST'; @@ -801,6 +840,7 @@ PREPARE: 'PREPARE'; PRIMARY: 'PRIMARY'; PRIVILEGES: 'PRIVILEGES'; PROPERTIES: 'PROPERTIES'; +PRUNE: 'PRUNE'; RANGE: 'RANGE'; READ: 'READ'; RECURSIVE: 'RECURSIVE'; diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java index 29549470dae9e..5524c03cb3f78 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java @@ -738,7 +738,7 @@ static String formatOrderBy(OrderBy orderBy, Optional> paramete return "ORDER BY " + formatSortItems(orderBy.getSortItems(), parameters); } - static String formatSortItems(List sortItems, Optional> parameters) + public static String formatSortItems(List sortItems, Optional> parameters) { return Joiner.on(", ").join(sortItems.stream() .map(sortItemFormatterFunction(parameters)) diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java index 29553e2bc4c95..0013c2a0aa2d8 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java @@ -112,6 +112,10 @@ import com.facebook.presto.sql.tree.SqlParameterDeclaration; import com.facebook.presto.sql.tree.StartTransaction; import com.facebook.presto.sql.tree.Table; +import 
com.facebook.presto.sql.tree.TableFunctionArgument; +import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; +import com.facebook.presto.sql.tree.TableFunctionInvocation; +import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TransactionAccessMode; import com.facebook.presto.sql.tree.TransactionMode; @@ -208,6 +212,113 @@ protected Void visitLateral(Lateral node, Integer indent) return null; } + @Override + protected Void visitTableFunctionInvocation(TableFunctionInvocation node, Integer indent) + { + append(indent, "TABLE("); + appendTableFunctionInvocation(node, indent + 1); + builder.append(")"); + return null; + } + + private void appendTableFunctionInvocation(TableFunctionInvocation node, Integer indent) + { + builder.append(formatName(node.getName())) + .append("(\n"); + appendTableFunctionArguments(node.getArguments(), indent + 1); + if (!node.getCopartitioning().isEmpty()) { + builder.append("\n"); + append(indent + 1, "COPARTITION "); + builder.append(node.getCopartitioning().stream() + .map(tableList -> tableList.stream() + .map(Formatter::formatName) + .collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", "))); + } + builder.append(")"); + } + + private void appendTableFunctionArguments(List arguments, int indent) + { + for (int i = 0; i < arguments.size(); i++) { + TableFunctionArgument argument = arguments.get(i); + if (argument.getName().isPresent()) { + append(indent, formatExpression(argument.getName().get(), parameters)); + builder.append(" => "); + } + else { + append(indent, ""); + } + Node value = argument.getValue(); + if (value instanceof Expression) { + builder.append(formatExpression((Expression) value, parameters)); + } + else { + process(value, indent + 1); + } + if (i < arguments.size() - 1) { + builder.append(",\n"); + } + } + } + + @Override + protected Void visitTableArgument(TableFunctionTableArgument 
node, Integer indent) + { + Relation relation = node.getTable(); + Node unaliased = relation instanceof AliasedRelation ? ((AliasedRelation) relation).getRelation() : relation; + if (unaliased instanceof TableSubquery) { + // unpack the relation from TableSubquery to avoid adding another pair of parentheses + unaliased = ((TableSubquery) unaliased).getQuery(); + } + builder.append("TABLE("); + process(unaliased, indent); + builder.append(")"); + if (relation instanceof AliasedRelation) { + AliasedRelation aliasedRelation = (AliasedRelation) relation; + builder.append(" AS ") + .append(formatExpression(aliasedRelation.getAlias(), parameters)); + appendAliasColumns(builder, aliasedRelation.getColumnNames()); + } + if (node.getPartitionBy().isPresent()) { + builder.append("\n"); + append(indent, "PARTITION BY ") + .append(node.getPartitionBy().get().stream() + .map(expr -> formatExpression(expr, parameters)) + .collect(joining(", "))); + } + node.getEmptyTableTreatment().ifPresent(treatment -> { + builder.append("\n"); + append(indent, treatment.getTreatment().name() + " WHEN EMPTY"); + }); + node.getOrderBy().ifPresent(orderBy -> { + builder.append("\n"); + append(indent, formatOrderBy(orderBy, Optional.empty())); + }); + return null; + } + + @Override + protected Void visitDescriptorArgument(TableFunctionDescriptorArgument node, Integer indent) + { + if (node.getDescriptor().isPresent()) { + builder.append(node.getDescriptor().get().getFields().stream() + .map(field -> { + String formattedField = formatExpression(field.getName(), parameters); + if (field.getType().isPresent()) { + formattedField = formattedField + " " + field.getType().get(); + } + return formattedField; + }) + .collect(Collectors.joining(", ", "DESCRIPTOR(", ")"))); + } + else { + builder.append("CAST (NULL AS DESCRIPTOR)"); + } + + return null; + } + @Override protected Void visitPrepare(Prepare node, Integer indent) { diff --git 
a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index 7b1dd90a57f5e..df6dbeed36e8a 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -55,6 +55,8 @@ import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; +import com.facebook.presto.sql.tree.Descriptor; +import com.facebook.presto.sql.tree.DescriptorField; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropConstraint; @@ -64,6 +66,8 @@ import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; +import com.facebook.presto.sql.tree.EmptyTableTreatment; +import com.facebook.presto.sql.tree.EmptyTableTreatment.Treatment; import com.facebook.presto.sql.tree.Except; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.ExistsPredicate; @@ -169,6 +173,9 @@ import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.Table; import com.facebook.presto.sql.tree.TableElement; +import com.facebook.presto.sql.tree.TableFunctionArgument; +import com.facebook.presto.sql.tree.TableFunctionInvocation; +import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TableVersionExpression; import com.facebook.presto.sql.tree.TimeLiteral; @@ -212,6 +219,8 @@ import static com.facebook.presto.sql.tree.RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT; import static com.facebook.presto.sql.tree.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; import static 
com.facebook.presto.sql.tree.SetProperties.Type.TABLE; +import static com.facebook.presto.sql.tree.TableFunctionDescriptorArgument.descriptorArgument; +import static com.facebook.presto.sql.tree.TableFunctionDescriptorArgument.nullDescriptorArgument; import static com.facebook.presto.sql.tree.TableVersionExpression.TableVersionOperator; import static com.facebook.presto.sql.tree.TableVersionExpression.TableVersionOperator.EQUAL; import static com.facebook.presto.sql.tree.TableVersionExpression.TableVersionOperator.LESS_THAN; @@ -1483,6 +1492,123 @@ public Node visitLateral(SqlBaseParser.LateralContext context) return new Lateral(getLocation(context), (Query) visit(context.query())); } + @Override + public Node visitTableFunctionInvocation(SqlBaseParser.TableFunctionInvocationContext context) + { + return visit(context.tableFunctionCall()); + } + + @Override + public Node visitTableFunctionCall(SqlBaseParser.TableFunctionCallContext context) + { + QualifiedName name = getQualifiedName(context.qualifiedName()); + List arguments = visit(context.tableFunctionArgument(), TableFunctionArgument.class); + List> copartitioning = ImmutableList.of(); + if (context.COPARTITION() != null) { + copartitioning = context.copartitionTables().stream() + .map(tablesList -> tablesList.qualifiedName().stream() + .map(this::getQualifiedName) + .collect(toImmutableList())) + .collect(toImmutableList()); + } + + return new TableFunctionInvocation(getLocation(context), name, arguments, copartitioning); + } + + @Override + public Node visitTableFunctionArgument(SqlBaseParser.TableFunctionArgumentContext context) + { + Optional name = visitIfPresent(context.identifier(), Identifier.class); + Node value; + if (context.tableArgument() != null) { + value = visit(context.tableArgument()); + } + else if (context.descriptorArgument() != null) { + value = visit(context.descriptorArgument()); + } + else { + value = visit(context.expression()); + } + + return new 
TableFunctionArgument(getLocation(context), name, value); + } + + @Override + public Node visitTableArgument(SqlBaseParser.TableArgumentContext context) + { + Relation table = (Relation) visit(context.tableArgumentRelation()); + + Optional> partitionBy = Optional.empty(); + if (context.PARTITION() != null) { + partitionBy = Optional.of(visit(context.expression(), Expression.class)); + } + + Optional orderBy = Optional.empty(); + if (context.ORDER() != null) { + orderBy = Optional.of(new OrderBy(visit(context.sortItem(), SortItem.class))); + } + + Optional emptyTableTreatment = Optional.empty(); + if (context.PRUNE() != null) { + emptyTableTreatment = Optional.of(new EmptyTableTreatment(getLocation(context.PRUNE()), Treatment.PRUNE)); + } + else if (context.KEEP() != null) { + emptyTableTreatment = Optional.of(new EmptyTableTreatment(getLocation(context.KEEP()), Treatment.KEEP)); + } + + return new TableFunctionTableArgument(getLocation(context), table, partitionBy, orderBy, emptyTableTreatment); + } + + @Override + public Node visitTableArgumentTable(SqlBaseParser.TableArgumentTableContext context) + { + Relation relation = new Table(getLocation(context.TABLE()), getQualifiedName(context.qualifiedName())); + + if (context.identifier() != null) { + Identifier alias = (Identifier) visit(context.identifier()); + List columnNames = null; + if (context.columnAliases() != null) { + columnNames = visit(context.columnAliases().identifier(), Identifier.class); + } + relation = new AliasedRelation(getLocation(context.TABLE()), relation, alias, columnNames); + } + + return relation; + } + + @Override + public Node visitTableArgumentQuery(SqlBaseParser.TableArgumentQueryContext context) + { + Relation relation = new TableSubquery(getLocation(context.TABLE()), (Query) visit(context.query())); + + if (context.identifier() != null) { + Identifier alias = (Identifier) visit(context.identifier()); + List columnNames = null; + if (context.columnAliases() != null) { + columnNames = 
visit(context.columnAliases().identifier(), Identifier.class); + } + relation = new AliasedRelation(getLocation(context.TABLE()), relation, alias, columnNames); + } + + return relation; + } + + @Override + public Node visitDescriptorArgument(SqlBaseParser.DescriptorArgumentContext context) + { + if (context.NULL() != null) { + return nullDescriptorArgument(getLocation(context)); + } + List fields = visit(context.descriptorField(), DescriptorField.class); + return descriptorArgument(getLocation(context), new Descriptor(getLocation(context.DESCRIPTOR()), fields)); + } + + @Override + public Node visitDescriptorField(SqlBaseParser.DescriptorFieldContext context) + { + return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.of(getType(context.type()))); + } + @Override public Node visitParenthesizedRelation(SqlBaseParser.ParenthesizedRelationContext context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java index 8f42707192a4a..b292be528f1bc 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java @@ -37,6 +37,11 @@ protected R visitExpression(Expression node, C context) return visitNode(node, context); } + protected R visitEmptyTableTreatment(EmptyTableTreatment node, C context) + { + return visitNode(node, context); + } + protected R visitCurrentTime(CurrentTime node, C context) { return visitExpression(node, context); @@ -851,6 +856,7 @@ protected R visitRoutineBody(RoutineBody node, C context) { return visitNode(node, context); } + protected R visitReturn(Return node, C context) { return visitNode(node, context); @@ -860,4 +866,34 @@ protected R visitExternalBodyReference(ExternalBodyReference node, C context) { return visitNode(node, context); } + + protected R 
visitTableFunctionInvocation(TableFunctionInvocation node, C context) + { + return visitRelation(node, context); + } + + protected R visitTableFunctionArgument(TableFunctionArgument node, C context) + { + return visitNode(node, context); + } + + protected R visitTableArgument(TableFunctionTableArgument node, C context) + { + return visitNode(node, context); + } + + protected R visitDescriptorArgument(TableFunctionDescriptorArgument node, C context) + { + return visitNode(node, context); + } + + protected R visitDescriptor(Descriptor node, C context) + { + return visitNode(node, context); + } + + protected R visitDescriptorField(DescriptorField node, C context) + { + return visitNode(node, context); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java index 7c90775c78d3a..e6bd352a69748 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java @@ -162,6 +162,16 @@ protected R visitFunctionCall(FunctionCall node, C context) return null; } + @Override + protected R visitTableFunctionInvocation(TableFunctionInvocation node, C context) + { + for (TableFunctionArgument argument : node.getArguments()) { + process(argument.getValue(), context); + } + + return null; + } + @Override protected R visitGroupingOperation(GroupingOperation node, C context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Descriptor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Descriptor.java new file mode 100644 index 0000000000000..b94dfc877d92e --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Descriptor.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class Descriptor + extends Node +{ + private final List fields; + + public Descriptor(NodeLocation location, List fields) + { + super(Optional.of(location)); + requireNonNull(fields, "fields is null"); + checkArgument(!fields.isEmpty(), "fields list is empty"); + this.fields = fields; + } + + public List getFields() + { + return fields; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDescriptor(this, context); + } + + @Override + public List getChildren() + { + return fields; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return Objects.equals(fields, ((Descriptor) o).fields); + } + + @Override + public int hashCode() + { + return Objects.hash(fields); + } + + @Override + public String toString() + { + return fields.stream() + .map(DescriptorField::toString) + .collect(Collectors.joining(", ", "DESCRIPTOR(", ")")); + } + + @Override + public boolean shallowEquals(Node o) + { + return sameClass(this, o); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DescriptorField.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DescriptorField.java new file mode 
100644 index 0000000000000..a5994d257c1d5 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DescriptorField.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class DescriptorField + extends Node +{ + private final Identifier name; + private final Optional type; + + public DescriptorField(NodeLocation location, Identifier name, Optional type) + { + super(Optional.of(location)); + this.name = requireNonNull(name, "name is null"); + this.type = requireNonNull(type, "type is null"); + } + + public Identifier getName() + { + return name; + } + + public Optional getType() + { + return type; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDescriptorField(this, context); + } + + @Override + public List getChildren() + { + return Collections.emptyList(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DescriptorField field = (DescriptorField) o; + return Objects.equals(name, field.name) && + Objects.equals(type, (field.type)); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } + + @Override + public String 
toString() + { + return type.map(dataType -> name + " " + dataType).orElse(name.toString()); + } + + @Override + public boolean shallowEquals(Node o) + { + if (!sameClass(this, o)) { + return false; + } + + return Objects.equals(name, ((DescriptorField) o).name); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/EmptyTableTreatment.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/EmptyTableTreatment.java new file mode 100644 index 0000000000000..8065f15ad437b --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/EmptyTableTreatment.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class EmptyTableTreatment + extends Node +{ + private final Treatment treatment; + + public EmptyTableTreatment(NodeLocation location, Treatment treatment) + { + super(Optional.of(location)); + this.treatment = requireNonNull(treatment, "treatment is null"); + } + + public Treatment getTreatment() + { + return treatment; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitEmptyTableTreatment(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + return treatment == ((EmptyTableTreatment) obj).treatment; + } + + @Override + public int hashCode() + { + return Objects.hash(treatment); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("treatment", treatment) + .toString(); + } + + @Override + public boolean shallowEquals(Node other) + { + if (!sameClass(this, other)) { + return false; + } + + return treatment == ((EmptyTableTreatment) other).treatment; + } + + public enum Treatment + { + KEEP, PRUNE + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Identifier.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Identifier.java index f85c74e0e3f37..56b3079eaa6f4 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Identifier.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Identifier.java @@ -74,6 +74,15 @@ public boolean isDelimited() return delimited; } + public String getCanonicalValue() + { + if 
(isDelimited()) { + return value; + } + + return value.toUpperCase(ENGLISH); + } + @Override public R accept(AstVisitor visitor, C context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Node.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Node.java index 2317c5258c1c9..0083696583624 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Node.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Node.java @@ -51,4 +51,21 @@ public Optional getLocation() @Override public abstract String toString(); + + /** + * Compare with another node by considering internal state excluding any Node returned by getChildren() + */ + public boolean shallowEquals(Node other) + { + throw new UnsupportedOperationException("not yet implemented: " + getClass().getName()); + } + + static boolean sameClass(Node left, Node right) + { + if (left == right) { + return true; + } + + return left.getClass() == right.getClass(); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionArgument.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionArgument.java new file mode 100644 index 0000000000000..7db4fc9de0e78 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionArgument.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class TableFunctionArgument + extends Node +{ + private final Optional name; + private final Node value; + + public TableFunctionArgument(NodeLocation location, Optional name, Node value) + { + super(Optional.of(location)); + this.name = requireNonNull(name, "name is null"); + requireNonNull(value, "value is null"); + checkArgument(value instanceof TableFunctionTableArgument || value instanceof TableFunctionDescriptorArgument || value instanceof Expression); + this.value = value; + } + + public Optional getName() + { + return name; + } + + public Node getValue() + { + return value; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitTableFunctionArgument(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(value); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TableFunctionArgument other = (TableFunctionArgument) o; + return Objects.equals(name, other.name) && + Objects.equals(value, other.value); + } + + @Override + public int hashCode() + { + return Objects.hash(name, value); + } + + @Override + public String toString() + { + return name.map(identifier -> identifier + " => ").orElse("") + value; + } + + @Override + public boolean shallowEquals(Node o) + { + if (!sameClass(this, o)) { + return false; + } + + return Objects.equals(name, ((TableFunctionArgument) o).name); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionDescriptorArgument.java 
b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionDescriptorArgument.java new file mode 100644 index 0000000000000..b4ee67a976cd8 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionDescriptorArgument.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionDescriptorArgument + extends Node +{ + private final Optional descriptor; + + public static TableFunctionDescriptorArgument descriptorArgument(NodeLocation location, Descriptor descriptor) + { + requireNonNull(descriptor, "descriptor is null"); + return new TableFunctionDescriptorArgument(location, Optional.of(descriptor)); + } + + public static TableFunctionDescriptorArgument nullDescriptorArgument(NodeLocation location) + { + return new TableFunctionDescriptorArgument(location, Optional.empty()); + } + + private TableFunctionDescriptorArgument(NodeLocation location, Optional descriptor) + { + super(Optional.of(location)); + this.descriptor = descriptor; + } + + public Optional getDescriptor() + { + return descriptor; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDescriptorArgument(this, context); + } + + @Override + public List 
getChildren() + { + return descriptor.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return Objects.equals(descriptor, ((TableFunctionDescriptorArgument) o).descriptor); + } + + @Override + public int hashCode() + { + return Objects.hash(descriptor); + } + + @Override + public String toString() + { + return descriptor.map(Descriptor::toString).orElse("CAST (NULL AS DESCRIPTOR)"); + } + + @Override + public boolean shallowEquals(Node o) + { + return sameClass(this, o); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionInvocation.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionInvocation.java new file mode 100644 index 0000000000000..2f75ac165fd50 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionInvocation.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.tree; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionInvocation + extends Relation +{ + private final QualifiedName name; + private final List arguments; + private final List> copartitioning; + + public TableFunctionInvocation(NodeLocation location, QualifiedName name, List arguments, List> copartitioning) + { + super(Optional.of(location)); + this.name = requireNonNull(name, "name is null"); + this.arguments = requireNonNull(arguments, "arguments is null"); + this.copartitioning = requireNonNull(copartitioning, "copartitioning is null"); + } + + public QualifiedName getName() + { + return name; + } + + public List getArguments() + { + return arguments; + } + + public List> getCopartitioning() + { + return copartitioning; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitTableFunctionInvocation(this, context); + } + + @Override + public List getChildren() + { + return arguments; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TableFunctionInvocation that = (TableFunctionInvocation) o; + return Objects.equals(name, that.name) && + Objects.equals(arguments, that.arguments) && + Objects.equals(copartitioning, that.copartitioning); + } + + @Override + public int hashCode() + { + return Objects.hash(name, arguments, copartitioning); + } + + @Override + public String toString() + { + StringBuilder builder = new StringBuilder(); + builder.append(name) + .append("("); + builder.append(arguments.stream() + .map(TableFunctionArgument::toString) + .collect(Collectors.joining(", "))); + if (!copartitioning.isEmpty()) { + builder.append(" COPARTITION"); + builder.append(copartitioning.stream() + .map(list -> list.stream() + 
.map(QualifiedName::toString) + .collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", "))); + } + builder.append(")"); + + return builder.toString(); + } + + @Override + public boolean shallowEquals(Node o) + { + if (!sameClass(this, o)) { + return false; + } + + TableFunctionInvocation other = (TableFunctionInvocation) o; + return Objects.equals(name, other.name) && + Objects.equals(copartitioning, other.copartitioning); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionTableArgument.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionTableArgument.java new file mode 100644 index 0000000000000..b23da149cabae --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/TableFunctionTableArgument.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.ExpressionFormatter.formatSortItems; +import static java.util.Objects.requireNonNull; + +public class TableFunctionTableArgument + extends Node +{ + private final Relation table; + private final Optional> partitionBy; // it is allowed to partition by empty list + private final Optional orderBy; + private final Optional emptyTableTreatment; + + public TableFunctionTableArgument( + NodeLocation location, + Relation table, + Optional> partitionBy, + Optional orderBy, + Optional emptyTableTreatment) + { + super(Optional.of(location)); + this.table = requireNonNull(table, "table is null"); + this.partitionBy = requireNonNull(partitionBy, "partitionBy is null"); + this.orderBy = requireNonNull(orderBy, "orderBy is null"); + this.emptyTableTreatment = requireNonNull(emptyTableTreatment, "emptyTableTreatment is null"); + } + + public Relation getTable() + { + return table; + } + + public Optional getEmptyTableTreatment() + { + return emptyTableTreatment; + } + + public Optional> getPartitionBy() + { + return partitionBy; + } + + public Optional getOrderBy() + { + return orderBy; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitTableArgument(this, context); + } + + @Override + public List getChildren() + { + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add(table); + partitionBy.ifPresent(builder::addAll); + orderBy.ifPresent(builder::add); + emptyTableTreatment.ifPresent(builder::add); + + return builder.build(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TableFunctionTableArgument other = (TableFunctionTableArgument) o; + return 
Objects.equals(table, other.table) && + Objects.equals(partitionBy, other.partitionBy) && + Objects.equals(orderBy, other.orderBy) && + Objects.equals(emptyTableTreatment, other.emptyTableTreatment); + } + + @Override + public int hashCode() + { + return Objects.hash(table, partitionBy, orderBy, emptyTableTreatment); + } + + @Override + public String toString() + { + StringBuilder builder = new StringBuilder(); + builder.append(table); + partitionBy.ifPresent(partitioning -> builder.append(partitioning.stream() + .map(Expression::toString) + .collect(Collectors.joining(", ", " PARTITION BY (", ")")))); + orderBy.ifPresent(ordering -> builder.append(" ORDER BY (") + .append(formatSortItems(ordering.getSortItems(), Optional.empty())) + .append(")")); + + return builder.toString(); + } +} diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java index 68b590834653d..e8e423af592db 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java @@ -51,6 +51,8 @@ import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; +import com.facebook.presto.sql.tree.Descriptor; +import com.facebook.presto.sql.tree.DescriptorField; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropConstraint; @@ -60,6 +62,7 @@ import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; +import com.facebook.presto.sql.tree.EmptyTableTreatment; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Explain; @@ -107,6 +110,7 @@ import 
com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.RefreshMaterializedView; +import com.facebook.presto.sql.tree.Relation; import com.facebook.presto.sql.tree.RenameColumn; import com.facebook.presto.sql.tree.RenameSchema; import com.facebook.presto.sql.tree.RenameTable; @@ -145,6 +149,9 @@ import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionArgument; +import com.facebook.presto.sql.tree.TableFunctionInvocation; +import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TableVersionExpression; import com.facebook.presto.sql.tree.TimeLiteral; @@ -190,14 +197,18 @@ import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static com.facebook.presto.sql.tree.ConstraintSpecification.ConstraintType.PRIMARY_KEY; import static com.facebook.presto.sql.tree.ConstraintSpecification.ConstraintType.UNIQUE; +import static com.facebook.presto.sql.tree.EmptyTableTreatment.Treatment.PRUNE; import static com.facebook.presto.sql.tree.RoutineCharacteristics.Determinism.DETERMINISTIC; import static com.facebook.presto.sql.tree.RoutineCharacteristics.Determinism.NOT_DETERMINISTIC; import static com.facebook.presto.sql.tree.RoutineCharacteristics.Language.SQL; import static com.facebook.presto.sql.tree.RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT; import static com.facebook.presto.sql.tree.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.UNDEFINED; import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; import static 
com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; +import static com.facebook.presto.sql.tree.TableFunctionDescriptorArgument.descriptorArgument; +import static com.facebook.presto.sql.tree.TableFunctionDescriptorArgument.nullDescriptorArgument; import static com.facebook.presto.sql.tree.TableVersionExpression.TableVersionOperator; import static com.facebook.presto.sql.tree.TableVersionExpression.TableVersionType.TIMESTAMP; import static com.facebook.presto.sql.tree.TableVersionExpression.TableVersionType.VERSION; @@ -3455,4 +3466,161 @@ public void testSelectWithBeforeTimestamp() assertStatement("CREATE VIEW view1 AS SELECT * FROM table1 FOR TIMESTAMP BEFORE TIMESTAMP '2023-08-17 13:29:46.822 America/Los_Angeles'", new CreateView(QualifiedName.of("view1"), query, false, Optional.empty())); } + + @Test + public void testTableFunctionInvocation() + { + assertStatement("SELECT * FROM TABLE(some_ptf(input => 1))", + selectAllFrom(new TableFunctionInvocation( + new NodeLocation(1, 21), + QualifiedName.of("some_ptf"), + ImmutableList.of(new TableFunctionArgument( + new NodeLocation(1, 30), + Optional.of(new Identifier(new NodeLocation(1, 30), "input", false)), + new LongLiteral(new NodeLocation(1, 39), "1"))), + ImmutableList.of()))); + + assertStatement("SELECT * FROM TABLE(some_ptf(" + + " arg1 => TABLE(orders) AS ord(a, b, c) " + + " PARTITION BY a " + + " PRUNE WHEN EMPTY " + + " ORDER BY b ASC NULLS LAST, " + + " arg2 => CAST(NULL AS DESCRIPTOR), " + + " arg3 => DESCRIPTOR(x integer, y varchar), " + + " arg4 => 5, " + + " 'not-named argument' " + + " COPARTITION (ord, nation)))", + selectAllFrom(new TableFunctionInvocation( + new NodeLocation(1, 21), + QualifiedName.of("some_ptf"), + ImmutableList.of( + new TableFunctionArgument( + new NodeLocation(1, 77), + Optional.of(new Identifier(new NodeLocation(1, 77), "arg1", false)), + new TableFunctionTableArgument( + new NodeLocation(1, 85), + new AliasedRelation( + new NodeLocation(1, 85), + new Table(new 
NodeLocation(1, 85), QualifiedName.of("orders")), + new Identifier(new NodeLocation(1, 102), "ord", false), + ImmutableList.of( + new Identifier(new NodeLocation(1, 106), "a", false), + new Identifier(new NodeLocation(1, 109), "b", false), + new Identifier(new NodeLocation(1, 112), "c", false))), + Optional.of(ImmutableList.of(new Identifier(new NodeLocation(1, 196), "a", false))), + Optional.of(new OrderBy(ImmutableList.of(new SortItem(new NodeLocation(1, 360), new Identifier(new NodeLocation(1, 360), "b", false), ASCENDING, LAST)))), + Optional.of(new EmptyTableTreatment(new NodeLocation(1, 266), PRUNE)))), + new TableFunctionArgument( + new NodeLocation(1, 425), + Optional.of(new Identifier(new NodeLocation(1, 425), "arg2", false)), + nullDescriptorArgument(new NodeLocation(1, 433))), + new TableFunctionArgument( + new NodeLocation(1, 506), + Optional.of(new Identifier(new NodeLocation(1, 506), "arg3", false)), + descriptorArgument( + new NodeLocation(1, 514), + new Descriptor(new NodeLocation(1, 514), ImmutableList.of( + new DescriptorField( + new NodeLocation(1, 525), + new Identifier(new NodeLocation(1, 525), "x", false), + Optional.of("integer")), + new DescriptorField( + new NodeLocation(1, 536), + new Identifier(new NodeLocation(1, 536), "y", false), + Optional.of("varchar")))))), + new TableFunctionArgument( + new NodeLocation(1, 595), + Optional.of(new Identifier(new NodeLocation(1, 595), "arg4", false)), + new LongLiteral(new NodeLocation(1, 603), "5")), + new TableFunctionArgument( + new NodeLocation(1, 653), + Optional.empty(), + new StringLiteral(new NodeLocation(1, 653), "not-named argument"))), + ImmutableList.of(ImmutableList.of( + QualifiedName.of("ord"), + QualifiedName.of("nation")))))); + } + + @Test + public void testTableFunctionTableArgumentAliasing() + { + // no alias + assertStatement("SELECT * FROM TABLE(some_ptf(input => TABLE(orders)))", + selectAllFrom(new TableFunctionInvocation( + new NodeLocation(1, 21), + 
QualifiedName.of("some_ptf"), + ImmutableList.of(new TableFunctionArgument( + new NodeLocation(1, 30), + Optional.of(new Identifier(new NodeLocation(1, 30), "input", false)), + new TableFunctionTableArgument( + new NodeLocation(1, 39), + new Table(new NodeLocation(1, 39), QualifiedName.of("orders")), + Optional.empty(), + Optional.empty(), + Optional.empty()))), + ImmutableList.of()))); + + // table alias; no column aliases + assertStatement("SELECT * FROM TABLE(some_ptf(input => TABLE(orders) AS ord))", + selectAllFrom(new TableFunctionInvocation( + new NodeLocation(1, 21), + QualifiedName.of("some_ptf"), + ImmutableList.of(new TableFunctionArgument( + new NodeLocation(1, 30), + Optional.of(new Identifier(new NodeLocation(1, 30), "input", false)), + new TableFunctionTableArgument( + new NodeLocation(1, 39), + new AliasedRelation( + new NodeLocation(1, 39), + new Table(new NodeLocation(1, 39), QualifiedName.of("orders")), + new Identifier(new NodeLocation(1, 56), "ord", false), + null), + Optional.empty(), + Optional.empty(), + Optional.empty()))), + ImmutableList.of()))); + + // table alias and column aliases + assertStatement("SELECT * FROM TABLE(some_ptf(input => TABLE(orders) AS ord(a, b, c)))", + selectAllFrom(new TableFunctionInvocation( + new NodeLocation(1, 21), + QualifiedName.of("some_ptf"), + ImmutableList.of(new TableFunctionArgument( + new NodeLocation(1, 30), + Optional.of(new Identifier(new NodeLocation(1, 30), "input", false)), + new TableFunctionTableArgument( + new NodeLocation(1, 39), + new AliasedRelation( + new NodeLocation(1, 39), + new Table(new NodeLocation(1, 39), QualifiedName.of("orders")), + new Identifier(new NodeLocation(1, 56), "ord", false), + ImmutableList.of( + new Identifier(new NodeLocation(1, 60), "a", false), + new Identifier(new NodeLocation(1, 63), "b", false), + new Identifier(new NodeLocation(1, 66), "c", false))), + Optional.empty(), + Optional.empty(), + Optional.empty()))), + ImmutableList.of()))); + } + + private static 
Query selectAllFrom(Relation relation) + { + return new Query( + new NodeLocation(1, 1), + Optional.empty(), + new QuerySpecification( + new NodeLocation(1, 1), + new Select(new NodeLocation(1, 1), false, ImmutableList.of(new AllColumns(new NodeLocation(1, 8)))), + Optional.of(relation), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } } diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java index dacb7065c991a..d31ec9f5db64e 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java @@ -47,15 +47,15 @@ public Object[][] getStatements() {"select * from foo where @what", "line 1:25: mismatched input '@'. Expecting: "}, {"select * from 'oops", - "line 1:15: mismatched input '''. Expecting: '(', 'LATERAL', 'UNNEST', "}, + "line 1:15: mismatched input '''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', "}, {"select *\nfrom x\nfrom", "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FETCH', 'FOR', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', , "}, {"select *\nfrom x\nwhere from", "line 3:7: mismatched input 'from'. Expecting: "}, {"select * from", - "line 1:14: mismatched input ''. Expecting: '(', 'LATERAL', 'UNNEST', "}, + "line 1:14: mismatched input ''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', "}, {"select * from ", - "line 1:16: mismatched input ''. Expecting: '(', 'LATERAL', 'UNNEST', "}, + "line 1:16: mismatched input ''. 
Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', "}, {"select * from `foo`", "line 1:15: backquoted identifiers are not supported; use double quotes to quote identifiers"}, {"select * from foo `bar`", @@ -103,7 +103,7 @@ public Object[][] getStatements() {"CREATE TABLE t (x bigint) COMMENT ", "line 1:35: mismatched input ''. Expecting: "}, {"SELECT * FROM ( ", - "line 1:17: mismatched input ''. Expecting: '(', 'LATERAL', 'UNNEST', , "}, + "line 1:17: mismatched input ''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', , "}, {"SELECT CAST(a AS )", "line 1:18: mismatched input ')'. Expecting: "}, {"SELECT CAST(a AS decimal()", diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index 6c666961fcd51..9d223a6711957 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -80,6 +80,7 @@ import com.facebook.presto.metadata.StaticCatalogStoreConfig; import com.facebook.presto.metadata.StaticFunctionNamespaceStore; import com.facebook.presto.metadata.StaticFunctionNamespaceStoreConfig; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.FileFragmentResultCacheConfig; @@ -367,6 +368,7 @@ protected void setup(Binder binder) // metadata binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(MetadataManager.class).in(Scopes.SINGLETON); binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); binder.bind(StaticFunctionNamespaceStore.class).in(Scopes.SINGLETON); diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java 
b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java index f06f859b9996a..2e7707d8e6547 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java @@ -223,7 +223,7 @@ private JavaPairRDD cre Optional taskSourceRdd; List sources = findTableScanNodes(fragment.getRoot()); if (!sources.isEmpty()) { - try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits)) { + try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager)) { SplitSourceFactory splitSourceFactory = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP); Map splitSources = splitSourceFactory.createSplitSources(fragment, session, tableWriteInfo); taskSourceRdd = Optional.of(createTaskSourcesRdd( diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java index 79f1550e7c274..f93a8bb9d2c0a 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; public interface ConnectorHandleResolver { @@ -60,4 +61,9 @@ default Class getMetadataUpdateHandleCl { throw new UnsupportedOperationException(); } + + default Class getTableFunctionHandleClass() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java index 
4de55e4aa38d6..f3848fae7c8e7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java @@ -75,6 +75,7 @@ public enum StandardErrorCode INVALID_LIMIT_CLAUSE(0x0000_0031, USER_ERROR), COLUMN_NOT_FOUND(0x0000_0032, USER_ERROR), UNKNOWN_TYPE(0x0000_0033, USER_ERROR), + MISSING_CATALOG_NAME(0x0000_0034, USER_ERROR), GENERIC_INTERNAL_ERROR(0x0001_0000, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(0x0001_0001, INTERNAL_ERROR, true), diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java index c23974963a37b..2b87d398c2b34 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.connector; import com.facebook.presto.spi.SystemTable; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; @@ -117,6 +118,14 @@ default Set getProcedures() return emptySet(); } + /** + * @return the set of table functions provided by this connector + */ + default Set getTableFunctions() + { + return emptySet(); + } + /** * @return the system properties for this connector */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java index 07b36b4dca528..37359e665d0d5 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java @@ -14,8 +14,14 @@ package com.facebook.presto.spi.connector; import 
com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import java.util.Map; +import java.util.Optional; +import java.util.function.Function; public interface ConnectorFactory { @@ -24,4 +30,19 @@ public interface ConnectorFactory ConnectorHandleResolver getHandleResolver(); Connector create(String catalogName, Map config, ConnectorContext context); + + default Optional> getTableFunctionProcessorProvider() + { + return Optional.empty(); + } + + default Optional getTableFunctionHandleResolver() + { + return Optional.empty(); + } + + default Optional getTableFunctionSplitResolver() + { + return Optional.empty(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java index 4f2aef245af2e..3b7ada30f72fd 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.TableLayoutFilterCoverage; import com.facebook.presto.spi.api.Experimental; import com.facebook.presto.spi.constraints.TableConstraint; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; @@ -851,4 +852,17 @@ default void addConstraint(ConnectorSession session, ConnectorTableHandle tableH { throw new PrestoException(NOT_SUPPORTED, "This connector does not support adding table constraints"); } + + /** + * Attempt to push down the table function invocation into the 
connector. + * <p>
+ * Connectors can indicate whether they don't support table function invocation pushdown or that the action had no + * effect by returning {@link Optional#empty()}. Connectors should expect this method may be called multiple times. + * <p>
+ * If the method returns a result, the returned table handle will be used in place of the table function invocation. + */ + default Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + return Optional.empty(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java index 69ac79c9f7522..5690fce23ce58 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java @@ -17,6 +17,8 @@ import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -71,4 +73,13 @@ public WarningCollector getWarningCollector() return warningCollector; } } + + default ConnectorSplitSource getSplits( + ConnectorTransactionHandle transaction, + ConnectorSession session, + SchemaFunctionName name, + ConnectorTableFunctionHandle function) + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/TableFunctionApplicationResult.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/TableFunctionApplicationResult.java new file mode 100644 index 0000000000000..916818fe63990 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/TableFunctionApplicationResult.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.connector; + +import com.facebook.presto.spi.ColumnHandle; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionApplicationResult +{ + private final T tableHandle; + private final List columnHandles; + + public TableFunctionApplicationResult(T tableHandle, List columnHandles) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.columnHandles = requireNonNull(columnHandles, "columnHandles is null"); + } + + public T getTableHandle() + { + return tableHandle; + } + + public List getColumnHandles() + { + return columnHandles; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java index e33b7d708a79c..798c3e943ebfc 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java @@ -44,7 +44,9 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorPartitioningMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; +import 
com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; @@ -792,4 +794,12 @@ public void addConstraint(ConnectorSession session, ConnectorTableHandle tableHa delegate.addConstraint(session, tableHandle, tableConstraint); } } + + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.applyTableFunction(session, handle); + } + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/CatalogSchemaFunctionName.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/CatalogSchemaFunctionName.java new file mode 100644 index 0000000000000..536cb4cce46be --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/CatalogSchemaFunctionName.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Locale.ROOT; +import static java.util.Objects.requireNonNull; + +public final class CatalogSchemaFunctionName +{ + private final String catalogName; + private final SchemaFunctionName schemaFunctionName; + + public CatalogSchemaFunctionName(String catalogName, SchemaFunctionName schemaFunctionName) + { + this.catalogName = catalogName.toLowerCase(ROOT); + if (catalogName.isEmpty()) { + throw new IllegalArgumentException("catalogName is empty"); + } + this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); + } + + @JsonCreator + public CatalogSchemaFunctionName( + @JsonProperty String catalogName, + @JsonProperty String schemaName, + @JsonProperty String functionName) + { + this(catalogName, new SchemaFunctionName(schemaName, functionName)); + } + + @JsonProperty + public String getCatalogName() + { + return catalogName; + } + + public SchemaFunctionName getSchemaFunctionName() + { + return schemaFunctionName; + } + + @JsonProperty + public String getSchemaName() + { + return schemaFunctionName.getSchemaName(); + } + + @JsonProperty + public String getFunctionName() + { + return schemaFunctionName.getFunctionName(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CatalogSchemaFunctionName that = (CatalogSchemaFunctionName) o; + return Objects.equals(catalogName, that.catalogName) && + Objects.equals(schemaFunctionName, that.schemaFunctionName); + } + + @Override + public int hashCode() + { + return Objects.hash(catalogName, schemaFunctionName); + } + + @Override + public String toString() + { + return catalogName + '.' 
+ schemaFunctionName; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SchemaFunctionName.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SchemaFunctionName.java new file mode 100644 index 0000000000000..262e399eeab3b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SchemaFunctionName.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.api.Experimental; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Locale.ROOT; + +@Experimental +public final class SchemaFunctionName +{ + private final String schemaName; + private final String functionName; + + @JsonCreator + public SchemaFunctionName(@JsonProperty("schemaName") String schemaName, @JsonProperty("functionName") String functionName) + { + this.schemaName = schemaName.toLowerCase(ROOT); + if (schemaName.isEmpty()) { + throw new IllegalArgumentException("schemaName is empty"); + } + this.functionName = functionName.toLowerCase(ROOT); + if (functionName.isEmpty()) { + throw new IllegalArgumentException("functionName is empty"); + } + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getFunctionName() + { + return functionName; + } + + @Override + public int 
hashCode() + { + return Objects.hash(schemaName, functionName); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + SchemaFunctionName other = (SchemaFunctionName) obj; + return Objects.equals(this.schemaName, other.schemaName) && + Objects.equals(this.functionName, other.functionName); + } + + @Override + public String toString() + { + return schemaName + '.' + functionName; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java new file mode 100644 index 0000000000000..fd24b9c694c50 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; + +import java.util.Set; + +public interface TableFunctionHandleResolver +{ + Set> getTableFunctionHandleClasses(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java new file mode 100644 index 0000000000000..2a31b1a9aa113 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.ConnectorSplit; + +import java.util.Set; + +public interface TableFunctionSplitResolver +{ + Set> getTableFunctionSplitClasses(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java new file mode 100644 index 0000000000000..f4190a6d93af5 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public abstract class AbstractConnectorTableFunction + implements ConnectorTableFunction +{ + private final String schema; + private final String name; + private final List arguments; + private final ReturnTypeSpecification returnTypeSpecification; + + public AbstractConnectorTableFunction(String schema, String name, List arguments, ReturnTypeSpecification returnTypeSpecification) + { + this.schema = requireNonNull(schema, "schema is null"); + this.name = requireNonNull(name, "name is null"); + this.arguments = Collections.unmodifiableList(new ArrayList<>(requireNonNull(arguments, "arguments is null"))); + this.returnTypeSpecification = requireNonNull(returnTypeSpecification, "returnTypeSpecification is null"); + } + + @Override + public String getSchema() + { + return schema; + } + + @Override + public String getName() + { + return name; + } + + @Override + public List getArguments() + { + return arguments; + } + + @Override + public ReturnTypeSpecification getReturnTypeSpecification() + { + return returnTypeSpecification; + } + + @Override + public abstract TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments); +} diff --git 
a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Argument.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Argument.java new file mode 100644 index 0000000000000..57ea0039dfa86 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Argument.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +/** + * This class represents the three types of arguments passed to a Table Function: + * scalar arguments, descriptor arguments, and table arguments. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include the Table Function arguments. + */ +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + property = "@type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = DescriptorArgument.class, name = "descriptor"), + @JsonSubTypes.Type(value = ScalarArgument.class, name = "scalar"), + @JsonSubTypes.Type(value = TableArgument.class, name = "table"), +}) +public abstract class Argument +{ +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java new file mode 100644 index 0000000000000..73e92d07d0323 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import javax.annotation.Nullable; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.facebook.presto.spi.function.table.Preconditions.checkNotNullOrEmpty; + +/** + * Abstract class to capture the three supported argument types for a table function: + * - Table arguments + * - Descriptor arguments + * - SQL scalar arguments + *

+ * Each argument is named, and is passed either positionally or using the `arg_name => value` convention. + *

+ * Default values are allowed for all arguments except Table arguments. + */ +public abstract class ArgumentSpecification +{ + private final String name; + private final boolean required; + + // native representation + private final Object defaultValue; + + ArgumentSpecification(String name, boolean required, @Nullable Object defaultValue) + { + this.name = checkNotNullOrEmpty(name, "name"); + checkArgument(!required || defaultValue == null, "non-null default value for a required argument"); + this.required = required; + this.defaultValue = defaultValue; + } + + public String getName() + { + return name; + } + + public boolean isRequired() + { + return required; + } + + public Object getDefaultValue() + { + return defaultValue; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunction.java new file mode 100644 index 0000000000000..4b3b1a9a0b06d --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunction.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +import java.util.List; +import java.util.Map; + +public interface ConnectorTableFunction +{ + String getSchema(); + + String getName(); + + List getArguments(); + + ReturnTypeSpecification getReturnTypeSpecification(); + + /** + * This method is called by the Analyzer. Its main purposes are to: + * 1. Determine the resulting relation type of the Table Function in case when the declared return type is GENERIC_TABLE. + * 2. Declare the required columns from the input tables. + * 3. Perform function-specific validation and pre-processing of the input arguments. + * As part of function-specific validation, the Table Function's author might want to: + * - check if the descriptors which reference input tables contain a correct number of column references + * - check if the referenced input columns have appropriate types to fit the function's logic // TODO return request for coercions to the Analyzer in the TableFunctionAnalysis object + * - if there is a descriptor which describes the function's output, check if it matches the shape of the actual function's output + * - for table arguments, check the number and types of ordering columns + *

+ * The actual argument values, and the pre-processing results can be stored in a ConnectorTableFunctionHandle + * object, which will be passed along with the Table Function invocation through subsequent phases of planning. + * + * @param arguments actual invocation arguments, mapped by argument names + */ + TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java new file mode 100644 index 0000000000000..8ecf4b023b4ad --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package com.facebook.presto.spi.function.table; + +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * An area to store all information necessary to execute the table function, gathered at analysis time + */ +@JsonInclude(JsonInclude.Include.ALWAYS) +public interface ConnectorTableFunctionHandle +{ +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Descriptor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Descriptor.java new file mode 100644 index 0000000000000..b4a8f337c0b54 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Descriptor.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.facebook.presto.spi.function.table.Preconditions.checkNotNullOrEmpty; +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +public class Descriptor +{ + private final List fields; + + @JsonCreator + public Descriptor(@JsonProperty("fields") List fields) + { + requireNonNull(fields, "fields is null"); + checkArgument(!fields.isEmpty(), "descriptor has no fields"); + this.fields = unmodifiableList(fields); + } + + public static Descriptor descriptor(String... 
names) + { + List fields = Arrays.stream(names) + .map(name -> new Field(name, Optional.empty())) + .collect(Collectors.toList()); + return new Descriptor(fields); + } + + public static Descriptor descriptor(List names, List types) + { + requireNonNull(names, "names is null"); + requireNonNull(types, "types is null"); + checkArgument(names.size() == types.size(), "names and types lists do not match"); + List fields = new ArrayList<>(); + for (int i = 0; i < names.size(); i++) { + fields.add(new Field(names.get(i), Optional.of(types.get(i)))); + } + return new Descriptor(fields); + } + + @JsonProperty + public List getFields() + { + return fields; + } + + public boolean isTyped() + { + return fields.stream().allMatch(field -> field.type.isPresent()); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor that = (Descriptor) o; + return fields.equals(that.fields); + } + + @Override + public int hashCode() + { + return Objects.hash(fields); + } + + public static class Field + { + private final String name; + private final Optional type; + + @JsonCreator + public Field(@JsonProperty("name") String name, @JsonProperty("type") Optional type) + { + this.name = checkNotNullOrEmpty(name, "name"); + this.type = requireNonNull(type, "type is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public Optional getType() + { + return type; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Field field = (Field) o; + return name.equals(field.name) && type.equals(field.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java 
b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java new file mode 100644 index 0000000000000..545a742465ecc --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * This class represents the descriptor argument passed to a Table Function. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include this kind of argument. + */ +public class DescriptorArgument + extends Argument +{ + public static final DescriptorArgument NULL_DESCRIPTOR = builder().build(); + private final Optional descriptor; + + @JsonCreator + private DescriptorArgument(@JsonProperty("descriptor") Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + @JsonProperty + public Optional getDescriptor() + { + return descriptor; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DescriptorArgument that = (DescriptorArgument) o; + return descriptor.equals(that.descriptor); + } + + @Override + public int hashCode() + { + return Objects.hash(descriptor); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Descriptor descriptor; + + private Builder() {} + + public Builder descriptor(Descriptor descriptor) + { + this.descriptor = descriptor; + return this; + } + + public DescriptorArgument build() + { + return new DescriptorArgument(Optional.ofNullable(descriptor)); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java new file mode 100644 index 0000000000000..637d25fac73be --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +public class DescriptorArgumentSpecification + extends ArgumentSpecification +{ + private DescriptorArgumentSpecification(String name, boolean required, Descriptor defaultValue) + { + super(name, required, defaultValue); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private String name; + private boolean required = true; + private Descriptor defaultValue; + + private Builder() {} + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder defaultValue(Descriptor defaultValue) + { + this.required = false; + this.defaultValue = defaultValue; + return this; + } + + public DescriptorArgumentSpecification build() + { + return new DescriptorArgumentSpecification(name, required, defaultValue); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/NameAndPosition.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/NameAndPosition.java new file mode 100644 index 0000000000000..93216c3acc975 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/NameAndPosition.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import java.util.Objects; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.facebook.presto.spi.function.table.Preconditions.checkNotNullOrEmpty; + +/** + * This class represents a descriptor field reference. + * `name` is the descriptor argument name, `position` is the zero-based field index. + *

+ * The specified field contains a column name, as passed by the Table Function caller. + * The column name is associated with an appropriate input table during the Analysis phase. + * The Table Function is supposed to refer to input data using `NameAndPosition`, + * and the engine should provide the requested column. + */ +public class NameAndPosition +{ + private final String name; + private final int position; + + public NameAndPosition(String name, int position) + { + this.name = checkNotNullOrEmpty(name, "name"); + checkArgument(position >= 0, "position in descriptor must not be negative"); + this.position = position; + } + + public String getName() + { + return name; + } + + public int getPosition() + { + return position; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NameAndPosition that = (NameAndPosition) o; + return position == that.position && Objects.equals(name, that.name); + } + + @Override + public int hashCode() + { + return Objects.hash(name, position); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Preconditions.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Preconditions.java new file mode 100644 index 0000000000000..83edc78526ade --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/Preconditions.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import static java.util.Objects.requireNonNull; + +public final class Preconditions +{ + private Preconditions() {} + + public static String checkNotNullOrEmpty(String value, String name) + { + requireNonNull(value, name + " is null"); + checkArgument(!value.isEmpty(), name + " is empty"); + return value; + } + + public static void checkArgument(boolean assertion, String message) + { + if (!assertion) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java new file mode 100644 index 0000000000000..45bf370a45714 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * The return type declaration refers to the proper columns of the table function. 
+ * These are the columns produced by the table function as opposed to the columns + * of input relations passed through by the table function. + */ +public abstract class ReturnTypeSpecification +{ + /** + * The proper columns of the table function are not known at function declaration time. + * They must be determined at query analysis time based on the actual call arguments. + */ + public static class GenericTable + extends ReturnTypeSpecification + { + public static final GenericTable GENERIC_TABLE = new GenericTable(); + + private GenericTable() {} + } + + /** + * The table function has no proper columns. + */ + public static class OnlyPassThrough + extends ReturnTypeSpecification + { + public static final OnlyPassThrough ONLY_PASS_THROUGH = new OnlyPassThrough(); + + private OnlyPassThrough() {} + } + + /** + * The proper columns of the table function are known at function declaration time. + * They do not depend on the actual call arguments. + */ + public static class DescribedTable + extends ReturnTypeSpecification + { + private final Descriptor descriptor; + + public DescribedTable(Descriptor descriptor) + { + requireNonNull(descriptor, "descriptor is null"); + checkArgument(descriptor.isTyped(), "field types not specified"); + this.descriptor = descriptor; + } + + public Descriptor getDescriptor() + { + return descriptor; + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgument.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgument.java new file mode 100644 index 0000000000000..3da8dae7199ae --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgument.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.predicate.NullableValue; +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import static java.util.Objects.requireNonNull; + +/** + * This class represents the scalar argument passed to a Table Function. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include this kind of argument. + *

+ * Additionally, only constant values are currently supported. In the future, + * we will add support for different kinds of expressions. + */ +public class ScalarArgument + extends Argument +{ + private final Type type; + + // native representation + @Nullable + private final Object value; + + public ScalarArgument(Type type, Object value) + { + this.type = requireNonNull(type, "type is null"); + this.value = value; + } + + public Type getType() + { + return type; + } + + public Object getValue() + { + return value; + } + + // deserialization + @JsonCreator + public static ScalarArgument fromNullableValue(@JsonProperty("nullableValue") NullableValue nullableValue) + { + return new ScalarArgument(nullableValue.getType(), nullableValue.getValue()); + } + + // serialization + @JsonProperty + public NullableValue getNullableValue() + { + return new NullableValue(type, value); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Type type; + private Object value; + + private Builder() {} + + public Builder type(Type type) + { + this.type = type; + return this; + } + + public Builder value(Object value) + { + this.value = value; + return this; + } + + public ScalarArgument build() + { + return new ScalarArgument(type, value); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java new file mode 100644 index 0000000000000..6016b4fc9b5ce --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.predicate.Primitives; +import com.facebook.presto.common.type.Type; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ScalarArgumentSpecification + extends ArgumentSpecification +{ + private final Type type; + + private ScalarArgumentSpecification(String name, Type type, boolean required, Object defaultValue) + { + super(name, required, defaultValue); + this.type = requireNonNull(type, "type is null"); + if (defaultValue != null) { + checkArgument(Primitives.wrap(type.getJavaType()).isInstance(defaultValue), format("default value %s does not match the declared type: %s", defaultValue, type)); + } + } + + public Type getType() + { + return type; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private String name; + private Type type; + private boolean required = true; + private Object defaultValue; + + private Builder() {} + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder type(Type type) + { + this.type = type; + return this; + } + + public Builder defaultValue(Object defaultValue) + { + this.required = false; + this.defaultValue = defaultValue; + return this; + } + + public ScalarArgumentSpecification build() + { + return new ScalarArgumentSpecification(name, type, required, defaultValue); + } + } +} diff 
--git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java new file mode 100644 index 0000000000000..eaf53c29b7d4c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.type.RowType; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collections; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * This class represents the table argument passed to a Table Function. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include this kind of argument. + */ +public class TableArgument + extends Argument +{ + private final RowType rowType; + private final List partitionBy; + private final List orderBy; + + @JsonCreator + public TableArgument( + @JsonProperty("rowType") RowType rowType, + @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("orderBy") List orderBy) + { + this.rowType = requireNonNull(rowType, "rowType is null"); + this.partitionBy = requireNonNull(partitionBy, "partitionBy is null"); + this.orderBy = requireNonNull(orderBy, "orderBy is null"); + } + + @JsonProperty + public RowType getRowType() + { + return rowType; + } + + @JsonProperty + public List getPartitionBy() + { + return partitionBy; + } + + @JsonProperty + public List getOrderBy() + { + return orderBy; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private RowType rowType; + private List partitionBy = Collections.emptyList(); + private List orderBy = Collections.emptyList(); + + private Builder() {} + + public Builder rowType(RowType rowType) + { + this.rowType = rowType; + return this; + } + + public Builder partitionBy(List partitionBy) + { + this.partitionBy = partitionBy; + return this; + } + + public Builder orderBy(List orderBy) + { + this.orderBy = orderBy; + return this; + } + + public TableArgument build() + { + return new TableArgument(rowType, partitionBy, orderBy); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java new file mode 100644 index 0000000000000..f45c42ca24f43 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java @@ -0,0 +1,103 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class TableArgumentSpecification + extends ArgumentSpecification +{ + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + + private TableArgumentSpecification(String name, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns) + { + super(name, true, null); + + requireNonNull(pruneWhenEmpty, "The pruneWhenEmpty property is not set"); + checkArgument(!rowSemantics || pruneWhenEmpty, "Cannot set the KEEP WHEN EMPTY property for a table argument with row semantics"); + + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + } + + public boolean isRowSemantics() + { + return rowSemantics; + } + + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean isPassThroughColumns() + { + return passThroughColumns; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private String name; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + + private Builder() {} + + public Builder name(String name) + { + this.name = name; + return this; + } + + public 
Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder keepWhenEmpty() + { + this.pruneWhenEmpty = false; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public TableArgumentSpecification build() + { + return new TableArgumentSpecification(name, rowSemantics, pruneWhenEmpty, passThroughColumns); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionAnalysis.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionAnalysis.java new file mode 100644 index 0000000000000..ce78f74d2bcfe --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionAnalysis.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * An object of this class is produced by the `analyze()` method of a `ConnectorTableFunction` + * implementation. 
It contains all the analysis results: + *

+ * The `returnedType` field is used to inform the Analyzer of the proper columns returned by the Table + * Function, that is, the columns produced by the function, as opposed to the columns passed from the + * input tables. The `returnedType` should only be set if the declared returned type is GENERIC_TABLE. + *

+ * The `handle` field can be used to carry all information necessary to execute the table function, + * gathered at analysis time. Typically, these are the values of the constant arguments, and results + * of pre-processing arguments. + */ +public final class TableFunctionAnalysis +{ + // a map from table argument name to list of column indexes for all columns required from the table argument + private final Map> requiredColumns; + + private final Optional returnedType; + private final ConnectorTableFunctionHandle handle; + + private TableFunctionAnalysis(Optional returnedType, Map> requiredColumns, ConnectorTableFunctionHandle handle) + { + this.returnedType = requireNonNull(returnedType, "returnedType is null"); + returnedType.ifPresent(descriptor -> checkArgument(descriptor.isTyped(), "field types not specified")); + this.requiredColumns = Collections.unmodifiableMap( + requiredColumns.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> Collections.unmodifiableList(entry.getValue())))); + this.handle = requireNonNull(handle, "handle is null"); + } + + public Optional getReturnedType() + { + return returnedType; + } + + public Map> getRequiredColumns() + { + return requiredColumns; + } + + public ConnectorTableFunctionHandle getHandle() + { + return handle; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Descriptor returnedType; + private final Map> requiredColumns = new HashMap<>(); + private ConnectorTableFunctionHandle handle = new ConnectorTableFunctionHandle() {}; + + private Builder() {} + + public Builder returnedType(Descriptor returnedType) + { + this.returnedType = returnedType; + return this; + } + + public Builder requiredColumns(String tableArgument, List columns) + { + this.requiredColumns.put(tableArgument, columns); + return this; + } + + public Builder handle(ConnectorTableFunctionHandle handle) + { + this.handle = handle; + return this; + } 
+ + public TableFunctionAnalysis build() + { + return new TableFunctionAnalysis(Optional.ofNullable(returnedType), requiredColumns, handle); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java new file mode 100644 index 0000000000000..8a5b176e60ae7 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.Page; + +import java.util.List; +import java.util.Optional; + +public interface TableFunctionDataProcessor +{ + /** + * This method processes a portion of data. It is called multiple times until the partition is fully processed. + * + * @param input a tuple of {@link Page} including one page for each table function's input table. + * Pages list is ordered according to the corresponding argument specifications in {@link ConnectorTableFunction}. + * A page for an argument consists of columns requested during analysis (see {@link TableFunctionAnalysis#getRequiredColumns()}). + * If any of the sources is fully processed, {@code Optional.empty()} is returned for that source. + * If all sources are fully processed, the argument is {@code null}. 
+ * @return {@link TableFunctionProcessorState} including the processor's state and optionally a portion of result. + * After the returned state is {@code FINISHED}, the method will not be called again. + */ + TableFunctionProcessorState process(List> input); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java new file mode 100644 index 0000000000000..556e3828eb79c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +public interface TableFunctionProcessorProvider +{ + /** + * This method returns a {@code TableFunctionDataProcessor}. All the necessary information collected during analysis is available + * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each partition processed by the table function. + */ + default TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + throw new UnsupportedOperationException("this table function does not process input data"); + } + + /** + * This method returns a {@code TableFunctionSplitProcessor}. 
All the necessary information collected during analysis is available + * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each split processed by the table function. + */ + default TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + throw new UnsupportedOperationException("this table function does not process splits"); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java new file mode 100644 index 0000000000000..70ae743c34853 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.Page; + +import javax.annotation.Nullable; + +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +/** + * The result of processing input by {@link TableFunctionDataProcessor} or {@link TableFunctionSplitProcessor}. 
+ * It can optionally include a portion of output data in the form of {@link Page} + * The returned {@link Page} should consist of: + * - proper columns produced by the table function + * - one column of type {@code BIGINT} for each table function's input table having the pass-through property (see {@link TableArgumentSpecification#isPassThroughColumns}), + * in order of the corresponding argument specifications. Entries in these columns are the indexes of input rows (from partition start) to be attached to output, + * or null to indicate that a row of nulls should be attached instead of an input row. The indexes are validated to be within the portion of the partition + * provided to the function so far. + * Note: when the input is empty, the only valid index value is null, because there are no input rows that could be attached to output. In such case, for performance + * reasons, the validation of indexes is skipped, and all pass-through columns are filled with nulls. + */ +public interface TableFunctionProcessorState +{ + final class Blocked + implements TableFunctionProcessorState + { + private final CompletableFuture future; + + private Blocked(CompletableFuture future) + { + this.future = requireNonNull(future, "future is null"); + } + + public static Blocked blocked(CompletableFuture future) + { + return new Blocked(future); + } + + public CompletableFuture getFuture() + { + return future; + } + } + + final class Finished + implements TableFunctionProcessorState + { + public static final Finished FINISHED = new Finished(); + + private Finished() {} + } + + final class Processed + implements TableFunctionProcessorState + { + private final boolean usedInput; + private final Page result; + + private Processed(boolean usedInput, @Nullable Page result) + { + this.usedInput = usedInput; + this.result = result; + } + + public static Processed usedInput() + { + return new Processed(true, null); + } + + public static Processed produced(Page result) + { + 
requireNonNull(result, "result is null"); + return new Processed(false, result); + } + + public static Processed usedInputAndProduced(Page result) + { + requireNonNull(result, "result is null"); + return new Processed(true, result); + } + + public boolean isUsedInput() + { + return usedInput; + } + + public Page getResult() + { + return result; + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java new file mode 100644 index 0000000000000..504ea54fcb61f --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.spi.ConnectorSplit; + +public interface TableFunctionSplitProcessor +{ + /** + * This method processes a split. It is called multiple times until the whole output for the split is produced. + * + * @param split a {@link ConnectorSplit} representing a subtask. + * @return {@link TableFunctionProcessorState} including the processor's state and optionally a portion of result. + * After the returned state is {@code FINISHED}, the method will not be called again. 
+ */ + TableFunctionProcessorState process(ConnectorSplit split); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java index d19f363b6c53d..275621e55d05f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/Assignments.java @@ -204,6 +204,20 @@ public Builder put(Entry assignment) return this; } + public Builder putIdentities(Iterable variables) + { + for (VariableReferenceExpression variable : variables) { + putIdentity(variable); + } + return this; + } + + public Builder putIdentity(VariableReferenceExpression variable) + { + put(variable, variable); + return this; + } + public Assignments build() { return new Assignments(assignments); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java new file mode 100644 index 0000000000000..30588227f76cf --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/DataOrganizationSpecification.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.plan; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +@Immutable +public class DataOrganizationSpecification +{ + private final List partitionBy; + private final Optional orderingScheme; + + @JsonCreator + public DataOrganizationSpecification( + @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("orderingScheme") Optional orderingScheme) + { + requireNonNull(partitionBy, "partitionBy is null"); + requireNonNull(orderingScheme, "orderingScheme is null"); + + this.partitionBy = unmodifiableList(new ArrayList<>(partitionBy)); + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + } + + @JsonProperty + public List getPartitionBy() + { + return partitionBy; + } + + @JsonProperty + public Optional getOrderingScheme() + { + return orderingScheme; + } + + @Override + public int hashCode() + { + return Objects.hash(partitionBy, orderingScheme); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + DataOrganizationSpecification other = (DataOrganizationSpecification) obj; + + return Objects.equals(this.partitionBy, other.partitionBy) && + Objects.equals(this.orderingScheme, other.orderingScheme); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/OrderingScheme.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/OrderingScheme.java index 179e290d1808d..7dec3153cedb8 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/OrderingScheme.java 
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/OrderingScheme.java @@ -14,6 +14,8 @@ package com.facebook.presto.spi.plan; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.LocalProperty; +import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -105,6 +107,14 @@ public String toString() return stringBuilder.toString(); } + public List> toLocalProperties() + { + return unmodifiableList( + getOrderBy().stream() + .map(variable -> new SortingProperty<>(variable.getVariable(), getOrdering(variable.getVariable()))) + .collect(toList())); + } + private static void checkArgument(boolean condition, String messageFormat, Object... args) { if (!condition) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java index d24da2ad567c6..a9d958b34ca14 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/WindowNode.java @@ -47,7 +47,7 @@ public class WindowNode { private final PlanNode source; private final Set prePartitionedInputs; - private final Specification specification; + private final DataOrganizationSpecification specification; private final int preSortedOrderPrefix; private final Map windowFunctions; private final Optional hashVariable; @@ -57,7 +57,7 @@ public WindowNode( @JsonProperty("sourceLocation") Optional sourceLocation, @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("windowFunctions") Map windowFunctions, @JsonProperty("hashVariable") Optional hashVariable, 
@JsonProperty("prePartitionedInputs") Set prePartitionedInputs, @@ -71,7 +71,7 @@ public WindowNode( PlanNodeId id, Optional statsEquivalentPlanNode, PlanNode source, - Specification specification, + DataOrganizationSpecification specification, Map windowFunctions, Optional hashVariable, Set prePartitionedInputs, @@ -122,7 +122,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } @@ -134,7 +134,7 @@ public List getPartitionBy() public Optional getOrderingScheme() { - return specification.orderingScheme; + return specification.getOrderingScheme(); } @JsonProperty @@ -188,60 +188,6 @@ public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalent return new WindowNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, windowFunctions, hashVariable, prePartitionedInputs, preSortedOrderPrefix); } - @Immutable - public static class Specification - { - private final List partitionBy; - private final Optional orderingScheme; - - @JsonCreator - public Specification( - @JsonProperty("partitionBy") List partitionBy, - @JsonProperty("orderingScheme") Optional orderingScheme) - { - requireNonNull(partitionBy, "partitionBy is null"); - requireNonNull(orderingScheme, "orderingScheme is null"); - - this.partitionBy = unmodifiableList(new ArrayList<>(partitionBy)); - this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); - } - - @JsonProperty - public List getPartitionBy() - { - return partitionBy; - } - - @JsonProperty - public Optional getOrderingScheme() - { - return orderingScheme; - } - - @Override - public int hashCode() - { - return Objects.hash(partitionBy, orderingScheme); - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - Specification other = (Specification) obj; - - 
return Objects.equals(this.partitionBy, other.partitionBy) && - Objects.equals(this.orderingScheme, other.orderingScheme); - } - } - @Immutable public static class Frame { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java new file mode 100644 index 0000000000000..0ef24d4d7eb66 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java @@ -0,0 +1,549 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tests; + +import com.facebook.presto.connector.tvf.MockConnectorColumnHandle; +import com.facebook.presto.connector.tvf.MockConnectorFactory; +import com.facebook.presto.connector.tvf.MockConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; +import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction.SimpleTableFunctionHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.connector.tvf.MockConnectorFactory.MockConnector.MockConnectorSplit.MOCK_CONNECTOR_SPLIT; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.ConstantFunction.getConstantFunctionSplitSource; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static com.google.common.collect.ImmutableMap.toImmutableMap; + +public class TestTableFunctionInvocation + extends AbstractTestQueryFramework +{ + private static final String TESTING_CATALOG = "testing_catalog1"; + private static final String TABLE_FUNCTION_SCHEMA = "table_function_schema"; + + @Override + protected 
QueryRunner createQueryRunner() + throws Exception + { + return DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog(TESTING_CATALOG) + .setSchema(TABLE_FUNCTION_SCHEMA) + .build()).setSingleExtraProperty("query.max-memory-per-node", "2GB") + .build(); + } + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + + @BeforeClass + public void setUp() + { + DistributedQueryRunner queryRunner = getDistributedQueryRunner(); + + BiFunction> getColumnHandles = (session, tableHandle) -> IntStream.range(0, 100) + .boxed() + .map(i -> "column_" + i) + .collect(toImmutableMap(column -> column, column -> new MockConnectorColumnHandle(column, createUnboundedVarcharType()) {})); + + queryRunner.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new SimpleTableFunction(), + new TestingTableFunctions.IdentityFunction(), + new TestingTableFunctions.IdentityPassThroughFunction(), + new TestingTableFunctions.RepeatFunction(), + new TestingTableFunctions.EmptyOutputFunction(), + new TestingTableFunctions.EmptyOutputWithPassThroughFunction(), + new TestingTableFunctions.TestInputsFunction(), + new TestingTableFunctions.PassThroughInputFunction(), + new TestingTableFunctions.TestInputFunction(), + new TestingTableFunctions.TestSingleInputRowSemanticsFunction(), + new TestingTableFunctions.ConstantFunction(), + new TestingTableFunctions.EmptySourceFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof SimpleTableFunctionHandle) { + SimpleTableFunctionHandle functionHandle = (SimpleTableFunctionHandle) handle; + return Optional.of(new 
TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Columns are missing")))); + } + return Optional.empty(); + }) + .withGetTableFunctionProcessorProvider(Optional.of(name -> { + if (name.equals(new SchemaFunctionName("system", "identity_function"))) { + return new TestingTableFunctions.IdentityFunction.IdentityFunctionProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "identity_pass_through_function"))) { + return new TestingTableFunctions.IdentityPassThroughFunction.IdentityPassThroughFunctionProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "repeat"))) { + return new TestingTableFunctions.RepeatFunction.RepeatFunctionProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "empty_output"))) { + return new TestingTableFunctions.EmptyOutputFunction.EmptyOutputProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "empty_output_with_pass_through"))) { + return new TestingTableFunctions.EmptyOutputWithPassThroughFunction.EmptyOutputWithPassThroughProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "test_inputs_function"))) { + return new TestingTableFunctions.TestInputsFunction.TestInputsFunctionProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "pass_through"))) { + return new TestingTableFunctions.PassThroughInputFunction.PassThroughInputProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "test_input"))) { + return new TestingTableFunctions.TestInputFunction.TestInputProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "test_single_input_function"))) { + return new TestingTableFunctions.TestSingleInputRowSemanticsFunction.TestSingleInputFunctionProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "constant"))) { 
+ return new TestingTableFunctions.ConstantFunction.ConstantFunctionProcessorProvider(); + } + else if (name.equals(new SchemaFunctionName("system", "empty_source"))) { + return new TestingTableFunctions.EmptySourceFunction.EmptySourceFunctionProcessorProvider(); + } + return null; + })) + .withTableFunctionResolver(TestingTableFunctions.RepeatFunction.RepeatFunctionHandle.class) + .withTableFunctionResolver(TestingTableFunctions.EmptyTableFunctionHandle.class) + .withTableFunctionResolver(TestingTableFunctions.ConstantFunction.ConstantFunctionHandle.class) + .withTableFunctionSplitResolver(MockConnectorFactory.MockConnector.MockConnectorSplit.class) + .withTableFunctionSplitResolver(TestingTableFunctions.ConstantFunction.ConstantFunctionSplit.class) + .withGetColumnHandles(getColumnHandles) + .withTableFunctionSplitSource( + new SchemaFunctionName("system", "constant"), + handle -> getConstantFunctionSplitSource((TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) handle)) + .withTableFunctionSplitSource( + new SchemaFunctionName("system", "empty_source"), + handle -> new FixedSplitSource(ImmutableList.of(MOCK_CONNECTOR_SPLIT))) + .withTableFunctionSplitSource( + new SchemaFunctionName("system", "identity_function"), + handle -> new FixedSplitSource(ImmutableList.of(MOCK_CONNECTOR_SPLIT))) + .build())); + queryRunner.createCatalog(TESTING_CATALOG, "mock"); + + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + } + + @Test + public void testPrimitiveDefaultArgument() + { + assertQuery("SELECT boolean_column FROM TABLE(system.simple_table_function(column => 'boolean_column', ignored => 1))", "SELECT true WHERE false"); + + // skip the `ignored` argument. 
+ assertQuery("SELECT boolean_column FROM TABLE(system.simple_table_function(column => 'boolean_column'))", + "SELECT true WHERE false"); + } + + @Test + public void testNoArgumentsPassed() + { + assertQuery("SELECT col FROM TABLE(system.simple_table_function())", + "SELECT true WHERE false"); + } + + @Test + public void testIdentityFunction() + { + assertQuery("SELECT b, a FROM TABLE(system.identity_function(input => TABLE(VALUES (1, 2), (3, 4), (5, 6)) T(a, b)))", + "VALUES (2, 1), (4, 3), (6, 5)"); + + assertQuery("SELECT b, a FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES (1, 2), (3, 4), (5, 6)) T(a, b)))", + "VALUES (2, 1), (4, 3), (6, 5)"); + + // null partitioning value + // TODO: Come back to this. It is supposed to be i.b. Table alias. + //assertQuery("SELECT b, a FROM TABLE(system.identity_function(input => TABLE(VALUES ('x', 1), ('y', 2), ('z', null)) T(a, b) PARTITION BY b)) i", + // "VALUES (1, 'x'), (2, 'y'), (null, 'z')"); + + assertQuery("SELECT b, a FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES ('x', 1), ('y', 2), ('z', null)) T(a, b) PARTITION BY b))", + "VALUES (1, 'x'), (2, 'y'), (null, 'z')"); + + // the identity_function copies all input columns and outputs them as proper columns. + // the table tpch.tiny.orders has a hidden column row_number, which is not exposed to the function. + assertQuery("SELECT * FROM TABLE(system.identity_function(input => TABLE(tpch.tiny.region)))", + "SELECT * FROM tpch.tiny.region"); + + // the identity_pass_through_function passes all input columns on output using the pass-through mechanism (as opposed to producing proper columns). + // the table tpch.tiny.orders has a hidden column row_number, which is exposed to the pass-through mechanism. + // the passed-through column row_number preserves its hidden property. 
+ assertQuery("SELECT row_number, * FROM TABLE(system.identity_pass_through_function(input => TABLE(tpch.tiny.orders)))", + "SELECT row_number, * FROM tpch.tiny.orders"); + } + + @Test + public void testRepeatFunction() + { + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES (1, 2), (3, 4), (5, 6))))", + "VALUES (1, 2), (1, 2), (3, 4), (3, 4), (5, 6), (5, 6)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)), 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) PARTITION BY x,4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) ORDER BY y, 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) PARTITION BY x ORDER BY y, 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part), 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) PARTITION BY type, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) ORDER BY size, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) PARTITION BY type ORDER BY size, 3))", + 
"SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + } + + @Test + public void testFunctionsReturningEmptyPages() + { + // the functions empty_output and empty_output_with_pass_through return an empty Page for each processed input Page. the argument has KEEP WHEN EMPTY property + + // non-empty input, no pass-through columns + + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(tpch.tiny.orders)))", + "SELECT true WHERE false"); + + // non-empty input, pass-through partitioning column + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(tpch.tiny.orders) PARTITION BY orderstatus))", + "SELECT true, 'X' WHERE false"); + + // non-empty input, argument has pass-through columns + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(tpch.tiny.orders)))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // non-empty input, argument has pass-through columns, partitioning column present + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(tpch.tiny.orders) PARTITION BY orderstatus))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // empty input, no pass-through columns + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", + "SELECT true WHERE false"); + + // empty input, pass-through partitioning column + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus))", + "SELECT true, 'X' WHERE false"); + + // empty input, argument has pass-through columns + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // empty input, argument has pass-through columns, partitioning column present + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(SELECT * FROM tpch.tiny.orders 
WHERE false) PARTITION BY orderstatus)) ", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // function empty_source returns an empty Page for each Split it processes + assertQuery("SELECT * FROM TABLE(system.empty_source())", + "SELECT true WHERE false"); + } + + @Test + public void testInputPartitioning() + { + // table function test_inputs_function has four table arguments. input_1 has row semantics. input_2, input_3 and input_4 have set semantics. + // the function outputs one row per each tuple of partition it processes. The row includes a true value, and partitioning values. + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2,\n" + + " input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3,\n" + + " input_4 => TABLE(VALUES 8, 9)))\n", + "VALUES (true, 4, 6), (true, 4, 7), (true, 5, 6), (true, 5, 7)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 9) t4(x4) PARTITION BY x4))", + "VALUES (true, 4, 6, 8), (true, 4, 6, 9), (true, 4, 7, 8), (true, 4, 7, 9), (true, 5, 6, 8), (true, 5, 6, 9), (true, 5, 7, 8), (true, 5, 7, 9)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 8) t4(x4) PARTITION BY x4))", + "VALUES (true, 4, 6, 8), (true, 4, 7, 8), (true, 5, 6, 8), (true, 5, 7, 8)"); + + // null partitioning values + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, null)," + + "input_2 => TABLE(VALUES 2, null, 2, null) t2(x2) PARTITION BY 
x2," + + "input_3 => TABLE(VALUES 3, null, 3, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES null, null) t4(x4) PARTITION BY x4))", + "VALUES (true, 2, 3, null), (true, 2, null, null), (true, null, 3, null), (true, null, null, null)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4)," + + "input_3 => TABLE(VALUES 6, 7, 6)," + + "input_4 => TABLE(VALUES 8, 9)))", + "VALUES true"); + + assertQuery("SELECT DISTINCT regionkey, nationkey FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(tpch.tiny.nation)," + + "input_2 => TABLE(tpch.tiny.nation) PARTITION BY regionkey ORDER BY name," + + "input_3 => TABLE(tpch.tiny.customer) PARTITION BY nationkey," + + "input_4 => TABLE(tpch.tiny.customer)))", + "SELECT DISTINCT n.regionkey, c.nationkey FROM tpch.tiny.nation n, tpch.tiny.customer c"); + } + + @Test + public void testEmptyPartitions() + { + // input_1 has row semantics, so it is prune when empty. 
input_2, input_3 and input_4 have set semantics, and are keep when empty by default + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false)," + + "input_3 => TABLE(SELECT 3 WHERE false)," + + "input_4 => TABLE(SELECT 4 WHERE false)))", + "VALUES true"); + + assertQueryReturnsEmptyResult("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(SELECT 1 WHERE false)," + + "input_2 => TABLE(VALUES 2)," + + "input_3 => TABLE(VALUES 3)," + + "input_4 => TABLE(VALUES 4)))"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(SELECT 4 WHERE false) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), CAST(null AS integer), CAST(null AS integer))"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 3, 4, 4) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 4, 4, 5, 5, 5, 5) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), 3, 4), (true, null, 4, 4), (true, null, 4, 5), (true, null, 3, 5)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 5) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), CAST(null AS integer), 4), (true, null, null, 5)"); + + assertQueryReturnsEmptyResult("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION 
BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 5) t4(x4) PARTITION BY x4))"); + } + + @Test + public void testCopartitioning() + { + // all tables are by default KEEP WHEN EMPTY. If there is no matching partition, it is null-completed + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2,\n" + + " input_3 => TABLE(VALUES 4, 5) t3(x3),\n" + + " input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4\n" + + " COPARTITION (t2, t4)))\n", + "VALUES (true, 1, null), (true, 2, 2), (true, null, 3)"); + + // partition `3` from input_4 is pruned because there is no matching partition in input_2 + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY,\n" + + " input_3 => TABLE(VALUES 4, 5) t3(x3),\n" + + " input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4\n" + + " COPARTITION (t2, t4)))\n", + "VALUES (true, 1, null), (true, 2, 2)"); + + // partition `1` from input_2 is pruned because there is no matching partition in input_4 + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2,\n" + + " input_3 => TABLE(VALUES 4, 5) t3(x3),\n" + + " input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY\n" + + " COPARTITION (t2, t4)))\n", + "VALUES (true, 2, 2), (true, null, 3)"); + + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY,\n" + + " input_3 => TABLE(VALUES 4, 5) t3(x3),\n" + + " input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION 
BY x4 PRUNE WHEN EMPTY\n" + + " COPARTITION (t2, t4)))\n", + "VALUES (true, 2, 2)"); + + // null partitioning values + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, null, null, 2, 2) t2(x2) PARTITION BY x2,\n" + + " input_3 => TABLE(VALUES 4, 5) t3(x3),\n" + + " input_4 => TABLE(VALUES null, 2, 2, 2, 3) t4(x4) PARTITION BY x4\n" + + " COPARTITION (t2, t4)))\n", + "VALUES (true, 1, null), (true, 2, 2), (true, null, null), (true, null, 3)"); + + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, null, null, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY,\n" + + " input_3 => TABLE(VALUES 4, 5) t3(x3),\n" + + " input_4 => TABLE(VALUES null, 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY\n" + + " COPARTITION (t2, t4)))\n", + "VALUES (true, 2, 2), (true, null, null)"); + + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2,\n" + + " input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3,\n" + + " input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4\n" + + " COPARTITION (t2, t4, t3)))\n", + "VALUES (true, 1, null, null), (true, null, null, null), (true, null, 2, 2), (true, null, null, 3)"); + + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2,\n" + + " input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3 PRUNE WHEN EMPTY,\n" + + " input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4\n" + + " COPARTITION (t2, t4, t3)))\n", + "VALUES (true, CAST(null AS integer), null, null), (true, null, 2, 2)"); + + assertQuery("SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " 
input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY,\n" + + " input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3,\n" + + " input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4\n" + + " COPARTITION (t2, t4, t3)))\n", + "VALUES (true, 1, CAST(null AS integer), CAST(null AS integer)), (true, null, null, null)"); + + assertQueryReturnsEmptyResult( + "SELECT *\n" + + "FROM TABLE(system.test_inputs_function(\n" + + " input_1 => TABLE(VALUES 1, 2, 3),\n" + + " input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY,\n" + + " input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3,\n" + + " input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY\n" + + " COPARTITION (t2, t4, t3)))\n"); + } + + @Test + public void testPassThroughWithEmptyPartitions() + { + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(VALUES (2, 'x'), (3, 'y')) t2(a2, b2) PARTITION BY a2 " + + "COPARTITION (t1, t2)))", + "VALUES (true, false, 1, 'a', null, null), (true, true, 2, 'b', 2, 'x'), (false, true, null, null, 3, 'y')"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(SELECT 2, 'x' WHERE false) t2(a2, b2) PARTITION BY a2 " + + "COPARTITION (t1, t2)))", + "VALUES (true, false, 1, 'a', CAST(null AS integer), CAST(null AS VARCHAR(1))), (true, false, 2, 'b', null, null)"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(SELECT 2, 'x' WHERE false) t2(a2, b2) PARTITION BY a2))", + "VALUES (true, false, 1, 'a', CAST(null AS integer), CAST(null AS VARCHAR(1))), (true, false, 2, 'b', null, null)"); + } + + @Test + public void testPassThroughWithEmptyInput() + { + assertQuery("SELECT * FROM 
TABLE(system.pass_through(TABLE(SELECT 1, 'x' WHERE false) t1(a1, b1) PARTITION BY a1, TABLE(SELECT 2, 'y' WHERE false) t2(a2, b2) PARTITION BY a2 COPARTITION (t1, t2)))", + "VALUES (false, false, CAST(null AS integer), CAST(null AS VARCHAR(1)), CAST(null AS integer), CAST(null AS VARCHAR(1)))"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(TABLE(SELECT 1, 'x' WHERE false) t1(a1, b1) PARTITION BY a1, TABLE(SELECT 2, 'y' WHERE false) t2(a2, b2) PARTITION BY a2))", + "VALUES (false, false, CAST(null AS integer), CAST(null AS VARCHAR(1)), CAST(null AS integer), CAST(null AS VARCHAR(1)))"); + } + + @Test + public void testInput() + { + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(VALUES 1)))", "VALUES true"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(VALUES 1, 2, 3) t(a) PARTITION BY a))", + "VALUES true, true, true"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT 1 WHERE false)))", "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT 1 WHERE false) t(a) PARTITION BY a))", + "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus ORDER BY orderkey))", "VALUES false"); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + assertQuery("SELECT * FROM TABLE(system.test_single_input_function(TABLE(VALUES (true), (false), (true))))", "VALUES true"); + } + + @Test + public void testConstantFunction() + { + assertQuery("SELECT * FROM TABLE(system.constant(5))", "VALUES 5"); + + assertQuery("SELECT * FROM TABLE(system.constant(2, 10))", "VALUES (2), (2), (2), (2), (2), (2), (2), (2), (2), (2)"); + + assertQuery("SELECT * FROM TABLE(system.constant(null, 3))", "VALUES (CAST(null AS integer)), (null), 
(null)"); + + // value as constant expression + assertQuery("SELECT * FROM TABLE(system.constant(5 * 4, 3))", "VALUES (20), (20), (20)"); + + assertQueryFails("SELECT * FROM TABLE(system.constant(2147483648, 3))", "line 1:37: Cannot cast type bigint to integer"); + + assertQuery("SELECT count(*), count(DISTINCT constant_column), min(constant_column) FROM TABLE(system.constant(2, 1000000))", "VALUES (BIGINT '1000000', BIGINT '1', 2)"); + } +}