From d3aca960f88ccdc0631fb70fc721733bcce804bd Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 08:52:11 -0800 Subject: [PATCH 01/20] lifecycle manager related changes --- .../celeborn/client/LifecycleManager.scala | 291 ++++++++++++++++-- common/src/main/proto/TransportMessages.proto | 24 ++ .../apache/celeborn/common/CelebornConf.scala | 29 ++ .../protocol/message/ControlMessages.scala | 28 ++ .../apache/celeborn/common/util/Utils.scala | 3 + 5 files changed, 351 insertions(+), 24 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 23189853544..11261770b12 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -25,7 +25,6 @@ import java.util.{function, List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicInteger, LongAdder} import java.util.function.{BiConsumer, BiFunction, Consumer} - import scala.collection.JavaConverters._ import scala.collection.generic.CanBuildFrom import scala.collection.mutable @@ -33,11 +32,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.Random - import com.google.common.annotations.VisibleForTesting import com.google.common.cache.{Cache, CacheBuilder} import org.roaringbitmap.RoaringBitmap - import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} import org.apache.celeborn.client.listener.WorkerStatusListener import org.apache.celeborn.common.{CelebornConf, CommitMetadata} @@ -56,6 +53,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc._ import org.apache.celeborn.common.rpc.{ClientSaslContextBuilder, RpcSecurityContext, RpcSecurityContextBuilder} import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} +import org.apache.celeborn.common.util.Utils.{KNOWN_MISSING_CELEBORN_SHUFFLE_ID, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID} import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils} // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ @@ -113,6 +111,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // app shuffle id -> whether shuffle is determinate, rerun of a indeterminate shuffle gets different result private val appShuffleDeterminateMap = JavaUtils.newConcurrentHashMap[Int, Boolean](); + // format ${stageid}.${attemptid} + private val stagesReceivedInvalidatingUpstream = + new mutable.HashMap[String, mutable.HashSet[Int]]() + private val rpcCacheSize = conf.clientRpcCacheSize private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime @@ -537,6 +539,23 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } else { context.reply(PbSerDeUtils.toPbApplicationMeta(applicationMeta)) } + + case pb: PbReportMissingShuffleId => + val appShuffleId = pb.getTriggerAppShuffleId + val readerStageId = pb.getReaderStageId + val stageAttemptId = pb.getAttemptId + logInfo( + s"Received ReportMissingShuffleId, appShuffleId $appShuffleId readerStageIdentifier:" + + s" $readerStageId.$stageAttemptId") + handleReportMissingShuffleId(context, appShuffleId, readerStageId, stageAttemptId) + + case pb: PbInvalidateAllUpstreamShuffle => + val readerStageId = pb.getReaderStageId + val attemptId = pb.getAttemptId + val triggerAppShuffleId = pb.getTriggerAppShuffleId + logInfo(s"received ReportFetchFailureForAllUpstream for stage $readerStageId," + + s" attemptId: $attemptId") + handleInvalidateAllUpstreamShuffle(context, readerStageId, attemptId, triggerAppShuffleId) } private def handleReducerPartitionEnd( @@ -979,7 +998,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends override def apply(id: Int) : scala.collection.mutable.LinkedHashMap[String, (Int, Boolean)] = { val newShuffleId = shuffleIdGenerator.getAndIncrement() - logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") + logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId" + + s" appShuffleIdentifier $appShuffleIdentifier") + stageToWriteCelebornShuffleCallback.foreach(callback => + callback.accept(newShuffleId, appShuffleIdentifier)) + celebornToAppShuffleIdMappingCallback.foreach(callback => + callback.accept(newShuffleId, appShuffleIdentifier)) scala.collection.mutable.LinkedHashMap(appShuffleIdentifier -> (newShuffleId, true)) } }) @@ -1014,8 +1038,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // So if a barrier stage is getting reexecuted, previous stage/attempt needs to // be cleaned up as it is entirely unusuable if (determinate && !isBarrierStage && !isCelebornSkewShuffleOrChildShuffle( - appShuffleId)) - shuffleIds.values.toSeq.reverse.find(e => e._2 == true) + appShuffleId)) { + val result = shuffleIds.values.toSeq.reverse.find(e => e._2 == true) + if (result.isEmpty) { + logWarning(s"cannot find candidate shuffleId for determinate" + + s" shuffle $appShuffleIdentifier") + } + result + } else None @@ -1035,9 +1065,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds ++= mapUpdates } val newShuffleId = shuffleIdGenerator.getAndIncrement() - logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") + logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId" + + s" appShuffleIdentifier $appShuffleIdentifier") validateCelebornShuffleIdForClean.foreach(callback => callback.accept(appShuffleIdentifier)) + stageToWriteCelebornShuffleCallback.foreach { callback => + callback.accept(newShuffleId, appShuffleIdentifier) + } + celebornToAppShuffleIdMappingCallback.foreach(callback => + callback.accept(newShuffleId, appShuffleIdentifier)) shuffleIds.put(appShuffleIdentifier, (newShuffleId, true)) newShuffleId } @@ -1049,29 +1085,135 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends s"unexpected! unknown appShuffleId $appShuffleId when checking shuffle deterministic level")) } } else { - shuffleIds.values.filter(v => v._2).map(v => v._1).toSeq.reverse.find( - areAllMapTasksEnd) match { - case Some(celebornShuffleId) => - val pbGetShuffleIdResponse = { - logDebug( - s"get shuffleId $celebornShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") - PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).setSuccess( - true).build() - } - context.reply(pbGetShuffleIdResponse) - case None => - val pbGetShuffleIdResponse = { - logInfo( - s"there is no finished map stage associated with appShuffleId $appShuffleId") - PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess( - false).build() + // this is not necessarily the most concise coding style, but it helps for debugging + // purpose + var found = false + shuffleIds.values.map(v => v._1).toSeq.reverse.foreach { celebornShuffleId: Int => + if (!found) { + try { + if (areAllMapTasksEnd(celebornShuffleId)) { + getCelebornShuffleIdForReaderCallback.foreach(callback => + callback.accept(celebornShuffleId, appShuffleIdentifier)) + getAppShuffleIdForReaderCallback.foreach(callback => + callback.accept(appShuffleId, appShuffleIdentifier)) + val pbGetShuffleIdResponse = { + logDebug( + s"get shuffleId $celebornShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") + PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).build() + } + context.reply(pbGetShuffleIdResponse) + found = true + } else { + logInfo(s"not all map tasks finished for shuffle $celebornShuffleId") + } + } catch { + case npe: NullPointerException => + if (conf.clientShuffleEarlyDeletion) { + logError( + s"hit error when getting celeborn shuffle id $celebornShuffleId for" + + s" appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier", + npe) + val canInvalidateAllUpstream = + checkWhetherToInvalidateAllUpstreamCallback.exists(func => + func.apply(appShuffleIdentifier)) + val pbGetShuffleIdResponse = PbGetShuffleIdResponse + .newBuilder() + .setShuffleId({ + if (canInvalidateAllUpstream) { + KNOWN_MISSING_CELEBORN_SHUFFLE_ID + } else { + UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID + } + }) + .build() + context.reply(pbGetShuffleIdResponse) + } else { + logError( + s"unexpected NullPointerException without" + + s" ${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key} turning on", + npe) + throw npe; + } } - context.reply(pbGetShuffleIdResponse) + } } } } } + private def invalidateAllKnownUpstreamShuffleOutput(stageIdentifier: String): Unit = { + val Array(readerStageId, _) = stageIdentifier.split('.').map(_.toInt) + val invalidatedUpstreamIds = + stagesReceivedInvalidatingUpstream.getOrElseUpdate( + stageIdentifier, + new mutable.HashSet[Int]()) + println(s"invalidating all upstream shuffles of stage $stageIdentifier") + val upstreamShuffleIds = getUpstreamAppShuffleIdsCallback.map(f => + f.apply(readerStageId)).getOrElse(Array()) + upstreamShuffleIds.foreach { upstreamAppShuffleId => + appShuffleTrackerCallback.foreach { callback => + logInfo(s"invalidated upstream app shuffle id $upstreamAppShuffleId for stage" + + s" $stageIdentifier") + callback.accept(upstreamAppShuffleId) + invalidatedUpstreamIds += upstreamAppShuffleId + val celebornShuffleIds = shuffleIdMapping.get(upstreamAppShuffleId) + val latestShuffle = celebornShuffleIds.maxBy(_._2._1) + celebornShuffleIds.put(latestShuffle._1, (KNOWN_MISSING_CELEBORN_SHUFFLE_ID, false)) + } + } + invalidateShuffleWrittenByStage(readerStageId) + } + + private def handleInvalidateAllUpstreamShuffle( + context: RpcCallContext, + readerStageId: Int, + readerStageAttemptId: Int, + triggerAppShuffleId: Int): Unit = stagesReceivedInvalidatingUpstream.synchronized { + require( + conf.clientShuffleEarlyDeletion, + "ReportFetchFailureForAllUpstream message is " + + s"supposed to be only received when turning on" + + s" ${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key}") + require( + getUpstreamAppShuffleIdsCallback.isDefined, + "no callback has been registered for" + + " invalidating all upstream shuffles for a reader stage") + var ret = true + try { + val stageIdentifier = s"$readerStageId.$readerStageAttemptId" + if (!stagesReceivedInvalidatingUpstream.contains(stageIdentifier)) { + invalidateAllKnownUpstreamShuffleOutput(stageIdentifier) + } else if (!stagesReceivedInvalidatingUpstream(stageIdentifier) + .contains(triggerAppShuffleId)) { + // in this case, it means that we haven't been able to capture a certain upstream app + // shuffle id for the current stage when we invalidate all upstream last time, + // and the new upstream shuffle id show up now, we need to add the new shuffle id + // dependency and then fallback to the fetchfailure error for this shuffle + // (since other captured upstream shuffles might have been regenerated) + logInfo(s"a new upstream shuffle id $triggerAppShuffleId show up for $stageIdentifier" + + s" after we have invalidated all known upstream shuffle outputs") + val appShuffleIdentifier = s"$triggerAppShuffleId-$readerStageId-$readerStageAttemptId" + getAppShuffleIdForReaderCallback.foreach(callback => + callback.accept(triggerAppShuffleId, appShuffleIdentifier)) + ret = false + } else { + logInfo(s"ignoring the message to invalidate all upstream shuffles for stage" + + s" $stageIdentifier (triggered appShuffleId $triggerAppShuffleId)," + + s" as it has been handled by another thread") + } + } catch { + case t: Throwable => + logError( + s"hit error when invalidating upstream shuffles for stage $readerStageId," + + s" attempt $readerStageAttemptId", + t) + ret = false + } + val pbInvalidateAllUpstreamShuffleResponse = + PbInvalidateAllUpstreamShuffleResponse.newBuilder().setSuccess(ret).build() + context.reply(pbInvalidateAllUpstreamShuffleResponse) + } + private def handleReportShuffleFetchFailure( context: RpcCallContext, appShuffleId: Int, @@ -1109,6 +1251,47 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends context.reply(pbReportShuffleFetchFailureResponse) } + private def handleReportMissingShuffleId( + context: RpcCallContext, + appShuffleId: Int, + stageId: Int, + stageAttemptId: Int): Unit = { + val shuffleIds = shuffleIdMapping.get(appShuffleId) + if (shuffleIds == null) { + throw new UnsupportedOperationException(s"unexpected! unknown appShuffleId $appShuffleId") + } + var ret = true + shuffleIds.synchronized { + val latestUpstreamShuffleId = shuffleIds.maxBy(_._2._1) + if (latestUpstreamShuffleId._2._1 == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { + logInfo(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + + s" it is already reported by other reader and handled") + } else { + logInfo(s"handle missing shuffle id for appShuffleId $appShuffleId stage" + + s" $stageId.$stageAttemptId") + appShuffleTrackerCallback match { + case Some(callback) => + try { + callback.accept(appShuffleId) + } catch { + case t: Throwable => + logError(t.toString) + ret = false + } + shuffleIds.put(latestUpstreamShuffleId._1, (UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID, false)) + case None => + throw new UnsupportedOperationException( + "unexpected! appShuffleTrackerCallback is not registered") + } + // invalidate the shuffle written by stage + invalidateShuffleWrittenByStage(stageId) + val pbReportMissingShuffleIdResponse = + PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() + context.reply(pbReportMissingShuffleIdResponse) + } + } + } + private def handleReportBarrierStageAttemptFailure( context: RpcCallContext, appShuffleId: Int, @@ -1179,6 +1362,24 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } + private def invalidateShuffleWrittenByStage(stageId: Int): Unit = { + val writtenShuffleId = getAppShuffleIdByStageIdCallback.map { callback => + callback.apply(stageId) + } + writtenShuffleId.foreach { shuffleId => + if (shuffleId >= 0) { + val celebornShuffleIds = shuffleIdMapping.get(writtenShuffleId) + if (celebornShuffleIds != null) { + logInfo(s"invalidating location of app shuffle id $writtenShuffleId written" + + s" by stage $stageId") + val latestShuffleId = celebornShuffleIds.maxBy(_._2._1) + celebornShuffleIds.put(latestShuffleId._1, (latestShuffleId._2._1, false)) + appShuffleTrackerCallback.foreach(callback => callback.accept(shuffleId)) + } + } + } + } + private def handleStageEnd(shuffleId: Int): Unit = { // check whether shuffle has registered if (!registeredShuffle.contains(shuffleId)) { @@ -2004,6 +2205,48 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends celebornSkewShuffleCheckCallback = Some(callback) } + @volatile private var getUpstreamAppShuffleIdsCallback + : Option[Function[Integer, Array[Integer]]] = None + def registerUpstreamAppShuffleIdsCallback(callback: Function[Integer, Array[Integer]]): Unit = { + getUpstreamAppShuffleIdsCallback = Some(callback) + } + + @volatile private var getAppShuffleIdByStageIdCallback + : Option[Function[Integer, Integer]] = None + def registerGetAppShuffleIdByStageIdCallback(callback: Function[Integer, Integer]): Unit = { + getAppShuffleIdByStageIdCallback = Some(callback) + } + + // expecting celeborn shuffle id and application shuffle identifier + @volatile private var getCelebornShuffleIdForReaderCallback + : Option[BiConsumer[Integer, String]] = None + def registerGetCelebornShuffleIdForReaderCallback(callback: BiConsumer[Integer, String]): Unit = { + getCelebornShuffleIdForReaderCallback = Some(callback) + } + + @volatile private var getAppShuffleIdForReaderCallback: Option[BiConsumer[Integer, String]] = None + def registerReaderStageToAppShuffleIdsCallback(callback: BiConsumer[Integer, String]): Unit = { + getAppShuffleIdForReaderCallback = Some(callback) + } + + @volatile private var stageToWriteCelebornShuffleCallback: Option[BiConsumer[Integer, String]] = + None + def registerStageToWriteCelebornShuffleCallback(callback: BiConsumer[Integer, String]): Unit = { + stageToWriteCelebornShuffleCallback = Some(callback) + } + + @volatile private var celebornToAppShuffleIdMappingCallback + : Option[BiConsumer[Integer, String]] = None + def registerCelebornToAppShuffleIdMappingCallback(callback: BiConsumer[Integer, String]): Unit = { + celebornToAppShuffleIdMappingCallback = Some(callback) + } + + @volatile private var checkWhetherToInvalidateAllUpstreamCallback + : Option[Function[String, Boolean]] = None + def registerInvalidateAllUpstreamCheckCallback(callback: Function[String, Boolean]): Unit = { + checkWhetherToInvalidateAllUpstreamCallback = Some(callback) + } + // Initialize at the end of LifecycleManager construction. initialize() diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 160c7e96201..707e08b9da6 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -117,6 +117,10 @@ enum MessageType { READ_REDUCER_PARTITION_END = 94; READ_REDUCER_PARTITION_END_RESPONSE = 95; REGISTER_APPLICATION_INFO = 96; + INVALIDATE_ALL_UPSTREAM_SHUFFLE = 97; + INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE = 98; + REPORT_MISSING_SHUFFLE_ID = 99; + REPORT_MISSING_SHUFFLE_ID_RESPONSE = 100; } enum StreamType { @@ -1072,3 +1076,23 @@ message PbRegisterApplicationInfo { map extraInfo = 3; string requestId = 4; } + +message PbReportMissingShuffleId { + int32 readerStageId = 1; + int32 attemptId = 2; + int32 triggerAppShuffleId = 3; +} + +message PbReportMissingShuffleIdResponse { + bool success = 1; +} + +message PbInvalidateAllUpstreamShuffle { + int32 readerStageId = 1; + int32 attemptId = 2; + int32 triggerAppShuffleId = 3; +} + +message PbInvalidateAllUpstreamShuffleResponse { + bool success = 1; +} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 7c1cc6875a4..5f6f89df4a2 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1036,6 +1036,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchCleanFailedShuffle: Boolean = get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE) def clientFetchCleanFailedShuffleIntervalMS: Long = get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL) + def clientShuffleEarlyDeletion: Boolean = get(CLIENT_SHUFFLE_EARLY_DELETION) + def clientShuffleEarlyDeletionCheckProp: Boolean = + get(CLIENT_SHUFFLE_EARLY_DELETION_CHECK_PROPERTY) + def clientShuffleEarlyDeletionIntervalMs: Long = get(CLIENT_SHUFFLE_EARLY_DELETION_INTERVAL_MS) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = get(CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED) def clientFetchExcludedWorkerExpireTimeout: Long = @@ -5085,6 +5089,31 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val CLIENT_SHUFFLE_EARLY_DELETION: ConfigEntry[Boolean] = + buildConf("celeborn.client.spark.fetch.shuffleEarlyDeletion") + .categories("client") + .version("0.7.0") + .doc("whether to delete shuffle when we determine a shuffle is not needed by any stage") + .booleanConf + .createWithDefault(false) + + val CLIENT_SHUFFLE_EARLY_DELETION_CHECK_PROPERTY: ConfigEntry[Boolean] = + buildConf("celeborn.client.spark.fetch.shuffleEarlyDeletion.checkProperty") + .categories("client") + .version("0.7.0") + .doc("when this is enabled, we only early delete shuffle when" + + " \"CELEBORN_EARLY_SHUFFLE_DELETION\" property is set to true") + .booleanConf + .createWithDefault(false) + + val CLIENT_SHUFFLE_EARLY_DELETION_INTERVAL_MS: ConfigEntry[Long] = + buildConf("celeborn.client.spark.fetch.shuffleEarlyDeletion.intervalMs") + .categories("client") + .version("0.7.0") + .doc("interval length to delete unused shuffle (ms)") + .longConf + .createWithDefault(5 * 60 * 1000) + val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL: ConfigEntry[Long] = buildConf("celeborn.client.spark.fetch.cleanFailedShuffleInterval") .categories("client") diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index eb9274632de..e097efa1ada 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -1073,6 +1073,22 @@ object ControlMessages extends Logging { case pb: PbApplicationMetaRequest => new TransportMessage(MessageType.APPLICATION_META_REQUEST, pb.toByteArray) + + case pb: PbInvalidateAllUpstreamShuffle => + new TransportMessage(MessageType.INVALIDATE_ALL_UPSTREAM_SHUFFLE, pb.toByteArray) + + case pb: PbInvalidateAllUpstreamShuffleResponse => + new TransportMessage( + MessageType.INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE, + pb.toByteArray) + + case pb: PbReportMissingShuffleId => + new TransportMessage(MessageType.REPORT_MISSING_SHUFFLE_ID, pb.toByteArray) + + case pb: PbReportMissingShuffleIdResponse => + new TransportMessage( + MessageType.REPORT_MISSING_SHUFFLE_ID_RESPONSE, + pb.toByteArray) } // TODO change return type to GeneratedMessageV3 @@ -1516,6 +1532,18 @@ object ControlMessages extends Logging { PbSerDeUtils.fromPbUserIdentifier(pbRegisterApplicationInfo.getUserIdentifier), pbRegisterApplicationInfo.getExtraInfoMap, pbRegisterApplicationInfo.getRequestId) + + case REPORT_MISSING_SHUFFLE_ID_VALUE => + message.getParsedPayload() + + case REPORT_MISSING_SHUFFLE_ID_RESPONSE_VALUE => + message.getParsedPayload() + + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_VALUE => + message.getParsedPayload() + + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE_VALUE => + message.getParsedPayload() } } } diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index e71db2eadff..3a76b48816c 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -1119,6 +1119,9 @@ object Utils extends Logging { val UNKNOWN_APP_SHUFFLE_ID = -1 + val UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID = -2 + val KNOWN_MISSING_CELEBORN_SHUFFLE_ID = -3 + def isHdfsPath(path: String): Boolean = { path.matches(COMPATIBLE_HDFS_REGEX) } From f1fba6bd5795407cdd601e614fe7f10d541f8398 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 08:58:14 -0800 Subject: [PATCH 02/20] stylistic fixes --- .../celeborn/client/LifecycleManager.scala | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 11261770b12..64795b0c05c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -25,6 +25,7 @@ import java.util.{function, List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicInteger, LongAdder} import java.util.function.{BiConsumer, BiFunction, Consumer} + import scala.collection.JavaConverters._ import scala.collection.generic.CanBuildFrom import scala.collection.mutable @@ -32,9 +33,11 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.Random + import com.google.common.annotations.VisibleForTesting import com.google.common.cache.{Cache, CacheBuilder} import org.roaringbitmap.RoaringBitmap + import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} import org.apache.celeborn.client.listener.WorkerStatusListener import org.apache.celeborn.common.{CelebornConf, CommitMetadata} @@ -53,11 +56,11 @@ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc._ import org.apache.celeborn.common.rpc.{ClientSaslContextBuilder, RpcSecurityContext, RpcSecurityContextBuilder} import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} -import org.apache.celeborn.common.util.Utils.{KNOWN_MISSING_CELEBORN_SHUFFLE_ID, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID} import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils} // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.ThreadUtils.awaitResult +import org.apache.celeborn.common.util.Utils.{KNOWN_MISSING_CELEBORN_SHUFFLE_ID, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID} import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID import org.apache.celeborn.common.write.LocationPushFailedBatches @@ -1045,8 +1048,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends s" shuffle $appShuffleIdentifier") } result - } - else + } else None val shuffleId: Integer = @@ -1184,7 +1186,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends if (!stagesReceivedInvalidatingUpstream.contains(stageIdentifier)) { invalidateAllKnownUpstreamShuffleOutput(stageIdentifier) } else if (!stagesReceivedInvalidatingUpstream(stageIdentifier) - .contains(triggerAppShuffleId)) { + .contains(triggerAppShuffleId)) { // in this case, it means that we haven't been able to capture a certain upstream app // shuffle id for the current stage when we invalidate all upstream last time, // and the new upstream shuffle id show up now, we need to add the new shuffle id @@ -1252,10 +1254,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } private def handleReportMissingShuffleId( - context: RpcCallContext, - appShuffleId: Int, - stageId: Int, - stageAttemptId: Int): Unit = { + context: RpcCallContext, + appShuffleId: Int, + stageId: Int, + stageAttemptId: Int): Unit = { val shuffleIds = shuffleIdMapping.get(appShuffleId) if (shuffleIds == null) { throw new UnsupportedOperationException(s"unexpected! unknown appShuffleId $appShuffleId") @@ -2211,15 +2213,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends getUpstreamAppShuffleIdsCallback = Some(callback) } - @volatile private var getAppShuffleIdByStageIdCallback - : Option[Function[Integer, Integer]] = None + @volatile private var getAppShuffleIdByStageIdCallback: Option[Function[Integer, Integer]] = None def registerGetAppShuffleIdByStageIdCallback(callback: Function[Integer, Integer]): Unit = { getAppShuffleIdByStageIdCallback = Some(callback) } // expecting celeborn shuffle id and application shuffle identifier - @volatile private var getCelebornShuffleIdForReaderCallback - : Option[BiConsumer[Integer, String]] = None + @volatile private var getCelebornShuffleIdForReaderCallback: Option[BiConsumer[Integer, String]] = + None def registerGetCelebornShuffleIdForReaderCallback(callback: BiConsumer[Integer, String]): Unit = { getCelebornShuffleIdForReaderCallback = Some(callback) } @@ -2235,8 +2236,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends stageToWriteCelebornShuffleCallback = Some(callback) } - @volatile private var celebornToAppShuffleIdMappingCallback - : Option[BiConsumer[Integer, String]] = None + @volatile private var celebornToAppShuffleIdMappingCallback: Option[BiConsumer[Integer, String]] = + None def registerCelebornToAppShuffleIdMappingCallback(callback: BiConsumer[Integer, String]): Unit = { celebornToAppShuffleIdMappingCallback = Some(callback) } From 90cea33791e73d617af2b46a202e92a564b37284 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 09:44:23 -0800 Subject: [PATCH 03/20] spark related changes --- .../spark/CelebornSparkContextHelper.scala | 43 +++ .../shuffle/celeborn/SparkShuffleManager.java | 48 +++- .../spark/shuffle/celeborn/SparkUtils.java | 57 +++- .../celeborn/StageDependencyManager.scala | 262 ++++++++++++++++++ .../CelebornShuffleEarlyCleanup.scala | 28 ++ .../spark/listener/ListenerHelper.scala | 46 +++ .../ShuffleStatsTrackingListener.scala | 64 +++++ docs/configuration/client.md | 3 + 8 files changed, 544 insertions(+), 7 deletions(-) create mode 100644 client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala create mode 100644 client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala create mode 100644 client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala create mode 100644 client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala create mode 100644 client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala diff --git a/client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala b/client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala new file mode 100644 index 00000000000..f07d7088dd6 --- /dev/null +++ b/client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.scheduler.{EventLoggingListener, SparkListenerInterface} + +object CelebornSparkContextHelper { + + def eventLogger: Option[EventLoggingListener] = SparkContext.getActive.get.eventLogger + + def env: SparkEnv = { + assert(SparkContext.getActive.isDefined) + SparkContext.getActive.get.env + } + + def activeSparkContext(): Option[SparkContext] = { + SparkContext.getActive + } + + def getListener(listenerClass: String): SparkListenerInterface = { + activeSparkContext().get.listenerBus.listeners.asScala.find(l => + l.getClass.getCanonicalName.contains(listenerClass)).getOrElse( + throw new RuntimeException( + s"cannot find any listener containing $listenerClass in class name")) + } +} diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index ed7865e19ff..27266103dcf 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -22,8 +22,10 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.spark.*; +import org.apache.spark.celeborn.StageDependencyManager; import org.apache.spark.internal.config.package$; import org.apache.spark.launcher.SparkLauncher; +import org.apache.spark.listener.ListenerHelper; import org.apache.spark.rdd.DeterministicLevel; import org.apache.spark.shuffle.*; import org.apache.spark.shuffle.sort.SortShuffleManager; @@ -91,6 +93,13 @@ public class SparkShuffleManager implements ShuffleManager { private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker(); + private StageDependencyManager stageDepManager = null; + + // for testing + public void initStageDepManager() { + this.stageDepManager = new StageDependencyManager(this); + } + public SparkShuffleManager(SparkConf conf, boolean isDriver) { if (conf.getBoolean(SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key(), true)) { logger.warn( @@ -151,9 +160,6 @@ private void initializeLifecycleManager(String appId) { taskId -> SparkUtils.shouldReportShuffleFetchFailure(taskId)); SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener()); - lifecycleManager.registerShuffleTrackerCallback( - shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); - if (celebornConf.clientAdaptiveOptimizeSkewedPartitionReadEnabled()) { lifecycleManager.registerCelebornSkewShuffleCheckCallback( SparkUtils::isCelebornSkewShuffleOrChildShuffle); @@ -177,6 +183,38 @@ private void initializeLifecycleManager(String appId) { (celebornShuffleId) -> SparkUtils.removeCleanedShuffleId(this, celebornShuffleId)); } + if (lifecycleManager.conf().clientShuffleEarlyDeletion()) { + logger.info("register early deletion callbacks"); + ListenerHelper.addShuffleStatsTrackingListener(); + lifecycleManager.registerStageToWriteCelebornShuffleCallback( + (celebornShuffleId, appShuffleIdentifier) -> + SparkUtils.addStageToWriteCelebornShuffleIdDep( + this, celebornShuffleId, appShuffleIdentifier)); + lifecycleManager.registerCelebornToAppShuffleIdMappingCallback( + (celebornShuffleId, appShuffleIdentifier) -> + SparkUtils.addCelebornToSparkShuffleIdRef( + this, celebornShuffleId, appShuffleIdentifier)); + lifecycleManager.registerGetCelebornShuffleIdForReaderCallback( + (celebornShuffleId, appShuffleIdentifier) -> + SparkUtils.addCelebornShuffleReadingStageDep( + this, celebornShuffleId, appShuffleIdentifier)); + lifecycleManager.registerUpstreamAppShuffleIdsCallback( + (stageId) -> SparkUtils.getAllUpstreamAppShuffleIds(this, stageId)); + lifecycleManager.registerGetAppShuffleIdByStageIdCallback( + (stageId) -> SparkUtils.getAppShuffleIdByStageId(this, stageId)); + lifecycleManager.registerReaderStageToAppShuffleIdsCallback( + (appShuffleId, appShuffleIdentifier) -> + SparkUtils.addAppShuffleReadingStageDep( + this, appShuffleId, appShuffleIdentifier)); + lifecycleManager.registerInvalidateAllUpstreamCheckCallback( + (appShuffleIdentifier) -> + SparkUtils.canInvalidateAllUpstream(this, appShuffleIdentifier)); + if (stageDepManager == null) { + stageDepManager = new StageDependencyManager(this); + } + stageDepManager.start(); + } + if (celebornConf.getReducerFileGroupBroadcastEnabled()) { lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback( (shuffleId, getReducerFileGroupResponse) -> @@ -497,4 +535,8 @@ public LifecycleManager getLifecycleManager() { public FailedShuffleCleaner getFailedShuffleCleaner() { return this.failedShuffleCleaner; } + + public StageDependencyManager getStageDepManager() { + return this.stageDepManager; + } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 2b30a20206f..658c76a0e02 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,10 +20,7 @@ import java.io.ByteArrayInputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.LongAdder; @@ -680,4 +677,56 @@ public static boolean isLocalMaster(SparkConf conf) { String master = conf.get("spark.master", ""); return master.equals("local") || master.startsWith("local["); } + + public static Integer[] getAllUpstreamAppShuffleIds( + SparkShuffleManager sparkShuffleManager, int readerStageId) { + int[] upstreamShuffleIds = + sparkShuffleManager + .getStageDepManager() + .getAllUpstreamAppShuffleIdsByStageId(readerStageId); + return Arrays.stream(upstreamShuffleIds).boxed().toArray(Integer[]::new); + } + + public static Integer getAppShuffleIdByStageId( + SparkShuffleManager sparkShuffleManager, int readerStageId) { + int writtenAppShuffleId = + sparkShuffleManager.getStageDepManager().getAppShuffleIdByStageId(readerStageId); + return writtenAppShuffleId; + } + + public static void addCelebornShuffleReadingStageDep( + SparkShuffleManager sparkShuffleManager, int celebornShuffeId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addCelebornShuffleIdReadingStageDep(celebornShuffeId, appShuffleIdentifier); + } + + public static void addAppShuffleReadingStageDep( + SparkShuffleManager sparkShuffleManager, int appShuffleId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addAppShuffleIdReadingStageDep(appShuffleId, appShuffleIdentifier); + } + + public static boolean canInvalidateAllUpstream( + SparkShuffleManager sparkShuffleManager, String appShuffleIdentifier) { + String[] decodedAppShuffleIdentifier = appShuffleIdentifier.split("-"); + return sparkShuffleManager + .getStageDepManager() + .hasAllUpstreamShuffleIdsInfo(Integer.valueOf(decodedAppShuffleIdentifier[1])); + } + + public static void addStageToWriteCelebornShuffleIdDep( + SparkShuffleManager sparkShuffleManager, int celebornShuffeId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addStageToCelebornShuffleIdRef(celebornShuffeId, appShuffleIdentifier); + } + + public static void addCelebornToSparkShuffleIdRef( + SparkShuffleManager sparkShuffleManager, int celebornShuffeId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addCelebornToAppShuffleIdMapping(celebornShuffeId, appShuffleIdentifier); + } } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala new file mode 100644 index 00000000000..3b061d37879 --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.celeborn + +import java.time.Instant +import java.util +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.CelebornSparkContextHelper +import org.apache.spark.listener.{CelebornShuffleEarlyCleanup, ShuffleStatsTrackingListener} +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.shuffle.celeborn.SparkShuffleManager + +import org.apache.celeborn.common.internal.Logging + +class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Logging { + + // celeborn shuffle id to all stages reading it, this is needed when we determine when to + // clean the shuffle + private[celeborn] val readShuffleToStageDep = new mutable.HashMap[Int, mutable.HashSet[Int]]() + // stage id to all celeborn shuffle ids it reads from, this structure is needed for fast + // tracking when a stage is completed + private val stageToReadCelebornShuffleDep = new mutable.HashMap[Int, mutable.HashSet[Int]]() + // spark stage id to celeborn shuffle id which it writes, + // we need to save this mapping so that we can query which celeborn shuffles is depended on by a + // certain stage when it is submitted + private val stageToCelebornShuffleIdWritten = new mutable.HashMap[Int, Int]() + // app shuffle id to all app shuffle ids it reads from, this structure used as the intermediate data + private val appShuffleIdToUpstream = new mutable.HashMap[Int, mutable.HashSet[Int]]() + // stage id to app shuffle id it writes, this structure used as the intermediate data + // to build appShuffleIdToUpstream and is needed when we need to + // invalidate all app shuffle map output location when a stage is failed + private val stageToAppShuffleIdWritten = new mutable.HashMap[Int, Int]() + + private val celebornToAppShuffleIdentifier = new mutable.HashMap[Int, String]() + private val appShuffleIdentifierToSize = new mutable.HashMap[String, Long]() + + private val shuffleIdsToBeCleaned = new LinkedBlockingQueue[Int]() + + private lazy val cleanInterval = shuffleManager.getLifecycleManager + .conf.clientShuffleEarlyDeletionIntervalMs + + def addShuffleAndStageDep(celebornShuffleId: Int, stageId: Int): Unit = this.synchronized { + val newStageIdSet = + readShuffleToStageDep.getOrElseUpdate(celebornShuffleId, new mutable.HashSet[Int]()) + newStageIdSet += stageId + val newShuffleIdSet = + stageToReadCelebornShuffleDep.getOrElseUpdate(stageId, new mutable.HashSet[Int]()) + newShuffleIdSet += celebornShuffleId + val correctionResult = shuffleIdsToBeCleaned.remove(celebornShuffleId) + if (correctionResult) { + logInfo(s"shuffle $celebornShuffleId is later recognized as needed by stage $stageId, " + + s"removed it from to be cleaned list") + } + } + + private def stageOutputToShuffleOrS3(stageInfo: StageInfo): Boolean = { + stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten > 0 || + stageInfo.taskMetrics.outputMetrics.bytesWritten > 0 + } + + private def removeStageAndReadInfo(stageId: Int): Unit = { + stageToReadCelebornShuffleDep.remove(stageId) + } + + // it is called when the stage is completed + def addAppShuffleIdentifierToSize(appShuffleIdentifier: String, bytes: Long): Unit = + this.synchronized { + appShuffleIdentifierToSize += appShuffleIdentifier -> bytes + } + + // this is called when a shuffle is cleaned up + def queryShuffleSizeByAppShuffleIdentifier(appShuffleIdentifier: String): Long = + this.synchronized { + appShuffleIdentifierToSize.getOrElse( + appShuffleIdentifier, { + logError(s"unexpected case: cannot find size information for shuffle identifier" + + s" $appShuffleIdentifier") + -1L + }) + } + + def removeShuffleAndStageDep(stageInfo: StageInfo): Unit = this.synchronized { + val stageId = stageInfo.stageId + val allReadCelebornIds = stageToReadCelebornShuffleDep.get(stageId) + allReadCelebornIds.foreach { celebornShuffleIds => + celebornShuffleIds.foreach { celebornShuffleId => + val allStages = readShuffleToStageDep.get(celebornShuffleId) + allStages.foreach { stages => + stages.remove(stageId) + if (stages.nonEmpty) { + readShuffleToStageDep += celebornShuffleId -> stages + } else { + val readyToDelete = { + if (shuffleManager.getLifecycleManager.conf.clientShuffleEarlyDeletionCheckProp) { + val propertySet = System.getProperty("CELEBORN_EARLY_SHUFFLE_DELETION", "false") + propertySet.toBoolean && stageOutputToShuffleOrS3(stageInfo) + } else { + stageOutputToShuffleOrS3(stageInfo) + } + } + if (readyToDelete) { + removeCelebornShuffleInternal(celebornShuffleId, stageId = Some(stageInfo.stageId)) + } else { + logInfo( + s"not ready to delete shuffle $celebornShuffleId while stage $stageId finished") + } + } + } + } + } + } + + private[celeborn] def removeCelebornShuffleInternal( + celebornShuffleId: Int, + stageId: Option[Int]): Unit = { + shuffleIdsToBeCleaned.put(celebornShuffleId) + readShuffleToStageDep.remove(celebornShuffleId) + val appShuffleIdentifierOpt = celebornToAppShuffleIdentifier.get(celebornShuffleId) + if (appShuffleIdentifierOpt.isEmpty) { + logWarning(s"cannot find appShuffleIdentifier for celeborn shuffle: $celebornShuffleId") + return + } + val appShuffleIdentifier = appShuffleIdentifierOpt.get + val Array(appShuffleId, stageOfShuffleBeingDeleted, _) = + appShuffleIdentifier.split('-') + val shuffleSize = queryShuffleSizeByAppShuffleIdentifier(appShuffleIdentifier) + celebornToAppShuffleIdentifier.remove(celebornShuffleId) + logInfo(s"clean up app shuffle id $appShuffleIdentifier," + + s" celeborn shuffle id : $celebornShuffleId") + stageId.foreach(sid => removeStageAndReadInfo(sid)) + // ClientMetricsSystem.updateShuffleWrittenBytes(shuffleSize * -1) + stageId.foreach(sid => + CelebornSparkContextHelper.eventLogger.foreach(e => { + // for shuffles being deleted when no one refers to it, we need to make a record of + // stage reading it to calculate the cost saving accurately + e.onOtherEvent(CelebornShuffleEarlyCleanup( + celebornShuffleId, + appShuffleId.toInt, + stageOfShuffleBeingDeleted.toInt, + shuffleSize, + readStageId = sid, + timeToEnqueue = Instant.now().toEpochMilli)) + })) + } + + def queryCelebornShuffleIdByWriterStageId(stageId: Int): Option[Int] = this.synchronized { + stageToCelebornShuffleIdWritten.get(stageId) + } + + def getAppShuffleIdByStageId(stageId: Int): Int = this.synchronized { + // return -1 means the stage is not writing any shuffle + stageToAppShuffleIdWritten.getOrElse(stageId, -1) + } + + def getAllUpstreamAppShuffleIdsByStageId(stageId: Int): Array[Int] = this.synchronized { + val writtenAppShuffleId = stageToAppShuffleIdWritten.getOrElse( + stageId, + throw new IllegalStateException(s"cannot find app shuffle id written by stage $stageId")) + val allUpstreamAppShuffleIds = appShuffleIdToUpstream.getOrElse( + writtenAppShuffleId, + throw new IllegalStateException(s"cannot find upstream shuffle ids written of shuffle " + + s"$writtenAppShuffleId")) + allUpstreamAppShuffleIds.toArray + } + + def addStageToCelebornShuffleIdRef(celebornShuffleId: Int, appShuffleIdentifier: String): Unit = + this.synchronized { + val Array(appShuffleId, stageId, _) = appShuffleIdentifier.split('-') + stageToCelebornShuffleIdWritten += stageId.toInt -> celebornShuffleId + stageToAppShuffleIdWritten += stageId.toInt -> appShuffleId.toInt + } + + def addCelebornToAppShuffleIdMapping( + celebornShuffleId: Int, + appShuffleIdentifier: String): Unit = { + this.synchronized { + celebornToAppShuffleIdentifier += celebornShuffleId -> appShuffleIdentifier + } + } + + def addCelebornShuffleIdReadingStageDep( + celebornShuffleId: Int, + appShuffleIdentifier: String): Unit = { + this.synchronized { + val Array(_, stageId, _) = appShuffleIdentifier.split('-') + val stageIds = + readShuffleToStageDep.getOrElseUpdate(celebornShuffleId, new mutable.HashSet[Int]()) + stageIds += stageId.toInt + val celebornShuffleIds = + stageToReadCelebornShuffleDep.getOrElseUpdate(stageId.toInt, new mutable.HashSet[Int]()) + celebornShuffleIds += celebornShuffleId + } + } + + def addAppShuffleIdReadingStageDep(appShuffleId: Int, appShuffleIdentifier: String): Unit = { + this.synchronized { + val Array(_, sid, _) = appShuffleIdentifier.split('-') + val stageId = sid.toInt + // update shuffle id to all upstream + if (stageToAppShuffleIdWritten.contains(stageId)) { + val upstreamAppShuffleIds = appShuffleIdToUpstream.getOrElseUpdate( + stageToAppShuffleIdWritten(stageId), + new mutable.HashSet[Int]()) + if (!upstreamAppShuffleIds.contains(appShuffleId)) { + logInfo(s"new upstream shuffleId detected for shuffle" + + s" ${stageToAppShuffleIdWritten(stageId)}, latest: $appShuffleIdToUpstream") + upstreamAppShuffleIds += appShuffleId + } + } + } + } + + def hasAllUpstreamShuffleIdsInfo(stageId: Int): Boolean = this.synchronized { + stageToAppShuffleIdWritten.contains(stageId) && + appShuffleIdToUpstream.contains(stageToAppShuffleIdWritten(stageId)) + } + + private var stopped: Boolean = false + + def start(): Unit = { + val cleanerThread = new Thread() { + override def run(): Unit = { + while (!stopped) { + val allShuffleIds = new util.ArrayList[Int] + shuffleIdsToBeCleaned.drainTo(allShuffleIds) + allShuffleIds.asScala.foreach { shuffleId => + shuffleManager.getLifecycleManager.unregisterShuffle(shuffleId) + logInfo(s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") + } + Thread.sleep(cleanInterval) + } + } + } + + cleanerThread.setName("shuffle early cleaner thread") + cleanerThread.setDaemon(true) + cleanerThread.start() + } + + def stop(): Unit = { + stopped = true + } +} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala new file mode 100644 index 00000000000..02e9c3ac14a --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.listener + +import org.apache.spark.scheduler.SparkListenerEvent + +case class CelebornShuffleEarlyCleanup( + celebornShuffleId: Int, + applicationShuffleId: Int, + stageId: Int, + shuffleSizeInBytes: Long, + readStageId: Int, + timeToEnqueue: Long) extends SparkListenerEvent diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala new file mode 100644 index 00000000000..ae7ffe2222b --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.listener + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.SparkListenerInterface +import org.apache.spark.util.Utils + +object ListenerHelper extends Logging { + + private var listenerAdded: Boolean = false + + def addShuffleStatsTrackingListener(): Unit = this.synchronized { + if (!listenerAdded) { + val sc = SparkContext.getActive.get + val listeners = Utils.loadExtensions( + classOf[SparkListenerInterface], + Seq("org.apache.spark.listener.ShuffleStatsTrackingListener"), + sc.conf) + listeners.foreach { l => sc.listenerBus.addToSharedQueue(l) } + logInfo("registered ShuffleStatsTrackingListener") + listenerAdded = true + } + } + + def reset(): Unit = { + listenerAdded = false + } + +} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala new file mode 100644 index 00000000000..10e80a802cc --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.listener + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerStageSubmitted, SparkListenerTaskEnd} +import org.apache.spark.shuffle.celeborn.SparkShuffleManager + +class ShuffleStatsTrackingListener extends SparkListener with Logging { + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + logInfo(s"stage ${stageSubmitted.stageInfo.stageId}.${stageSubmitted.stageInfo.attemptNumber()} started") + val stageId = stageSubmitted.stageInfo.stageId + val shuffleMgr = SparkEnv.get.shuffleManager.asInstanceOf[SparkShuffleManager] + val parentStages = stageSubmitted.stageInfo.parentIds + if (shuffleMgr.getLifecycleManager.conf.clientShuffleEarlyDeletion) { + parentStages.foreach { parentStageId => + val celebornShuffleId = shuffleMgr.getStageDepManager + .queryCelebornShuffleIdByWriterStageId(parentStageId) + celebornShuffleId.foreach { sid => + shuffleMgr.getStageDepManager.addShuffleAndStageDep(sid, stageId) + } + } + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageIdentifier = s"${stageCompleted.stageInfo.stageId}-" + + s"${stageCompleted.stageInfo.attemptNumber()}" + logInfo(s"stage $stageIdentifier finished with" + + s" ${stageCompleted.stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten} shuffle bytes") + val shuffleMgr = SparkEnv.get.shuffleManager.asInstanceOf[SparkShuffleManager] + if (shuffleMgr.getLifecycleManager.conf.clientShuffleEarlyDeletion) { + val shuffleIdOpt = stageCompleted.stageInfo.shuffleDepId + shuffleIdOpt.foreach { appShuffleId => + val appShuffleIdentifier = s"$appShuffleId-${stageCompleted.stageInfo.stageId}-" + + s"${stageCompleted.stageInfo.attemptNumber()}" + shuffleMgr.getStageDepManager.addAppShuffleIdentifierToSize( + appShuffleIdentifier, + stageCompleted.stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten) + } + } + if (shuffleMgr.getLifecycleManager.conf.clientShuffleEarlyDeletion && + stageCompleted.stageInfo.failureReason.isEmpty) { + shuffleMgr.getStageDepManager.removeShuffleAndStageDep(stageCompleted.stageInfo) + } + } +} diff --git a/docs/configuration/client.md b/docs/configuration/client.md index fb56d8d7252..f5db7bc6f64 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -125,6 +125,9 @@ license: | | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. | 0.3.1 | | | celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 | | | celeborn.client.spark.fetch.cleanFailedShuffleInterval | 1s | false | the interval to clean the failed-to-fetch shuffle files, only valid when celeborn.client.spark.fetch.cleanFailedShuffle is enabled | 0.6.0 | | +| celeborn.client.spark.fetch.shuffleEarlyDeletion | false | false | whether to delete shuffle when we determine a shuffle is not needed by any stage | 0.7.0 | | +| celeborn.client.spark.fetch.shuffleEarlyDeletion.checkProperty | false | false | when this is enabled, we only early delete shuffle when "CELEBORN_EARLY_SHUFFLE_DELETION" property is set to true | 0.7.0 | | +| celeborn.client.spark.fetch.shuffleEarlyDeletion.intervalMs | 300000 | false | interval length to delete unused shuffle (ms) | 0.7.0 | | | celeborn.client.spark.push.dynamicWriteMode.enabled | false | false | Whether to dynamically switch push write mode based on conditions.If true, shuffle mode will be only determined by partition count | 0.5.0 | | | celeborn.client.spark.push.dynamicWriteMode.partitionNum.threshold | 2000 | false | Threshold of shuffle partition number for dynamically switching push writer mode. When the shuffle partition number is greater than this value, use the sort-based shuffle writer for memory efficiency; otherwise use the hash-based shuffle writer for speed. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is true. | 0.5.0 | | | celeborn.client.spark.push.sort.memory.maxMemoryFactor | 0.4 | false | the max portion of executor memory which can be used for SortBasedWriter buffer (only valid when celeborn.client.spark.push.sort.memory.useAdaptiveThreshold is enabled | 0.5.0 | | From 81791b9bf6171c2782fcc910f331e484d27301e6 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 10:01:46 -0800 Subject: [PATCH 04/20] client related changes --- .../celeborn/CelebornShuffleReader.scala | 105 +++++++++++++----- .../celeborn/client/DummyShuffleClient.java | 10 ++ .../apache/celeborn/client/ShuffleClient.java | 11 ++ .../celeborn/client/ShuffleClientImpl.java | 32 ++++++ 4 files changed, 131 insertions(+), 27 deletions(-) diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 55e036155b2..57057157e25 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -47,6 +47,7 @@ import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils} +import org.apache.celeborn.common.util.Utils.{KNOWN_MISSING_CELEBORN_SHUFFLE_ID, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID} class CelebornShuffleReader[K, C]( handle: CelebornShuffleHandle[K, _, C], @@ -99,32 +100,82 @@ class CelebornShuffleReader[K, C]( private val pushReplicateEnabled = conf.clientPushReplicateEnabled private val preferReplicaRead = context.attemptNumber % 2 == 1 - override def read(): Iterator[Product2[K, C]] = { + private def throwFetchFailureForMissingId(partitionId: Int, celebornShuffleId: Int): Unit = { + throw new FetchFailedException( + null, + handle.shuffleId, + -1, + -1, + partitionId, + SparkUtils.FETCH_FAILURE_ERROR_MSG + celebornShuffleId, + new CelebornIOException(s"cannot find shuffle id for ${handle.shuffleId}")) + } - val startTime = System.currentTimeMillis() - val serializerInstance = newSerializerInstance(dep) - val shuffleId = - try { - SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) - } catch { - case e: CelebornRuntimeException => - logError(s"Failed to get shuffleId for appShuffleId ${handle.shuffleId}", e) - if (stageRerunEnabled) { - throw new FetchFailedException( - null, + private def handleMissingCelebornShuffleId(celebornShuffleId: Int, stageId: Int): Unit = { + if (conf.clientShuffleEarlyDeletion) { + if (celebornShuffleId == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { + logError(s"cannot find celeborn shuffle id for app shuffle ${handle.shuffleId} which " + + s"never appear before, throwing FetchFailureException") + (startPartition until endPartition).foreach(partitionId => { + if (stageRerunEnabled && + shuffleClient.reportMissingShuffleId( handle.shuffleId, - -1, - -1, - startPartition, - SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + handle.shuffleId, - e) + context.stageId(), + context.stageAttemptNumber())) { + throwFetchFailureForMissingId(partitionId, celebornShuffleId) } else { + val e = new IllegalStateException(s"failed to handle missing celeborn id for app" + + s" shuffle ${handle.shuffleId}") + logError(s"failed to handle missing celeborn id for app shuffle ${handle.shuffleId}", e) throw e } + }) + } else if (celebornShuffleId == KNOWN_MISSING_CELEBORN_SHUFFLE_ID) { + logError(s"cannot find celeborn shuffle id for app shuffle ${handle.shuffleId} which " + + s"has appeared before, invalidating all upstream shuffle of this shuffle") + (startPartition until endPartition).foreach(partitionId => { + if (stageRerunEnabled) { + val invalidateAllUpstreamRet = shuffleClient.invalidateAllUpstreamShuffle( + context.stageId(), + context.stageAttemptNumber(), + handle.shuffleId) + if (invalidateAllUpstreamRet) { + throwFetchFailureForMissingId(partitionId, celebornShuffleId) + } else { + // if we cannot invalidate all upstream, we need to report regular fetch failure + // for this particular shuffle id + val fetchFailureResponse = shuffleClient.reportMissingShuffleId( + handle.shuffleId, + context.stageId(), + context.stageAttemptNumber()) + if (fetchFailureResponse) { + throwFetchFailureForMissingId(partitionId, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) + } else { + val e = new IllegalStateException(s"failed to handle missing celeborn id for app" + + s" shuffle ${handle.shuffleId}") + logError( + s"failed to handle missing celeborn id for app shuffle" + + s" ${handle.shuffleId}", + e) + throw e + } + } + } + }) } - shuffleIdTracker.track(handle.shuffleId, shuffleId) + } + } + + override def read(): Iterator[Product2[K, C]] = { + + val startTime = System.currentTimeMillis() + val serializerInstance = newSerializerInstance(dep) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) + + handleMissingCelebornShuffleId(celebornShuffleId, context.stageId()) logDebug( - s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} attemptNum ${context.stageAttemptNumber()}") + s"get shuffleId $celebornShuffleId for appShuffleId ${handle.shuffleId}" + + s" attemptNum ${context.stageAttemptNumber()}") // Update the context task metrics for each record read. val metricsCallback = new MetricsCallback { @@ -151,22 +202,22 @@ class CelebornShuffleReader[K, C]( val fetchTimeoutMs = conf.clientFetchTimeoutMs val localFetchEnabled = conf.enableReadLocalShuffleFile val localHostAddress = Utils.localHostName(conf) - val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId) + val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, celebornShuffleId) var fileGroups: ReduceFileGroups = null var isShuffleStageEnd: Boolean = false var updateFileGroupsRetryTimes = 0 do { isShuffleStageEnd = try { - shuffleClient.isShuffleStageEnd(shuffleId) + shuffleClient.isShuffleStageEnd(celebornShuffleId) } catch { case e: Exception => - logInfo(s"Failed to check shuffle stage end for $shuffleId, assume ended", e) + logInfo(s"Failed to check shuffle stage end for $celebornShuffleId, assume ended", e) true } try { // startPartition is irrelevant - fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) + fileGroups = shuffleClient.updateFileGroup(celebornShuffleId, startPartition) } catch { case ce: CelebornIOException if ce.getCause != null && ce.getCause.isInstanceOf[ @@ -180,7 +231,7 @@ class CelebornShuffleReader[K, C]( // if a task is interrupted, should not report fetch failure // if a task update file group timeout, should not report fetch failure // if a task GetReducerFileGroupResponse failed via broadcast, should not report fetch failure - checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce) + checkAndReportFetchFailureForUpdateFileGroupFailure(celebornShuffleId, ce) } } while (fileGroups == null) @@ -333,7 +384,7 @@ class CelebornShuffleReader[K, C]( if (exceptionRef.get() == null) { try { val inputStream = shuffleClient.readPartition( - shuffleId, + celebornShuffleId, handle.shuffleId, partitionId, encodedAttemptId, @@ -380,7 +431,7 @@ class CelebornShuffleReader[K, C]( if (exceptionRef.get() != null) { exceptionRef.get() match { case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - handleFetchExceptions(handle.shuffleId, shuffleId, partitionId, ce) + handleFetchExceptions(handle.shuffleId, celebornShuffleId, partitionId, ce) case e => throw e } } @@ -424,7 +475,7 @@ class CelebornShuffleReader[K, C]( iter } catch { case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - handleFetchExceptions(handle.shuffleId, shuffleId, partitionId, e) + handleFetchExceptions(handle.shuffleId, celebornShuffleId, partitionId, e) } } diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 69cc3cd6f9c..1a8d78b2d97 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -236,4 +236,14 @@ public void initReducePartitionMap(int shuffleId, int numPartitions, int workerN public Map> getReducePartitionMap() { return reducePartitionMap; } + + @Override + public boolean invalidateAllUpstreamShuffle(int stageId, int attemptId, int appShuffleId) { + return true; + } + + @Override + public boolean reportMissingShuffleId(int appShuffleId, int readerStageId, int stageAttemptId) { + return true; + } } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index d37ec644231..d2371b1059f 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -332,4 +332,15 @@ public static ControlMessages.GetReducerFileGroupResponse deserializeReducerFile } return deserializeReducerFileGroupResponseFunction.get().apply(shuffleId, bytes); } + + /** + * report fetch failure for all upstream shuffles for a given stage id, It must be a sync call and + * make sure the cleanup is done, otherwise, incorrect shuffle data can be fetched in re-run tasks + */ + public abstract boolean invalidateAllUpstreamShuffle( + int stageId, int attemptId, int triggerAppShuffleId); + + /** report the failure to find the corresponding celeborn id for a shuffle id */ + public abstract boolean reportMissingShuffleId( + int appShuffleId, int readerStageId, int stageAttemptId); } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index cfe40a2968f..85b844c3f47 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -2159,4 +2159,36 @@ public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) { fetchExcludedWorkers.put(hostAndFetchPort, System.currentTimeMillis()); } } + + @Override + public boolean invalidateAllUpstreamShuffle(int stageId, int attemptId, int triggerAppId) { + PbInvalidateAllUpstreamShuffle pbInvalidateAllUpstreamShuffle = + PbInvalidateAllUpstreamShuffle.newBuilder() + .setReaderStageId(stageId) + .setAttemptId(attemptId) + .setTriggerAppShuffleId(triggerAppId) + .build(); + PbInvalidateAllUpstreamShuffleResponse pbInvalidateAllUpstreamShuffleResponse = + lifecycleManagerRef.askSync( + pbInvalidateAllUpstreamShuffle, + conf.clientRpcRegisterShuffleAskTimeout(), + ClassTag$.MODULE$.apply(PbInvalidateAllUpstreamShuffleResponse.class)); + return pbInvalidateAllUpstreamShuffleResponse.getSuccess(); + } + + @Override + public boolean reportMissingShuffleId(int appShuffleId, int readerStageId, int stageAttemptId) { + PbReportMissingShuffleId pbReportMissingShuffleId = + PbReportMissingShuffleId.newBuilder() + .setReaderStageId(readerStageId) + .setAttemptId(stageAttemptId) + .setTriggerAppShuffleId(appShuffleId) + .build(); + PbReportMissingShuffleIdResponse response = + lifecycleManagerRef.askSync( + pbReportMissingShuffleId, + conf.clientRpcRegisterShuffleAskTimeout(), + ClassTag$.MODULE$.apply(PbReportMissingShuffleIdResponse.class)); + return response.getSuccess(); + } } From ca177e1816908a0d6b37afc396dab705290a620d Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 11:52:49 -0800 Subject: [PATCH 05/20] fix tests by setting success of getShuffleIdResponse --- .../scala/org/apache/celeborn/client/LifecycleManager.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 64795b0c05c..670b5cb9fb5 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1101,7 +1101,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val pbGetShuffleIdResponse = { logDebug( s"get shuffleId $celebornShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") - PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).build() + PbGetShuffleIdResponse.newBuilder() + .setShuffleId(celebornShuffleId) + .setSuccess(true) + .build() } context.reply(pbGetShuffleIdResponse) found = true @@ -1127,6 +1130,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID } }) + .setSuccess(true) .build() context.reply(pbGetShuffleIdResponse) } else { From a798a42a495dfd098ffe5686a93d837aab73f6fe Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 13:11:58 -0800 Subject: [PATCH 06/20] remove ListenerHelper and add back mis-deleted registerShuffleTrackerCallback --- .../shuffle/celeborn/SparkShuffleManager.java | 7 ++- .../spark/listener/ListenerHelper.scala | 46 ------------------- 2 files changed, 5 insertions(+), 48 deletions(-) delete mode 100644 client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 27266103dcf..aadc8464394 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -25,7 +25,7 @@ import org.apache.spark.celeborn.StageDependencyManager; import org.apache.spark.internal.config.package$; import org.apache.spark.launcher.SparkLauncher; -import org.apache.spark.listener.ListenerHelper; +import org.apache.spark.listener.ShuffleStatsTrackingListener; import org.apache.spark.rdd.DeterministicLevel; import org.apache.spark.shuffle.*; import org.apache.spark.shuffle.sort.SortShuffleManager; @@ -160,6 +160,9 @@ private void initializeLifecycleManager(String appId) { taskId -> SparkUtils.shouldReportShuffleFetchFailure(taskId)); SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener()); + lifecycleManager.registerShuffleTrackerCallback( + shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); + if (celebornConf.clientAdaptiveOptimizeSkewedPartitionReadEnabled()) { lifecycleManager.registerCelebornSkewShuffleCheckCallback( SparkUtils::isCelebornSkewShuffleOrChildShuffle); @@ -185,7 +188,7 @@ private void initializeLifecycleManager(String appId) { if (lifecycleManager.conf().clientShuffleEarlyDeletion()) { logger.info("register early deletion callbacks"); - ListenerHelper.addShuffleStatsTrackingListener(); + SparkUtils.addSparkListener(new ShuffleStatsTrackingListener()); lifecycleManager.registerStageToWriteCelebornShuffleCallback( (celebornShuffleId, appShuffleIdentifier) -> SparkUtils.addStageToWriteCelebornShuffleIdDep( diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala deleted file mode 100644 index ae7ffe2222b..00000000000 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ListenerHelper.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.listener - -import org.apache.spark.SparkContext -import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.SparkListenerInterface -import org.apache.spark.util.Utils - -object ListenerHelper extends Logging { - - private var listenerAdded: Boolean = false - - def addShuffleStatsTrackingListener(): Unit = this.synchronized { - if (!listenerAdded) { - val sc = SparkContext.getActive.get - val listeners = Utils.loadExtensions( - classOf[SparkListenerInterface], - Seq("org.apache.spark.listener.ShuffleStatsTrackingListener"), - sc.conf) - listeners.foreach { l => sc.listenerBus.addToSharedQueue(l) } - logInfo("registered ShuffleStatsTrackingListener") - listenerAdded = true - } - } - - def reset(): Unit = { - listenerAdded = false - } - -} From 7a11ed35496e5a54874937cbc3fc7cc51e82752e Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 16 Dec 2025 13:14:28 -0800 Subject: [PATCH 07/20] stylistic fix --- .../org/apache/spark/shuffle/celeborn/SparkShuffleManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index aadc8464394..4eef7a862d6 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -161,7 +161,7 @@ private void initializeLifecycleManager(String appId) { SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener()); lifecycleManager.registerShuffleTrackerCallback( - shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); + shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); if (celebornConf.clientAdaptiveOptimizeSkewedPartitionReadEnabled()) { lifecycleManager.registerCelebornSkewShuffleCheckCallback( From 90abb89c4a12b5d3bbc117eef6038b3b4965605a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 17 Dec 2025 09:46:12 -0800 Subject: [PATCH 08/20] attempt to add early delete test --- .../celeborn/StageDependencyManager.scala | 2 +- .../CelebornFetchFailureDiskCleanSuite.scala | 71 +-- .../CelebornShuffleEarlyDeleteSuite.scala | 447 ++++++++++++++++++ .../tests/spark/StorageCheckUtils.scala | 92 ++++ .../fetch/failure/ShuffleReaderGetHooks.scala | 57 ++- 5 files changed, 597 insertions(+), 72 deletions(-) create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala index 3b061d37879..badb587924f 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -129,7 +129,7 @@ class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Loggin } } - private[celeborn] def removeCelebornShuffleInternal( + def removeCelebornShuffleInternal( celebornShuffleId: Int, stageId: Option[Int]): Unit = { shuffleIdsToBeCleaned.put(celebornShuffleId) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala index 936ea696113..a970db43bf8 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala @@ -16,8 +16,6 @@ */ package org.apache.celeborn.tests.spark -import java.io.File - import org.apache.spark.SparkConf import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} import org.apache.spark.sql.SparkSession @@ -73,10 +71,10 @@ class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite shuffleIdToBeDeleted = Seq(0)) TestCelebornShuffleManager.registerReaderGetHook(hook) val checkingThread = - triggerStorageCheckThread(Seq(0), Seq(1), sparkSession) + StorageCheckUtils.triggerStorageCheckThread(workerDirs, Seq(0), Seq(1), sparkSession) val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2) .map { i => (i, i) }.groupByKey(4).collect() - checkStorageValidation(checkingThread) + StorageCheckUtils.checkStorageValidation(checkingThread) // verify result assert(hook.executed.get()) assert(tuples.length == 10000) @@ -86,69 +84,4 @@ class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite sparkSession.stop() } } - - class CheckingThread( - shuffleIdShouldNotExist: Seq[Int], - shuffleIdMustExist: Seq[Int], - sparkSession: SparkSession) - extends Thread { - var exception: Exception = _ - - protected def checkDirStatus(): Boolean = { - val deletedSuccessfully = shuffleIdShouldNotExist.forall(shuffleId => { - workerDirs.forall(dir => - !new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) - }) - val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId => { - shuffleId.toString + ":" + - workerDirs.map(dir => - !new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList - }).mkString(",") - val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => { - workerDirs.exists(dir => - new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) - }) - val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => { - shuffleId.toString + ":" + - workerDirs.map(dir => - new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList - }).mkString(",") - println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" + - s"shuffle-to-be-created status: $createdSuccessfullyString") - deletedSuccessfully && createdSuccessfully - } - - override def run(): Unit = { - var allDataInShape = checkDirStatus() - while (!allDataInShape) { - Thread.sleep(1000) - allDataInShape = checkDirStatus() - } - } - } - - protected def triggerStorageCheckThread( - shuffleIdShouldNotExist: Seq[Int], - shuffleIdMustExist: Seq[Int], - sparkSession: SparkSession): CheckingThread = { - val checkingThread = - new CheckingThread(shuffleIdShouldNotExist, shuffleIdMustExist, sparkSession) - checkingThread.setDaemon(true) - checkingThread.start() - checkingThread - } - - protected def checkStorageValidation(thread: Thread, timeout: Long = 1200 * 1000): Unit = { - val checkingThread = thread.asInstanceOf[CheckingThread] - checkingThread.join(timeout) - if (checkingThread.isAlive || checkingThread.exception != null) { - throw new IllegalStateException("the storage checking status failed," + - s"${checkingThread.isAlive} ${if (checkingThread.exception != null) checkingThread.exception.getMessage - else "NULL"}") - } - } } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala new file mode 100644 index 00000000000..947702d3f90 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.sql.SparkSession +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.celeborn.service.deploy.worker.Worker +import org.apache.celeborn.tests.spark.fetch.failure.FailedCommitAndExpireDataReaderHook + +class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { + + override def beforeAll(): Unit = { + logInfo("test initialized , setup Celeborn mini cluster") + setupMiniClusterWithRandomPorts(workerNum = 1) + } + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + override def createWorker(map: Map[String, String]): Worker = { + val storageDir = createTmpDir() + workerDirs = workerDirs :+ storageDir + super.createWorker(map ++ Map("celeborn.master.heartbeat.worker.timeout" -> "10s"), storageDir) + } + + private def createSparkSession(additionalConf: Map[String, String] = Map()): SparkSession = { + var builder = SparkSession + .builder() + .master("local[*, 4]") + .appName("celeborn early delete") + .config(updateSparkConf(new SparkConf(), ShuffleMode.SORT)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.shuffle.enabled", "true") + .config("spark.celeborn.client.shuffle.expired.checkInterval", "1s") + .config("spark.kryoserializer.buffer.max", "2047m") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config(s"spark.${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key}", "true") + .config(s"spark.${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION_INTERVAL_MS.key}", "1000") + additionalConf.foreach { case (k, v) => + builder = builder.config(k, v) + } + builder.getOrCreate() + } + + /* + test("spark integration test - delete shuffle data from unneeded stages") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + .repartition(4) + val t = new Thread() { + override def run(): Unit = { + // shuffle 1 + rdd1.mapPartitions(iter => { + Thread.sleep(20000) + iter + }).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 2), // guard on 2 to prevent any stage retry + shuffleIdMustExist = Seq(1), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + } + + test("spark integration test - delete shuffle data only when all child stages finished") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + val rdd2 = rdd1.repartition(4) + val rdd3 = rdd1.repartition(4) + val t = new Thread() { + override def run(): Unit = { + rdd2.union(rdd3).mapPartitions(iter => { + Thread.sleep(20000) + iter + }).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 3), // guard on 3 to prevent any stage retry + shuffleIdMustExist = Seq(1, 2), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + } + + test("spark integration test - delete shuffle data only when all child stages finished" + + " (multi-level lineage)") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + val rdd2 = rdd1.repartition(4).repartition(2) + val rdd3 = rdd1.repartition(4).repartition(2) + val t = new Thread() { + override def run(): Unit = { + rdd2.union(rdd3).mapPartitions(iter => { + Thread.sleep(20000) + iter + }).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 1, 2, 5), // guard on 5 to prevent any stage retry + shuffleIdMustExist = Seq(3, 4), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + } + + test("spark integration test - when the stage has a skipped parent stage, we should still be" + + " able to delete data") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + rdd1.count() + val t = new Thread() { + override def run(): Unit = { + rdd1.mapPartitions(iter => { + Thread.sleep(20000) + iter + }).repartition(3).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 2), + shuffleIdMustExist = Seq(1), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + }*/ + + private def deleteTooEarlyTest( + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + spark: SparkSession): Unit = { + if (Spark3OrNewer) { + var r = 0L + try { + // shuffle 0 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + rdd1.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 1 + val rdd2 = rdd1.mapPartitions(iter => { + Thread.sleep(10000) + iter + }).repartition(3) + rdd2.count() + println("rdd2.count() finished") + // leaving enough time for shuffle 0 to be expired + Thread.sleep(10000) + // shuffle 2 + val rdd3 = rdd1.repartition(5).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd3.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = shuffleIdShouldNotExist, + shuffleIdMustExist = shuffleIdMustExist, + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 20) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when shuffle is deleted \"too early\"") { + val spark = createSparkSession() + deleteTooEarlyTest(Seq(0, 3, 5), Seq(1, 2, 4), spark) + } + + // test("spark integration test - do not fail job when shuffle is deleted \"too early\"" + + // " (with failed shuffle deletion)") { + // val spark = createSparkSession( + // Map(s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + // deleteTooEarlyTest(Seq(0, 2, 3, 5), Seq(1, 4), spark) + // } + + test("spark integration test - do not fail job when shuffle files" + + " are deleted \"too early\" (ancestor dependency)") { + val spark = createSparkSession() + if (Spark3OrNewer) { + var r = 0L + try { + // shuffle 0 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + rdd1.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 1 + val rdd2 = rdd1.repartition(3) + rdd2.count() + println("rdd2.count finished()") + // leaving enough time for shuffle 0 to be expired + Thread.sleep(10000) + // shuffle 2 + rdd2.repartition(4).count() + // leaving enough time for shuffle 1 to be expired + Thread.sleep(10000) + val rdd4 = rdd1.union(rdd2).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd4.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 1, 5), + shuffleIdMustExist = Seq(3, 4), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 40) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when multiple shuffles (be unioned)" + + " are deleted \"too early\"") { + if (Spark3OrNewer) { + val spark = createSparkSession() + var r = 0L + try { + // shuffle 0&1 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd2 = spark.sparkContext.parallelize((0 until 30), 3).repartition(2) + rdd1.count() + rdd2.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 2&3 + val rdd3 = rdd1.repartition(3) + val rdd4 = rdd2.repartition(3) + rdd3.count() + rdd4.count() + // leaving enough time for shuffle 0&1 to be expired + Thread.sleep(10000) + // shuffle 4&5 + rdd3.repartition(4).count() + rdd4.repartition(4).count() + // leaving enough time for shuffle 2&3 to be expired + Thread.sleep(10000) + val rdd5 = rdd3.union(rdd4).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd5.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + // 4,5 are based on vanilla spark gc which are not necessarily stable in a test + // 6,7 is based on failed shuffle cleanup, which is not covered here + shuffleIdShouldNotExist = Seq(0, 1, 2, 3, 8, 9, 12), + shuffleIdMustExist = Seq(10, 11), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 50) + } finally { + spark.stop() + } + } + } + + // test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + // " are deleted \"too early\"") { + // if (runningWithSpark3OrNewer()) { + // val spark = createSparkSession( + // Map(s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + // var r = 0L + // try { + // // shuffle 0&1 + // val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + // val rdd2 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + // rdd1.count() + // rdd2.count() + // val t = new Thread() { + // override def run(): Unit = { + // // shuffle 2&3 + // val rdd3 = rdd1.repartition(3) + // val rdd4 = rdd2.repartition(3) + // rdd3.count() + // rdd4.count() + // // leaving enough time for shuffle 0&1 to be expired + // Thread.sleep(10000) + // // shuffle 4&5 + // rdd3.repartition(4).count() + // rdd4.repartition(4).count() + // // leaving enough time for shuffle 2&3 to be expired + // Thread.sleep(10000) + // println("starting job for rdd 5") + // val rdd5 = rdd3.zip(rdd4).mapPartitions(iter => { + // Thread.sleep(10000) + // iter + // }) + // r = rdd5.count() + // } + // } + // t.start() + // val thread = StorageCheckUtils.triggerStorageCheckThread( + // workerDirs, + // // 4,5 are based on vanilla spark gc which are not necessarily stable in a test + // // 6,9 is based on failed shuffle cleanup, which is not covered here + // shuffleIdShouldNotExist = Seq(0, 1, 2, 3, 7, 10, 12), + // shuffleIdMustExist = Seq(8, 11), + // sparkSession = spark, + // forStableStatusChecking = false) + // StorageCheckUtils.checkStorageValidation(thread) + // t.join() + // assert(r === 20) + // } finally { + // spark.stop( + // } + // } + // } + + private def multiShuffleFailureTest( + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + spark: SparkSession): Unit = { + if (Spark3OrNewer) { + val celebornConf = SparkUtils.fromSparkConf(spark.sparkContext.getConf) + val hook = new FailedCommitAndExpireDataReaderHook( + celebornConf, + triggerShuffleId = 6, + shuffleIdsToExpire = (0 to 5).toList) + TestCelebornShuffleManager.registerReaderGetHook(hook) + var r = 0L + try { + // shuffle 0&1&2 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd2 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd3 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val t = new Thread() { + override def run(): Unit = { + // shuffle 3&4&5 + val rdd4 = rdd1.repartition(3) + val rdd5 = rdd2.repartition(3) + val rdd6 = rdd3.repartition(3) + println("starting job for rdd 7") + val rdd7 = rdd4.zip(rdd5).zip(rdd6).repartition(2) + r = rdd7.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = shuffleIdShouldNotExist, + shuffleIdMustExist = shuffleIdMustExist, + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 20) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + " are to be retried for fetching") { + val spark = createSparkSession(Map("spark.stage.maxConsecutiveAttempts" -> "3")) + multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5), Seq(17), spark) + } + + // test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + // " are to be retried for fetching (with failed shuffle deletion)") { + // val spark = createSparkSession(Map( + // "spark.stage.maxConsecutiveAttempts" -> "3", + // s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + // multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5, 8, 9, 10), Seq(17), spark) + // } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala new file mode 100644 index 00000000000..e456e331916 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark + +import java.io.File + +import org.apache.spark.sql.SparkSession + +private[tests] object StorageCheckUtils { + + class CheckingThread( + workerDirs: Seq[String], + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + sparkSession: SparkSession) + extends Thread { + var exception: Exception = _ + + protected def checkDirStatus(): Boolean = { + val deletedSuccessfully = shuffleIdShouldNotExist.forall(shuffleId => { + workerDirs.forall(dir => + !new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) + }) + val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId => { + shuffleId.toString + ":" + + workerDirs.map(dir => + !new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList + }).mkString(",") + val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => { + workerDirs.exists(dir => + new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) + }) + val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => { + shuffleId.toString + ":" + + workerDirs.map(dir => + new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList + }).mkString(",") + println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" + + s"shuffle-to-be-created status: $createdSuccessfullyString") + deletedSuccessfully && createdSuccessfully + } + + override def run(): Unit = { + var allDataInShape = checkDirStatus() + while (!allDataInShape) { + Thread.sleep(1000) + allDataInShape = checkDirStatus() + } + } + } + + def triggerStorageCheckThread( + workerDirs: Seq[String], + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + sparkSession: SparkSession): CheckingThread = { + val checkingThread = + new CheckingThread(workerDirs, shuffleIdShouldNotExist, shuffleIdMustExist, sparkSession) + checkingThread.setDaemon(true) + checkingThread.start() + checkingThread + } + + def checkStorageValidation(thread: Thread, timeout: Long = 1200 * 1000): Unit = { + val checkingThread = thread.asInstanceOf[CheckingThread] + checkingThread.join(timeout) + if (checkingThread.isAlive || checkingThread.exception != null) { + throw new IllegalStateException("the storage checking status failed," + + s"${checkingThread.isAlive} ${if (checkingThread.exception != null) checkingThread.exception.getMessage + else "NULL"}") + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala index adac14242bd..176369fb0a2 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala @@ -20,13 +20,66 @@ package org.apache.celeborn.tests.spark.fetch.failure import java.io.File import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.shuffle.ShuffleHandle -import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkUtils, TestCelebornShuffleManager} import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.client.commit.ReducePartitionCommitHandler import org.apache.celeborn.common.CelebornConf +class FailedCommitAndExpireDataReaderHook( + conf: CelebornConf, + triggerShuffleId: Int, + shuffleIdsToExpire: List[Int]) + extends ShuffleManagerHook { + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + + if (executed.get()) return + + lock.synchronized { + // this has to be used in local mode since it leverages that the lifecycle manager + // is in the same process with reader + handle match { + case h: CelebornShuffleHandle[_, _, _] => + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val lifecycleManager = + SparkEnv.get.shuffleManager.asInstanceOf[TestCelebornShuffleManager] + .getLifecycleManager + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + if (celebornShuffleId == triggerShuffleId && !executed.get()) { + println(s"putting celeborn shuffle $celebornShuffleId as commit failure") + val commitHandler = lifecycleManager.commitManager.getCommitHandler(celebornShuffleId) + commitHandler.asInstanceOf[ReducePartitionCommitHandler].dataLostShuffleSet.add( + celebornShuffleId) + shuffleIdsToExpire.foreach(sid => + SparkEnv.get.shuffleManager.asInstanceOf[TestCelebornShuffleManager] + .getStageDepManager.removeCelebornShuffleInternal(sid, None)) + // leaving enough time for all shuffles to expire + Thread.sleep(10000) + executed.set(true) + } else { + println(s"ignore hook with $celebornShuffleId $triggerShuffleId and ${executed.get()}") + } + case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") + } + } + } +} + class ShuffleReaderGetHooks( conf: CelebornConf, workerDirs: Seq[String], From c85070013c877cc4418f4919cab71c82cc773828 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 17 Dec 2025 10:53:56 -0800 Subject: [PATCH 09/20] fix existing tests --- .../apache/celeborn/client/LifecycleManager.scala | 13 +++++++------ .../commit/ReducePartitionCommitHandler.scala | 2 +- .../common/network/protocol/TransportMessage.java | 8 ++++++++ .../spark/CelebornShuffleEarlyDeleteSuite.scala | 5 ++--- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 670b5cb9fb5..51f6dff03e8 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1090,7 +1090,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // this is not necessarily the most concise coding style, but it helps for debugging // purpose var found = false - shuffleIds.values.map(v => v._1).toSeq.reverse.foreach { celebornShuffleId: Int => + val revertedShuffleIds = shuffleIds.values.map(v => v._1).toSeq.reverse + revertedShuffleIds.foreach { celebornShuffleId: Int => if (!found) { try { if (areAllMapTasksEnd(celebornShuffleId)) { @@ -1112,12 +1113,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logInfo(s"not all map tasks finished for shuffle $celebornShuffleId") } } catch { - case npe: NullPointerException => + case ise: IllegalStateException => if (conf.clientShuffleEarlyDeletion) { logError( s"hit error when getting celeborn shuffle id $celebornShuffleId for" + s" appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier", - npe) + ise) val canInvalidateAllUpstream = checkWhetherToInvalidateAllUpstreamCallback.exists(func => func.apply(appShuffleIdentifier)) @@ -1135,10 +1136,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends context.reply(pbGetShuffleIdResponse) } else { logError( - s"unexpected NullPointerException without" + + s"unexpected IllegalStateException without" + s" ${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key} turning on", - npe) - throw npe; + ise) + throw ise; } } } diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index fb557ebbe00..ac844375339 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -293,7 +293,7 @@ class ReducePartitionCommitHandler( if (null != attempts) { attempts.length == shuffleToCompletedMappers.get(shuffleId) } else { - false + throw new IllegalStateException(s"cannot find mapper attempts record for shuffle $shuffleId") } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java index 1e729332f3e..96c43c9b6b4 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java @@ -125,6 +125,14 @@ public T getParsedPayload() throws InvalidProtoco return (T) PbNotifyRequiredSegment.parseFrom(payload); case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE: return (T) PbPushMergedDataSplitPartitionInfo.parseFrom(payload); + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_VALUE: + return (T) PbInvalidateAllUpstreamShuffle.parseFrom(payload); + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE_VALUE: + return (T) PbInvalidateAllUpstreamShuffleResponse.parseFrom(payload); + case REPORT_MISSING_SHUFFLE_ID_VALUE: + return (T) PbReportMissingShuffleId.parseFrom(payload); + case REPORT_MISSING_SHUFFLE_ID_RESPONSE_VALUE: + return (T) PbReportMissingShuffleIdResponse.parseFrom(payload); default: logger.error("Unexpected type {}", type); } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala index 947702d3f90..20e1b012b8f 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala @@ -69,8 +69,7 @@ class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { } builder.getOrCreate() } - - /* + test("spark integration test - delete shuffle data from unneeded stages") { if (Spark3OrNewer) { val spark = createSparkSession() @@ -186,7 +185,7 @@ class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { spark.stop() } } - }*/ + } private def deleteTooEarlyTest( shuffleIdShouldNotExist: Seq[Int], From 7d5e1a6e0848de6d28e964ad936c5b371418319b Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 17 Dec 2025 10:55:03 -0800 Subject: [PATCH 10/20] stylistic --- .../celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala index 20e1b012b8f..8a40003ccd2 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala @@ -20,6 +20,7 @@ package org.apache.celeborn.tests.spark import org.apache.spark.SparkConf import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} import org.apache.spark.sql.SparkSession + import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode @@ -69,7 +70,7 @@ class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { } builder.getOrCreate() } - + test("spark integration test - delete shuffle data from unneeded stages") { if (Spark3OrNewer) { val spark = createSparkSession() From 74710efa1918fd043223a3170cbcbf81f46fbc83 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 17 Dec 2025 19:40:28 -0800 Subject: [PATCH 11/20] debugging instrumentations --- .../scala/org/apache/celeborn/client/LifecycleManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 51f6dff03e8..42932c91cb9 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1271,10 +1271,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds.synchronized { val latestUpstreamShuffleId = shuffleIds.maxBy(_._2._1) if (latestUpstreamShuffleId._2._1 == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { - logInfo(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + + println(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + s" it is already reported by other reader and handled") } else { - logInfo(s"handle missing shuffle id for appShuffleId $appShuffleId stage" + + println(s"handle missing shuffle id for appShuffleId $appShuffleId stage" + s" $stageId.$stageAttemptId") appShuffleTrackerCallback match { case Some(callback) => From 470e018f4ebeda8e9c4cc211cf0c5da1a2a199b3 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 18 Dec 2025 10:57:01 -0800 Subject: [PATCH 12/20] filter shit --- .../scala/org/apache/celeborn/client/LifecycleManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 42932c91cb9..983aa27f2d4 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1090,7 +1090,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // this is not necessarily the most concise coding style, but it helps for debugging // purpose var found = false - val revertedShuffleIds = shuffleIds.values.map(v => v._1).toSeq.reverse + val revertedShuffleIds = shuffleIds.values.map(v => v._1).toSeq.reverse.filter(_ >= 0) revertedShuffleIds.foreach { celebornShuffleId: Int => if (!found) { try { From cdf6a868fbdfe5303c22d6bdf746b18419c92f97 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 24 Dec 2025 21:45:35 -0800 Subject: [PATCH 13/20] using appshuffleidentifier to dedup missing shuffle id report and disable reuse shuffle id when early deletion feature turned on --- .../ShuffleStatsTrackingListener.scala | 2 +- .../celeborn/client/LifecycleManager.scala | 43 +++++++++++++++++-- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala index 10e80a802cc..958bf6d7131 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.celeborn.SparkShuffleManager class ShuffleStatsTrackingListener extends SparkListener with Logging { override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { - logInfo(s"stage ${stageSubmitted.stageInfo.stageId}.${stageSubmitted.stageInfo.attemptNumber()} started") + println(s"stage ${stageSubmitted.stageInfo.stageId}.${stageSubmitted.stageInfo.attemptNumber()} started") val stageId = stageSubmitted.stageInfo.stageId val shuffleMgr = SparkEnv.get.shuffleManager.asInstanceOf[SparkShuffleManager] val parentStages = stageSubmitted.stageInfo.parentIds diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 983aa27f2d4..3b0dd31290a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -991,6 +991,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends appShuffleIdentifier: String, isWriter: Boolean, isBarrierStage: Boolean): Unit = { + println(s"get shuffle id for $appShuffleIdentifier isWriter: $isWriter") val shuffleIds = if (isWriter) { shuffleIdMapping.computeIfAbsent( @@ -1033,6 +1034,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends case Some((shuffleId, _)) => val pbGetShuffleIdResponse = PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build() + println(s"reply with shuffle id ${shuffleId} for $appShuffleIdentifier") context.reply(pbGetShuffleIdResponse) case None => Option(appShuffleDeterminateMap.get(appShuffleId)).map { determinate => @@ -1041,7 +1043,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // So if a barrier stage is getting reexecuted, previous stage/attempt needs to // be cleaned up as it is entirely unusuable if (determinate && !isBarrierStage && !isCelebornSkewShuffleOrChildShuffle( - appShuffleId)) { + appShuffleId) && !conf.clientShuffleEarlyDeletion) { val result = shuffleIds.values.toSeq.reverse.find(e => e._2 == true) if (result.isEmpty) { logWarning(s"cannot find candidate shuffleId for determinate" + @@ -1054,7 +1056,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val shuffleId: Integer = if (determinate && candidateShuffle.isDefined) { val id = candidateShuffle.get._1 - logInfo(s"reuse existing shuffleId $id for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") + println(s"reuse existing shuffleId $id for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") id } else { // this branch means it is a redo of previous write stage @@ -1269,6 +1271,39 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } var ret = true shuffleIds.synchronized { + val stageIdentifier = s"$stageId.$stageAttemptId" + val shuffleIdentifier = s"$appShuffleId.$stageId.$stageAttemptId" + if (stagesReceivedInvalidatingUpstream.getOrElse(stageIdentifier, new mutable.HashSet[Int]()) + .contains(appShuffleId)) { + println(s"${Thread.currentThread().getName} " + + s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + + s" it is already reported by other reader and handled") + } else { + println(s"${Thread.currentThread().getName} handle missing shuffle id for appShuffleId" + + s" $appShuffleId stage" + + s" $stageId.$stageAttemptId") + appShuffleTrackerCallback match { + case Some(callback) => + try { + callback.accept(appShuffleId) + } catch { + case t: Throwable => + logError(t.toString) + ret = false + } + shuffleIds.put(shuffleIdentifier, (UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID, false)) + case None => + throw new UnsupportedOperationException( + "unexpected! appShuffleTrackerCallback is not registered") + } + invalidateShuffleWrittenByStage(stageId) + stagesReceivedInvalidatingUpstream += stageIdentifier -> + (stagesReceivedInvalidatingUpstream.getOrElse( + stageIdentifier, new mutable.HashSet[Int]()) ++ Set(appShuffleId)) + val pbReportMissingShuffleIdResponse = + PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() + context.reply(pbReportMissingShuffleIdResponse) + /* val latestUpstreamShuffleId = shuffleIds.maxBy(_._2._1) if (latestUpstreamShuffleId._2._1 == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { println(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + @@ -1294,7 +1329,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends invalidateShuffleWrittenByStage(stageId) val pbReportMissingShuffleIdResponse = PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() - context.reply(pbReportMissingShuffleIdResponse) + context.reply(pbReportMissingShuffleIdResponse)*/ } } } @@ -1377,7 +1412,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends if (shuffleId >= 0) { val celebornShuffleIds = shuffleIdMapping.get(writtenShuffleId) if (celebornShuffleIds != null) { - logInfo(s"invalidating location of app shuffle id $writtenShuffleId written" + + println(s"invalidating location of app shuffle id $writtenShuffleId written" + s" by stage $stageId") val latestShuffleId = celebornShuffleIds.maxBy(_._2._1) celebornShuffleIds.put(latestShuffleId._1, (latestShuffleId._2._1, false)) From 4a213abca6b457e23d29ffc88c48b989a4f0ab7e Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 24 Dec 2025 21:50:02 -0800 Subject: [PATCH 14/20] remove added filter --- .../scala/org/apache/celeborn/client/LifecycleManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 3b0dd31290a..f0e8bddac5f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1092,7 +1092,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // this is not necessarily the most concise coding style, but it helps for debugging // purpose var found = false - val revertedShuffleIds = shuffleIds.values.map(v => v._1).toSeq.reverse.filter(_ >= 0) + val revertedShuffleIds = shuffleIds.values.map(v => v._1).toSeq.reverse revertedShuffleIds.foreach { celebornShuffleId: Int => if (!found) { try { From 67a55b8a34fb9c0b9a1eb2315691e8d401c48501 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 24 Dec 2025 21:51:59 -0800 Subject: [PATCH 15/20] stylistic fixes --- .../org/apache/celeborn/client/LifecycleManager.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index f0e8bddac5f..aff6debaf9f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1274,7 +1274,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val stageIdentifier = s"$stageId.$stageAttemptId" val shuffleIdentifier = s"$appShuffleId.$stageId.$stageAttemptId" if (stagesReceivedInvalidatingUpstream.getOrElse(stageIdentifier, new mutable.HashSet[Int]()) - .contains(appShuffleId)) { + .contains(appShuffleId)) { println(s"${Thread.currentThread().getName} " + s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + s" it is already reported by other reader and handled") @@ -1299,11 +1299,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends invalidateShuffleWrittenByStage(stageId) stagesReceivedInvalidatingUpstream += stageIdentifier -> (stagesReceivedInvalidatingUpstream.getOrElse( - stageIdentifier, new mutable.HashSet[Int]()) ++ Set(appShuffleId)) + stageIdentifier, + new mutable.HashSet[Int]()) ++ Set(appShuffleId)) val pbReportMissingShuffleIdResponse = PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() context.reply(pbReportMissingShuffleIdResponse) - /* + /* val latestUpstreamShuffleId = shuffleIds.maxBy(_._2._1) if (latestUpstreamShuffleId._2._1 == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { println(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + From 4064fb85c25db350eea457975cd01a926cd9caff Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 24 Dec 2025 23:13:38 -0800 Subject: [PATCH 16/20] code clean up --- .../ShuffleStatsTrackingListener.scala | 2 +- .../celeborn/client/LifecycleManager.scala | 38 ++----------------- 2 files changed, 5 insertions(+), 35 deletions(-) diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala index 958bf6d7131..10e80a802cc 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.celeborn.SparkShuffleManager class ShuffleStatsTrackingListener extends SparkListener with Logging { override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { - println(s"stage ${stageSubmitted.stageInfo.stageId}.${stageSubmitted.stageInfo.attemptNumber()} started") + logInfo(s"stage ${stageSubmitted.stageInfo.stageId}.${stageSubmitted.stageInfo.attemptNumber()} started") val stageId = stageSubmitted.stageInfo.stageId val shuffleMgr = SparkEnv.get.shuffleManager.asInstanceOf[SparkShuffleManager] val parentStages = stageSubmitted.stageInfo.parentIds diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index aff6debaf9f..7ea22d914f2 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -1034,7 +1034,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends case Some((shuffleId, _)) => val pbGetShuffleIdResponse = PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build() - println(s"reply with shuffle id ${shuffleId} for $appShuffleIdentifier") context.reply(pbGetShuffleIdResponse) case None => Option(appShuffleDeterminateMap.get(appShuffleId)).map { determinate => @@ -1156,7 +1155,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends stagesReceivedInvalidatingUpstream.getOrElseUpdate( stageIdentifier, new mutable.HashSet[Int]()) - println(s"invalidating all upstream shuffles of stage $stageIdentifier") + logInfo(s"invalidating all upstream shuffles of stage $stageIdentifier") val upstreamShuffleIds = getUpstreamAppShuffleIdsCallback.map(f => f.apply(readerStageId)).getOrElse(Array()) upstreamShuffleIds.foreach { upstreamAppShuffleId => @@ -1275,12 +1274,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val shuffleIdentifier = s"$appShuffleId.$stageId.$stageAttemptId" if (stagesReceivedInvalidatingUpstream.getOrElse(stageIdentifier, new mutable.HashSet[Int]()) .contains(appShuffleId)) { - println(s"${Thread.currentThread().getName} " + - s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + + logInfo(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + s" it is already reported by other reader and handled") } else { - println(s"${Thread.currentThread().getName} handle missing shuffle id for appShuffleId" + - s" $appShuffleId stage" + + logInfo(s"handle missing shuffle id for appShuffleId $appShuffleId stage" + s" $stageId.$stageAttemptId") appShuffleTrackerCallback match { case Some(callback) => @@ -1304,33 +1301,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val pbReportMissingShuffleIdResponse = PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() context.reply(pbReportMissingShuffleIdResponse) - /* - val latestUpstreamShuffleId = shuffleIds.maxBy(_._2._1) - if (latestUpstreamShuffleId._2._1 == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { - println(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + - s" it is already reported by other reader and handled") - } else { - println(s"handle missing shuffle id for appShuffleId $appShuffleId stage" + - s" $stageId.$stageAttemptId") - appShuffleTrackerCallback match { - case Some(callback) => - try { - callback.accept(appShuffleId) - } catch { - case t: Throwable => - logError(t.toString) - ret = false - } - shuffleIds.put(latestUpstreamShuffleId._1, (UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID, false)) - case None => - throw new UnsupportedOperationException( - "unexpected! appShuffleTrackerCallback is not registered") - } - // invalidate the shuffle written by stage - invalidateShuffleWrittenByStage(stageId) - val pbReportMissingShuffleIdResponse = - PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() - context.reply(pbReportMissingShuffleIdResponse)*/ } } } @@ -1413,7 +1383,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends if (shuffleId >= 0) { val celebornShuffleIds = shuffleIdMapping.get(writtenShuffleId) if (celebornShuffleIds != null) { - println(s"invalidating location of app shuffle id $writtenShuffleId written" + + logInfo(s"invalidating location of app shuffle id $writtenShuffleId written" + s" by stage $stageId") val latestShuffleId = celebornShuffleIds.maxBy(_._2._1) celebornShuffleIds.put(latestShuffleId._1, (latestShuffleId._2._1, false)) From 6c521eea0447b3586a17458b3a970d4631fb838a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 25 Dec 2025 07:52:32 -0800 Subject: [PATCH 17/20] integrate with failed shuffle deletion --- .../shuffle/celeborn/SparkShuffleManager.java | 9 +- .../celeborn/StageDependencyManager.scala | 17 ++- .../CelebornShuffleEarlyDeleteSuite.scala | 126 +++++++++--------- 3 files changed, 81 insertions(+), 71 deletions(-) diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 4eef7a862d6..dc7fbedca83 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -187,7 +187,14 @@ private void initializeLifecycleManager(String appId) { } if (lifecycleManager.conf().clientShuffleEarlyDeletion()) { - logger.info("register early deletion callbacks"); + if (!lifecycleManager.conf().clientStageRerunEnabled()) { + throw new IllegalArgumentException( + CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key() + + " has to be " + + "enabled, when " + + CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION().key() + + " is set to true"); + } SparkUtils.addSparkListener(new ShuffleStatsTrackingListener()); lifecycleManager.registerStageToWriteCelebornShuffleCallback( (celebornShuffleId, appShuffleIdentifier) -> diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala index badb587924f..51ad9d0237e 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -240,13 +240,18 @@ class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Loggin val cleanerThread = new Thread() { override def run(): Unit = { while (!stopped) { - val allShuffleIds = new util.ArrayList[Int] - shuffleIdsToBeCleaned.drainTo(allShuffleIds) - allShuffleIds.asScala.foreach { shuffleId => - shuffleManager.getLifecycleManager.unregisterShuffle(shuffleId) - logInfo(s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") + try { + val allShuffleIds = new util.ArrayList[Int] + shuffleIdsToBeCleaned.drainTo(allShuffleIds) + allShuffleIds.asScala.foreach { shuffleId => + shuffleManager.getLifecycleManager.unregisterShuffle(shuffleId) + logInfo(s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") + } + Thread.sleep(cleanInterval) + } catch { + case t: Throwable => + logError("unexpected error in shuffle early cleaner thread", t) } - Thread.sleep(cleanInterval) } } } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala index 8a40003ccd2..f301fc02339 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala @@ -237,12 +237,12 @@ class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { deleteTooEarlyTest(Seq(0, 3, 5), Seq(1, 2, 4), spark) } - // test("spark integration test - do not fail job when shuffle is deleted \"too early\"" + - // " (with failed shuffle deletion)") { - // val spark = createSparkSession( - // Map(s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) - // deleteTooEarlyTest(Seq(0, 2, 3, 5), Seq(1, 4), spark) - // } + test("spark integration test - do not fail job when shuffle is deleted \"too early\"" + + " (with failed shuffle deletion)") { + val spark = createSparkSession( + Map(s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + deleteTooEarlyTest(Seq(0, 2, 3, 5), Seq(1, 4), spark) + } test("spark integration test - do not fail job when shuffle files" + " are deleted \"too early\" (ancestor dependency)") { @@ -336,57 +336,55 @@ class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { } } - // test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + - // " are deleted \"too early\"") { - // if (runningWithSpark3OrNewer()) { - // val spark = createSparkSession( - // Map(s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) - // var r = 0L - // try { - // // shuffle 0&1 - // val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) - // val rdd2 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) - // rdd1.count() - // rdd2.count() - // val t = new Thread() { - // override def run(): Unit = { - // // shuffle 2&3 - // val rdd3 = rdd1.repartition(3) - // val rdd4 = rdd2.repartition(3) - // rdd3.count() - // rdd4.count() - // // leaving enough time for shuffle 0&1 to be expired - // Thread.sleep(10000) - // // shuffle 4&5 - // rdd3.repartition(4).count() - // rdd4.repartition(4).count() - // // leaving enough time for shuffle 2&3 to be expired - // Thread.sleep(10000) - // println("starting job for rdd 5") - // val rdd5 = rdd3.zip(rdd4).mapPartitions(iter => { - // Thread.sleep(10000) - // iter - // }) - // r = rdd5.count() - // } - // } - // t.start() - // val thread = StorageCheckUtils.triggerStorageCheckThread( - // workerDirs, - // // 4,5 are based on vanilla spark gc which are not necessarily stable in a test - // // 6,9 is based on failed shuffle cleanup, which is not covered here - // shuffleIdShouldNotExist = Seq(0, 1, 2, 3, 7, 10, 12), - // shuffleIdMustExist = Seq(8, 11), - // sparkSession = spark, - // forStableStatusChecking = false) - // StorageCheckUtils.checkStorageValidation(thread) - // t.join() - // assert(r === 20) - // } finally { - // spark.stop( - // } - // } - // } + test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + " are deleted \"too early\"") { + if (Spark3OrNewer) { + val spark = createSparkSession() + var r = 0L + try { + // shuffle 0&1 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd2 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + rdd1.count() + rdd2.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 2&3 + val rdd3 = rdd1.repartition(3) + val rdd4 = rdd2.repartition(3) + rdd3.count() + rdd4.count() + // leaving enough time for shuffle 0&1 to be expired + Thread.sleep(10000) + // shuffle 4&5 + rdd3.repartition(4).count() + rdd4.repartition(4).count() + // leaving enough time for shuffle 2&3 to be expired + Thread.sleep(10000) + println("starting job for rdd 5") + val rdd5 = rdd3.zip(rdd4).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd5.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + // 4,5 are based on vanilla spark gc which are not necessarily stable in a test + // 6,9 are based on failed shuffle cleanup, which is not covered here + shuffleIdShouldNotExist = Seq(0, 1, 2, 3, 7, 10, 12), + shuffleIdMustExist = Seq(8, 11), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 20) + } finally { + spark.stop() + } + } + } private def multiShuffleFailureTest( shuffleIdShouldNotExist: Seq[Int], @@ -437,11 +435,11 @@ class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5), Seq(17), spark) } - // test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + - // " are to be retried for fetching (with failed shuffle deletion)") { - // val spark = createSparkSession(Map( - // "spark.stage.maxConsecutiveAttempts" -> "3", - // s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) - // multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5, 8, 9, 10), Seq(17), spark) - // } + test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + " are to be retried for fetching (with failed shuffle deletion)") { + val spark = createSparkSession(Map( + "spark.stage.maxConsecutiveAttempts" -> "3", + s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5, 8, 9, 10), Seq(17), spark) + } } From f94a0f458d4511d82fb73484e8d53fc6f32ed1e6 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 25 Dec 2025 07:59:26 -0800 Subject: [PATCH 18/20] stylistic fixes --- .../spark/shuffle/celeborn/SparkShuffleManager.java | 10 +++++----- .../apache/spark/celeborn/StageDependencyManager.scala | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index dc7fbedca83..b3d86cb4ee5 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -189,11 +189,11 @@ private void initializeLifecycleManager(String appId) { if (lifecycleManager.conf().clientShuffleEarlyDeletion()) { if (!lifecycleManager.conf().clientStageRerunEnabled()) { throw new IllegalArgumentException( - CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key() - + " has to be " - + "enabled, when " - + CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION().key() - + " is set to true"); + CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key() + + " has to be " + + "enabled, when " + + CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION().key() + + " is set to true"); } SparkUtils.addSparkListener(new ShuffleStatsTrackingListener()); lifecycleManager.registerStageToWriteCelebornShuffleCallback( diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala index 51ad9d0237e..ea69c176e00 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -245,7 +245,8 @@ class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Loggin shuffleIdsToBeCleaned.drainTo(allShuffleIds) allShuffleIds.asScala.foreach { shuffleId => shuffleManager.getLifecycleManager.unregisterShuffle(shuffleId) - logInfo(s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") + logInfo( + s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") } Thread.sleep(cleanInterval) } catch { From daa69c6b3665240226cc337303158434edae97df Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 25 Dec 2025 11:58:36 -0800 Subject: [PATCH 19/20] fix spark 2 compilation problem --- .../celeborn/StageDependencyManager.scala | 29 +++++++++++++++++++ .../celeborn/TestCelebornShuffleManager.java | 5 ++++ 2 files changed, 34 insertions(+) create mode 100644 client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala new file mode 100644 index 00000000000..7a89894dfcf --- /dev/null +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.celeborn + +import org.apache.celeborn.common.internal.Logging +import org.apache.spark.shuffle.celeborn.SparkShuffleManager + +class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Logging { + def removeCelebornShuffleInternal( + celebornShuffleId: Int, + stageId: Option[Int]): Unit = { + throw new NotImplementedError("the method is not implemented") + } +} \ No newline at end of file diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java index 5c995e75306..de5e27d97d1 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java @@ -19,6 +19,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; +import org.apache.spark.celeborn.StageDependencyManager; import org.apache.spark.shuffle.ShuffleHandle; import org.apache.spark.shuffle.ShuffleReader; @@ -42,4 +43,8 @@ public ShuffleReader getReader( } return super.getReader(handle, startPartition, endPartition, context); } + + public StageDependencyManager getStageDepManager() { + return new StageDependencyManager(this); + } } From 352a201c2d3e79a55ef917c2bc61c592bc799ed7 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 25 Dec 2025 12:04:34 -0800 Subject: [PATCH 20/20] stylistic fixes --- .../apache/spark/celeborn/StageDependencyManager.scala | 9 +++++---- .../shuffle/celeborn/TestCelebornShuffleManager.java | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala index 7a89894dfcf..c1aca450f19 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -17,13 +17,14 @@ package org.apache.spark.celeborn -import org.apache.celeborn.common.internal.Logging import org.apache.spark.shuffle.celeborn.SparkShuffleManager +import org.apache.celeborn.common.internal.Logging + class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Logging { def removeCelebornShuffleInternal( - celebornShuffleId: Int, - stageId: Option[Int]): Unit = { + celebornShuffleId: Int, + stageId: Option[Int]): Unit = { throw new NotImplementedError("the method is not implemented") } -} \ No newline at end of file +} diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java index de5e27d97d1..0fd4f8673ce 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java @@ -43,7 +43,7 @@ public ShuffleReader getReader( } return super.getReader(handle, startPartition, endPartition, context); } - + public StageDependencyManager getStageDepManager() { return new StageDependencyManager(this); }