diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala index 660592cf791ef..2187e5ab6903f 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField, Literal} import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusCache, NoopCache, PartitionDirectory} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils import org.apache.spark.sql.internal.SQLConf @@ -105,6 +105,10 @@ case class HoodieFileIndex(spark: SparkSession, @transient protected var hasPushedDownPartitionPredicates: Boolean = false + /** True when any partition column is a nested field path (e.g. "nested_record.level"). */ + private val hasNestedPartitionColumns: Boolean = + getPartitionColumns.exists(_.contains(".")) + /** * NOTE: [[indicesSupport]] is a transient state, since it's only relevant while logical plan * is handled by the Spark's driver @@ -167,19 +171,44 @@ case class HoodieFileIndex(spark: SparkSession, /** * Invoked by Spark to fetch list of latest base files per partition. * - * @param partitionFilters partition column filters - * @param dataFilters data columns filters - * @return list of PartitionDirectory containing partition to base files mapping + * For regular partition columns, Spark passes correct `partitionFilters` directly. + * + * For nested partition columns (e.g. 
`nested_record.level`), Spark cannot match + * [[GetStructField]] expressions against the flat dot-path partition schema and passes + * `partitionFilters = []`. The nested predicates land in `dataFilters` instead. + * We re-extract them via [[extractNestedPartitionFilters]]. + * + * Example: `SELECT * FROM t WHERE nested_record.level = 'INFO' AND int_field > 0` + * - Spark passes: `partitionFilters = []`, `dataFilters = [nested_record.level = 'INFO', int_field > 0]` + * - We extract: `effectivePartitionFilters = [nested_record.level = 'INFO']` + * + * This is stateless — safe under AQE re-planning, subqueries, and FileIndex reuse. + * + * Known limitation: for mixed flat+nested partitions (e.g. `["country", "nested_record.level"]`), + * if Spark passes `partitionFilters = [country = 'US']`, we skip extraction and the nested + * filter is not used for partition pruning. A future fix could merge extracted nested filters + * with the provided `partitionFilters`. */ override def listFiles(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { - val slices = filterFileSlices(dataFilters, partitionFilters).flatMap( + val effectivePartitionFilters = if (partitionFilters.isEmpty && hasNestedPartitionColumns) { + extractNestedPartitionFilters(dataFilters) + } else { + partitionFilters + } + + val slices = filterFileSlices(dataFilters, effectivePartitionFilters).flatMap( { case (partitionOpt, fileSlices) => - fileSlices.filter(!_.isEmpty).map(fs => ( InternalRow.fromSeq(partitionOpt.get.getValues), fs)) + fileSlices.filter(!_.isEmpty).map(fs => (InternalRow.fromSeq(partitionOpt.get.getValues), fs)) } ) prepareFileSlices(slices) } + /** Delegates to companion object with this table's partition columns. 
*/ + private def extractNestedPartitionFilters(dataFilters: Seq[Expression]): Seq[Expression] = { + HoodieFileIndex.extractNestedPartitionFilters(dataFilters, getPartitionColumns.toSet) + } + protected def prepareFileSlices(slices: Seq[(InternalRow, FileSlice)]): Seq[PartitionDirectory] = { hasPushedDownPartitionPredicates = true @@ -212,25 +241,25 @@ case class HoodieFileIndex(spark: SparkSession, } /** - * The functions prunes the partition paths based on the input partition filters. For every partition path, the file - * slices are further filtered after querying metadata table based on the data filters. + * Prunes partitions by `partitionFilters`, then optionally applies data skipping via metadata + * table indices (column stats, record-level index, etc.) to filter file slices. * - * @param dataFilters data columns filters - * @param partitionFilters partition column filters - * @param partitionPrune for HoodiePruneFileSourcePartitions rule only prune partitions - * @return A sequence of pruned partitions and corresponding filtered file slices + * @param dataFilters data column filters (used for data skipping) + * @param partitionFilters partition column filters (used for partition pruning) + * @param isPartitionPruneOnly when true, skip data skipping. Used by [[HoodiePruneFileSourcePartitions]] + * during planning (data skipping runs later in [[listFiles]]). */ - def filterFileSlices(dataFilters: Seq[Expression], partitionFilters: Seq[Expression], isPartitionPruned: Boolean = false) + def filterFileSlices(dataFilters: Seq[Expression], partitionFilters: Seq[Expression], + isPartitionPruneOnly: Boolean = false) : Seq[(Option[BaseHoodieTableFileIndex.PartitionPath], Seq[FileSlice])] = { val (isPruned, prunedPartitionsAndFileSlices) = prunePartitionsAndGetFileSlices(dataFilters, partitionFilters) hasPushedDownPartitionPredicates = true - // If there are no data filters, return all the file slices. 
- // If isPartitionPurge is true, this fun is trigger by HoodiePruneFileSourcePartitions, don't look up candidate files - // If there are no file slices, return empty list. - if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty || isPartitionPruned ) { + // Skip data skipping when: no file slices, no data filters, or partition-prune-only mode + // (planning phase — data skipping runs later during execution). + if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty || isPartitionPruneOnly) { prunedPartitionsAndFileSlices } else { // Look up candidate files names in the col-stats or record level index, if all of the following conditions are true @@ -502,6 +531,65 @@ object HoodieFileIndex extends Logging { val Strict: Val = Val("strict") } + /** + * Extracts filters from `dataFilters` that reference nested partition columns by walking + * [[GetStructField]] chains to reconstruct the full dot-path and matching against partition + * column names. We cannot match on the struct root alone because sibling fields share it + * (e.g. `nested_record.level` and `nested_record.nested_int` both reference `nested_record`). + * + * Given partition column `nested_record.level` and: + * {{{ + * dataFilters = [nested_record.level = 'INFO', nested_record.nested_int > 0, int_field = 5] + * }}} + * Returns: `[nested_record.level = 'INFO']` + * + * Known limitations vs regular partition columns: + * - `(nested_record.level = 'INFO' AND d = 2) OR (nested_record.level = 'ERROR')` is excluded + * entirely (references both partition and data columns). A weaker predicate like + * `nested_record.level IN ('INFO', 'ERROR')` could be extracted but is not implemented. + * Spark has the same OR limitation for regular partition columns. + * + * @param dataFilters filters to scan for nested partition predicates + * @param partitionColumnNames partition column dot-paths, e.g. 
`Set("nested_record.level")` + * @return only the filters whose every column reference is a partition column + */ + private[hudi] def extractNestedPartitionFilters(dataFilters: Seq[Expression], + partitionColumnNames: Set[String]): Seq[Expression] = { + val partitionColumnRoots = partitionColumnNames.map(_.split("\\.", 2)(0)) + dataFilters.filter { expr => + // Resolve all outermost GetStructField chains to their full dot-paths. + val structFieldPaths = collectOutermostStructFieldPaths(expr) + // The expression is a partition filter only when: + // 1. It contains at least one GetStructField that resolves to a partition column path, AND + // 2. ALL resolved paths are partition columns (no non-partition nested fields), AND + // 3. ALL attribute references are roots of partition columns + // (guards against mixed expressions like "nested_record.level = 'INFO' AND int_field > 0") + structFieldPaths.nonEmpty && + structFieldPaths.forall(partitionColumnNames.contains) && + expr.references.map(_.name).forall(partitionColumnRoots.contains) + } + } + + /** + * Collects full dot-paths of outermost [[GetStructField]] chains in an expression. + * `EqualTo(a.b.c, 1)` → `Seq("a.b.c")` (not intermediate `"a.b"`). + */ + private[hudi] def collectOutermostStructFieldPaths(expr: Expression): Seq[String] = { + expr match { + case g: GetStructField => resolveGetStructFieldPath(g).toSeq + case _ => expr.children.flatMap(collectOutermostStructFieldPaths) + } + } + + /** Resolves a [[GetStructField]] chain to its full dot-path: `attr("a").b.c` → `"a.b.c"`. */ + private[hudi] def resolveGetStructFieldPath(expr: Expression): Option[String] = expr match { + case GetStructField(child: AttributeReference, _, Some(fieldName)) => + Some(child.name + "." + fieldName) + case GetStructField(child: GetStructField, _, Some(fieldName)) => + resolveGetStructFieldPath(child).map(_ + "." 
+ fieldName) + case _ => None + } + def collectReferencedColumns(spark: SparkSession, queryFilters: Seq[Expression], schema: StructType): Seq[String] = { val resolver = spark.sessionState.analyzer.resolver val refs = queryFilters.flatMap(_.references) diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala index 1ba9628af3b75..b80eb204823af 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala @@ -46,11 +46,11 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BasePredicate, BoundReference, EmptyRow, EqualTo, Expression, InterpretedPredicate, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BasePredicate, BoundReference, EmptyRow, EqualTo, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{FileStatusCache, NoopCache} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ByteType, DataType, DateType, IntegerType, LongType, ShortType, StringType, StructField, StructType} import org.slf4j.LoggerFactory import javax.annotation.concurrent.NotThreadSafe @@ -59,6 +59,7 @@ import java.lang.reflect.{Array => JArray} import java.util.Collections import scala.collection.JavaConverters._ +import scala.collection.mutable.LinkedHashMap import 
scala.language.implicitConversions import scala.util.{Success, Try} @@ -201,6 +202,25 @@ class SparkHoodieTableFileIndex(spark: SparkSession, } } + /** + * Spark-facing partition schema that preserves nested structure for nested partition columns. + * + * NOTE: Hudi's [[partitionSchema]] intentionally returns a *flat* schema where field names use full + * dot-paths (for example, "a.b.c") to avoid collisions with top-level data columns. Some Spark + * planner/analyzer paths, however, reason about nested columns as nested [[StructType]]s and + * require a nested schema shape to properly resolve [[GetStructField]] chains. + * + * This method reconstructs a nested [[StructType]] from the flat partition schema, using the same + * leaf data-types, and preserving deterministic field ordering based on the original flat schema. + */ + def partitionSchemaForSpark: StructType = { + if (!shouldReadAsPartitionedTable) { + new StructType() + } else { + SparkHoodieTableFileIndex.buildNestedPartitionSchema(_partitionSchemaFromProperties) + } + } + /** * Fetch list of latest base files w/ corresponding log files, after performing * partition pruning @@ -238,13 +258,32 @@ class SparkHoodieTableFileIndex(spark: SparkSession, def listMatchingPartitionPaths(predicates: Seq[Expression]): Seq[PartitionPath] = { val resolve = spark.sessionState.analyzer.resolver val partitionColumnNames = getPartitionColumns - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).forall { ref => - // NOTE: We're leveraging Spark's resolver here to appropriately handle case-sensitivity - partitionColumnNames.exists(partCol => resolve(ref, partCol)) - } + + // Resolves GetStructField chain to full dot-path: GetStructField(attr("a"), _, "b") → "a.b" + def getFieldPath(expr: Expression): Option[String] = expr match { + case a: AttributeReference => Some(a.name) + case GetStructField(child, _, Some(fieldName)) => + getFieldPath(child).map(_ + "." 
+ fieldName) + case _ => None } + // True if every column reference in expr resolves to a partition column. + // For nested columns, walks GetStructField chains to match the full dot-path. + // Example: partition = "nested_record.level" + // nested_record.level = 'INFO' → GetStructField path "nested_record.level" → true + // nested_record.nested_int = 10 → GetStructField path "nested_record.nested_int" → false + // IsNotNull(nested_record) → AttributeReference "nested_record" not in partitionColumnNames → false + def referencesOnlyPartitionColumns(expr: Expression): Boolean = expr match { + case g: GetStructField => + getFieldPath(g).exists(path => partitionColumnNames.exists(pc => resolve(path, pc))) + case a: AttributeReference => + partitionColumnNames.exists(pc => resolve(a.name, pc)) + case _ => + expr.children.forall(referencesOnlyPartitionColumns) + } + + val partitionPruningPredicates = predicates.filter(referencesOnlyPartitionColumns) + if (partitionPruningPredicates.isEmpty) { val queryPartitionPaths = getAllQueryPartitionPaths.asScala.toSeq logInfo(s"No partition predicates provided, listing full table (${queryPartitionPaths.size} partitions)") @@ -269,10 +308,18 @@ class SparkHoodieTableFileIndex(spark: SparkSession, // the whole table if (haveProperPartitionValues(partitionPaths.toSeq) && partitionSchema.nonEmpty) { val predicate = partitionPruningPredicates.reduce(expressions.And) + val partitionFieldNames = partitionSchema.fieldNames val transformedPredicate = predicate.transform { + case g @ GetStructField(_, _, Some(_)) => + getFieldPath(g).flatMap { path => + val idx = partitionFieldNames.indexWhere(name => resolve(path, name)) + if (idx >= 0) Some(BoundReference(idx, partitionSchema(idx).dataType, nullable = true)) + else None + }.getOrElse(g) case a: AttributeReference => - val index = partitionSchema.indexWhere(a.name == _.name) - BoundReference(index, partitionSchema(index).dataType, nullable = true) + val index = partitionSchema.indexWhere(sf 
=> resolve(a.name, sf.name)) + if (index >= 0) BoundReference(index, partitionSchema(index).dataType, nullable = true) + else a } val boundPredicate: BasePredicate = try { // Try using 1-arg constructor via reflection @@ -488,6 +535,76 @@ object SparkHoodieTableFileIndex extends SparkAdapterSupport { private val LOG = LoggerFactory.getLogger(classOf[SparkHoodieTableFileIndex]) private val PUT_LEAF_FILES_METHOD_NAME = "putLeafFiles" + private case class NestedFieldNode( + leafType: Option[DataType], + children: LinkedHashMap[String, NestedFieldNode] + ) + + /** + * Reconstruct nested partition schema from a flat partition schema containing dot-path field names. + * + * For example, flat fields ["a.b": int, "a.c": string, "d": long] becomes: + * + * StructType( + * StructField("a", StructType(StructField("b", int), StructField("c", string))), + * StructField("d", long) + * ) + */ + private[hudi] def buildNestedPartitionSchema(flatPartitionSchema: StructType): StructType = { + if (flatPartitionSchema.isEmpty) { + new StructType() + } else { + val root = NestedFieldNode(None, LinkedHashMap.empty) + + def getOrCreateChild(parent: NestedFieldNode, name: String): NestedFieldNode = { + parent.children.getOrElseUpdate(name, NestedFieldNode(None, LinkedHashMap.empty)) + } + + flatPartitionSchema.fields.foreach { field => + val parts = field.name.split("\\.", -1) + checkState(parts.forall(p => p.nonEmpty), + s"Invalid partition field path '${field.name}' in partition schema") + + var node = root + var i = 0 + while (i < parts.length) { + val part = parts(i) + val isLeaf = i == parts.length - 1 + + if (isLeaf) { + val child = getOrCreateChild(node, part) + checkState(child.children.isEmpty, + s"Conflicting partition schema: '${field.name}' collides with nested fields under '${parts.take(i + 1).mkString(".")}'") + checkState(child.leafType.isEmpty || child.leafType.contains(field.dataType), + s"Conflicting partition schema: '${field.name}' has inconsistent types 
(${child.leafType.orNull} vs ${field.dataType})") + node.children.update(part, child.copy(leafType = Some(field.dataType))) + } else { + val child = getOrCreateChild(node, part) + checkState(child.leafType.isEmpty, + s"Conflicting partition schema: '${field.name}' requires struct at '${parts.take(i + 1).mkString(".")}', but a leaf is defined") + node = child + } + + i += 1 + } + } + + def toStructType(node: NestedFieldNode): StructType = { + val fields = node.children.map { case (name, child) => + child.leafType match { + case Some(dt) if child.children.isEmpty => + StructField(name, dt, nullable = true) + case _ => + StructField(name, toStructType(child), nullable = true) + } + }.toArray + StructType(fields) + } + + toStructType(root) + } + } + private def haveProperPartitionValues(partitionPaths: Seq[PartitionPath]) = { partitionPaths.forall(_.getValues.length > 0) } @@ -520,27 +637,10 @@ object SparkHoodieTableFileIndex extends SparkAdapterSupport { } /** - * This method unravels [[StructType]] into a [[Map]] of pairs of dot-path notation with corresponding - * [[StructField]] object for every field of the provided [[StructType]], recursively. - * - * For example, following struct - *
- * StructType(
- * StructField("a",
- * StructType(
- * StructField("b", StringType),
- * StructField("c", IntType)
- * )
- * )
- * )
- *
- *
- * will be converted into following mapping:
- *
- *
- * "a.b" -> StructField("b", StringType),
- * "a.c" -> StructField("c", IntType),
- *
+ * Maps every leaf field in `structType` to its dot-path name.
+ * Both the key and [[StructField.name]] use the full path.
+ * E.g. `StructType(StructField("a", StructType(StructField("b", IntegerType))))`
+ * → `Map("a.b" -> StructField("a.b", IntegerType))`.
*/
private def generateFieldMap(structType: StructType) : Map[String, StructField] = {
def traverse(structField: Either[StructField, StructType]) : Map[String, StructField] = {
@@ -548,7 +648,10 @@ object SparkHoodieTableFileIndex extends SparkAdapterSupport {
case Right(struct) => struct.fields.flatMap(f => traverse(Left(f))).toMap
case Left(field) => field.dataType match {
case struct: StructType => traverse(Right(struct)).map {
- case (key, structField) => (s"${field.name}.$key", structField)
+ case (key, structField) => {
+ val fullPath = s"${field.name}.$key"
+ (fullPath, structField.copy(name = fullPath))
+ }
}
case _ => Map(field.name -> field)
}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index d5e0c6a927ac3..8d06e257d1787 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -44,14 +44,14 @@ import org.apache.hudi.testutils.HoodieSparkClientTestBase
import org.apache.hudi.util.JFunction
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThanOrEqual, LessThan, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GetStructField, GreaterThanOrEqual, LessThan, Literal, Or}
import org.apache.spark.sql.execution.datasources.{NoopCache, PartitionDirectory}
import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.hudi.HoodieSparkSessionExtension
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.junit.jupiter.api.{BeforeEach, Test}
-import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
+import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue}
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.{Arguments, CsvSource, MethodSource, ValueSource}
@@ -858,6 +858,35 @@ class TestHoodieFileIndex extends HoodieSparkClientTestBase with ScalaAssertionS
partitionValues.mkString(StoragePath.SEPARATOR)
}
}
+
+ // ---- buildNestedPartitionSchema tests ----
+
+ @ParameterizedTest
+ @MethodSource(Array("buildNestedPartitionSchemaCases"))
+ def testBuildNestedPartitionSchema(name: String, flat: StructType, expected: StructType): Unit = {
+ assertEquals(expected, SparkHoodieTableFileIndex.buildNestedPartitionSchema(flat))
+ }
+
+ @Test
+ def testBuildNestedPartitionSchemaConflictThrows(): Unit = {
+ // "a" as leaf and "a.b" as nested — conflict
+ val flat = StructType(Seq(StructField("a", StringType), StructField("a.b", IntegerType)))
+ assertThrows(classOf[IllegalStateException]) {
+ SparkHoodieTableFileIndex.buildNestedPartitionSchema(flat)
+ }
+ }
+
+ // ---- extractNestedPartitionFilters tests ----
+
+ @ParameterizedTest
+ @MethodSource(Array("extractNestedPartitionFiltersCases"))
+ def testExtractNestedPartitionFilters(name: String,
+ filters: Seq[Expression],
+ partitionColumns: Set[String],
+ expected: Seq[Expression]): Unit = {
+ assertEquals(expected, HoodieFileIndex.extractNestedPartitionFilters(filters, partitionColumns))
+ }
+
}
object TestHoodieFileIndex {
@@ -870,4 +899,65 @@ object TestHoodieFileIndex {
Arguments.arguments("org.apache.hudi.keygen.TimestampBasedKeyGenerator")
)
}
+
+ def buildNestedPartitionSchemaCases(): java.util.stream.Stream[Arguments] = {
+ val nested = StructType(Seq(
+ StructField("nested_record", StructType(Seq(StructField("level", StringType, nullable = true))), nullable = true)))
+ val twoLevel = StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("b", StructType(Seq(StructField("c", IntegerType, nullable = true))), nullable = true))), nullable = true)))
+ val siblings = StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("b", StringType, nullable = true),
+ StructField("c", IntegerType, nullable = true))), nullable = true)))
+ val mixed = StructType(Seq(
+ StructField("country", StringType, nullable = true),
+ StructField("nested_record", StructType(Seq(StructField("level", StringType, nullable = true))), nullable = true)))
+ java.util.stream.Stream.of(
+ Arguments.of("empty",
+ new StructType(),
+ new StructType()),
+ Arguments.of("flat",
+ StructType(Seq(StructField("country", StringType))),
+ StructType(Seq(StructField("country", StringType, nullable = true)))),
+ Arguments.of("singleNested",
+ StructType(Seq(StructField("nested_record.level", StringType))),
+ nested),
+ Arguments.of("twoLevelNesting",
+ StructType(Seq(StructField("a.b.c", IntegerType))),
+ twoLevel),
+ Arguments.of("siblingFields",
+ StructType(Seq(StructField("a.b", StringType), StructField("a.c", IntegerType))),
+ siblings),
+ Arguments.of("mixedFlatAndNested",
+ StructType(Seq(StructField("country", StringType), StructField("nested_record.level", StringType))),
+ mixed)
+ )
+ }
+
+ def extractNestedPartitionFiltersCases(): java.util.stream.Stream[Arguments] = {
+ val levelStruct = StructType(Seq(StructField("level", StringType)))
+ val multiFieldStruct = StructType(Seq(
+ StructField("nested_int", IntegerType), StructField("level", StringType)))
+
+ val partFilter = EqualTo(
+ GetStructField(AttributeReference("nested_record", levelStruct)(), 0, Some("level")),
+ Literal("INFO"))
+ val dataFilter = EqualTo(AttributeReference("int_field", IntegerType)(), Literal(5))
+ val siblingFilter = EqualTo(
+ GetStructField(AttributeReference("nested_record", multiFieldStruct)(), 0, Some("nested_int")),
+ Literal(10))
+ val orFilter = Or(
+ EqualTo(GetStructField(AttributeReference("nested_record", levelStruct)(), 0, Some("level")), Literal("INFO")),
+ EqualTo(GetStructField(AttributeReference("nested_record", levelStruct)(), 0, Some("level")), Literal("ERROR")))
+
+ java.util.stream.Stream.of(
+ Arguments.of("partitionFilterExtractedDataFilterDropped",
+ Seq(partFilter, dataFilter), Set("nested_record.level"), Seq(partFilter)),
+ Arguments.of("siblingFieldExcluded",
+ Seq(siblingFilter), Set("nested_record.level"), Seq.empty[Expression]),
+ Arguments.of("orWithOnlyPartitionColumnsExtracted",
+ Seq(orFilter), Set("nested_record.level"), Seq(orFilter))
+ )
+ }
}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index db5ca1ee78002..106c1f09a7ae2 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -17,7 +17,7 @@
package org.apache.hudi.functional
-import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers, HoodieSchemaConversionUtils, HoodieSparkUtils, QuickstartUtils, ScalaAssertionSupport}
+import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions, DataSourceWriteOptions, HoodieBaseRelation, HoodieDataSourceHelpers, HoodieFileIndex, HoodieSchemaConversionUtils, HoodieSparkUtils, QuickstartUtils, ScalaAssertionSupport}
import org.apache.hudi.DataSourceWriteOptions.{INLINE_CLUSTERING_ENABLE, KEYGENERATOR_CLASS_NAME}
import org.apache.hudi.HoodieConversionUtils.toJavaOption
import org.apache.hudi.QuickstartUtils.{convertToStringList, getQuickstartWriteConfigs}
@@ -42,13 +42,14 @@ import org.apache.hudi.hive.HiveSyncConfigHolder
import org.apache.hudi.keygen.{ComplexKeyGenerator, CustomKeyGenerator, GlobalDeleteKeyGenerator, NonpartitionedKeyGenerator, SimpleKeyGenerator, TimestampBasedKeyGenerator}
import org.apache.hudi.keygen.constant.{KeyGeneratorOptions, KeyGeneratorType}
import org.apache.hudi.metrics.{Metrics, MetricsReporterType}
-import org.apache.hudi.storage.{StoragePath, StoragePathFilter}
+import org.apache.hudi.storage.{HoodieStorage, StoragePath, StoragePathFilter}
import org.apache.hudi.table.HoodieSparkTable
import org.apache.hudi.testutils.{DataSourceTestUtils, HoodieSparkClientTestBase}
import org.apache.hudi.util.JFunction
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Dataset, Encoders, Row, SaveMode, SparkSession, SparkSessionExtensions}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.functions.{col, concat, lit, udf, when}
import org.apache.spark.sql.hudi.HoodieSparkSessionExtension
import org.apache.spark.sql.types.{ArrayType, DataTypes, DateType, IntegerType, LongType, MapType, StringType, StructField, StructType, TimestampType}
@@ -2616,9 +2617,300 @@ class TestCOWDataSource extends HoodieSparkClientTestBase with ScalaAssertionSup
assertEquals("row3", results(2).getAs[String]("_row_key"))
assertEquals("value3", results(2).getAs[String]("data"))
}
+
+ @Test
+ def testNestedFieldPartition(): Unit = {
+ TestCOWDataSource.runNestedFieldPartitionTest(spark, basePath, storage, "COW")
+ }
}
object TestCOWDataSource {
+
+ /**
+ * Shared test logic for nested field partition (COW and MOR).
+ * Used by TestCOWDataSource.testNestedFieldPartition and TestMORDataSource.testNestedFieldPartition.
+ */
+ def runNestedFieldPartitionTest(spark: SparkSession, basePath: String, storage: HoodieStorage, tableType: String): Unit = {
+ // Define schema with nested_record containing level field
+ val nestedSchema = StructType(Seq(
+ StructField("nested_int", IntegerType, nullable = false),
+ StructField("level", StringType, nullable = false)
+ ))
+
+ val schema = StructType(Seq(
+ StructField("key", StringType, nullable = false),
+ StructField("ts", LongType, nullable = false),
+ StructField("level", StringType, nullable = false),
+ StructField("int_field", IntegerType, nullable = false),
+ StructField("string_field", StringType, nullable = true),
+ StructField("nested_record", nestedSchema, nullable = true)
+ ))
+
+ // Create test data where top-level 'level' and 'nested_record.level' have DIFFERENT values
+ // This helps verify we're correctly partitioning/filtering on the nested field
+ val recordsCommit1 = Seq(
+ Row("key1", 1L, "L1", 1, "str1", Row(10, "INFO")),
+ Row("key2", 2L, "L2", 2, "str2", Row(20, "ERROR")),
+ Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")),
+ Row("key4", 4L, "L4", 4, "str4", Row(40, "DEBUG")),
+ Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO"))
+ )
+
+ val tableTypeOptVal = if (tableType == "MOR") {
+ DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL
+ } else {
+ DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL
+ }
+
+ val baseWriteOpts = Map(
+ "hoodie.insert.shuffle.parallelism" -> "4",
+ "hoodie.upsert.shuffle.parallelism" -> "4",
+ DataSourceWriteOptions.RECORDKEY_FIELD.key -> "key",
+ DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "nested_record.level",
+ HoodieTableConfig.ORDERING_FIELDS.key -> "ts",
+ HoodieWriteConfig.TBL_NAME.key -> "test_nested_partition",
+ DataSourceWriteOptions.TABLE_TYPE.key -> tableTypeOptVal
+ )
+ val writeOpts = if (tableType == "MOR") {
+ baseWriteOpts + ("hoodie.compact.inline" -> "false")
+ } else {
+ baseWriteOpts
+ }
+
+ // Commit 1 - Initial insert
+ val inputDF1 = spark.createDataFrame(
+ spark.sparkContext.parallelize(recordsCommit1),
+ schema
+ )
+ inputDF1.write.format("hudi")
+ .options(writeOpts)
+ .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+ .mode(SaveMode.Overwrite)
+ .save(basePath)
+ val commit1 = DataSourceTestUtils.latestCommitCompletionTime(storage, basePath)
+
+ // Commit 2 - Upsert: update key1 (int_field 1->100), insert key6 (INFO)
+ val recordsCommit2 = Seq(
+ Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+ Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO"))
+ )
+ val inputDF2 = spark.createDataFrame(
+ spark.sparkContext.parallelize(recordsCommit2),
+ schema
+ )
+ inputDF2.write.format("hudi")
+ .options(writeOpts)
+ .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+ .mode(SaveMode.Append)
+ .save(basePath)
+ val commit2 = DataSourceTestUtils.latestCommitCompletionTime(storage, basePath)
+
+ // Commit 3 - Upsert: update key3 (int_field 3->300), insert key7 (INFO)
+ val recordsCommit3 = Seq(
+ Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")),
+ Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO"))
+ )
+ val inputDF3 = spark.createDataFrame(
+ spark.sparkContext.parallelize(recordsCommit3),
+ schema
+ )
+ inputDF3.write.format("hudi")
+ .options(writeOpts)
+ .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+ .mode(SaveMode.Append)
+ .save(basePath)
+ val commit3 = DataSourceTestUtils.latestCommitCompletionTime(storage, basePath)
+
+ // Verify partition structure - we should have 3 partitions: INFO, ERROR, DEBUG
+ val allPartitions = storage.listDirectEntries(new StoragePath(basePath))
+ .asScala.filter(_.isDirectory)
+ .map(_.getPath.getName)
+ .filterNot(_.startsWith(".")) // Filter out .hoodie and other hidden directories
+ .sorted
+ assertEquals(3, allPartitions.size, s"Expected 3 partitions for $tableType, but got: ${allPartitions.mkString(", ")}")
+ assertTrue(allPartitions.contains("INFO"), s"Missing INFO partition for $tableType")
+ assertTrue(allPartitions.contains("ERROR"), s"Missing ERROR partition for $tableType")
+ assertTrue(allPartitions.contains("DEBUG"), s"Missing DEBUG partition for $tableType")
+
+ // Snapshot read - filter on nested_record.level = 'INFO' (latest state: 5 records)
+ val snapshotDF = spark.read.format("hudi")
+ .load(basePath)
+ .filter("nested_record.level = 'INFO'")
+ .select("key", "ts", "level", "int_field", "string_field", "nested_record")
+ .orderBy("key")
+
+ // VERIFICATION 1: Check partition schema contains the nested field
+ val snapshotRelation = snapshotDF.queryExecution.optimizedPlan.collectFirst {
+ case lr: LogicalRelation => lr
+ }
+ assertTrue(snapshotRelation.isDefined, s"LogicalRelation should exist for $tableType")
+ val fileIndex = snapshotRelation.get.relation match {
+ case fsRelation: HadoopFsRelation =>
+ fsRelation.location.asInstanceOf[HoodieFileIndex]
+ case baseRelation: HoodieBaseRelation =>
+ baseRelation.fileIndex
+ case _ => null
+ }
+ assertTrue(fileIndex != null, s"FileIndex should be available for $tableType")
+ assertEquals(1, fileIndex.partitionSchema.fields.length,
+ s"Partition schema should have 1 field for $tableType")
+ assertEquals("nested_record.level", fileIndex.partitionSchema.fields(0).name,
+ s"Partition field should be 'nested_record.level' for $tableType")
+
+ // VERIFICATION 2: Check that predicates were pushed down to FileIndex
+ assertTrue(fileIndex.hasPredicatesPushedDown,
+ s"Partition predicates should be pushed down to FileIndex for $tableType")
+
+ // VERIFICATION 3: Verify partition pruning by checking the physical plan
+ // The physical plan should show that only specific files are being scanned
+ val physicalPlan = snapshotDF.queryExecution.executedPlan.toString()
+ assertTrue(physicalPlan.contains("Scan") || physicalPlan.contains("FileScan"),
+ s"Physical plan should contain scan operation for $tableType")
+
+ // Collect results to execute the query
+ val snapshotResults = snapshotDF.collect()
+ val expectedSnapshot = Array(
+ Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+ Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")),
+ Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")),
+ Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO")),
+ Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO"))
+ )
+ assertEquals(expectedSnapshot.length, snapshotResults.length,
+ s"Snapshot (INFO) count mismatch for $tableType")
+ expectedSnapshot.zip(snapshotResults).foreach { case (expected, actual) =>
+ assertEquals(expected, actual)
+ }
+
+ // Time travel - as of commit1 (only initial 5 records; INFO = key1, key3, key5)
+ val timeTravelDF1 = spark.read.format("hudi")
+ .option(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key, commit1)
+ .load(basePath)
+ .filter("nested_record.level = 'INFO'")
+ .select("key", "ts", "level", "int_field", "string_field", "nested_record")
+ .orderBy("key")
+
+ // VERIFICATION 4: Verify partition pruning works for time travel queries
+ // Check that the time travel query with partition filter returns correct results
+ val timeTravelCommit1 = timeTravelDF1.collect()
+ val expectedAfterCommit1 = Array(
+ Row("key1", 1L, "L1", 1, "str1", Row(10, "INFO")),
+ Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")),
+ Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO"))
+ )
+ assertEquals(expectedAfterCommit1.length, timeTravelCommit1.length,
+ s"Time travel to commit1 (INFO) count mismatch for $tableType")
+ expectedAfterCommit1.zip(timeTravelCommit1).foreach { case (expected, actual) =>
+ assertEquals(expected, actual)
+ }
+
+ // Time travel - as of commit2 (after 2nd commit; INFO = key1 updated, key3, key5, key6)
+ val timeTravelCommit2 = spark.read.format("hudi")
+ .option(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key, commit2)
+ .load(basePath)
+ .filter("nested_record.level = 'INFO'")
+ .select("key", "ts", "level", "int_field", "string_field", "nested_record")
+ .orderBy("key")
+ .collect()
+
+ val expectedAfterCommit2 = Array(
+ Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+ Row("key3", 3L, "L3", 3, "str3", Row(30, "INFO")),
+ Row("key5", 5L, "L5", 5, "str5", Row(50, "INFO")),
+ Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO"))
+ )
+ assertEquals(expectedAfterCommit2.length, timeTravelCommit2.length,
+ s"Time travel to commit2 (INFO) count mismatch for $tableType")
+ expectedAfterCommit2.zip(timeTravelCommit2).foreach { case (expected, actual) =>
+ assertEquals(expected, actual)
+ }
+
+ // Incremental query - from commit1 to commit2 (only key1 update and key6 insert; both INFO)
+ val incrementalDF1To2 = spark.read.format("hudi")
+ .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)
+ .option(DataSourceReadOptions.START_COMMIT.key, commit1)
+ .option(DataSourceReadOptions.END_COMMIT.key, commit2)
+ .load(basePath)
+ .filter("nested_record.level = 'INFO'")
+ .select("key", "ts", "level", "int_field", "string_field", "nested_record")
+ .orderBy("key")
+
+ // VERIFICATION 5: Verify partition filtering works for incremental queries
+ // For incremental queries, the filter on nested_record.level should still limit scanned data
+ val incrementalPlan1To2 = incrementalDF1To2.queryExecution.executedPlan.toString()
+ // The plan should show filtering is happening
+ assertTrue(incrementalPlan1To2.contains("Filter") || incrementalPlan1To2.contains("Scan"),
+ s"Incremental query plan should show filtering for $tableType")
+
+ val incrementalCommit1To2 = incrementalDF1To2.collect()
+ val expectedInc1To2 = Array(
+ Row("key1", 10L, "L1", 100, "str1", Row(10, "INFO")),
+ Row("key6", 6L, "L6", 6, "str6", Row(60, "INFO"))
+ )
+ assertEquals(expectedInc1To2.length, incrementalCommit1To2.length,
+ s"Incremental (commit1->commit2, INFO) count mismatch for $tableType")
+ expectedInc1To2.zip(incrementalCommit1To2).foreach { case (expected, actual) =>
+ assertEquals(expected, actual)
+ }
+
+ // Incremental query - from commit2 to commit3 (only key3 update and key7 insert; both INFO)
+ val incrementalCommit2To3 = spark.read.format("hudi")
+ .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)
+ .option(DataSourceReadOptions.START_COMMIT.key, commit2)
+ .option(DataSourceReadOptions.END_COMMIT.key, commit3)
+ .load(basePath)
+ .filter("nested_record.level = 'INFO'")
+ .select("key", "ts", "level", "int_field", "string_field", "nested_record")
+ .orderBy("key")
+ .collect()
+
+ val expectedInc2To3 = Array(
+ Row("key3", 30L, "L3", 300, "str3", Row(30, "INFO")),
+ Row("key7", 7L, "L7", 7, "str7", Row(70, "INFO"))
+ )
+ assertEquals(expectedInc2To3.length, incrementalCommit2To3.length,
+ s"Incremental (commit2->commit3, INFO) count mismatch for $tableType")
+ expectedInc2To3.zip(incrementalCommit2To3).foreach { case (expected, actual) =>
+ assertEquals(expected, actual)
+ }
+
+ // VERIFICATION 7: Test with different partition values to ensure filtering is working correctly
+ // Query for ERROR partition (should only return key2)
+ val errorPartitionDF = spark.read.format("hudi")
+ .load(basePath)
+ .filter("nested_record.level = 'ERROR'")
+ .select("key", "nested_record")
+
+ val errorResults = errorPartitionDF.collect()
+ assertEquals(1, errorResults.length, s"ERROR partition should have 1 record for $tableType")
+ assertEquals("key2", errorResults(0).getString(0),
+ s"ERROR partition should contain key2 for $tableType")
+
+ // VERIFICATION 8: Test with DEBUG partition
+ val debugPartitionDF = spark.read.format("hudi")
+ .load(basePath)
+ .filter("nested_record.level = 'DEBUG'")
+ .select("key", "nested_record")
+
+ val debugResults = debugPartitionDF.collect()
+ assertEquals(1, debugResults.length, s"DEBUG partition should have 1 record for $tableType")
+ assertEquals("key4", debugResults(0).getString(0),
+ s"DEBUG partition should contain key4 for $tableType")
+
+ // VERIFICATION 9: Verify that filtering on top-level 'level' field returns correct results
+ // This ensures we're correctly distinguishing between nested_record.level (partition) and level (data column)
+ val topLevelFilterDF = spark.read.format("hudi")
+ .load(basePath)
+ .filter("level = 'L1'") // Filter on top-level 'level', not nested_record.level
+ .select("key", "level", "nested_record")
+
+ val topLevelResults = topLevelFilterDF.collect()
+ // Should return key1 which has level='L1' and is in INFO partition
+ assertEquals(1, topLevelResults.length, s"Top-level level='L1' should return 1 record for $tableType")
+ assertEquals("key1", topLevelResults(0).getString(0),
+ s"Top-level level='L1' should return key1 for $tableType")
+ }
+
def convertColumnsToNullable(df: DataFrame, cols: String*): DataFrame = {
cols.foldLeft(df) { (df, c) =>
// NOTE: This is the trick to make Spark convert a non-null column "c" into a nullable
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
index ab7bbcd097d28..81d049d432439 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
@@ -2347,6 +2347,11 @@ class TestMORDataSource extends HoodieSparkClientTestBase with SparkDatasetMixin
assertEquals("row4", results(3).getAs[String]("_row_key"))
assertEquals("value4", results(3).getAs[String]("data"))
}
+
+ @Test
+ def testNestedFieldPartition(): Unit = {
+ TestCOWDataSource.runNestedFieldPartitionTest(spark, basePath, storage, "MOR")
+ }
}
object TestMORDataSource {
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
index 589ee9774d3b7..3cb9e8250da96 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
+++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark3HoodiePruneFileSourcePartitions.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
* Prune the partitions of Hudi table based relations by the means of pushing down the
* partition filters
*
- * NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's [[PruneFileSourcePartitions]]
+ * NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's [[org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions]]
*/
case class Spark3HoodiePruneFileSourcePartitions(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -48,11 +48,11 @@ case class Spark3HoodiePruneFileSourcePartitions(spark: SparkSession) extends Ru
val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters, lr.output)
val (partitionPruningFilters, dataFilters) =
- getPartitionFiltersAndDataFilters(fileIndex.partitionSchema, normalizedFilters)
+ getPartitionFiltersAndDataFilters(fileIndex.partitionSchemaForSpark, normalizedFilters)
// [[HudiFileIndex]] is a caching one, therefore we don't need to reconstruct new relation,
// instead we simply just refresh the index and update the stats
- fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruned = true)
+ fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruneOnly = true)
if (partitionPruningFilters.nonEmpty) {
// Change table stats based on the sizeInBytes of pruned files
@@ -105,11 +105,21 @@ private object Spark3HoodiePruneFileSourcePartitions extends PredicateHelper {
Project(projects, withFilter)
}
+ /**
+ * Returns true if the given attribute references a partition column. For nested partition columns
+ * (e.g. `nested_record.level`), `partitionSchema` is the nested [[StructType]] from
+ * `partitionSchemaForSpark`, so the top-level name is the struct root (e.g. `nested_record`),
+ * which matches `attr.name` directly via `contains`.
+ */
+ private def isPartitionColumnReference(attr: AttributeReference, partitionSchema: StructType): Boolean = {
+ partitionSchema.names.contains(attr.name)
+ }
+
def getPartitionFiltersAndDataFilters(partitionSchema: StructType,
normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val partitionColumns = normalizedFilters.flatMap { expr =>
expr.collect {
- case attr: AttributeReference if partitionSchema.names.contains(attr.name) =>
+ case attr: AttributeReference if isPartitionColumnReference(attr, partitionSchema) =>
attr
}
}
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
index 7d7240231cd09..add1b7aaebf14 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark33HoodiePruneFileSourcePartitions.scala
@@ -60,7 +60,7 @@ case class Spark33HoodiePruneFileSourcePartitions(spark: SparkSession) extends R
// [[HudiFileIndex]] is a caching one, therefore we don't need to reconstruct new relation,
// instead we simply just refresh the index and update the stats
- fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruned = true)
+ fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruneOnly = true)
if (partitionPruningFilters.nonEmpty) {
// Change table stats based on the sizeInBytes of pruned files
diff --git a/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala b/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
index 8412018c22db6..0f6cf87da86f6 100644
--- a/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
+++ b/hudi-spark-datasource/hudi-spark4-common/src/main/scala/org/apache/spark/sql/hudi/analysis/Spark4HoodiePruneFileSourcePartitions.scala
@@ -48,11 +48,11 @@ case class Spark4HoodiePruneFileSourcePartitions(spark: SparkSession) extends Ru
val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters, lr.output)
val (partitionPruningFilters, dataFilters) =
- getPartitionFiltersAndDataFilters(fileIndex.partitionSchema, normalizedFilters)
+ getPartitionFiltersAndDataFilters(fileIndex.partitionSchemaForSpark, normalizedFilters)
// [[HudiFileIndex]] is a caching one, therefore we don't need to reconstruct new relation,
// instead we simply just refresh the index and update the stats
- fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruned = true)
+ fileIndex.filterFileSlices(dataFilters, partitionPruningFilters, isPartitionPruneOnly = true)
if (partitionPruningFilters.nonEmpty) {
// Change table stats based on the sizeInBytes of pruned files
@@ -105,11 +105,21 @@ private object Spark4HoodiePruneFileSourcePartitions extends PredicateHelper {
Project(projects, withFilter)
}
+ /**
+ * Returns true if the given attribute references a partition column. For nested partition columns
+ * (e.g. `nested_record.level`), `partitionSchema` is the nested [[StructType]] from
+ * `partitionSchemaForSpark`, so the top-level name is the struct root (e.g. `nested_record`),
+ * which matches `attr.name` directly via `contains`.
+ */
+ private def isPartitionColumnReference(attr: AttributeReference, partitionSchema: StructType): Boolean = {
+ partitionSchema.names.contains(attr.name)
+ }
+
def getPartitionFiltersAndDataFilters(partitionSchema: StructType,
normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val partitionColumns = normalizedFilters.flatMap { expr =>
expr.collect {
- case attr: AttributeReference if partitionSchema.names.contains(attr.name) =>
+ case attr: AttributeReference if isPartitionColumnReference(attr, partitionSchema) =>
attr
}
}
diff --git a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
index fede1b8fba030..2510edce72a8c 100644
--- a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
+++ b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HiveIncrementalPuller.java
@@ -20,9 +20,9 @@
import org.apache.hudi.common.table.HoodieTableMetaClient;
import org.apache.hudi.common.table.timeline.HoodieInstant;
-import org.apache.hudi.common.util.FileIOUtils;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.exception.HoodieException;
+import org.apache.hudi.hadoop.fs.HadoopFSUtils;
import org.apache.hudi.utilities.exception.HoodieIncrementalPullException;
import org.apache.hudi.utilities.exception.HoodieIncrementalPullSQLException;
@@ -50,6 +50,8 @@
import java.util.Scanner;
import java.util.stream.Collectors;
+import static org.apache.hudi.io.util.FileIOUtils.readAsUTFString;
+
/**
* Utility to pull data after a given commit, based on the supplied HiveQL and save the delta as another hive temporary
* table. This temporary table can be further read using {@link org.apache.hudi.utilities.sources.HiveIncrPullSource} and the changes can
@@ -115,7 +117,7 @@ public HiveIncrementalPuller(Config config) throws IOException {
this.config = config;
validateConfig(config);
String templateContent =
- FileIOUtils.readAsUTFString(this.getClass().getResourceAsStream("/IncrementalPull.sqltemplate"));
+ readAsUTFString(this.getClass().getResourceAsStream("/IncrementalPull.sqltemplate"));
incrementalPullSQLTemplate = new ST(templateContent);
}
@@ -298,12 +300,13 @@ private String scanForCommitTime(FileSystem fs, String targetDataPath) throws IO
if (!fs.exists(new Path(targetDataPath)) || !fs.exists(new Path(targetDataPath + "/.hoodie"))) {
return "0";
}
- HoodieTableMetaClient metadata = HoodieTableMetaClient.builder().setConf(fs.getConf()).setBasePath(targetDataPath).build();
+ HoodieTableMetaClient metadata = HoodieTableMetaClient.builder()
+ .setConf(HadoopFSUtils.getStorageConfWithCopy(fs.getConf())).setBasePath(targetDataPath).build();
Option