diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala index 660592cf791ef..2187e5ab6903f 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField, Literal} import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusCache, NoopCache, PartitionDirectory} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils import org.apache.spark.sql.internal.SQLConf @@ -105,6 +105,10 @@ case class HoodieFileIndex(spark: SparkSession, @transient protected var hasPushedDownPartitionPredicates: Boolean = false + /** True when any partition column is a nested field path (e.g. "nested_record.level"). */ + private val hasNestedPartitionColumns: Boolean = + getPartitionColumns.exists(_.contains(".")) + /** * NOTE: [[indicesSupport]] is a transient state, since it's only relevant while logical plan * is handled by the Spark's driver @@ -167,19 +171,44 @@ case class HoodieFileIndex(spark: SparkSession, /** * Invoked by Spark to fetch list of latest base files per partition. * - * @param partitionFilters partition column filters - * @param dataFilters data columns filters - * @return list of PartitionDirectory containing partition to base files mapping + * For regular partition columns, Spark passes correct `partitionFilters` directly. + * + * For nested partition columns (e.g. 
`nested_record.level`), Spark cannot match + * [[GetStructField]] expressions against the flat dot-path partition schema and passes + * `partitionFilters = []`. The nested predicates land in `dataFilters` instead. + * We re-extract them via [[extractNestedPartitionFilters]]. + * + * Example: `SELECT * FROM t WHERE nested_record.level = 'INFO' AND int_field > 0` + * - Spark passes: `partitionFilters = []`, `dataFilters = [nested_record.level = 'INFO', int_field > 0]` + * - We extract: `effectivePartitionFilters = [nested_record.level = 'INFO']` + * + * This is stateless — safe under AQE re-planning, subqueries, and FileIndex reuse. + * + * Known limitation: for mixed flat+nested partitions (e.g. `["country", "nested_record.level"]`), + * if Spark passes `partitionFilters = [country = 'US']`, we skip extraction and the nested + * filter is not used for partition pruning. A future fix could merge extracted nested filters + * with the provided `partitionFilters`. */ override def listFiles(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { - val slices = filterFileSlices(dataFilters, partitionFilters).flatMap( + val effectivePartitionFilters = if (partitionFilters.isEmpty && hasNestedPartitionColumns) { + extractNestedPartitionFilters(dataFilters) + } else { + partitionFilters + } + + val slices = filterFileSlices(dataFilters, effectivePartitionFilters).flatMap( { case (partitionOpt, fileSlices) => - fileSlices.filter(!_.isEmpty).map(fs => ( InternalRow.fromSeq(partitionOpt.get.getValues), fs)) + fileSlices.filter(!_.isEmpty).map(fs => (InternalRow.fromSeq(partitionOpt.get.getValues), fs)) } ) prepareFileSlices(slices) } + /** Delegates to companion object with this table's partition columns. 
*/ + private def extractNestedPartitionFilters(dataFilters: Seq[Expression]): Seq[Expression] = { + HoodieFileIndex.extractNestedPartitionFilters(dataFilters, getPartitionColumns.toSet) + } + protected def prepareFileSlices(slices: Seq[(InternalRow, FileSlice)]): Seq[PartitionDirectory] = { hasPushedDownPartitionPredicates = true @@ -212,25 +241,25 @@ case class HoodieFileIndex(spark: SparkSession, } /** - * The functions prunes the partition paths based on the input partition filters. For every partition path, the file - * slices are further filtered after querying metadata table based on the data filters. + * Prunes partitions by `partitionFilters`, then optionally applies data skipping via metadata + * table indices (column stats, record-level index, etc.) to filter file slices. * - * @param dataFilters data columns filters - * @param partitionFilters partition column filters - * @param partitionPrune for HoodiePruneFileSourcePartitions rule only prune partitions - * @return A sequence of pruned partitions and corresponding filtered file slices + * @param dataFilters data column filters (used for data skipping) + * @param partitionFilters partition column filters (used for partition pruning) + * @param isPartitionPruneOnly when true, skip data skipping. Used by [[HoodiePruneFileSourcePartitions]] + * during planning (data skipping runs later in [[listFiles]]). */ - def filterFileSlices(dataFilters: Seq[Expression], partitionFilters: Seq[Expression], isPartitionPruned: Boolean = false) + def filterFileSlices(dataFilters: Seq[Expression], partitionFilters: Seq[Expression], + isPartitionPruneOnly: Boolean = false) : Seq[(Option[BaseHoodieTableFileIndex.PartitionPath], Seq[FileSlice])] = { val (isPruned, prunedPartitionsAndFileSlices) = prunePartitionsAndGetFileSlices(dataFilters, partitionFilters) hasPushedDownPartitionPredicates = true - // If there are no data filters, return all the file slices. 
- // If isPartitionPurge is true, this fun is trigger by HoodiePruneFileSourcePartitions, don't look up candidate files - // If there are no file slices, return empty list. - if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty || isPartitionPruned ) { + // Skip data skipping when: no file slices, no data filters, or partition-prune-only mode + // (planning phase — data skipping runs later during execution). + if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty || isPartitionPruneOnly) { prunedPartitionsAndFileSlices } else { // Look up candidate files names in the col-stats or record level index, if all of the following conditions are true @@ -502,6 +531,65 @@ object HoodieFileIndex extends Logging { val Strict: Val = Val("strict") } + /** + * Extracts filters from `dataFilters` that reference nested partition columns by walking + * [[GetStructField]] chains to reconstruct the full dot-path and matching against partition + * column names. We cannot match on the struct root alone because sibling fields share it + * (e.g. `nested_record.level` and `nested_record.nested_int` both reference `nested_record`). + * + * Given partition column `nested_record.level` and: + * {{{ + * dataFilters = [nested_record.level = 'INFO', nested_record.nested_int > 0, int_field = 5] + * }}} + * Returns: `[nested_record.level = 'INFO']` + * + * Known limitations vs regular partition columns: + * - `(nested_record.level = 'INFO' AND d = 2) OR (nested_record.level = 'ERROR')` is excluded + * entirely (references both partition and data columns). A weaker predicate like + * `nested_record.level IN ('INFO', 'ERROR')` could be extracted but is not implemented. + * Spark has the same OR limitation for regular partition columns. + * + * @param dataFilters filters to scan for nested partition predicates + * @param partitionColumnNames partition column dot-paths, e.g. 
`Set("nested_record.level")` + * @return only the filters whose every column reference is a partition column + */ + private[hudi] def extractNestedPartitionFilters(dataFilters: Seq[Expression], + partitionColumnNames: Set[String]): Seq[Expression] = { + val partitionColumnRoots = partitionColumnNames.map(_.split("\\.", 2)(0)) + dataFilters.filter { expr => + // Resolve all outermost GetStructField chains to their full dot-paths. + val structFieldPaths = collectOutermostStructFieldPaths(expr) + // The expression is a partition filter only when: + // 1. It contains at least one GetStructField that resolves to a partition column path, AND + // 2. ALL resolved paths are partition columns (no non-partition nested fields), AND + // 3. ALL attribute references are roots of partition columns + // (guards against mixed expressions like "nested_record.level = 'INFO' AND int_field > 0") + structFieldPaths.nonEmpty && + structFieldPaths.forall(partitionColumnNames.contains) && + expr.references.map(_.name).forall(partitionColumnRoots.contains) + } + } + + /** + * Collects full dot-paths of outermost [[GetStructField]] chains in an expression. + * `EqualTo(a.b.c, 1)` → `Seq("a.b.c")` (not intermediate `"a.b"`). + */ + private[hudi] def collectOutermostStructFieldPaths(expr: Expression): Seq[String] = { + expr match { + case g: GetStructField => resolveGetStructFieldPath(g).toSeq + case _ => expr.children.flatMap(collectOutermostStructFieldPaths) + } + } + + /** Resolves a [[GetStructField]] chain to its full dot-path: `attr("a").b.c` → `"a.b.c"`. */ + private[hudi] def resolveGetStructFieldPath(expr: Expression): Option[String] = expr match { + case GetStructField(child: AttributeReference, _, Some(fieldName)) => + Some(child.name + "." + fieldName) + case GetStructField(child: GetStructField, _, Some(fieldName)) => + resolveGetStructFieldPath(child).map(_ + "." 
+ fieldName) + case _ => None + } + def collectReferencedColumns(spark: SparkSession, queryFilters: Seq[Expression], schema: StructType): Seq[String] = { val resolver = spark.sessionState.analyzer.resolver val refs = queryFilters.flatMap(_.references) diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala index 1ba9628af3b75..b80eb204823af 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala @@ -46,11 +46,11 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BasePredicate, BoundReference, EmptyRow, EqualTo, Expression, InterpretedPredicate, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BasePredicate, BoundReference, EmptyRow, EqualTo, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{FileStatusCache, NoopCache} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ByteType, DataType, DateType, IntegerType, LongType, ShortType, StringType, StructField, StructType} import org.slf4j.LoggerFactory import javax.annotation.concurrent.NotThreadSafe @@ -59,6 +59,7 @@ import java.lang.reflect.{Array => JArray} import java.util.Collections import scala.collection.JavaConverters._ +import scala.collection.mutable.LinkedHashMap import 
scala.language.implicitConversions import scala.util.{Success, Try} @@ -201,6 +202,25 @@ class SparkHoodieTableFileIndex(spark: SparkSession, } } + /** + * Spark-facing partition schema that preserves nested structure for nested partition columns. + * + * NOTE: Hudi's [[partitionSchema]] intentionally returns a *flat* schema where field names use full + * dot-paths (for example, "a.b.c") to avoid collisions with top-level data columns. Some Spark + * planner/analyzer paths, however, reason about nested columns as nested [[StructType]]s and + * require a nested schema shape to properly resolve [[GetStructField]] chains. + * + * This method reconstructs a nested [[StructType]] from the flat partition schema, using the same + * leaf data-types, and preserving deterministic field ordering based on the original flat schema. + */ + def partitionSchemaForSpark: StructType = { + if (!shouldReadAsPartitionedTable) { + new StructType() + } else { + SparkHoodieTableFileIndex.buildNestedPartitionSchema(_partitionSchemaFromProperties) + } + } + /** * Fetch list of latest base files w/ corresponding log files, after performing * partition pruning @@ -238,13 +258,32 @@ class SparkHoodieTableFileIndex(spark: SparkSession, def listMatchingPartitionPaths(predicates: Seq[Expression]): Seq[PartitionPath] = { val resolve = spark.sessionState.analyzer.resolver val partitionColumnNames = getPartitionColumns - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).forall { ref => - // NOTE: We're leveraging Spark's resolver here to appropriately handle case-sensitivity - partitionColumnNames.exists(partCol => resolve(ref, partCol)) - } + + // Resolves GetStructField chain to full dot-path: GetStructField(attr("a"), _, "b") → "a.b" + def getFieldPath(expr: Expression): Option[String] = expr match { + case a: AttributeReference => Some(a.name) + case GetStructField(child, _, Some(fieldName)) => + getFieldPath(child).map(_ + "." 
+ fieldName) + case _ => None } + // True if every column reference in expr resolves to a partition column. + // For nested columns, walks GetStructField chains to match the full dot-path. + // Example: partition = "nested_record.level" + // nested_record.level = 'INFO' → GetStructField path "nested_record.level" → true + // nested_record.nested_int = 10 → GetStructField path "nested_record.nested_int" → false + // IsNotNull(nested_record) → AttributeReference "nested_record" not in partitionColumnNames → false + def referencesOnlyPartitionColumns(expr: Expression): Boolean = expr match { + case g: GetStructField => + getFieldPath(g).exists(path => partitionColumnNames.exists(pc => resolve(path, pc))) + case a: AttributeReference => + partitionColumnNames.exists(pc => resolve(a.name, pc)) + case _ => + expr.children.forall(referencesOnlyPartitionColumns) + } + + val partitionPruningPredicates = predicates.filter(referencesOnlyPartitionColumns) + if (partitionPruningPredicates.isEmpty) { val queryPartitionPaths = getAllQueryPartitionPaths.asScala.toSeq logInfo(s"No partition predicates provided, listing full table (${queryPartitionPaths.size} partitions)") @@ -269,10 +308,18 @@ class SparkHoodieTableFileIndex(spark: SparkSession, // the whole table if (haveProperPartitionValues(partitionPaths.toSeq) && partitionSchema.nonEmpty) { val predicate = partitionPruningPredicates.reduce(expressions.And) + val partitionFieldNames = partitionSchema.fieldNames val transformedPredicate = predicate.transform { + case g @ GetStructField(_, _, Some(_)) => + getFieldPath(g).flatMap { path => + val idx = partitionFieldNames.indexWhere(name => resolve(path, name)) + if (idx >= 0) Some(BoundReference(idx, partitionSchema(idx).dataType, nullable = true)) + else None + }.getOrElse(g) case a: AttributeReference => - val index = partitionSchema.indexWhere(a.name == _.name) - BoundReference(index, partitionSchema(index).dataType, nullable = true) + val index = partitionSchema.indexWhere(sf 
=> resolve(a.name, sf.name)) + if (index >= 0) BoundReference(index, partitionSchema(index).dataType, nullable = true) + else a } val boundPredicate: BasePredicate = try { // Try using 1-arg constructor via reflection @@ -488,6 +535,76 @@ object SparkHoodieTableFileIndex extends SparkAdapterSupport { private val LOG = LoggerFactory.getLogger(classOf[SparkHoodieTableFileIndex]) private val PUT_LEAF_FILES_METHOD_NAME = "putLeafFiles" + private case class NestedFieldNode( + leafType: Option[DataType], + children: LinkedHashMap[String, NestedFieldNode] + ) + + /** + * Reconstruct nested partition schema from a flat partition schema containing dot-path field names. + * + * For example, flat fields ["a.b": int, "a.c": string, "d": long] becomes: + * + * StructType( + * StructField("a", StructType(StructField("b", int), StructField("c", string))), + * StructField("d", long) + * ) + */ + private[hudi] def buildNestedPartitionSchema(flatPartitionSchema: StructType): StructType = { + if (flatPartitionSchema.isEmpty) { + new StructType() + } else { + val root = NestedFieldNode(None, LinkedHashMap.empty) + + def getOrCreateChild(parent: NestedFieldNode, name: String): NestedFieldNode = { + parent.children.getOrElseUpdate(name, NestedFieldNode(None, LinkedHashMap.empty)) + } + + flatPartitionSchema.fields.foreach { field => + val parts = field.name.split("\\.", -1) + checkState(parts.forall(p => p.nonEmpty), + s"Invalid partition field path '${field.name}' in partition schema") + + var node = root + var i = 0 + while (i < parts.length) { + val part = parts(i) + val isLeaf = i == parts.length - 1 + + if (isLeaf) { + val child = getOrCreateChild(node, part) + checkState(child.children.isEmpty, + s"Conflicting partition schema: '${field.name}' collides with nested fields under '${parts.take(i + 1).mkString(".")}'") + checkState(child.leafType.isEmpty || child.leafType.contains(field.dataType), + s"Conflicting partition schema: '${field.name}' has inconsistent types 
(${child.leafType.orNull} vs ${field.dataType})") + node.children.update(part, child.copy(leafType = Some(field.dataType))) + } else { + val child = getOrCreateChild(node, part) + checkState(child.leafType.isEmpty, + s"Conflicting partition schema: '${field.name}' requires struct at '${parts.take(i + 1).mkString(".")}', but a leaf is defined") + node = child + } + + i += 1 + } + } + + def toStructType(node: NestedFieldNode): StructType = { + val fields = node.children.map { case (name, child) => + child.leafType match { + case Some(dt) if child.children.isEmpty => + StructField(name, dt, nullable = true) + case _ => + StructField(name, toStructType(child), nullable = true) + } + }.toArray + StructType(fields) + } + + toStructType(root) + } + } + private def haveProperPartitionValues(partitionPaths: Seq[PartitionPath]) = { partitionPaths.forall(_.getValues.length > 0) } @@ -520,27 +637,10 @@ object SparkHoodieTableFileIndex extends SparkAdapterSupport { } /** - * This method unravels [[StructType]] into a [[Map]] of pairs of dot-path notation with corresponding - * [[StructField]] object for every field of the provided [[StructType]], recursively. - * - * For example, following struct - *
-   *   StructType(
-   *     StructField("a",
-   *       StructType(
-   *          StructField("b", StringType),
-   *          StructField("c", IntType)
-   *       )
-   *     )
-   *   )
-   * 
- * - * will be converted into following mapping: - * - *
-   *   "a.b" -> StructField("b", StringType),
-   *   "a.c" -> StructField("c", IntType),
-   * 
+ * Maps every leaf field in `structType` to its dot-path name. + * Both the key and [[StructField.name]] use the full path. + * E.g. `StructType(StructField("a", StructType(StructField("b", IntegerType))))` + * → `Map("a.b" -> StructField("a.b", IntegerType))`. */ private def generateFieldMap(structType: StructType) : Map[String, StructField] = { def traverse(structField: Either[StructField, StructType]) : Map[String, StructField] = { @@ -548,7 +648,10 @@ object SparkHoodieTableFileIndex extends SparkAdapterSupport { case Right(struct) => struct.fields.flatMap(f => traverse(Left(f))).toMap case Left(field) => field.dataType match { case struct: StructType => traverse(Right(struct)).map { - case (key, structField) => (s"${field.name}.$key", structField) + case (key, structField) => { + val fullPath = s"${field.name}.$key" + (fullPath, structField.copy(name = fullPath)) + } } case _ => Map(field.name -> field) } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala index d5e0c6a927ac3..8d06e257d1787 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala @@ -44,14 +44,14 @@ import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.util.JFunction import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThanOrEqual, LessThan, Literal} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GetStructField, GreaterThanOrEqual, LessThan, Literal, Or} import org.apache.spark.sql.execution.datasources.{NoopCache, PartitionDirectory} import org.apache.spark.sql.functions.{lit, struct} import org.apache.spark.sql.hudi.HoodieSparkSessionExtension import 
org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.junit.jupiter.api.{BeforeEach, Test} -import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue} import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.{Arguments, CsvSource, MethodSource, ValueSource} @@ -858,6 +858,35 @@ class TestHoodieFileIndex extends HoodieSparkClientTestBase with ScalaAssertionS partitionValues.mkString(StoragePath.SEPARATOR) } } + + // ---- buildNestedPartitionSchema tests ---- + + @ParameterizedTest + @MethodSource(Array("buildNestedPartitionSchemaCases")) + def testBuildNestedPartitionSchema(name: String, flat: StructType, expected: StructType): Unit = { + assertEquals(expected, SparkHoodieTableFileIndex.buildNestedPartitionSchema(flat)) + } + + @Test + def testBuildNestedPartitionSchemaConflictThrows(): Unit = { + // "a" as leaf and "a.b" as nested — conflict + val flat = StructType(Seq(StructField("a", StringType), StructField("a.b", IntegerType))) + assertThrows(classOf[IllegalStateException]) { + SparkHoodieTableFileIndex.buildNestedPartitionSchema(flat) + } + } + + // ---- extractNestedPartitionFilters tests ---- + + @ParameterizedTest + @MethodSource(Array("extractNestedPartitionFiltersCases")) + def testExtractNestedPartitionFilters(name: String, + filters: Seq[Expression], + partitionColumns: Set[String], + expected: Seq[Expression]): Unit = { + assertEquals(expected, HoodieFileIndex.extractNestedPartitionFilters(filters, partitionColumns)) + } + } object TestHoodieFileIndex { @@ -870,4 +899,65 @@ object TestHoodieFileIndex { Arguments.arguments("org.apache.hudi.keygen.TimestampBasedKeyGenerator") ) } + + def buildNestedPartitionSchemaCases(): java.util.stream.Stream[Arguments] = { + val nested = StructType(Seq( + StructField("nested_record", StructType(Seq(StructField("level", StringType, nullable = true))), nullable = 
true))) + val twoLevel = StructType(Seq( + StructField("a", StructType(Seq( + StructField("b", StructType(Seq(StructField("c", IntegerType, nullable = true))), nullable = true))), nullable = true))) + val siblings = StructType(Seq( + StructField("a", StructType(Seq( + StructField("b", StringType, nullable = true), + StructField("c", IntegerType, nullable = true))), nullable = true))) + val mixed = StructType(Seq( + StructField("country", StringType, nullable = true), + StructField("nested_record", StructType(Seq(StructField("level", StringType, nullable = true))), nullable = true))) + java.util.stream.Stream.of( + Arguments.of("empty", + new StructType(), + new StructType()), + Arguments.of("flat", + StructType(Seq(StructField("country", StringType))), + StructType(Seq(StructField("country", StringType, nullable = true)))), + Arguments.of("singleNested", + StructType(Seq(StructField("nested_record.level", StringType))), + nested), + Arguments.of("twoLevelNesting", + StructType(Seq(StructField("a.b.c", IntegerType))), + twoLevel), + Arguments.of("siblingFields", + StructType(Seq(StructField("a.b", StringType), StructField("a.c", IntegerType))), + siblings), + Arguments.of("mixedFlatAndNested", + StructType(Seq(StructField("country", StringType), StructField("nested_record.level", StringType))), + mixed) + ) + } + + def extractNestedPartitionFiltersCases(): java.util.stream.Stream[Arguments] = { + val levelStruct = StructType(Seq(StructField("level", StringType))) + val multiFieldStruct = StructType(Seq( + StructField("nested_int", IntegerType), StructField("level", StringType))) + + val partFilter = EqualTo( + GetStructField(AttributeReference("nested_record", levelStruct)(), 0, Some("level")), + Literal("INFO")) + val dataFilter = EqualTo(AttributeReference("int_field", IntegerType)(), Literal(5)) + val siblingFilter = EqualTo( + GetStructField(AttributeReference("nested_record", multiFieldStruct)(), 0, Some("nested_int")), + Literal(10)) + val orFilter = Or( + 
EqualTo(GetStructField(AttributeReference("nested_record", levelStruct)(), 0, Some("level")), Literal("INFO")), + EqualTo(GetStructField(AttributeReference("nested_record", levelStruct)(), 0, Some("level")), Literal("ERROR"))) + + java.util.stream.Stream.of( + Arguments.of("partitionFilterExtractedDataFilterDropped", + Seq(partFilter, dataFilter), Set("nested_record.level"), Seq(partFilter)), + Arguments.of("siblingFieldExcluded", + Seq(siblingFilter), Set("nested_record.level"), Seq.empty[Expression]), + Arguments.of("orWithOnlyPartitionColumnsExtracted", + Seq(orFilter), Set("nested_record.level"), Seq(orFilter)) + ) + } } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala index db5ca1ee78002..106c1f09a7ae2 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala @@ -17,7 +17,7 @@ package org.apache.hudi.functional -import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers, HoodieSchemaConversionUtils, HoodieSparkUtils, QuickstartUtils, ScalaAssertionSupport} +import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions, DataSourceWriteOptions, HoodieBaseRelation, HoodieDataSourceHelpers, HoodieFileIndex, HoodieSchemaConversionUtils, HoodieSparkUtils, QuickstartUtils, ScalaAssertionSupport} import org.apache.hudi.DataSourceWriteOptions.{INLINE_CLUSTERING_ENABLE, KEYGENERATOR_CLASS_NAME} import org.apache.hudi.HoodieConversionUtils.toJavaOption import org.apache.hudi.QuickstartUtils.{convertToStringList, getQuickstartWriteConfigs} @@ -42,13 +42,14 @@ import org.apache.hudi.hive.HiveSyncConfigHolder import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator, 
GlobalDeleteKeyGenerator, NonpartitionedKeyGenerator, SimpleKeyGenerator, TimestampBasedKeyGenerator} import org.apache.hudi.keygen.constant.{KeyGeneratorOptions, KeyGeneratorType} import org.apache.hudi.metrics.{Metrics, MetricsReporterType} -import org.apache.hudi.storage.{StoragePath, StoragePathFilter} +import org.apache.hudi.storage.{HoodieStorage, StoragePath, StoragePathFilter} import org.apache.hudi.table.HoodieSparkTable import org.apache.hudi.testutils.{DataSourceTestUtils, HoodieSparkClientTestBase} import org.apache.hudi.util.JFunction import org.apache.hadoop.fs.FileSystem import org.apache.spark.sql.{DataFrame, DataFrameWriter, Dataset, Encoders, Row, SaveMode, SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions.{col, concat, lit, udf, when} import org.apache.spark.sql.hudi.HoodieSparkSessionExtension import org.apache.spark.sql.types.{ArrayType, DataTypes, DateType, IntegerType, LongType, MapType, StringType, StructField, StructType, TimestampType} @@ -2616,9 +2617,300 @@ class TestCOWDataSource extends HoodieSparkClientTestBase with ScalaAssertionSup assertEquals("row3", results(2).getAs[String]("_row_key")) assertEquals("value3", results(2).getAs[String]("data")) } + + @Test + def testNestedFieldPartition(): Unit = { + TestCOWDataSource.runNestedFieldPartitionTest(spark, basePath, storage, "COW") + } } object TestCOWDataSource { + + /** + * Shared test logic for nested field partition (COW and MOR). + * Used by TestCOWDataSource.testNestedFieldPartition and TestMORDataSource.testNestedFieldPartition. 
+ */ + def runNestedFieldPartitionTest(spark: SparkSession, basePath: String, storage: HoodieStorage, tableType: String): Unit = { + // Define schema with nested_record containing level field + val nestedSchema = StructType(Seq( + StructField("nested_int", IntegerType, nullable = false), + StructField("level", StringType, nullable = false) + )) + + val schema = StructType(Seq( + StructField("key", StringType, nullable = false), + StructField("ts", LongType, nullable = false), + StructField("level", StringType, nullable = false), + StructField("int_field", IntegerType, nullable = false), + StructField("string_field", StringType, nullable = true), + StructField("nested_record", nestedSchema, nullable = true) + )) + + // Create test data where top-level 'level' and 'nested_record.level' have DIFFERENT values + // This helps verify we're correctly partitioning/filtering on the nested field + val recordsCommit1 = Seq( + Row("key1", 1L, "L1", 1, "str1", Row(10, "INFO")), + Row("key2", 2L, "L2", 2, "str2", Row(20, "ERROR")), + Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")), + Row("key4", 4L, "L4", 4, "str4", Row(40, "DEBUG")), + Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")) + ) + + val tableTypeOptVal = if (tableType == "MOR") { + DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL + } else { + DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL + } + + val baseWriteOpts = Map( + "hoodie.insert.shuffle.parallelism" -> "4", + "hoodie.upsert.shuffle.parallelism" -> "4", + DataSourceWriteOptions.RECORDKEY_FIELD.key -> "key", + DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "nested_record.level", + HoodieTableConfig.ORDERING_FIELDS.key -> "ts", + HoodieWriteConfig.TBL_NAME.key -> "test_nested_partition", + DataSourceWriteOptions.TABLE_TYPE.key -> tableTypeOptVal + ) + val writeOpts = if (tableType == "MOR") { + baseWriteOpts + ("hoodie.compact.inline" -> "false") + } else { + baseWriteOpts + } + + // Commit 1 - Initial insert + val inputDF1 = spark.createDataFrame( + 
spark.sparkContext.parallelize(recordsCommit1), + schema + ) + inputDF1.write.format("hudi") + .options(writeOpts) + .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL) + .mode(SaveMode.Overwrite) + .save(basePath) + val commit1 = DataSourceTestUtils.latestCommitCompletionTime(storage, basePath) + + // Commit 2 - Upsert: update key1 (int_field 1->100), insert key6 (INFO) + val recordsCommit2 = Seq( + Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")), + Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO")) + ) + val inputDF2 = spark.createDataFrame( + spark.sparkContext.parallelize(recordsCommit2), + schema + ) + inputDF2.write.format("hudi") + .options(writeOpts) + .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL) + .mode(SaveMode.Append) + .save(basePath) + val commit2 = DataSourceTestUtils.latestCommitCompletionTime(storage, basePath) + + // Commit 3 - Upsert: update key3 (int_field 3->300), insert key7 (INFO) + val recordsCommit3 = Seq( + Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")), + Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO")) + ) + val inputDF3 = spark.createDataFrame( + spark.sparkContext.parallelize(recordsCommit3), + schema + ) + inputDF3.write.format("hudi") + .options(writeOpts) + .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL) + .mode(SaveMode.Append) + .save(basePath) + val commit3 = DataSourceTestUtils.latestCommitCompletionTime(storage, basePath) + + // Verify partition structure - we should have 3 partitions: INFO, ERROR, DEBUG + val allPartitions = storage.listDirectEntries(new StoragePath(basePath)) + .asScala.filter(_.isDirectory) + .map(_.getPath.getName) + .filterNot(_.startsWith(".")) // Filter out .hoodie and other hidden directories + .sorted + assertEquals(3, allPartitions.size, s"Expected 3 partitions for $tableType, but got: ${allPartitions.mkString(", ")}") + 
assertTrue(allPartitions.contains("INFO"), s"Missing INFO partition for $tableType") + assertTrue(allPartitions.contains("ERROR"), s"Missing ERROR partition for $tableType") + assertTrue(allPartitions.contains("DEBUG"), s"Missing DEBUG partition for $tableType") + + // Snapshot read - filter on nested_record.level = 'INFO' (latest state: 5 records) + val snapshotDF = spark.read.format("hudi") + .load(basePath) + .filter("nested_record.level = 'INFO'") + .select("key", "ts", "level", "int_field", "string_field", "nested_record") + .orderBy("key") + + // VERIFICATION 1: Check partition schema contains the nested field + val snapshotRelation = snapshotDF.queryExecution.optimizedPlan.collectFirst { + case lr: LogicalRelation => lr + } + assertTrue(snapshotRelation.isDefined, s"LogicalRelation should exist for $tableType") + val fileIndex = snapshotRelation.get.relation match { + case fsRelation: HadoopFsRelation => + fsRelation.location.asInstanceOf[HoodieFileIndex] + case baseRelation: HoodieBaseRelation => + baseRelation.fileIndex + case _ => null + } + assertTrue(fileIndex != null, s"FileIndex should be available for $tableType") + assertEquals(1, fileIndex.partitionSchema.fields.length, + s"Partition schema should have 1 field for $tableType") + assertEquals("nested_record.level", fileIndex.partitionSchema.fields(0).name, + s"Partition field should be 'nested_record.level' for $tableType") + + // VERIFICATION 2: Check that predicates were pushed down to FileIndex + assertTrue(fileIndex.hasPredicatesPushedDown, + s"Partition predicates should be pushed down to FileIndex for $tableType") + + // VERIFICATION 3: Verify partition pruning by checking the physical plan + // The physical plan should show that only specific files are being scanned + val physicalPlan = snapshotDF.queryExecution.executedPlan.toString() + assertTrue(physicalPlan.contains("Scan") || physicalPlan.contains("FileScan"), + s"Physical plan should contain scan operation for $tableType") + + // 
Collect results to execute the query + val snapshotResults = snapshotDF.collect() + val expectedSnapshot = Array( + Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")), + Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")), + Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")), + Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO")), + Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO")) + ) + assertEquals(expectedSnapshot.length, snapshotResults.length, + s"Snapshot (INFO) count mismatch for $tableType") + expectedSnapshot.zip(snapshotResults).foreach { case (expected, actual) => + assertEquals(expected, actual) + } + + // Time travel - as of commit1 (only initial 5 records; INFO = key1, key3, key5) + val timeTravelDF1 = spark.read.format("hudi") + .option(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key, commit1) + .load(basePath) + .filter("nested_record.level = 'INFO'") + .select("key", "ts", "level", "int_field", "string_field", "nested_record") + .orderBy("key") + + // VERIFICATION 4: Verify partition pruning works for time travel queries + // Check that the time travel query with partition filter returns correct results + val timeTravelCommit1 = timeTravelDF1.collect() + val expectedAfterCommit1 = Array( + Row("key1", 1L, "L1", 1, "str1", Row(10, "INFO")), + Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")), + Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")) + ) + assertEquals(expectedAfterCommit1.length, timeTravelCommit1.length, + s"Time travel to commit1 (INFO) count mismatch for $tableType") + expectedAfterCommit1.zip(timeTravelCommit1).foreach { case (expected, actual) => + assertEquals(expected, actual) + } + + // Time travel - as of commit2 (after 2nd commit; INFO = key1 updated, key3, key5, key6) + val timeTravelCommit2 = spark.read.format("hudi") + .option(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key, commit2) + .load(basePath) + .filter("nested_record.level = 'INFO'") + .select("key", "ts", "level", "int_field", "string_field", "nested_record") 
+ .orderBy("key") + .collect() + + val expectedAfterCommit2 = Array( + Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")), + Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")), + Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")), + Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO")) + ) + assertEquals(expectedAfterCommit2.length, timeTravelCommit2.length, + s"Time travel to commit2 (INFO) count mismatch for $tableType") + expectedAfterCommit2.zip(timeTravelCommit2).foreach { case (expected, actual) => + assertEquals(expected, actual) + } + + // Incremental query - from commit1 to commit2 (only key1 update and key6 insert; both INFO) + val incrementalDF1To2 = spark.read.format("hudi") + .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL) + .option(DataSourceReadOptions.START_COMMIT.key, commit1) + .option(DataSourceReadOptions.END_COMMIT.key, commit2) + .load(basePath) + .filter("nested_record.level = 'INFO'") + .select("key", "ts", "level", "int_field", "string_field", "nested_record") + .orderBy("key") + + // VERIFICATION 5: Verify partition filtering works for incremental queries + // For incremental queries, the filter on nested_record.level should still limit scanned data + val incrementalPlan1To2 = incrementalDF1To2.queryExecution.executedPlan.toString() + // The plan should show filtering is happening + assertTrue(incrementalPlan1To2.contains("Filter") || incrementalPlan1To2.contains("Scan"), + s"Incremental query plan should show filtering for $tableType") + + val incrementalCommit1To2 = incrementalDF1To2.collect() + val expectedInc1To2 = Array( + Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")), + Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO")) + ) + assertEquals(expectedInc1To2.length, incrementalCommit1To2.length, + s"Incremental (commit1->commit2, INFO) count mismatch for $tableType") + expectedInc1To2.zip(incrementalCommit1To2).foreach { case (expected, actual) => + assertEquals(expected, actual) + } 
+ + // Incremental query - from commit2 to commit3 (only key3 update and key7 insert; both INFO) + val incrementalCommit2To3 = spark.read.format("hudi") + .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL) + .option(DataSourceReadOptions.START_COMMIT.key, commit2) + .option(DataSourceReadOptions.END_COMMIT.key, commit3) + .load(basePath) + .filter("nested_record.level = 'INFO'") + .select("key", "ts", "level", "int_field", "string_field", "nested_record") + .orderBy("key") + .collect() + + val expectedInc2To3 = Array( + Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")), + Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO")) + ) + assertEquals(expectedInc2To3.length, incrementalCommit2To3.length, + s"Incremental (commit2->commit3, INFO) count mismatch for $tableType") + expectedInc2To3.zip(incrementalCommit2To3).foreach { case (expected, actual) => + assertEquals(expected, actual) + } + + // VERIFICATION 6: Test with different partition values to ensure filtering is working correctly + // Query for ERROR partition (should only return key2) + val errorPartitionDF = spark.read.format("hudi") + .load(basePath) + .filter("nested_record.level = 'ERROR'") + .select("key", "nested_record") + + val errorResults = errorPartitionDF.collect() + assertEquals(1, errorResults.length, s"ERROR partition should have 1 record for $tableType") + assertEquals("key2", errorResults(0).getString(0), + s"ERROR partition should contain key2 for $tableType") + + // VERIFICATION 7: Test with DEBUG partition + val debugPartitionDF = spark.read.format("hudi") + .load(basePath) + .filter("nested_record.level = 'DEBUG'") + .select("key", "nested_record") + + val debugResults = debugPartitionDF.collect() + assertEquals(1, debugResults.length, s"DEBUG partition should have 1 record for $tableType") + assertEquals("key4", debugResults(0).getString(0), + s"DEBUG partition should contain key4 for $tableType") + + // VERIFICATION 8: Verify that filtering 
on top-level 'level' field returns correct results + // This ensures we're correctly distinguishing between nested_record.level (partition) and level (data column) + val topLevelFilterDF = spark.read.format("hudi") + .load(basePath) + .filter("level = 'L1'") // Filter on top-level 'level', not nested_record.level + .select("key", "level", "nested_record") + + val topLevelResults = topLevelFilterDF.collect() + // Should return key1 which has level='L1' and is in INFO partition + assertEquals(1, topLevelResults.length, s"Top-level level='L1' should return 1 record for $tableType") + assertEquals("key1", topLevelResults(0).getString(0), + s"Top-level level='L1' should return key1 for $tableType") + } + def convertColumnsToNullable(df: DataFrame, cols: String*): DataFrame = { cols.foldLeft(df) { (df, c) => // NOTE: This is the trick to make Spark convert a non-null column "c" into a nullable diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala index ab7bbcd097d28..81d049d432439 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala @@ -2347,6 +2347,11 @@ class TestMORDataSource extends HoodieSparkClientTestBase with SparkDatasetMixin assertEquals("row4", results(3).getAs[String]("_row_key")) assertEquals("value4", results(3).getAs[String]("data")) } + + @Test + def testNestedFieldPartition(): Unit = { + TestCOWDataSource.runNestedFieldPartitionTest(spark, basePath, storage, "MOR") + } } object TestMORDataSource { diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala 
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala index 589ee9774d3b7..3cb9e8250da96 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType * Prune the partitions of Hudi table based relations by the means of pushing down the * partition filters * - * NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's [[PruneFileSourcePartitions]] + * NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's [[org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions]] */ case class Spark3HoodiePruneFileSourcePartitions(spark: SparkSession) extends Rule[LogicalPlan] { @@ -48,11 +48,11 @@ case class Spark3HoodiePruneFileSourcePartitions(spark: SparkSession) extends Ru val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters, lr.output) val (partitionPruningFilters, dataFilters) = - getPartitionFiltersAndDataFilters(fileIndex.partitionSchema, normalizedFilters) + getPartitionFiltersAndDataFilters(fileIndex.partitionSchemaForSpark, normalizedFilters) // [[HudiFileIndex]] is a caching one, therefore we don't need to reconstruct new relation, // instead we simply just refresh the index and update the stats - fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruned = true) + fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruneOnly = true) if (partitionPruningFilters.nonEmpty) { // Change table stats based on the sizeInBytes of pruned files @@ -105,11 +105,21 @@ private object Spark3HoodiePruneFileSourcePartitions extends PredicateHelper { Project(projects, withFilter) } + /** + * Returns true if 
the given attribute references a partition column. For nested partition columns + * (e.g. `nested_record.level`), `partitionSchema` is the nested [[StructType]] from + * `partitionSchemaForSpark`, so the top-level name is the struct root (e.g. `nested_record`), + * which matches `attr.name` directly via `contains`. + */ + private def isPartitionColumnReference(attr: AttributeReference, partitionSchema: StructType): Boolean = { + partitionSchema.names.contains(attr.name) + } + def getPartitionFiltersAndDataFilters(partitionSchema: StructType, normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val partitionColumns = normalizedFilters.flatMap { expr => expr.collect { - case attr: AttributeReference if partitionSchema.names.contains(attr.name) => + case attr: AttributeReference if isPartitionColumnReference(attr, partitionSchema) => attr } } diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala index 7d7240231cd09..add1b7aaebf14 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala @@ -60,7 +60,7 @@ case class Spark33HoodiePruneFileSourcePartitions(spark: SparkSession) extends R // [[HudiFileIndex]] is a caching one, therefore we don't need to reconstruct new relation, // instead we simply just refresh the index and update the stats - fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruned = true) + fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruneOnly = true) if (partitionPruningFilters.nonEmpty) { // Change table stats based on the 
sizeInBytes of pruned files diff --git a/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala b/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala index 8412018c22db6..0f6cf87da86f6 100644 --- a/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala +++ b/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala @@ -48,11 +48,11 @@ case class Spark4HoodiePruneFileSourcePartitions(spark: SparkSession) extends Ru val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters, lr.output) val (partitionPruningFilters, dataFilters) = - getPartitionFiltersAndDataFilters(fileIndex.partitionSchema, normalizedFilters) + getPartitionFiltersAndDataFilters(fileIndex.partitionSchemaForSpark, normalizedFilters) // [[HudiFileIndex]] is a caching one, therefore we don't need to reconstruct new relation, // instead we simply just refresh the index and update the stats - fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruned = true) + fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruneOnly = true) if (partitionPruningFilters.nonEmpty) { // Change table stats based on the sizeInBytes of pruned files @@ -105,11 +105,21 @@ private object Spark4HoodiePruneFileSourcePartitions extends PredicateHelper { Project(projects, withFilter) } + /** + * Returns true if the given attribute references a partition column. For nested partition columns + * (e.g. `nested_record.level`), `partitionSchema` is the nested [[StructType]] from + * `partitionSchemaForSpark`, so the top-level name is the struct root (e.g. `nested_record`), + * which matches `attr.name` directly via `contains`. 
+ */ + private def isPartitionColumnReference(attr: AttributeReference, partitionSchema: StructType): Boolean = { + partitionSchema.names.contains(attr.name) + } + def getPartitionFiltersAndDataFilters(partitionSchema: StructType, normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val partitionColumns = normalizedFilters.flatMap { expr => expr.collect { - case attr: AttributeReference if partitionSchema.names.contains(attr.name) => + case attr: AttributeReference if isPartitionColumnReference(attr, partitionSchema) => attr } } diff --git a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java index fede1b8fba030..2510edce72a8c 100644 --- a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java +++ b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java @@ -20,9 +20,9 @@ import org.apache.hudi.common.table.HoodieTableMetaClient; import org.apache.hudi.common.table.timeline.HoodieInstant; -import org.apache.hudi.common.util.FileIOUtils; import org.apache.hudi.common.util.Option; import org.apache.hudi.exception.HoodieException; +import org.apache.hudi.hadoop.fs.HadoopFSUtils; import org.apache.hudi.utilities.exception.HoodieIncrementalPullException; import org.apache.hudi.utilities.exception.HoodieIncrementalPullSQLException; @@ -50,6 +50,8 @@ import java.util.Scanner; import java.util.stream.Collectors; +import static org.apache.hudi.io.util.FileIOUtils.readAsUTFString; + /** * Utility to pull data after a given commit, based on the supplied HiveQL and save the delta as another hive temporary * table. 
This temporary table can be further read using {@link org.apache.hudi.utilities.sources.HiveIncrPullSource} and the changes can @@ -115,7 +117,7 @@ public HiveIncrementalPuller(Config config) throws IOException { this.config = config; validateConfig(config); String templateContent = - FileIOUtils.readAsUTFString(this.getClass().getResourceAsStream("/IncrementalPull.sqltemplate")); + readAsUTFString(this.getClass().getResourceAsStream("/IncrementalPull.sqltemplate")); incrementalPullSQLTemplate = new ST(templateContent); } @@ -298,12 +300,13 @@ private String scanForCommitTime(FileSystem fs, String targetDataPath) throws IO if (!fs.exists(new Path(targetDataPath)) || !fs.exists(new Path(targetDataPath + "/.hoodie"))) { return "0"; } - HoodieTableMetaClient metadata = HoodieTableMetaClient.builder().setConf(fs.getConf()).setBasePath(targetDataPath).build(); + HoodieTableMetaClient metadata = HoodieTableMetaClient.builder() + .setConf(HadoopFSUtils.getStorageConfWithCopy(fs.getConf())).setBasePath(targetDataPath).build(); Option lastCommit = metadata.getActiveTimeline().getCommitsTimeline().filterCompletedInstants().lastInstant(); if (lastCommit.isPresent()) { - return lastCommit.get().getTimestamp(); + return lastCommit.get().requestedTime(); } return "0"; } @@ -331,14 +334,15 @@ private boolean ensureTempPathExists(FileSystem fs, String lastCommitTime) throw } private String getLastCommitTimePulled(FileSystem fs, String sourceTableLocation) { - HoodieTableMetaClient metadata = HoodieTableMetaClient.builder().setConf(fs.getConf()).setBasePath(sourceTableLocation).build(); + HoodieTableMetaClient metadata = HoodieTableMetaClient.builder() + .setConf(HadoopFSUtils.getStorageConfWithCopy(fs.getConf())) + .setBasePath(sourceTableLocation).build(); List commitsToSync = metadata.getActiveTimeline().getCommitsTimeline().filterCompletedInstants() - .findInstantsAfter(config.fromCommitTime, config.maxCommits).getInstantsAsStream().map(HoodieInstant::getTimestamp) + 
.findInstantsAfter(config.fromCommitTime, config.maxCommits).getInstantsAsStream().map(HoodieInstant::requestedTime) .collect(Collectors.toList()); if (commitsToSync.isEmpty()) { - LOG.info("Nothing to sync. All commits in {} are {} and from commit time is {}", config.sourceTable, - metadata.getActiveTimeline().getCommitsTimeline().filterCompletedInstants().getInstants(), - config.fromCommitTime); + LOG.info("Nothing to sync. All commits in {} are {} and from commit time is {}", config.sourceTable, metadata.getActiveTimeline().getCommitsTimeline() + .filterCompletedInstants().getInstants(), config.fromCommitTime); return null; } LOG.info("Syncing commits {}", commitsToSync);