diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index df1049e..4c2d2a8 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -19,7 +19,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        java: [temurin@8]
+        java: [temurin@17]
     runs-on: ubuntu-latest
     steps:
       - name: Checkout current branch (full)
@@ -27,12 +27,16 @@ jobs:
         with:
           fetch-depth: 0

-      - name: Setup Java (temurin@8)
-        if: matrix.java == 'temurin@8'
+      # Java 17: required to run pluginspark4 tests against Spark 4.0.1 (Spark 4 needs ≥17).
+      # Spark 3.5 also supports Java 17, so pluginspark3 tests run on the same JVM.
+      # Scala 2.12/2.13 still emit Java 8 bytecode by default, so the produced jars stay
+      # compatible with Java 8 runtimes.
+      - name: Setup Java (temurin@17)
+        if: matrix.java == 'temurin@17'
         uses: actions/setup-java@v3
         with:
           distribution: temurin
-          java-version: 8
+          java-version: 17
           cache: sbt

       - name: Set up sbt
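The compatibility note in the step comment above leans on scalac's default bytecode target. A minimal sbt sketch of how that guarantee could be made explicit (illustration only, not part of this change; the setting name is hypothetical, `-release` needs Scala 2.12.9+/2.13, and it would apply to the modules that must still run on Java 8, not to pluginspark4):

```scala
// Hypothetical helper, not in build.sbt: pins both the emitted bytecode and the JDK API
// surface to Java 8, which is a stronger guarantee than relying on the default target alone.
lazy val java8CompatSettings = Seq(
  Compile / scalacOptions ++= Seq("-release", "8"),
  Compile / javacOptions ++= Seq("--release", "8")
)
```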
diff --git a/spark-plugin/build.sbt b/spark-plugin/build.sbt
index 2f327eb..5973a44 100644
--- a/spark-plugin/build.sbt
+++ b/spark-plugin/build.sbt
@@ -165,7 +165,41 @@ lazy val pluginspark4 = (project in file("pluginspark4"))
     Compile / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala",

     // Include resources from plugin directory for static UI files
-    Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value
+    Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value,
+
+    // Test dependencies — Spark 4.0.1 + scalatest. Mirrors pluginspark3 so we can run the
+    // same regression suites against the Spark 4 surface (cross-version validation).
+    // Requires the launching JVM to be Java 17+ since Spark 4 won't run on Java 8/11.
+    libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % Test,
+    libraryDependencies += "org.scalatest" %% "scalatest-shouldmatchers" % "3.2.17" % Test,
+    libraryDependencies += "org.apache.spark" %% "spark-core" % "4.0.1" % Test,
+    libraryDependencies += "org.apache.spark" %% "spark-sql" % "4.0.1" % Test,
+
+    // Share version-portable test sources with pluginspark3. Most pluginspark3 specs
+    // depend on Spark-3-only internals (Dataset constructor, PythonMapInArrowExec, etc.)
+    // and don't compile against Spark 4, so we explicitly include only the suites that
+    // exercise version-stable surface area.
+    Test / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala",
+    Test / unmanagedSources ++= {
+      val pluginspark3Tests = (pluginspark3 / Test / sourceDirectory).value / "scala"
+      Seq(
+        pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "DataFlintCodegenFallbackSpec.scala"
+      )
+    },
+
+    // Fork JVM for tests; Spark on Java 9+ requires the same --add-opens as pluginspark3.
+    Test / fork := true,
+    Test / parallelExecution := false,
+    Test / javaOptions ++= {
+      if (sys.props("java.specification.version").startsWith("1.")) Seq.empty
+      else Seq(
+        "--add-opens=java.base/java.lang=ALL-UNNAMED",
+        "--add-opens=java.base/java.nio=ALL-UNNAMED",
+        "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
+        "--add-opens=java.base/java.util=ALL-UNNAMED",
+        "--add-opens=java.base/java.io=ALL-UNNAMED",
+      )
+    }
   )

 lazy val pluginspark4databricks = (project in file("pluginspark4databricks"))
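The `Test / javaOptions` guard above keys on `java.specification.version`, which a Java 8 JVM reports as "1.8" while Java 9+ reports a bare major number such as "17". A small illustrative sketch of that check (the wrapper object and helper name are hypothetical, for clarity only):

```scala
// Illustration of the guard in Test / javaOptions: startsWith("1.") means "pre-Java-9 JVM",
// where the module system does not exist and --add-opens would be rejected at startup.
object AddOpensCheckSketch {
  def needsAddOpens(specVersion: String): Boolean = !specVersion.startsWith("1.")

  // needsAddOpens("1.8") == false  (Java 8: no module system, the flags are unsupported)
  // needsAddOpens("17")  == true   (Java 17: Spark needs the java.base opens)
}
```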
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/MetricsUtils.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/MetricsUtils.scala
index a36d47f..cac5e77 100644
--- a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/MetricsUtils.scala
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/MetricsUtils.scala
@@ -80,6 +80,14 @@ object MetricsUtils {
     }
   }

+  /**
+   * Create a plain sum metric (displayed as a raw number in Spark UI). Useful for
+   * non-byte numeric values like the RDD id, where "size" formatting (B/KB/MB) would
+   * mis-display the value.
+   */
+  def getSumMetric(name: String)(implicit sparkContext: SparkContext): (String, SQLMetric) =
+    name -> SQLMetrics.createMetric(sparkContext, name)
+
   /**
    * Create a "timing" metric (displayed as milliseconds with total/min/med/max in Spark UI).
    * Used by TimedExec for the "duration" metric.
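A short sketch of the UI-formatting difference that motivates the new helper (assumes a live SparkContext; `SQLMetrics` is the same Spark factory used above, the wrapper object is illustrative only):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.SQLMetrics

object MetricDisplaySketch {
  // Both metrics store the same Long, but the Spark UI renders a size metric with byte
  // units ("12.0 B") and a plain metric as the raw number ("12"), which is why the rddId
  // metric goes through getSumMetric / SQLMetrics.createMetric.
  def rddIdMetricVariants(sc: SparkContext) = {
    val asSize = SQLMetrics.createSizeMetric(sc, "rddId") // rendered with B/KB/MB units
    val asSum  = SQLMetrics.createMetric(sc, "rddId")     // rendered as a raw number
    (asSize, asSum)
  }
}
```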
diff --git a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/TimedExec.scala b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/TimedExec.scala
index 04478c9..fb52276 100644
--- a/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/TimedExec.scala
+++ b/spark-plugin/plugin/src/main/scala/org/apache/spark/dataflint/TimedExec.scala
@@ -3,7 +3,7 @@ package org.apache.spark.dataflint
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
@@ -57,11 +57,12 @@ class TimedExec(val child: SparkPlan) extends SparkPlan with Logging {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
   override def supportsColumnar: Boolean = child.supportsColumnar

-  // Preserves ALL of child's existing metrics (spillSize, numOutputRows, etc.) + adds duration and rddId
+  // Preserves ALL of child's existing metrics (spillSize, numOutputRows, etc.) + adds duration and rddId.
+  // rddId uses a plain sum metric — a "size" metric would render it as bytes ("12 B") in the UI.
   override lazy val metrics: Map[String, SQLMetric] = child.metrics ++ Map(
     MetricsUtils.getTimingMetric("duration")(sparkContext),
-    MetricsUtils.getSizeMetric("rddId")(sparkContext)
+    MetricsUtils.getSumMetric("rddId")(sparkContext)
   )

   // Delegate prepare() to child so that DataWritingCommandExec (and similar nodes that
@@ -72,7 +73,11 @@

   protected def postRddId(rddId: Int): Unit = {
     val rddIdMetric = longMetric("rddId")
-    rddIdMetric += rddId
+    // `set` instead of `+=` so re-execution of the same TimedExec instance overwrites
+    // the metric instead of accumulating. doExecute is invoked once per execute() call;
+    // a plan instance reused across queries (or AQE materialization) would otherwise
+    // sum every RDD id it ever wrapped.
+    rddIdMetric.set(rddId.toLong)
     MetricsUtils.postDriverMetrics(sparkContext, rddIdMetric)
   }

@@ -104,25 +109,33 @@
   override def executeCollect(): Array[InternalRow] = {
     if (child.getClass.getSimpleName == "DataWritingCommandExec") {
       val durationMetric = longMetric("duration")
-      val innerChild = child.children.head
-      val wrappedChild = if (innerChild.getClass.getSimpleName == "WriteFilesExec") {
-        // Spark 3.4+: wrap the data plan inside WriteFilesExec
-        val dataPlan = innerChild.children.head
-        val timedDataPlan = new TimedExec.RDDTimingWrapper(dataPlan, durationMetric)
-        val wrappedWriteFiles = innerChild.withNewChildren(IndexedSeq(timedDataPlan))
-        child.withNewChildren(IndexedSeq(wrappedWriteFiles))
-      } else {
-        // Older Spark: wrap the data plan directly
-        val timedDataPlan = new TimedExec.RDDTimingWrapper(innerChild, durationMetric)
-        child.withNewChildren(IndexedSeq(timedDataPlan))
+      // child.children may be empty on unusual DataWritingCommandExec shapes (vendor
+      // forks, future Spark versions). Fall through to super.executeCollect() in that
+      // case; duration will be zero on that fallback path, but the write still runs.
+      val maybeRebuilt: Option[SparkPlan] = child.children.headOption.flatMap { innerChild =>
+        if (innerChild.getClass.getSimpleName == "WriteFilesExec") {
+          // Spark 3.4+: wrap the data plan inside WriteFilesExec
+          innerChild.children.headOption.map { dataPlan =>
+            val timedDataPlan = new TimedExec.RDDTimingWrapper(dataPlan, durationMetric)
+            val wrappedWriteFiles = innerChild.withNewChildren(IndexedSeq(timedDataPlan))
+            child.withNewChildren(IndexedSeq(wrappedWriteFiles))
+          }
+        } else {
+          // Older Spark: wrap the data plan directly
+          val timedDataPlan = new TimedExec.RDDTimingWrapper(innerChild, durationMetric)
+          Some(child.withNewChildren(IndexedSeq(timedDataPlan)))
+        }
       }
-      wrappedChild.executeCollect()
+      maybeRebuilt.fold(super.executeCollect())(_.executeCollect())
     } else {
       super.executeCollect()
     }
   }

-  override def canEqual(that: Any): Boolean = that.isInstanceOf[TimedExec]
+  // Match the runtime class so TimedExec(x) and TimedWithCodegenExec(x) don't compare
+  // equal — they have different execution semantics (codegen vs RDD path), and TreeNode
+  // equality / canonicalization use canEqual to decide plan reuse.
+  override def canEqual(that: Any): Boolean = that.getClass == this.getClass
   // productArity/productElement support makeCopy on Spark 3.0/3.1 (constructor arg = child)
   override def productArity: Int = 1
   override def productElement(n: Int): Any =
@@ -153,9 +166,21 @@ class TimedWithCodegenExec(override val child: SparkPlan) extends TimedExec(chil
   // On 3.2+ (transparent): children = child.children, multi-child nodes (joins) expose
   // multiple children which breaks codegen assumptions → restrict to single-child.
   // On 3.0/3.1 (non-transparent): children = Seq(child), always length 1 → no restriction.
+  //
+  // Mirror Spark's CollapseCodegenStages CodegenFallback check on `child.expressions`.
+  // The framework normally excludes plans whose expressions contain a CodegenFallback
+  // (e.g. JsonToStructs / from_json), but our transparent `children = child.children`
+  // hides `child` from that check, so we must do it ourselves — otherwise downstream
+  // CodegenFallback.doGenCode reads ctx.INPUT_ROW = null and NPEs in Block.code
+  // interpolation. (issue #74)
   override def supportCodegen: Boolean = {
     val c = child.asInstanceOf[CodegenSupport]
-    c.supportCodegen && (TimedExec.isLegacySpark || child.children.length <= 1)
+    // Use TreeNode.find (available since Spark 3.0) rather than TreeNode.exists, which
+    // was added in 3.2: calling `.exists` on an Expression throws NoSuchMethodError at
+    // runtime on Spark 3.0/3.1 even though it compiles fine against newer Spark headers.
+    c.supportCodegen &&
+      (TimedExec.isLegacySpark || child.children.length <= 1) &&
+      !child.expressions.exists(_.find(_.isInstanceOf[CodegenFallback]).isDefined)
   }

   override def needCopyResult: Boolean = child.asInstanceOf[CodegenSupport].needCopyResult
@@ -201,9 +226,15 @@
   // Spark 3.0/3.1's withNewChildren uses mapProductIterator + containsChild which is
   // incompatible with the transparent wrapper (children = child.children). Detected by
   // checking for withNewChildrenInternal which was added in Spark 3.2.
+  //
+  // Wrap the parse in Try — any vendor distribution with a non-numeric major/minor (or
+  // a version string Spark doesn't expose at all) defaults to non-legacy, matching every
+  // released Spark line ≥ 3.2.
   val isLegacySpark: Boolean = {
-    val parts = org.apache.spark.SPARK_VERSION.split("\\.")
-    parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
+    scala.util.Try {
+      val parts = org.apache.spark.SPARK_VERSION.split("\\.")
+      parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
+    }.getOrElse(false)
   }

   def apply(child: SparkPlan): TimedExec = child match {
@@ -215,6 +246,15 @@
    * A minimal SparkPlan that wraps execute() with per-partition duration timing.
    * Used by the write path: inserted inside WriteFilesExec (or as the direct child on older
    * Spark) so that the write command consumes a timed RDD per partition.
+   *
+   * Reconstruction note: `durationMetric` is intentionally NOT exposed via productElement.
+   * The constructor takes (child, durationMetric) but `productArity = 1` reports only
+   * `child`, so any code that tries to clone us via `makeCopy` would fail to supply
+   * `durationMetric`. We sidestep this by overriding `withNewChildrenInternal` to plumb
+   * the metric through manually — it's the only reconstruction path Spark 3.2+ takes for
+   * us. Spark 3.0/3.1's makeCopy-based path is not reached because this wrapper is only
+   * constructed inside `executeCollect` on Spark 3.4+ (WriteFilesExec branch) or as a
+   * direct child wrap on older Spark, neither of which round-trips through `makeCopy`.
    */
   private[dataflint] class RDDTimingWrapper(val child: SparkPlan, durationMetric: SQLMetric) extends SparkPlan {
     override def output: Seq[Attribute] = child.output
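As a worked example of the guarded version check introduced above (a standalone mirror of the `isLegacySpark` logic, for illustration only; the object name is hypothetical):

```scala
import scala.util.Try

object LegacySparkCheckSketch {
  // Same shape as isLegacySpark: only a parseable "3.0.x"/"3.1.x" counts as legacy,
  // everything else (including unparsable vendor strings) falls back to non-legacy.
  def legacy(version: String): Boolean = Try {
    val parts = version.split("\\.")
    parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
  }.getOrElse(false)

  // legacy("3.1.3")        == true   (Spark 3.0/3.1: makeCopy-based withNewChildren)
  // legacy("3.5.1-amzn-1") == false  (vendor suffix sits on the patch part, minor still parses)
  // legacy("4.0.1")        == false
  // legacy("3.x-preview")  == false  (non-numeric minor: Try swallows the NumberFormatException)
}
```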
diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintCodegenFallbackSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintCodegenFallbackSpec.scala
new file mode 100644
index 0000000..8e1c1b1
--- /dev/null
+++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintCodegenFallbackSpec.scala
@@ -0,0 +1,71 @@
+package org.apache.spark.dataflint
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{col, explode, from_json}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+/**
+ * Regression test for issue #74: TimedWithCodegenExec must exclude itself from
+ * whole-stage codegen when its child contains a CodegenFallback expression (e.g.
+ * from_json / JsonToStructs). The transparent `children = child.children` hides
+ * the wrapped child from CollapseCodegenStages' CodegenFallback check, so that
+ * exclusion has to happen in TimedWithCodegenExec.supportCodegen.
+ *
+ * Without the fix, the wrapped child gets wholestaged anyway and CodegenFallback's
+ * generated code reads ctx.INPUT_ROW = null, NPE'ing in Block.code interpolation:
+ *   java.lang.NullPointerException: Cannot invoke "Object.getClass()" because "arg" is null
+ */
+class DataFlintCodegenFallbackSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll {
+
+  private var spark: SparkSession = _
+
+  override def beforeAll(): Unit = {
+    spark = SparkSession.builder()
+      .master("local[1]")
+      .appName("DataFlintCodegenFallbackSpec")
+      .config(DataflintSparkUICommonLoader.INSTRUMENT_SQL_NODES_ENABLED, "true")
+      .config("spark.sql.codegen.wholeStage", "true")
+      .config("spark.sql.adaptive.enabled", "false")
+      .config("spark.ui.enabled", "false")
+      .withExtensions(new DataFlintInstrumentationExtension)
+      .getOrCreate()
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) spark.stop()
+  }
+
+  test("from_json under whole-stage codegen does not NPE (issue #74)") {
+    val session = spark
+    import session.implicits._
+
+    val schema = ArrayType(StructType(Seq(
+      StructField("name", StringType, nullable = true),
+      StructField("kind", StringType, nullable = true)
+    )))
+
+    // Cache to materialize the rows so JsonToStructs is not constant-folded into the
+    // LocalTableScan — otherwise from_json never reaches whole-stage codegen and the bug
+    // doesn't fire.
+    val raw = Seq(
+      ("k1", """[{"name":"a","kind":"x"}]"""),
+      ("k2", null.asInstanceOf[String])
+    ).toDF("key", "payload")
+      .repartition(1)
+      .cache()
+    raw.count()
+
+    val rowCount: Long = raw
+      .filter(col("payload").isNotNull)
+      .withColumn("parsed", from_json(col("payload"), schema))
+      .filter(col("parsed").isNotNull)
+      .select(col("key"), explode(col("parsed")).alias("d"))
+      .filter(col("d.name").isNotNull)
+      .count()
+
+    rowCount shouldBe 1L
+  }
+}
diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala
new file mode 100644
index 0000000..ee08da4
--- /dev/null
+++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/DataFlintWriteMetricsSpec.scala
@@ -0,0 +1,100 @@
+package org.apache.spark.dataflint
+
+import java.util.concurrent.atomic.AtomicReference
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+/**
+ * Regression test for the executeCollect write-path rebuild in TimedExec:
+ *   TimedExec(DataWritingCommandExec(WriteFilesExec(dataPlan)))
+ * gets rebuilt at execute-time as
+ *   NewDataWritingCommandExec(NewWriteFilesExec(RDDTimingWrapper(dataPlan)))
+ * and `executeCollect` runs on the rebuilt root.
+ *
+ * The rebuild is safe for metrics because:
+ *  - DataWritingCommandExec.metrics delegates to cmd.metrics, and `withNewChildren` reuses
+ *    the same cmd instance — so numOutputRows / numFiles / numOutputBytes / etc. are all
+ *    the same SQLMetric objects on both the original and rebuilt nodes.
+ *  - WriteFilesExec (Spark 3.4+) has no metrics of its own.
+ *  - The data plan is shared by reference via RDDTimingWrapper, so its metrics are the
+ *    original instances.
+ *
+ * If a future change breaks any of these assumptions, the metrics displayed on the original
+ * (UI-visible) plan tree would go stale; this spec catches that.
+ */
+class DataFlintWriteMetricsSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll {
+
+  private var spark: SparkSession = _
+  private val capturedQE = new AtomicReference[QueryExecution](null)
+
+  private val listener = new QueryExecutionListener {
+    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
+      capturedQE.set(qe)
+    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
+  }
+
+  override def beforeAll(): Unit = {
+    spark = SparkSession.builder()
+      .master("local[1]")
+      .appName("DataFlintWriteMetricsSpec")
+      .config(DataflintSparkUICommonLoader.INSTRUMENT_SQL_NODES_ENABLED, "true")
+      .config("spark.sql.adaptive.enabled", "false")
+      .config("spark.ui.enabled", "false")
+      .withExtensions(new DataFlintInstrumentationExtension)
+      .getOrCreate()
+    spark.listenerManager.register(listener)
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) {
+      spark.listenerManager.unregister(listener)
+      spark.stop()
+    }
+  }
+
+  private def deleteRecursively(f: java.io.File): Unit = {
+    if (f.isDirectory) Option(f.listFiles).foreach(_.foreach(deleteRecursively))
+    f.delete()
+  }
+
+  test("DataWritingCommandExec metrics survive the executeCollect rebuild") {
+    val tempDir = java.nio.file.Files.createTempDirectory("dataflint-write-test").toFile
+    try {
+      capturedQE.set(null)
+      val session = spark
+      import session.implicits._
+      Seq(("a", 1), ("b", 2), ("c", 3), ("d", 4)).toDF("key", "value")
+        .write.mode("overwrite").parquet(tempDir.getAbsolutePath)
+
+      val qe = capturedQE.get()
+      qe should not be null
+
+      // The DataFlint-wrapped DataWritingCommandExec — metrics on this node feed the SparkUI.
+      val timedDwce = qe.executedPlan.collect {
+        case t: TimedExec if t.child.getClass.getSimpleName == "DataWritingCommandExec" => t
+      }.headOption.getOrElse(fail(s"No TimedExec(DataWritingCommandExec) in plan:\n${qe.executedPlan.treeString}"))
+
+      // cmd.metrics is shared via withNewChildren, so these must reach the original node.
+      val m = timedDwce.metrics
+      withClue(s"metrics: ${m.map { case (k, v) => s"$k=${v.value}" }.mkString(", ")}\n") {
+        m("numOutputRows").value shouldBe 4L
+        m("numFiles").value shouldBe 1L
+        m("numOutputBytes").value should be > 0L
+        m("duration").value should be >= 0L
+      }
+
+      // The wrapped data plan shares its metrics with the rebuilt RDDTimingWrapper subtree.
+      val scan = qe.executedPlan.collect {
+        case n if n.getClass.getSimpleName == "LocalTableScanExec" => n
+      }.headOption.getOrElse(fail("No LocalTableScanExec in plan"))
+      scan.metrics("numOutputRows").value shouldBe 4L
+    } finally {
+      deleteRecursively(tempDir)
+    }
+  }
+}
diff --git a/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/TimedExecMetricsSpec.scala b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/TimedExecMetricsSpec.scala
new file mode 100644
index 0000000..bfb49e7
--- /dev/null
+++ b/spark-plugin/pluginspark3/src/test/scala/org/apache/spark/dataflint/TimedExecMetricsSpec.scala
@@ -0,0 +1,67 @@
+package org.apache.spark.dataflint
+
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+/**
+ * Unit tests for TimedExec: the canEqual/equals contract across the two wrapper classes, and the overwrite semantics of the rddId metric.
+ */
+class TimedExecMetricsSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll {
+
+  private var spark: SparkSession = _
+
+  override def beforeAll(): Unit = {
+    spark = SparkSession.builder()
+      .master("local[1]")
+      .appName("TimedExecMetricsSpec")
+      .config(DataflintSparkUICommonLoader.INSTRUMENT_SQL_NODES_ENABLED, "true")
+      .config("spark.sql.adaptive.enabled", "false")
+      .config("spark.ui.enabled", "false")
+      .withExtensions(new DataFlintInstrumentationExtension)
+      .getOrCreate()
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) spark.stop()
+  }
+
+  test("TimedExec and TimedWithCodegenExec do not compare equal even when wrapping the same child") {
+    // FilterExec is CodegenSupport, so TimedExec.apply picks TimedWithCodegenExec for it.
+    // Pair it with a hand-built plain TimedExec wrapping the same child to exercise the
+    // canEqual / equals contract.
+    val df = spark.range(0, 5, 1, 1).filter("id > 1")
+    df.collect()
+    val timedCodegen = df.queryExecution.executedPlan.collect {
+      case t: TimedWithCodegenExec if t.child.getClass.getSimpleName == "FilterExec" => t
+    }.headOption.getOrElse(fail("no TimedWithCodegenExec(FilterExec) in plan"))
+
+    val plain = new TimedExec(timedCodegen.child)
+
+    plain.canEqual(timedCodegen) shouldBe false
+    timedCodegen.canEqual(plain) shouldBe false
+    (plain == timedCodegen) shouldBe false
+  }
+
+  test("postRddId overwrites the rddId metric instead of accumulating") {
+    // Build a real TimedExec from the rule so we exercise the actual path.
+    val df = spark.range(0, 10, 1, 1).filter("id > 5")
+    df.collect()
+    val timed = df.queryExecution.executedPlan.collect {
+      case t: TimedExec if t.child.getClass.getSimpleName == "FilterExec" => t
+    }.headOption.getOrElse(fail("no TimedExec(FilterExec) in plan"))
+
+    // postRddId is `protected`; reach it via reflection so we can exercise it with
+    // deterministic inputs and not couple the test to RDD-id allocation order.
+    val postRddId = classOf[TimedExec].getDeclaredMethod("postRddId", classOf[Int])
+    postRddId.setAccessible(true)
+
+    postRddId.invoke(timed, Int.box(100))
+    timed.metrics("rddId").value shouldBe 100L
+
+    postRddId.invoke(timed, Int.box(200))
+    // With `+=` semantics this would be 300; with `set` semantics it is 200.
+    timed.metrics("rddId").value shouldBe 200L
+  }
+}