@@ -0,0 +1,117 @@
package org.apache.spark.dataflint.executor

import com.fasterxml.jackson.databind.ObjectMapper

import java.io.BufferedReader
import java.io.InputStreamReader
import java.nio.file.{Files, Paths}
import java.util.concurrent.TimeUnit
import scala.util.Try

object CloudMetadataDetector {

  case class CloudMetadata(
      cloudProvider: Option[String],
      instanceType: Option[String],
      lifecycleType: Option[String]
  )

  private val COMMAND_TIMEOUT_MS = 5000L

  // The DMI vendor string identifies the hypervisor without any network call.
  private val SYS_VENDOR_PATH = "/sys/class/dmi/id/sys_vendor"

  // AWS: try to obtain an IMDSv2 token; if that fails, fall back to plain IMDSv1 requests.
  private val AWS_COMMAND =
    """TOKEN=$(curl -sf -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600" --connect-timeout 1 --max-time 2 2>/dev/null)
      |if [ -n "$TOKEN" ]; then
      |  HEADER="-H X-aws-ec2-metadata-token:$TOKEN"
      |else
      |  HEADER=""
      |fi
      |IT=$(curl -sf $HEADER "http://169.254.169.254/latest/meta-data/instance-type" --connect-timeout 1 --max-time 2 2>/dev/null)
      |if [ -z "$IT" ]; then exit 1; fi
      |LC=$(curl -sf $HEADER "http://169.254.169.254/latest/meta-data/instance-life-cycle" --connect-timeout 1 --max-time 2 2>/dev/null)
      |echo "{\"cloudProvider\":\"aws\",\"instanceType\":\"$IT\",\"lifecycleType\":\"$LC\"}"
      |""".stripMargin

  // GCP: machine-type comes back as a full resource path, so keep only the last path segment.
  private val GCP_COMMAND =
    """MT=$(curl -sf -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/machine-type" --connect-timeout 1 --max-time 2 2>/dev/null)
      |if [ -z "$MT" ]; then exit 1; fi
      |IT=$(echo "$MT" | rev | cut -d/ -f1 | rev)
      |PR=$(curl -sf -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/scheduling/preemptible" --connect-timeout 1 --max-time 2 2>/dev/null)
      |if [ "$PR" = "TRUE" ]; then LC="preemptible"; else LC="on-demand"; fi
      |echo "{\"cloudProvider\":\"gcp\",\"instanceType\":\"$IT\",\"lifecycleType\":\"$LC\"}"
      |""".stripMargin

  // Azure: IMDS requires the "Metadata: true" header; priority is "Spot" for spot VMs.
  private val AZURE_COMMAND =
    """VS=$(curl -sf -H "Metadata: true" "http://169.254.169.254/metadata/instance/compute/vmSize?api-version=2021-02-01&format=text" --connect-timeout 1 --max-time 2 2>/dev/null)
      |if [ -z "$VS" ]; then exit 1; fi
      |PR=$(curl -sf -H "Metadata: true" "http://169.254.169.254/metadata/instance/compute/priority?api-version=2021-02-01&format=text" --connect-timeout 1 --max-time 2 2>/dev/null)
      |if [ "$PR" = "Spot" ]; then LC="spot"; else LC="on-demand"; fi
      |echo "{\"cloudProvider\":\"azure\",\"instanceType\":\"$VS\",\"lifecycleType\":\"$LC\"}"
      |""".stripMargin

  private val mapper = new ObjectMapper()

  def detect(): CloudMetadata = {
    detectCloudProvider() match {
      case Some("aws")   => runCloudCommand(AWS_COMMAND)
      case Some("gcp")   => runCloudCommand(GCP_COMMAND)
      case Some("azure") => runCloudCommand(AZURE_COMMAND)
      case _             => CloudMetadata(None, None, None)
    }
  }

  private def detectCloudProvider(): Option[String] = {
    Try {
      val vendor = new String(Files.readAllBytes(Paths.get(SYS_VENDOR_PATH))).trim
      if (vendor.contains("Amazon")) Some("aws")
      else if (vendor.contains("Google")) Some("gcp")
      else if (vendor.contains("Microsoft")) Some("azure")
      else None
    }.getOrElse(None)
  }

  private def runCloudCommand(command: String): CloudMetadata = {
    Try {
      runBashCommand(command).flatMap(parseJson)
    }.getOrElse(None).getOrElse(CloudMetadata(None, None, None))
  }

  private def runBashCommand(command: String): Option[String] = {
    val process = new ProcessBuilder("bash", "-c", command)
      .redirectErrorStream(false)
      .start()

    // Reading stdout after waitFor is safe here: each script prints at most one
    // short line, which fits comfortably in the pipe buffer.
    val completed = process.waitFor(COMMAND_TIMEOUT_MS, TimeUnit.MILLISECONDS)
    if (!completed) {
      process.destroyForcibly()
      return None
    }

    if (process.exitValue() != 0) {
      return None
    }

    val reader = new BufferedReader(new InputStreamReader(process.getInputStream))
    try {
      val output = reader.readLine()
      if (output != null && output.nonEmpty) Some(output.trim) else None
    } finally {
      reader.close()
    }
  }

  private def parseJson(json: String): Option[CloudMetadata] = {
    Try {
      val node = mapper.readTree(json)
      val cloudProvider = Option(node.get("cloudProvider")).map(_.asText()).filter(_.nonEmpty)
      val instanceType = Option(node.get("instanceType")).map(_.asText()).filter(_.nonEmpty)
      val lifecycleType = Option(node.get("lifecycleType")).map(_.asText()).filter(_.nonEmpty)
      if (instanceType.isDefined || cloudProvider.isDefined) {
        Some(CloudMetadata(cloudProvider, instanceType, lifecycleType))
      } else {
        None
      }
    }.getOrElse(None)
  }
}
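
For reference, each provider script above prints a single JSON line that parseJson turns into a CloudMetadata. A minimal, self-contained sketch of that mapping, assuming jackson-databind on the classpath as the detector already does; the sample payload and object name are illustrative, not part of the patch:

import com.fasterxml.jackson.databind.ObjectMapper

object ParseSketch extends App {
  // Example of the one-line payload AWS_COMMAND would emit on a spot instance (hypothetical values).
  val sample = """{"cloudProvider":"aws","instanceType":"m5.xlarge","lifecycleType":"spot"}"""
  val node = new ObjectMapper().readTree(sample)
  // Mirrors parseJson: a missing or empty field collapses to None.
  def field(name: String): Option[String] =
    Option(node.get(name)).map(_.asText()).filter(_.nonEmpty)
  println((field("cloudProvider"), field("instanceType"), field("lifecycleType")))
  // prints (Some(aws),Some(m5.xlarge),Some(spot))
}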
@@ -0,0 +1,77 @@
package org.apache.spark.dataflint.executor

import org.apache.spark.api.plugin.{ExecutorPlugin, PluginContext}
import org.apache.spark.internal.Logging

import java.util

class DataflintExecutorPlugin extends ExecutorPlugin with Logging {

  override def init(ctx: PluginContext, extraConf: util.Map[String, String]): Unit = {
    // The flag is injected by the driver plugin's init (see SparkDataflintDriverPlugin below).
    val enabled = Option(extraConf.get("executor.metadata.enabled")).exists(_.toBoolean)
    if (!enabled) {
      return
    }

    val executorId = ctx.executorID()
    val hostname = try {
      java.net.InetAddress.getLocalHost.getHostName
    } catch {
      case _: Throwable => "unknown"
    }

    try {
      val osName = System.getProperty("os.name", "unknown")
      val osArch = System.getProperty("os.arch", "unknown")
      val jvmVersion = System.getProperty("java.version", "unknown")
      val availableProcessors = Runtime.getRuntime.availableProcessors()
      val totalMemoryBytes = Runtime.getRuntime.maxMemory()

      val cloudMetadata = try {
        CloudMetadataDetector.detect()
      } catch {
        case e: Throwable =>
          logWarning("Failed to detect cloud metadata", e)
          CloudMetadataDetector.CloudMetadata(None, None, None)
      }

      val message = ExecutorMetadataMessage(
        executorId = executorId,
        executorHost = hostname,
        instanceType = cloudMetadata.instanceType,
        lifecycleType = cloudMetadata.lifecycleType,
        cloudProvider = cloudMetadata.cloudProvider,
        osName = osName,
        osArch = osArch,
        jvmVersion = jvmVersion,
        availableProcessors = availableProcessors,
        totalMemoryBytes = totalMemoryBytes,
        collectionError = None
      )
      ctx.send(message)
      logInfo(s"Sent executor metadata: provider=${cloudMetadata.cloudProvider}, " +
        s"instance=${cloudMetadata.instanceType}, lifecycle=${cloudMetadata.lifecycleType}")
    } catch {
      case e: Throwable =>
        logWarning("Failed to collect/send executor metadata", e)
        // Best effort: still report the JVM/OS facts plus the error, so the driver sees something.
        try {
          ctx.send(ExecutorMetadataMessage(
            executorId = executorId,
            executorHost = hostname,
            instanceType = None,
            lifecycleType = None,
            cloudProvider = None,
            osName = System.getProperty("os.name", "unknown"),
            osArch = System.getProperty("os.arch", "unknown"),
            jvmVersion = System.getProperty("java.version", "unknown"),
            availableProcessors = Runtime.getRuntime.availableProcessors(),
            totalMemoryBytes = Runtime.getRuntime.maxMemory(),
            collectionError = Some(e.getMessage)
          ))
        } catch {
          case inner: Throwable =>
            logWarning("Failed to send error metadata to driver", inner)
        }
    }
  }
}
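
Note that the executor.metadata.enabled key read in init above is not a Spark conf: it is whatever map the driver plugin's own init returns, which Spark hands to every executor plugin as extraConf. A minimal sketch of that handshake pattern, with hypothetical names:

import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext}

import java.util
import scala.collection.JavaConverters.mapAsJavaMapConverter

// Hypothetical plugin pair showing how a driver-side value reaches every executor.
class DemoDriverPlugin extends DriverPlugin {
  override def init(sc: SparkContext, ctx: PluginContext): util.Map[String, String] =
    Map("demo.flag" -> "true").asJava // delivered to each executor plugin's init
}

class DemoExecutorPlugin extends ExecutorPlugin {
  override def init(ctx: PluginContext, extraConf: util.Map[String, String]): Unit = {
    val flag = extraConf.get("demo.flag")
    println(s"demo.flag=$flag") // prints demo.flag=true
  }
}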
@@ -0,0 +1,15 @@
package org.apache.spark.dataflint.executor

import org.apache.spark.SparkContext
import org.apache.spark.dataflint.listener.{DataflintExecutorMetadataEvent, DataflintExecutorMetadataInfo}

object DriverMetadataHelper {

  def isExecutorMetadataEnabled(sc: SparkContext): Boolean = {
    sc.conf.getBoolean("spark.dataflint.experimental.executor.metadata.enabled", defaultValue = false)
  }

  def postExecutorMetadataEvent(sc: SparkContext, info: DataflintExecutorMetadataInfo): Unit = {
    sc.listenerBus.post(DataflintExecutorMetadataEvent(info))
  }
}
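
A quick sketch of this gate in action; the conf key comes from the helper above, while the local app setup is a placeholder:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.dataflint.executor.DriverMetadataHelper

object GateCheck extends App {
  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("gate-check") // placeholder
    .set("spark.dataflint.experimental.executor.metadata.enabled", "true")
  val sc = new SparkContext(conf)
  println(DriverMetadataHelper.isExecutorMetadataEnabled(sc)) // true; without the .set it defaults to false
  sc.stop()
}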
@@ -0,0 +1,15 @@
package org.apache.spark.dataflint.executor

case class ExecutorMetadataMessage(
    executorId: String,
    executorHost: String,
    instanceType: Option[String],
    lifecycleType: Option[String],
    cloudProvider: Option[String],
    osName: String,
    osArch: String,
    jvmVersion: String,
    availableProcessors: Int,
    totalMemoryBytes: Long,
    collectionError: Option[String]
) extends java.io.Serializable
@@ -26,6 +26,10 @@ class DataflintListener(store: ElementTrackingStore) extends SparkListener with
          val wrapper = new DataflintDeltaLakeScanInfoWrapper(e.scanInfo)
          store.write(wrapper)
        }
+       case e: DataflintExecutorMetadataEvent => {
+         val wrapper = new DataflintExecutorMetadataWrapper(e.metadata)
+         store.write(wrapper)
+       }
        case _ => {}
      }
    } catch {
@@ -36,4 +36,8 @@ class DataflintStore(val store: KVStore) {
      .sortBy(_.minExecutionId)
  }

+  def executorMetadata(): Seq[DataflintExecutorMetadataInfo] = {
+    mapToSeq(store.view(classOf[DataflintExecutorMetadataWrapper]))(_.info)
+  }
+
}
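
Rows written by DataflintListener can then be read back through this new accessor. A hedged usage sketch (KVStore is Spark's org.apache.spark.util.kvstore.KVStore; obtaining one, and the DataflintStore import path, are outside this diff):

import org.apache.spark.util.kvstore.KVStore

// Sketch only: assumes a DataflintStore import matching this module.
def printExecutorMetadata(kvStore: KVStore): Unit = {
  val ds = new DataflintStore(kvStore)
  ds.executorMetadata().foreach { m =>
    val provider = m.cloudProvider.getOrElse("unknown")
    val instance = m.instanceType.getOrElse("unknown")
    println(s"executor ${m.executorId}: $provider / $instance")
  }
}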
@@ -91,3 +91,25 @@ class DataflintDeltaLakeScanInfoWrapper(val info: DataflintDeltaLakeScanInfo) {
   @JsonIgnore
   def id: String = s"${info.minExecutionId}_${info.tablePath.replaceAll(" ", "")}"
 }
+
+case class DataflintExecutorMetadataInfo(
+    executorId: String,
+    executorHost: String,
+    instanceType: Option[String],
+    lifecycleType: Option[String],
+    cloudProvider: Option[String],
+    osName: String,
+    osArch: String,
+    jvmVersion: String,
+    availableProcessors: Int,
+    totalMemoryBytes: Long,
+    collectionError: Option[String]
+)
+
+case class DataflintExecutorMetadataEvent(metadata: DataflintExecutorMetadataInfo) extends SparkListenerEvent
+
+class DataflintExecutorMetadataWrapper(val info: DataflintExecutorMetadataInfo) {
+  // Keyed by executorId, so the KV store keeps one row per executor.
+  @KVIndex
+  @JsonIgnore
+  def id: String = info.executorId
+}
@@ -1,6 +1,6 @@
package org.apache.spark.dataflint.saas

-import org.apache.spark.dataflint.listener.{DatabricksAdditionalExecutionWrapper, DataflintDeltaLakeScanInfoWrapper, DataflintEnvironmentInfoEvent, DataflintEnvironmentInfoWrapper, DataflintRDDStorageInfoWrapper, IcebergCommitWrapper}
+import org.apache.spark.dataflint.listener.{DatabricksAdditionalExecutionWrapper, DataflintDeltaLakeScanInfoWrapper, DataflintEnvironmentInfoEvent, DataflintEnvironmentInfoWrapper, DataflintExecutorMetadataWrapper, DataflintRDDStorageInfoWrapper, IcebergCommitWrapper}
import org.apache.spark.sql.execution.ui.{SQLExecutionUIData, SparkPlanGraphWrapper}
import org.apache.spark.status._

@@ -27,5 +27,6 @@ case class SparkRunStore(
    icebergCommit: Seq[IcebergCommitWrapper],
    dataflintEnvironmentInfo: Seq[DataflintEnvironmentInfoWrapper],
    dataflintRDDStorageInfo: Seq[DataflintRDDStorageInfoWrapper],
-   dataflintDeltaLakeScanInfo: Seq[DataflintDeltaLakeScanInfoWrapper]
+   dataflintDeltaLakeScanInfo: Seq[DataflintDeltaLakeScanInfoWrapper],
+   dataflintExecutorMetadata: Seq[DataflintExecutorMetadataWrapper] = Seq.empty
)
@@ -1,6 +1,6 @@
package org.apache.spark.dataflint.saas

-import org.apache.spark.dataflint.listener.{DatabricksAdditionalExecutionWrapper, DataflintDeltaLakeScanInfoWrapper, DataflintEnvironmentInfoWrapper, DataflintRDDStorageInfoWrapper, IcebergCommitWrapper}
+import org.apache.spark.dataflint.listener.{DatabricksAdditionalExecutionWrapper, DataflintDeltaLakeScanInfoWrapper, DataflintEnvironmentInfoWrapper, DataflintExecutorMetadataWrapper, DataflintRDDStorageInfoWrapper, IcebergCommitWrapper}
import org.apache.spark.sql.execution.ui.{SQLExecutionUIData, SparkPlanGraphWrapper}
import org.apache.spark.status._

@@ -35,7 +35,8 @@ class StoreDataExtractor(store: AppStatusStore) {
      icebergCommit = readAll[IcebergCommitWrapper],
      dataflintEnvironmentInfo = readAll[DataflintEnvironmentInfoWrapper],
      dataflintRDDStorageInfo = readAll[DataflintRDDStorageInfoWrapper],
-     dataflintDeltaLakeScanInfo = readAll[DataflintDeltaLakeScanInfoWrapper]
+     dataflintDeltaLakeScanInfo = readAll[DataflintDeltaLakeScanInfoWrapper],
+     dataflintExecutorMetadata = readAll[DataflintExecutorMetadataWrapper]
)
}

@@ -3,6 +3,8 @@ package io.dataflint.spark
import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.dataflint.{DataflintSparkUICommonLoader, DataflintSparkUILoader}
+import org.apache.spark.dataflint.executor.{DataflintExecutorPlugin, DriverMetadataHelper, ExecutorMetadataMessage}
+import org.apache.spark.dataflint.listener.DataflintExecutorMetadataInfo
import org.apache.spark.internal.Logging

import java.util
@@ -11,7 +13,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
class SparkDataflintPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = new SparkDataflintDriverPlugin()

-  override def executorPlugin(): ExecutorPlugin = null
+  override def executorPlugin(): ExecutorPlugin = new DataflintExecutorPlugin()
}

class SparkDataflintDriverPlugin extends DriverPlugin with Logging {
@@ -20,7 +22,40 @@ class SparkDataflintDriverPlugin extends DriverPlugin with Logging {
  override def init(sc: SparkContext, pluginContext: PluginContext): util.Map[String, String] = {
    this.sc = sc
    DataflintSparkUICommonLoader.registerInstrumentationExtension(sc)
-    Map[String, String]().asJava
+    val executorMetadataEnabled = DriverMetadataHelper.isExecutorMetadataEnabled(sc)
+    Map("executor.metadata.enabled" -> executorMetadataEnabled.toString).asJava
  }

+  override def receive(message: Any): String = {
+    message match {
+      case msg: ExecutorMetadataMessage =>
+        try {
+          val info = DataflintExecutorMetadataInfo(
+            executorId = msg.executorId,
+            executorHost = msg.executorHost,
+            instanceType = msg.instanceType,
+            lifecycleType = msg.lifecycleType,
+            cloudProvider = msg.cloudProvider,
+            osName = msg.osName,
+            osArch = msg.osArch,
+            jvmVersion = msg.jvmVersion,
+            availableProcessors = msg.availableProcessors,
+            totalMemoryBytes = msg.totalMemoryBytes,
+            collectionError = msg.collectionError
+          )
+          DriverMetadataHelper.postExecutorMetadataEvent(sc, info)
+          logInfo(s"Received executor metadata from executor ${msg.executorId}: " +
+            s"provider=${msg.cloudProvider}, instance=${msg.instanceType}, lifecycle=${msg.lifecycleType}")
+          null
+        } catch {
+          case e: Throwable =>
+            logWarning(s"Failed to process executor metadata from ${msg.executorId}", e)
+            null
+        }
+      case _ =>
+        logWarning(s"Received unknown message type: ${message.getClass.getName}")
+        null
+    }
+  }

  override def registerMetrics(appId: String, pluginContext: PluginContext): Unit = {
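
End to end, none of this activates unless the plugin is registered with Spark. A hedged wiring sketch: spark.plugins is the standard Spark mechanism for loading SparkPlugin implementations, the plugin class and the experimental flag come from this diff, and the app name is a placeholder:

import org.apache.spark.sql.SparkSession

object PluginWiring extends App {
  val spark = SparkSession.builder()
    .appName("dataflint-metadata-demo") // placeholder
    .config("spark.plugins", "io.dataflint.spark.SparkDataflintPlugin")
    .config("spark.dataflint.experimental.executor.metadata.enabled", "true")
    .getOrCreate()
  // Executors now run DataflintExecutorPlugin.init, detect cloud metadata, and send it
  // to the driver, which posts it onto the listener bus for the KV store.
  spark.stop()
}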