12 changes: 8 additions & 4 deletions .github/workflows/ci.yml
@@ -19,20 +19,24 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        java: [temurin@8]
+        java: [temurin@17]
     runs-on: ubuntu-latest
     steps:
       - name: Checkout current branch (full)
         uses: actions/checkout@v4
         with:
           fetch-depth: 0

-      - name: Setup Java (temurin@8)
-        if: matrix.java == 'temurin@8'
+      # Java 17: required to run pluginspark4 tests against Spark 4.0.1 (Spark 4 needs ≥17).
+      # Spark 3.5 also supports Java 17, so pluginspark3 tests run on the same JVM.
+      # Scala 2.12/2.13 still emit Java 8 bytecode by default, so the produced jars stay
+      # compatible with Java 8 runtimes.
+      - name: Setup Java (temurin@17)
+        if: matrix.java == 'temurin@17'
         uses: actions/setup-java@v3
         with:
           distribution: temurin
-          java-version: 8
+          java-version: 17
           cache: sbt

       - name: Set up sbt
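The bytecode note above is doing real work: CI moves to a 17 JVM, but the published jars stay Java-8-compatible only as long as scalac keeps emitting Java 8 class files. A minimal sbt sketch of pinning that explicitly (illustrative settings, not claimed to be present in this repo's build.sbt; flag spelling varies slightly across Scala versions):

```scala
// Illustrative only (assumption): force Java-8-compatible class files even when the
// build itself runs on Temurin 17.
ThisBuild / scalacOptions ++= Seq("-release", "8") // scalac 2.12.5+/2.13: compile against JDK 8 APIs, emit 52.0 bytecode
ThisBuild / javacOptions  ++= Seq("--release", "8") // javac: the same guarantee for any Java sources
```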
36 changes: 35 additions & 1 deletion spark-plugin/build.sbt
@@ -165,7 +165,41 @@ lazy val pluginspark4 = (project in file("pluginspark4"))
     Compile / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala",

     // Include resources from plugin directory for static UI files
-    Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value
+    Compile / unmanagedResourceDirectories += (plugin / Compile / resourceDirectory).value,
+
+    // Test dependencies — Spark 4.0.1 + scalatest. Mirrors pluginspark3 so we can run the
+    // same regression suites against the Spark 4 surface (cross-version validation).
+    // Requires the launching JVM to be Java 17+ since Spark 4 won't run on Java 8/11.
+    libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % Test,
+    libraryDependencies += "org.scalatest" %% "scalatest-shouldmatchers" % "3.2.17" % Test,
+    libraryDependencies += "org.apache.spark" %% "spark-core" % "4.0.1" % Test,
+    libraryDependencies += "org.apache.spark" %% "spark-sql" % "4.0.1" % Test,
+
+    // Share version-portable test sources with pluginspark3. Most pluginspark3 specs
+    // depend on Spark-3-only internals (Dataset constructor, PythonMapInArrowExec, etc.)
+    // and don't compile against Spark 4, so we explicitly include only the suites that
+    // exercise version-stable surface area.
+    Test / unmanagedSourceDirectories += (plugin / Compile / sourceDirectory).value / "scala",
+    Test / unmanagedSources ++= {
+      val pluginspark3Tests = (pluginspark3 / Test / sourceDirectory).value / "scala"
+      Seq(
+        pluginspark3Tests / "org" / "apache" / "spark" / "dataflint" / "DataFlintCodegenFallbackSpec.scala"
+      )
+    },
+
+    // Fork JVM for tests; Spark on Java 9+ requires the same --add-opens as pluginspark3.
+    Test / fork := true,
+    Test / parallelExecution := false,
+    Test / javaOptions ++= {
+      if (sys.props("java.specification.version").startsWith("1.")) Seq.empty
+      else Seq(
+        "--add-opens=java.base/java.lang=ALL-UNNAMED",
+        "--add-opens=java.base/java.nio=ALL-UNNAMED",
+        "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
+        "--add-opens=java.base/java.util=ALL-UNNAMED",
+        "--add-opens=java.base/java.io=ALL-UNNAMED",
+      )
+    }
   )

 lazy val pluginspark4databricks = (project in file("pluginspark4databricks"))
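Usage note: because `Test / fork := true`, the `--add-opens` flags above land on the forked test JVM, not on sbt itself. Illustrative invocations (project ids from this build; the JDK path is an assumption): `JAVA_HOME=/path/to/temurin-17 sbt pluginspark4/test`, or for just the shared suite, `sbt "pluginspark4/testOnly org.apache.spark.dataflint.DataFlintCodegenFallbackSpec"`.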
MetricsUtils.scala

@@ -80,6 +80,14 @@ object MetricsUtils {
     }
   }

+  /**
+   * Create a plain sum metric (displayed as a raw number in Spark UI). Useful for
+   * non-byte numeric values like the RDD id, where "size" formatting (B/KB/MB) would
+   * mis-display the value.
+   */
+  def getSumMetric(name: String)(implicit sparkContext: SparkContext): (String, SQLMetric) =
+    name -> SQLMetrics.createMetric(sparkContext, name)
+
   /**
    * Create a "timing" metric (displayed as milliseconds with total/min/med/max in Spark UI).
    * Used by TimedExec for the "duration" metric.
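A short sketch of how the two helpers differ in the UI (the rendered strings are typical Spark UI formatting, shown as an assumption rather than quoted output; assumes an active SparkContext):

```scala
import org.apache.spark.SparkContext

implicit val sc: SparkContext = SparkContext.getOrCreate() // assumption: an active context

val (_, rddIdMetric) = MetricsUtils.getSumMetric("rddId")      // SQLMetrics.createMetric
val (_, spillMetric) = MetricsUtils.getSizeMetric("spillSize") // SQLMetrics.createSizeMetric

rddIdMetric.set(12L) // UI renders the raw number: "12"
spillMetric.set(12L) // UI applies byte formatting: roughly "12.0 B"
```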
TimedExec.scala

@@ -3,7 +3,7 @@ package org.apache.spark.dataflint
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
@@ -57,11 +57,12 @@ class TimedExec(val child: SparkPlan) extends SparkPlan with Logging {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
   override def supportsColumnar: Boolean = child.supportsColumnar

-  // Preserves ALL of child's existing metrics (spillSize, numOutputRows, etc.) + adds duration and rddId
+  // Preserves ALL of child's existing metrics (spillSize, numOutputRows, etc.) + adds duration and rddId.
+  // rddId uses a plain sum metric — a "size" metric would render it as bytes ("12 B") in the UI.
   override lazy val metrics: Map[String, SQLMetric] =
     child.metrics ++ Map(
       MetricsUtils.getTimingMetric("duration")(sparkContext),
-      MetricsUtils.getSizeMetric("rddId")(sparkContext)
+      MetricsUtils.getSumMetric("rddId")(sparkContext)
     )

   // Delegate prepare() to child so that DataWritingCommandExec (and similar nodes that
@@ -72,7 +73,11 @@ class TimedExec(val child: SparkPlan) extends SparkPlan with Logging {

   protected def postRddId(rddId: Int): Unit = {
     val rddIdMetric = longMetric("rddId")
-    rddIdMetric += rddId
+    // `set` instead of `+=` so re-execution of the same TimedExec instance overwrites
+    // the metric instead of accumulating. doExecute is invoked once per execute() call;
+    // a plan instance reused across queries (or AQE materialization) would otherwise
+    // sum every RDD id it ever wrapped.
+    rddIdMetric.set(rddId.toLong)
     MetricsUtils.postDriverMetrics(sparkContext, rddIdMetric)
   }
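The accumulation hazard the comment describes, as a tiny sketch (illustrative values; assumes a SparkContext `sc`):

```scala
// SQLMetric is an AccumulatorV2: `+=` adds, `set` overwrites.
val m = SQLMetrics.createMetric(sc, "rddId")
m += 12
m += 15      // m.value == 27: repeated executions would keep summing
m.set(15L)   // m.value == 15: overwriting is idempotent across re-executions
```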

@@ -104,25 +109,33 @@ class TimedExec(val child: SparkPlan) extends SparkPlan with Logging {
   override def executeCollect(): Array[InternalRow] = {
     if (child.getClass.getSimpleName == "DataWritingCommandExec") {
       val durationMetric = longMetric("duration")
-      val innerChild = child.children.head
-      val wrappedChild = if (innerChild.getClass.getSimpleName == "WriteFilesExec") {
-        // Spark 3.4+: wrap the data plan inside WriteFilesExec
-        val dataPlan = innerChild.children.head
-        val timedDataPlan = new TimedExec.RDDTimingWrapper(dataPlan, durationMetric)
-        val wrappedWriteFiles = innerChild.withNewChildren(IndexedSeq(timedDataPlan))
-        child.withNewChildren(IndexedSeq(wrappedWriteFiles))
-      } else {
-        // Older Spark: wrap the data plan directly
-        val timedDataPlan = new TimedExec.RDDTimingWrapper(innerChild, durationMetric)
-        child.withNewChildren(IndexedSeq(timedDataPlan))
-      }
+      // child.children may be empty on unusual DataWritingCommandExec shapes (vendor
+      // forks, future Spark versions). Fall through to super.executeCollect() in that
+      // case — duration will be zero on the fallback path, but the write still runs.
+      val maybeRebuilt: Option[SparkPlan] = child.children.headOption.flatMap { innerChild =>
+        if (innerChild.getClass.getSimpleName == "WriteFilesExec") {
+          // Spark 3.4+: wrap the data plan inside WriteFilesExec
+          innerChild.children.headOption.map { dataPlan =>
+            val timedDataPlan = new TimedExec.RDDTimingWrapper(dataPlan, durationMetric)
+            val wrappedWriteFiles = innerChild.withNewChildren(IndexedSeq(timedDataPlan))
+            child.withNewChildren(IndexedSeq(wrappedWriteFiles))
+          }
+        } else {
+          // Older Spark: wrap the data plan directly
+          val timedDataPlan = new TimedExec.RDDTimingWrapper(innerChild, durationMetric)
+          Some(child.withNewChildren(IndexedSeq(timedDataPlan)))
+        }
+      }
-      wrappedChild.executeCollect()
+      maybeRebuilt.fold(super.executeCollect())(_.executeCollect())
     } else {
       super.executeCollect()
     }
   }

-  override def canEqual(that: Any): Boolean = that.isInstanceOf[TimedExec]
+  // Match the runtime class so TimedExec(x) and TimedWithCodegenExec(x) don't compare
+  // equal — they have different execution semantics (codegen vs RDD path), and TreeNode
+  // equality / canonicalization use canEqual to decide plan reuse.
+  override def canEqual(that: Any): Boolean = that.getClass == this.getClass
   // productArity/productElement support makeCopy on Spark 3.0/3.1 (constructor arg = child)
   override def productArity: Int = 1
   override def productElement(n: Int): Any =
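A concrete illustration of what the `canEqual` change prevents (hypothetical `somePlan`; the wrapper classes are the ones in this file):

```scala
val plain   = new TimedExec(somePlan)            // somePlan: any SparkPlan (assumed in scope)
val codegen = new TimedWithCodegenExec(somePlan) // subclass of TimedExec

// Old check: that.isInstanceOf[TimedExec] is true for both, so the two wrappers around
// the same child could compare equal and be swapped by plan reuse despite different
// execution semantics. New check: runtime-class equality keeps them distinct.
plain.canEqual(codegen)   // now false (classes differ); previously true
codegen.canEqual(codegen) // still true
```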
@@ -153,9 +166,21 @@ class TimedWithCodegenExec(override val child: SparkPlan) extends TimedExec(child)
   // On 3.2+ (transparent): children = child.children, multi-child nodes (joins) expose
   // multiple children which breaks codegen assumptions → restrict to single-child.
   // On 3.0/3.1 (non-transparent): children = Seq(child), always length 1 → no restriction.
+  //
+  // Mirror Spark's CollapseCodegenStages CodegenFallback check on `child.expressions`.
+  // The framework normally excludes plans whose expressions contain a CodegenFallback
+  // (e.g. JsonToStructs / from_json), but our transparent `children = child.children`
+  // hides `child` from that check, so we must do it ourselves — otherwise downstream
+  // CodegenFallback.doGenCode reads ctx.INPUT_ROW = null and NPEs in Block.code
+  // interpolation. (issue #74)
   override def supportCodegen: Boolean = {
     val c = child.asInstanceOf[CodegenSupport]
-    c.supportCodegen && (TimedExec.isLegacySpark || child.children.length <= 1)
+    // Use TreeNode.find (available since Spark 3.0) rather than TreeNode.exists, which
+    // was added in 3.2 — calling `.exists` on an Expression throws NoSuchMethodError at
+    // runtime on Spark 3.0/3.1 even though it compiles fine against newer Spark headers.
+    c.supportCodegen &&
+      (TimedExec.isLegacySpark || child.children.length <= 1) &&
+      !child.expressions.exists(_.find(_.isInstanceOf[CodegenFallback]).isDefined)
   }

   override def needCopyResult: Boolean = child.asInstanceOf[CodegenSupport].needCopyResult
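Spelling out the `find` vs `exists` portability point from the comment (a sketch; the availability claims are the comment's, mirrored here):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback

def hasCodegenFallback(e: Expression): Boolean =
  e.find(_.isInstanceOf[CodegenFallback]).isDefined // TreeNode.find: present since Spark 3.0

// e.exists(_.isInstanceOf[CodegenFallback]) compiles against Spark 3.2+ jars but throws
// NoSuchMethodError when the same jar runs on a Spark 3.0/3.1 cluster.
```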
@@ -201,9 +226,15 @@ object TimedExec {
   // Spark 3.0/3.1's withNewChildren uses mapProductIterator + containsChild which is
   // incompatible with the transparent wrapper (children = child.children). Detected by
   // checking for withNewChildrenInternal which was added in Spark 3.2.
+  //
+  // Wrap the parse in Try — any vendor distribution with a non-numeric major/minor (or
+  // a version string Spark doesn't expose at all) defaults to non-legacy, matching every
+  // released Spark line ≥ 3.2.
   val isLegacySpark: Boolean = {
-    val parts = org.apache.spark.SPARK_VERSION.split("\\.")
-    parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
+    scala.util.Try {
+      val parts = org.apache.spark.SPARK_VERSION.split("\\.")
+      parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
+    }.getOrElse(false)
   }

   def apply(child: SparkPlan): TimedExec = child match {
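How the guarded parse behaves on a few representative version strings (a standalone sketch of the same logic; inputs are illustrative):

```scala
def isLegacy(v: String): Boolean = scala.util.Try {
  val parts = v.split("\\.")
  parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
}.getOrElse(false)

isLegacy("3.0.3")      // true: legacy withNewChildren path
isLegacy("3.1.2")      // true
isLegacy("3.2.4")      // false
isLegacy("4.0.1")      // false: major version is not 3
isLegacy("3.vendor.7") // false: toInt throws, Try defaults to non-legacy
```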
@@ -215,6 +246,15 @@ object TimedExec {
    * A minimal SparkPlan that wraps execute() with per-partition duration timing.
    * Used by the write path: inserted inside WriteFilesExec (or as the direct child on older
    * Spark) so that the write command consumes a timed RDD per partition.
+   *
+   * Reconstruction note: `durationMetric` is intentionally NOT exposed via productElement.
+   * The constructor takes (child, durationMetric) but `productArity = 1` reports only
+   * `child`, so any code that tries to clone us via `makeCopy` would fail to supply
+   * `durationMetric`. We sidestep this by overriding `withNewChildrenInternal` to plumb
+   * the metric through manually — it's the only reconstruction path Spark 3.2+ takes for
+   * us. Spark 3.0/3.1's makeCopy-based path is not reached because this wrapper is only
+   * constructed inside `executeCollect` on Spark 3.4+ (WriteFilesExec branch) or as a
+   * direct child wrap on older Spark, neither of which round-trips through `makeCopy`.
    */
   private[dataflint] class RDDTimingWrapper(val child: SparkPlan, durationMetric: SQLMetric) extends SparkPlan {
     override def output: Seq[Attribute] = child.output
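A sketch of the `withNewChildrenInternal` override the note refers to (the real one sits below the fold of this diff; the signature assumes the Spark 3.2+ API):

```scala
// Hypothetical shape: re-plumbs the metric that productElement hides from makeCopy.
override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
  new RDDTimingWrapper(newChildren.head, durationMetric)
```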
DataFlintCodegenFallbackSpec.scala (new file under pluginspark3's test sources, referenced from build.sbt above)

@@ -0,0 +1,71 @@
+package org.apache.spark.dataflint
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{col, explode, from_json}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+/**
+ * Regression test for issue #74: TimedWithCodegenExec must exclude itself from
+ * whole-stage codegen when its child contains a CodegenFallback expression (e.g.
+ * from_json / JsonToStructs). The transparent `children = child.children` hides
+ * the wrapped child from CollapseCodegenStages' CodegenFallback check, so that
+ * exclusion has to happen in TimedWithCodegenExec.supportCodegen.
+ *
+ * Without the fix, the wrapped child gets wholestaged anyway and CodegenFallback's
+ * generated code reads ctx.INPUT_ROW = null, NPE'ing in Block.code interpolation:
+ *   java.lang.NullPointerException: Cannot invoke "Object.getClass()" because "arg" is null
+ */
+class DataFlintCodegenFallbackSpec extends AnyFunSuite with Matchers with BeforeAndAfterAll {
+
+  private var spark: SparkSession = _
+
+  override def beforeAll(): Unit = {
+    spark = SparkSession.builder()
+      .master("local[1]")
+      .appName("DataFlintCodegenFallbackSpec")
+      .config(DataflintSparkUICommonLoader.INSTRUMENT_SQL_NODES_ENABLED, "true")
+      .config("spark.sql.codegen.wholeStage", "true")
+      .config("spark.sql.adaptive.enabled", "false")
+      .config("spark.ui.enabled", "false")
+      .withExtensions(new DataFlintInstrumentationExtension)
+      .getOrCreate()
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) spark.stop()
+  }
+
+  test("from_json under whole-stage codegen does not NPE (issue #74)") {
+    val session = spark
+    import session.implicits._
+
+    val schema = ArrayType(StructType(Seq(
+      StructField("name", StringType, nullable = true),
+      StructField("kind", StringType, nullable = true)
+    )))
+
+    // Cache to materialize the rows so JsonToStructs is not constant-folded into the
+    // LocalTableScan — otherwise from_json never reaches whole-stage codegen and the bug
+    // doesn't fire.
+    val raw = Seq(
+      ("k1", """[{"name":"a","kind":"x"}]"""),
+      ("k2", null.asInstanceOf[String])
+    ).toDF("key", "payload")
+      .repartition(1)
+      .cache()
+    raw.count()
+
+    val rowCount: Long = raw
+      .filter(col("payload").isNotNull)
+      .withColumn("parsed", from_json(col("payload"), schema))
+      .filter(col("parsed").isNotNull)
+      .select(col("key"), explode(col("parsed")).alias("d"))
+      .filter(col("d.name").isNotNull)
+      .count()
+
+    rowCount shouldBe 1L
+  }
+}
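An optional strengthening of this test (a sketch, not part of the PR): inspect the executed plan and confirm `from_json` stayed outside the generated stage.

```scala
// Hypothetical extra check: JsonToStructs should not sit inside a WholeStageCodegen
// subtree ("*(n) ..." nodes in the tree string) once TimedWithCodegenExec opts out.
val planString = raw
  .filter(col("payload").isNotNull)
  .withColumn("parsed", from_json(col("payload"), schema))
  .queryExecution.executedPlan.treeString
println(planString) // eyeball while debugging; a regex check could assert this in CI
```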