@@ -45,10 +45,13 @@ class TimedExec(val child: SparkPlan) extends SparkPlan with Logging {
   override def nodeName: String = "DataFlint" + child.nodeName
   override def output: Seq[Attribute] = child.output
 
-  // Expose child's children directly so TimedExec appears as a single node in the plan graph.
-  // The wrapped child is not visible in the tree; plan transformations see and update the
-  // grandchildren directly, and withNewChildrenInternal rebuilds child with the new ones.
-  override def children: Seq[SparkPlan] = child.children
+  // On Spark 3.2+: transparent — children = child.children, so TimedExec appears as one node
+  // in the plan graph. withNewChildrenInternal rebuilds child with the new grandchildren.
+  // On Spark 3.0/3.1: non-transparent — children = Seq(child), because 3.1's withNewChildren
+  // maps product elements via containsChild, which cannot see through the transparent wrapper.
+  // This shows two nodes in the plan graph, but plan transformations work correctly.
+  override def children: Seq[SparkPlan] =
+    if (TimedExec.isLegacySpark) Seq(child) else child.children
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
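To make the two wrapper modes concrete, here is a minimal, self-contained sketch; the toy `Node` types stand in for `SparkPlan` and are not part of this PR:

```scala
// Toy model of the two wrapper modes; Node stands in for SparkPlan.
sealed trait Node { def children: Seq[Node] }
case class Leaf(name: String) extends Node { val children: Seq[Node] = Nil }
case class Exec(name: String, children: Node*) extends Node

// Spark 3.2+ (transparent): report the grandchildren, so the wrapper
// collapses into a single visible node in the plan graph.
case class TransparentTimed(child: Node) extends Node {
  val children: Seq[Node] = child.children
}

// Spark 3.0/3.1 (non-transparent): report Seq(child), so the graph shows
// wrapper and child as two separate nodes.
case class LegacyTimed(child: Node) extends Node {
  val children: Seq[Node] = Seq(child)
}

object WrapperShapeDemo extends App {
  val filter = Exec("Filter", Leaf("Scan"))
  println(TransparentTimed(filter).children) // List(Leaf(Scan))      -> one visible node
  println(LegacyTimed(filter).children)      // List(Exec(Filter,..)) -> two visible nodes
}
```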
@@ -125,10 +128,12 @@ class TimedExec(val child: SparkPlan) extends SparkPlan with Logging {
   override def productElement(n: Int): Any =
     if (n == 0) child else throw new IndexOutOfBoundsException(s"$n")
 
-  // When Spark updates our children (= child's children), rebuild child with new children.
-  // Uses TimedExec.apply to pick the right variant (with or without codegen).
+  // On 3.2+: children = child.children, so newChildren are the grandchildren → rebuild child.
+  // On 3.0/3.1: children = Seq(child), so newChildren has one element → the new child itself.
+  // (3.0/3.1 doesn't call withNewChildrenInternal, but makeCopy handles it via productElement.)
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
-    TimedExec(child.withNewChildren(newChildren))
+    if (TimedExec.isLegacySpark) TimedExec(newChildren.head)
+    else TimedExec(child.withNewChildren(newChildren))
 }
 
 /**
@@ -144,9 +149,12 @@ class TimedWithCodegenExec(override val child: SparkPlan) extends TimedExec(chil
     rdds
   }
 
+  // On 3.2+ (transparent): children = child.children, so multi-child nodes (joins) expose
+  // multiple children, which breaks codegen assumptions → restrict to single-child.
+  // On 3.0/3.1 (non-transparent): children = Seq(child), always length 1 → no restriction.
   override def supportCodegen: Boolean = {
     val c = child.asInstanceOf[CodegenSupport]
-    c.supportCodegen && child.children.length <= 1
+    c.supportCodegen && (TimedExec.isLegacySpark || child.children.length <= 1)
   }
 
   override def needCopyResult: Boolean = child.asInstanceOf[CodegenSupport].needCopyResult
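Why the transparent mode breaks on 3.0/3.1: those versions rewrite children by walking a node's product elements and swapping any element found in `containsChild`. A rough, self-contained model of that interaction (simplified, not Spark's actual code):

```scala
// Simplified model: 3.0/3.1-style withNewChildren only swaps product
// elements that are *direct* children (members of containsChild).
case class Plan(name: String, kids: Seq[Plan])

class TimedModel(val inner: Plan) {
  def children: Seq[Plan] = inner.kids       // transparent: grandchildren exposed
  def productElements: Seq[Any] = Seq(inner) // what mapProductIterator walks
}

object LegacyRewriteDemo extends App {
  val wrapper = new TimedModel(Plan("Filter", Seq(Plan("Scan", Nil))))
  val containsChild: Set[Any] = wrapper.children.toSet // Set(Plan(Scan, ...))
  val newChild = Plan("ColumnarToRow", Nil)
  // 3.1-style rewrite: swap product elements found in containsChild.
  val rewritten = wrapper.productElements.map {
    case p if containsChild(p) => newChild // never taken: `inner` is Filter, not Scan
    case other => other
  }
  println(rewritten) // List(Plan(Filter, ...)) -- the new child was silently dropped
}
```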
@@ -184,10 +192,19 @@ class TimedWithCodegenExec(override val child: SparkPlan) extends TimedExec(chil
     consume(ctx, input)
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
-    TimedExec(child.withNewChildren(newChildren))
+    if (TimedExec.isLegacySpark) TimedExec(newChildren.head)
+    else TimedExec(child.withNewChildren(newChildren))
 }
 
 object TimedExec {
+  // Spark 3.0/3.1's withNewChildren uses mapProductIterator + containsChild, which is
+  // incompatible with the transparent wrapper (children = child.children). Detected by
+  // parsing SPARK_VERSION (withNewChildrenInternal only exists on Spark 3.2+).
+  val isLegacySpark: Boolean = {
+    val parts = org.apache.spark.SPARK_VERSION.split("\\.")
+    parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
+  }
+
   def apply(child: SparkPlan): TimedExec = child match {
     case _: CodegenSupport => new TimedWithCodegenExec(child)
     case _ => new TimedExec(child)
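A quick sanity check of the version gate, assuming standard `major.minor.patch` SPARK_VERSION strings:

```scala
object LegacyGateDemo extends App {
  // Same check as TimedExec.isLegacySpark, parameterized for testing.
  def isLegacy(version: String): Boolean = {
    val parts = version.split("\\.")
    parts.length >= 2 && parts(0).toInt == 3 && parts(1).toInt < 2
  }
  Seq("3.0.3", "3.1.2", "3.2.4", "3.5.1", "4.0.0").foreach { v =>
    println(s"$v -> legacy=${isLegacy(v)}") // true, true, false, false, false
  }
}
```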
@@ -1,6 +1,5 @@
 package org.apache.spark.dataflint
 
-import org.apache.spark.SPARK_VERSION
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -37,31 +36,16 @@ case class DataFlintInstrumentationColumnarRule(session: SparkSession) extends C
   // Eagerly compute the set of node simple-class-names to wrap, respecting per-type flags.
   // When the global flag is on everything is enabled; otherwise only nodes whose specific
   // flag is enabled are included.
-  // TimedExec uses a transparent wrapper pattern (children = child.children) that is incompatible
-  // with Spark 3.0/3.1's withNewChildren (which maps product elements via containsChild).
-  // On 3.0/3.1, CollapseCodegenStages cannot update children through the wrapper, causing
-  // ClassCastExceptions. SQL nodes participate in codegen pipelines and are affected;
-  // Python exec nodes do not participate in codegen and are safe to instrument on all versions.
-  private val isLegacySpark: Boolean = {
-    val parts = SPARK_VERSION.split("\\.")
-    parts.length >= 2 && parts(0) == "3" && (parts(1) == "0" || parts(1) == "1")
-  }
-
   private val enabledNodeNames: Set[String] = {
     val conf = session.sparkContext.conf
     val globalEnabled = conf.getBoolean(DataflintSparkUICommonLoader.INSTRUMENT_SPARK_ENABLED, defaultValue = false)
-    val sqlNodes = if (isLegacySpark) {
-      logInfo("DataFlint: Spark 3.0/3.1 detected — skipping SQL node instrumentation (codegen incompatibility)")
-      Set.empty[String]
-    } else {
-      Set(
-        "FilterExec", "ProjectExec", "ExpandExec", "GenerateExec",
-        "SortMergeJoinExec", "BroadcastHashJoinExec", "BroadcastNestedLoopJoinExec",
-        "CartesianProductExec", "WindowGroupLimitExec", "SortAggregateExec", "SortExec", "HashAggregateExec",
-        "DataWritingCommandExec",
-        "FileSourceScanExec", "RowDataSourceScanExec", "BatchScanExec", "RDDScanExec",
-      )
-    }
+    val sqlNodes = Set(
+      "FilterExec", "ProjectExec", "ExpandExec", "GenerateExec",
+      "SortMergeJoinExec", "BroadcastHashJoinExec", "BroadcastNestedLoopJoinExec",
+      "CartesianProductExec", "WindowGroupLimitExec", "SortAggregateExec", "SortExec", "HashAggregateExec",
+      "DataWritingCommandExec",
+      "FileSourceScanExec", "RowDataSourceScanExec", "BatchScanExec", "RDDScanExec",
+    )
     val all = Set(
       "BatchEvalPythonExec",
       "ArrowEvalPythonExec",
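The surviving selection logic reads roughly like the sketch below. The per-node conf-key shape is inferred from the test script's `spark.dataflint.instrument.spark.<name>.enabled` flags; treat it as an approximation, not the rule's exact code:

```scala
// Approximate shape of enabledNodeNames: the global flag enables everything,
// otherwise only nodes whose specific flag is turned on survive.
def enabledNodeNames(
    globalEnabled: Boolean,
    flagEnabled: String => Boolean, // e.g. conf.getBoolean(key, false)
    candidates: Set[String]): Set[String] =
  if (globalEnabled) candidates
  else candidates.filter(n => flagEnabled(s"spark.dataflint.instrument.spark.$n.enabled"))
```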
26 changes: 23 additions & 3 deletions spark-plugin/pyspark-testing/dataflint_pyspark_example.py
@@ -21,7 +21,7 @@
 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
 import time
 
-SLEEP_ENABLED = True
+SLEEP_ENABLED = False
 
 def sleep(seconds):
     if SLEEP_ENABLED:
@@ -64,7 +64,7 @@ def sleep(seconds):
     .config("spark.plugins", "io.dataflint.spark.SparkDataflintPlugin") \
     .config("spark.ui.port", "10000") \
     .config("spark.sql.maxMetadataStringLength", "10000") \
-    .config("spark.sql.adaptive.enabled", "false") \
+    .config("spark.sql.adaptive.enabled", "true") \
     .config("spark.dataflint.telemetry.enabled", "false") \
     .config("spark.dataflint.instrument.spark.mapInPandas.enabled", instrument) \
     .config("spark.dataflint.instrument.spark.mapInArrow.enabled", instrument) \
@@ -229,7 +229,7 @@ def compute_discounted_totals_arrow(iterator):
 
 # slow_sum is a Scala UDAF registered by DataFlintInstrumentationExtension.
 # It sums Doubles but sleeps `spark.dataflint.test.slowSumSleepMs` per row on the JVM.
-df_window = df.withColumn("rank_in_category", rank().over(window_by_category)) \
+df_window = df.limit(1000).withColumn("rank_in_category", rank().over(window_by_category)) \
    .withColumn("cumulative_slow_revenue", expr("slow_sum(price)").over(window_by_category)) \
    .withColumn("avg_price_in_category", expr("slow_sum(price)").over(window_category_total))
 
@@ -426,6 +426,26 @@ def apply_category_discount(left_pdf, right_pdf):
 
 from pyspark.sql.functions import explode, array, collect_list, row_number, lit
 
+# ── Python UDF over Parquet (columnar scan + shuffle) ────────────────────────
+# Reproduces ColumnarBatch→InternalRow ClassCastException on Spark 3.0/3.1 when
+# the transparent wrapper prevents ColumnarToRowExec insertion.
+print("="*80)
+print("Running Python UDF over Parquet source (columnar scan + shuffle)")
+print("="*80)
+spark.sparkContext.setJobDescription("Python UDF over Parquet: columnar scan + ArrowEvalPython + shuffle")
+df.write.mode("overwrite").parquet("/tmp/dataflint_columnar_test_input")
+df_parquet = spark.read.parquet("/tmp/dataflint_columnar_test_input")
+
+@pandas_udf(DoubleType())
+def double_price(price: pd.Series) -> pd.Series:
+    return price * 2.0
+
+df_parquet.withColumn("doubled_price", double_price("price")) \
+    .groupBy("category") \
+    .agg(spark_sum("doubled_price").alias("total")) \
+    .write.mode("overwrite").parquet("/tmp/dataflint_columnar_udf_shuffle_example")
+print("\nResult written to /tmp/dataflint_columnar_udf_shuffle_example")
+
 # ── FilterExec + ProjectExec ──────────────────────────────────────────────────
 print("="*80)
 print("Running FilterExec + ProjectExec example")
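The failure mode this new test section targets can be modeled in a few lines. This is a toy model; Spark's real types are `ColumnarBatch`, `InternalRow`, and the `ColumnarToRowExec` transition operator:

```scala
// Toy model: a Parquet scan emits columnar batches, while row-based operators
// cast each element to a row. Without a columnar-to-row transition in between,
// the cast fails at runtime -- the ClassCastException described above.
sealed trait ExecOutput
case class ToyColumnarBatch(numRows: Int) extends ExecOutput
case class ToyRow(values: Seq[Any]) extends ExecOutput

def rowBasedOperator(input: ExecOutput): ToyRow =
  input.asInstanceOf[ToyRow] // throws ClassCastException for ToyColumnarBatch

def columnarToRow(batch: ToyColumnarBatch): Seq[ToyRow] =
  (0 until batch.numRows).map(i => ToyRow(Seq(i))) // the transition the planner must insert
```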
2 changes: 1 addition & 1 deletion spark-ui/src/reducers/PlanParsers/WindowParser.ts
@@ -6,7 +6,7 @@ import {
 
 export function parseWindow(input: string): ParsedWindowPlan {
   // Improved regex to correctly capture each part of the window specification
-  const regex = /Window \[(.*?)\](?:,\s*\[(.*?)\])?(?:,\s*\[(.*?)\])?/;
+  const regex = /\w+ \[(.*?)\](?:,\s*\[(.*?)\])?(?:,\s*\[(.*?)\])?/;
 
   // Remove any unwanted hash numbers
   const sanitizedInput = hashNumbersRemover(input);
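The loosened pattern accepts any leading node name instead of the literal `Window`, presumably so window variants and prefixed node names still parse. A small check of both patterns on a made-up plan string (hashes assumed already removed):

```scala
object WindowRegexDemo extends App {
  val before = """Window \[(.*?)\](?:,\s*\[(.*?)\])?(?:,\s*\[(.*?)\])?""".r
  val after  = """\w+ \[(.*?)\](?:,\s*\[(.*?)\])?(?:,\s*\[(.*?)\])?""".r
  val input  = "WindowGroupLimit [rank], [category], [price DESC]" // hypothetical input
  println(before.findFirstMatchIn(input).isDefined)      // false: literal "Window " never matches
  println(after.findFirstMatchIn(input).map(_.group(1))) // Some(rank)
}
```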
26 changes: 25 additions & 1 deletion spark-ui/src/reducers/PlanParsers/batchEvalPythonParser.spec.ts
@@ -79,4 +79,28 @@ describe("parseBatchEvalPython", () => {
     };
     expect(parseBatchEvalPython(input)).toEqual(expected);
   });
-});
+
+  it("should parse MapInPandas with bare function name", () => {
+    const input = "MapInPandas compute_discounted_totals_pandas(customer#1, category#2, quantity#3, price#4)#9, [customer#10, category#11], false";
+    expect(parseBatchEvalPython(input)).toEqual({
+      functionNames: ["compute_discounted_totals_pandas"],
+      udfNames: [],
+    });
+  });
+
+  it("should parse FlatMapGroupsInPandas with grouped function name", () => {
+    const input = "FlatMapGroupsInPandas [category#2], enrich_group(customer#1, category#2, quantity#3, price#4)#217, [customer#218, category#219]";
+    expect(parseBatchEvalPython(input)).toEqual({
+      functionNames: ["enrich_group"],
+      udfNames: [],
+    });
+  });
+
+  it("should parse FlatMapCoGroupsInPandas with cogroup function name", () => {
+    const input = "FlatMapCoGroupsInPandas [category#2], [category#247], apply_category_discount(customer#1, category#2)#257, [customer#258, category#259]";
+    expect(parseBatchEvalPython(input)).toEqual({
+      functionNames: ["apply_category_discount"],
+      udfNames: [],
+    });
+  });
+});
42 changes: 22 additions & 20 deletions spark-ui/src/reducers/PlanParsers/batchEvalPythonParser.ts
@@ -5,30 +5,32 @@ export function parseBatchEvalPython(input: string): ParsedBatchEvalPythonPlan {
   // Remove hash numbers for cleaner parsing
   const cleanedInput = hashNumbersRemover(input);
 
-  // Pattern: Any text followed by [first_list], [second_list] - with flexible whitespace and optional content at end
-  const regex = /^.*?\[\s*(.*?)\s*\]\s*,\s*\[\s*(.*?)\s*\].*$/;
+  // BatchEvalPython/ArrowEvalPython: [func1(col), func2(col)], [pythonUDF0, pythonUDF1]
+  // The first bracket must contain function calls (parentheses) or be empty, to distinguish
+  // it from FlatMapCoGroupsInPandas, which has [group_keys], [group_keys] before the function.
+  const regex = /^.*?\[\s*((?:.*?\(.*?)?)\s*\]\s*,\s*\[\s*(.*?)\s*\].*$/;
   const match = cleanedInput.match(regex);
 
-  if (!match) {
-    throw new Error("Invalid Python evaluation input format");
+  if (match) {
+    const [, functionNamesStr, udfNamesStr] = match;
+    const functionNames = functionNamesStr?.trim()
+      ? bracedSplit(functionNamesStr).map(name => name.trim())
+      : [];
+    const udfNames = udfNamesStr?.trim()
+      ? bracedSplit(udfNamesStr).map(name => name.trim())
+      : [];
+    return { functionNames, udfNames };
   }
 
-  const [, functionNamesStr, udfNamesStr] = match;
-
-  // Parse the function names list
-  let functionNames: string[] = [];
-  if (functionNamesStr && functionNamesStr.trim()) {
-    functionNames = bracedSplit(functionNamesStr).map(name => name.trim());
-  }
-
-  // Parse the UDF names list
-  let udfNames: string[] = [];
-  if (udfNamesStr && udfNamesStr.trim()) {
-    udfNames = bracedSplit(udfNamesStr).map(name => name.trim());
+  // MapInPandas/FlatMapGroupsInPandas/FlatMapCoGroupsInPandas have the function name
+  // as a bare identifier outside brackets:
+  //   "MapInPandas compute_func(col1, col2), [output_cols], false"
+  //   "FlatMapGroupsInPandas [group_keys], enrich_group(col1, col2), [output_cols]"
+  //   "FlatMapCoGroupsInPandas [left_keys], [right_keys], func(cols), [output_cols]"
+  const funcMatch = cleanedInput.match(/\b(\w+)\([^)]*\)[^,]*,\s*\[/);
+  if (funcMatch) {
+    return { functionNames: [funcMatch[1]], udfNames: [] };
   }
 
-  return {
-    functionNames,
-    udfNames,
-  };
+  throw new Error("Invalid Python evaluation input format");
 }
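End to end, the two-stage matching behaves like this. The sketch exercises the same two regexes from Scala on simplified plan strings, with hashes assumed already stripped:

```scala
object PythonPlanParseDemo extends App {
  // Stage 1: BatchEvalPython/ArrowEvalPython -- first bracket holds calls (or is empty).
  val batchRe = """^.*?\[\s*((?:.*?\(.*?)?)\s*\]\s*,\s*\[\s*(.*?)\s*\].*$""".r
  // Stage 2 fallback: a bare identifier with an argument list, followed by ", [".
  val bareRe = """\b(\w+)\([^)]*\)[^,]*,\s*\[""".r

  val batch = "BatchEvalPython [f(a), g(b)], [pythonUDF0, pythonUDF1]"
  val mapIn = "MapInPandas compute(a, b), [c, d], false"

  println(batchRe.findFirstMatchIn(batch).map(_.group(1))) // Some(f(a), g(b))
  println(batchRe.findFirstMatchIn(mapIn).isDefined)       // false -> falls through to stage 2
  println(bareRe.findFirstMatchIn(mapIn).map(_.group(1)))  // Some(compute)
}
```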
53 changes: 52 additions & 1 deletion spark-ui/src/reducers/SqlReducer.ts
@@ -255,7 +255,44 @@ function calculateSql(
   stages: SparkStagesStore,
 ): EnrichedSparkSQL {
   const enrichedSql = sql as EnrichedSparkSQL;
+
+  // Capture the node count before the merge so the incremental update path (case 3) doesn't
+  // see a mismatch between the pre-merge API count and the post-merge stored count.
+  const originalNumOfNodes = enrichedSql.nodes.length;
+
+  // Merge duplicate nodes from the non-transparent TimedExec wrapper (Spark 3.0/3.1).
+  // On legacy Spark, TimedExec uses children = Seq(child), which creates two nodes in the plan:
+  //   "DataFlintFilter" (wrapper with duration metric) ← "Filter" (actual child with plan info)
+  // Strategy: keep the CHILD node (it has the rich plan description, e.g. the filter condition)
+  // and add the wrapper's extra metrics (duration, rddId) to it. Mark it as instrumented via
+  // nodeName. Edges: fromId (child/input) → toId (parent/output).
+  const mergedWrapperIds = new Set<number>();
+  for (const node of enrichedSql.nodes) {
+    if (!node.nodeName.startsWith("DataFlint")) continue;
+    const strippedName = node.nodeName.slice("DataFlint".length);
+    // Find the wrapped child: the edge where toId === wrapper
+    const childEdges = enrichedSql.edges.filter(e => e.toId === node.nodeId);
+    if (childEdges.length !== 1) continue;
+    const childNode = enrichedSql.nodes.find(n => n.nodeId === childEdges[0].fromId);
+    if (!childNode || childNode.nodeName !== strippedName) continue;
+    // Keep the child, add the wrapper's extra metrics, and mark it instrumented via the prefix
+    const childMetricNames = new Set(childNode.metrics.map(m => m.name));
+    const extraMetrics = node.metrics.filter(m => !childMetricNames.has(m.name));
+    childNode.metrics = [...childNode.metrics, ...extraMetrics];
+    childNode.nodeName = "DataFlint" + childNode.nodeName;
+    // Remove the wrapper: redirect its outgoing edges so they originate from the child instead
+    mergedWrapperIds.add(node.nodeId);
+    for (const edge of enrichedSql.edges) {
+      if (edge.fromId === node.nodeId) {
+        edge.fromId = childNode.nodeId;
+      }
+    }
+  }
+  if (mergedWrapperIds.size > 0) {
+    enrichedSql.nodes = enrichedSql.nodes.filter(n => !mergedWrapperIds.has(n.nodeId));
+    enrichedSql.edges = enrichedSql.edges.filter(e => !mergedWrapperIds.has(e.toId));
+  }
 
   const typeEnrichedNodes = enrichedSql.nodes.map((node) => {
     const isInstrumented = node.nodeName.startsWith("DataFlint");
     const strippedNodeName = isInstrumented ? node.nodeName.slice("DataFlint".length) : node.nodeName;
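The merge is easier to see on a tiny example graph. The ids, metric names, and `Node`/`Edge` types below are hypothetical stand-ins for the store's types:

```scala
case class Node(id: Int, name: String, metrics: Map[String, String])
case class Edge(fromId: Int, toId: Int)

object WrapperMergeDemo extends App {
  // Legacy plan fragment: Scan -> Filter (child) -> DataFlintFilter (wrapper) -> Project
  val nodes = Seq(
    Node(0, "Scan", Map.empty),
    Node(1, "Filter", Map("condition" -> "price > 10")),
    Node(2, "DataFlintFilter", Map("duration" -> "42 ms")),
    Node(3, "Project", Map.empty))
  val edges = Seq(Edge(0, 1), Edge(1, 2), Edge(2, 3))

  // Keep the child, fold in the wrapper's metrics, re-point the wrapper's
  // outgoing edge at the child, drop the wrapper and its incoming edge.
  val merged = nodes(1).copy(
    name = "DataFlint" + nodes(1).name,
    metrics = nodes(1).metrics ++ nodes(2).metrics)
  val newNodes = Seq(nodes(0), merged, nodes(3))
  val newEdges = edges.filter(_.toId != 2).map(e => if (e.fromId == 2) e.copy(fromId = 1) else e)
  println(newEdges) // List(Edge(0,1), Edge(1,3)) -- Scan -> DataFlintFilter -> Project
}
```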
@@ -494,7 +531,21 @@ export function calculateSqlStore(
       newSql.status === SqlStatus.Completed.valueOf() ||
       newSql.status === SqlStatus.Failed.valueOf()
     ) {
-      updatedSqls.push(calculateSql(newSql, plan, icebergCommit, deltaLakeScans, stages));
+      // If plan data is unavailable (offset advanced past this SQL) but we already
+      // have a fully calculated SQL with a parsedPlan, preserve it — just update status/duration.
+      // This prevents losing plan descriptions on repeated polls with the non-paginated SQL API.
+      if (plan === undefined && currentSql.nodes.length > 0 && currentSql.nodes.some(n => n.parsedPlan !== undefined)) {
+        updatedSqls.push({
+          ...currentSql,
+          status: newSql.status,
+          duration: newSql.duration,
+          failedJobIds: newSql.failedJobIds,
+          runningJobIds: newSql.runningJobIds,
+          successJobIds: newSql.successJobIds,
+        });
+      } else {
+        updatedSqls.push(calculateSql(newSql, plan, icebergCommit, deltaLakeScans, stages));
+      }
       // From here newSql.status must be RUNNING
       // case 3: running SQL structure, so we need to update the plan
     } else if (currentSql.originalNumOfNodes !== newSql.nodes.length) {
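The new branch amounts to a status-only merge when the plan payload is missing. Roughly, with field names abridged from the reducer above:

```scala
// Sketch: keep the previously calculated SQL (with its parsed plan) and
// copy over only the volatile fields from the fresh API payload.
case class SqlEntry(status: String, duration: Long, hasParsedPlan: Boolean)

def updateEntry(current: SqlEntry, incoming: SqlEntry, planAvailable: Boolean): SqlEntry =
  if (!planAvailable && current.hasParsedPlan)
    current.copy(status = incoming.status, duration = incoming.duration)
  else incoming // plan available (or nothing to preserve): recompute from scratch
```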