diff --git a/client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala b/client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala new file mode 100644 index 00000000000..f07d7088dd6 --- /dev/null +++ b/client-spark/common/src/main/scala/org/apache/spark/CelebornSparkContextHelper.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.scheduler.{EventLoggingListener, SparkListenerInterface} + +object CelebornSparkContextHelper { + + def eventLogger: Option[EventLoggingListener] = SparkContext.getActive.get.eventLogger + + def env: SparkEnv = { + assert(SparkContext.getActive.isDefined) + SparkContext.getActive.get.env + } + + def activeSparkContext(): Option[SparkContext] = { + SparkContext.getActive + } + + def getListener(listenerClass: String): SparkListenerInterface = { + activeSparkContext().get.listenerBus.listeners.asScala.find(l => + l.getClass.getCanonicalName.contains(listenerClass)).getOrElse( + throw new RuntimeException( + s"cannot find any listener containing $listenerClass in class name")) + } +} diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala new file mode 100644 index 00000000000..c1aca450f19 --- /dev/null +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.celeborn + +import org.apache.spark.shuffle.celeborn.SparkShuffleManager + +import org.apache.celeborn.common.internal.Logging + +class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Logging { + def removeCelebornShuffleInternal( + celebornShuffleId: Int, + stageId: Option[Int]): Unit = { + throw new NotImplementedError("the method is not implemented") + } +} diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java index 5c995e75306..0fd4f8673ce 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java @@ -19,6 +19,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; +import org.apache.spark.celeborn.StageDependencyManager; import org.apache.spark.shuffle.ShuffleHandle; import org.apache.spark.shuffle.ShuffleReader; @@ -42,4 +43,8 @@ public ShuffleReader getReader( } return super.getReader(handle, startPartition, endPartition, context); } + + public StageDependencyManager getStageDepManager() { + return new StageDependencyManager(this); + } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index ed7865e19ff..b3d86cb4ee5 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -22,8 +22,10 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.spark.*; +import org.apache.spark.celeborn.StageDependencyManager; import org.apache.spark.internal.config.package$; import org.apache.spark.launcher.SparkLauncher; +import org.apache.spark.listener.ShuffleStatsTrackingListener; import org.apache.spark.rdd.DeterministicLevel; import org.apache.spark.shuffle.*; import org.apache.spark.shuffle.sort.SortShuffleManager; @@ -91,6 +93,13 @@ public class SparkShuffleManager implements ShuffleManager { private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker(); + private StageDependencyManager stageDepManager = null; + + // for testing + public void initStageDepManager() { + this.stageDepManager = new StageDependencyManager(this); + } + public SparkShuffleManager(SparkConf conf, boolean isDriver) { if (conf.getBoolean(SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key(), true)) { logger.warn( @@ -177,6 +186,45 @@ private void initializeLifecycleManager(String appId) { (celebornShuffleId) -> SparkUtils.removeCleanedShuffleId(this, celebornShuffleId)); } + if (lifecycleManager.conf().clientShuffleEarlyDeletion()) { + if (!lifecycleManager.conf().clientStageRerunEnabled()) { + throw new IllegalArgumentException( + CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key() + + " has to be " + + "enabled, when " + + CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION().key() + + " is set to true"); + } + SparkUtils.addSparkListener(new ShuffleStatsTrackingListener()); + lifecycleManager.registerStageToWriteCelebornShuffleCallback( + (celebornShuffleId, appShuffleIdentifier) -> + SparkUtils.addStageToWriteCelebornShuffleIdDep( + this, celebornShuffleId, appShuffleIdentifier)); + 
lifecycleManager.registerCelebornToAppShuffleIdMappingCallback( + (celebornShuffleId, appShuffleIdentifier) -> + SparkUtils.addCelebornToSparkShuffleIdRef( + this, celebornShuffleId, appShuffleIdentifier)); + lifecycleManager.registerGetCelebornShuffleIdForReaderCallback( + (celebornShuffleId, appShuffleIdentifier) -> + SparkUtils.addCelebornShuffleReadingStageDep( + this, celebornShuffleId, appShuffleIdentifier)); + lifecycleManager.registerUpstreamAppShuffleIdsCallback( + (stageId) -> SparkUtils.getAllUpstreamAppShuffleIds(this, stageId)); + lifecycleManager.registerGetAppShuffleIdByStageIdCallback( + (stageId) -> SparkUtils.getAppShuffleIdByStageId(this, stageId)); + lifecycleManager.registerReaderStageToAppShuffleIdsCallback( + (appShuffleId, appShuffleIdentifier) -> + SparkUtils.addAppShuffleReadingStageDep( + this, appShuffleId, appShuffleIdentifier)); + lifecycleManager.registerInvalidateAllUpstreamCheckCallback( + (appShuffleIdentifier) -> + SparkUtils.canInvalidateAllUpstream(this, appShuffleIdentifier)); + if (stageDepManager == null) { + stageDepManager = new StageDependencyManager(this); + } + stageDepManager.start(); + } + if (celebornConf.getReducerFileGroupBroadcastEnabled()) { lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback( (shuffleId, getReducerFileGroupResponse) -> @@ -497,4 +545,8 @@ public LifecycleManager getLifecycleManager() { public FailedShuffleCleaner getFailedShuffleCleaner() { return this.failedShuffleCleaner; } + + public StageDependencyManager getStageDepManager() { + return this.stageDepManager; + } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 2b30a20206f..658c76a0e02 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,10 +20,7 @@ import java.io.ByteArrayInputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.LongAdder; @@ -680,4 +677,56 @@ public static boolean isLocalMaster(SparkConf conf) { String master = conf.get("spark.master", ""); return master.equals("local") || master.startsWith("local["); } + + public static Integer[] getAllUpstreamAppShuffleIds( + SparkShuffleManager sparkShuffleManager, int readerStageId) { + int[] upstreamShuffleIds = + sparkShuffleManager + .getStageDepManager() + .getAllUpstreamAppShuffleIdsByStageId(readerStageId); + return Arrays.stream(upstreamShuffleIds).boxed().toArray(Integer[]::new); + } + + public static Integer getAppShuffleIdByStageId( + SparkShuffleManager sparkShuffleManager, int readerStageId) { + int writtenAppShuffleId = + sparkShuffleManager.getStageDepManager().getAppShuffleIdByStageId(readerStageId); + return writtenAppShuffleId; + } + + public static void addCelebornShuffleReadingStageDep( + SparkShuffleManager sparkShuffleManager, int celebornShuffeId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addCelebornShuffleIdReadingStageDep(celebornShuffeId, appShuffleIdentifier); + } + + public static void addAppShuffleReadingStageDep( + SparkShuffleManager sparkShuffleManager, int appShuffleId, 
String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addAppShuffleIdReadingStageDep(appShuffleId, appShuffleIdentifier); + } + + public static boolean canInvalidateAllUpstream( + SparkShuffleManager sparkShuffleManager, String appShuffleIdentifier) { + String[] decodedAppShuffleIdentifier = appShuffleIdentifier.split("-"); + return sparkShuffleManager + .getStageDepManager() + .hasAllUpstreamShuffleIdsInfo(Integer.valueOf(decodedAppShuffleIdentifier[1])); + } + + public static void addStageToWriteCelebornShuffleIdDep( + SparkShuffleManager sparkShuffleManager, int celebornShuffeId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addStageToCelebornShuffleIdRef(celebornShuffeId, appShuffleIdentifier); + } + + public static void addCelebornToSparkShuffleIdRef( + SparkShuffleManager sparkShuffleManager, int celebornShuffeId, String appShuffleIdentifier) { + sparkShuffleManager + .getStageDepManager() + .addCelebornToAppShuffleIdMapping(celebornShuffeId, appShuffleIdentifier); + } } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala new file mode 100644 index 00000000000..ea69c176e00 --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/StageDependencyManager.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.celeborn + +import java.time.Instant +import java.util +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.CelebornSparkContextHelper +import org.apache.spark.listener.{CelebornShuffleEarlyCleanup, ShuffleStatsTrackingListener} +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.shuffle.celeborn.SparkShuffleManager + +import org.apache.celeborn.common.internal.Logging + +class StageDependencyManager(shuffleManager: SparkShuffleManager) extends Logging { + + // celeborn shuffle id to all stages reading it, this is needed when we determine when to + // clean the shuffle + private[celeborn] val readShuffleToStageDep = new mutable.HashMap[Int, mutable.HashSet[Int]]() + // stage id to all celeborn shuffle ids it reads from, this structure is needed for fast + // tracking when a stage is completed + private val stageToReadCelebornShuffleDep = new mutable.HashMap[Int, mutable.HashSet[Int]]() + // spark stage id to celeborn shuffle id which it writes, + // we need to save this mapping so that we can query which celeborn shuffles is depended on by a + // certain stage when it is submitted + private val stageToCelebornShuffleIdWritten = new mutable.HashMap[Int, Int]() + // app shuffle id to all app shuffle ids it reads from, this structure used as the intermediate data + private val appShuffleIdToUpstream = new mutable.HashMap[Int, mutable.HashSet[Int]]() + // stage id to app shuffle id it writes, this structure used as the intermediate data + // to build appShuffleIdToUpstream and is needed when we need to + // invalidate all app shuffle map output location when a stage is failed + private val stageToAppShuffleIdWritten = new mutable.HashMap[Int, Int]() + + private val celebornToAppShuffleIdentifier = new mutable.HashMap[Int, String]() + private val appShuffleIdentifierToSize = new mutable.HashMap[String, Long]() + + private val shuffleIdsToBeCleaned = new LinkedBlockingQueue[Int]() + + private lazy val cleanInterval = shuffleManager.getLifecycleManager + .conf.clientShuffleEarlyDeletionIntervalMs + + def addShuffleAndStageDep(celebornShuffleId: Int, stageId: Int): Unit = this.synchronized { + val newStageIdSet = + readShuffleToStageDep.getOrElseUpdate(celebornShuffleId, new mutable.HashSet[Int]()) + newStageIdSet += stageId + val newShuffleIdSet = + stageToReadCelebornShuffleDep.getOrElseUpdate(stageId, new mutable.HashSet[Int]()) + newShuffleIdSet += celebornShuffleId + val correctionResult = shuffleIdsToBeCleaned.remove(celebornShuffleId) + if (correctionResult) { + logInfo(s"shuffle $celebornShuffleId is later recognized as needed by stage $stageId, " + + s"removed it from to be cleaned list") + } + } + + private def stageOutputToShuffleOrS3(stageInfo: StageInfo): Boolean = { + stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten > 0 || + stageInfo.taskMetrics.outputMetrics.bytesWritten > 0 + } + + private def removeStageAndReadInfo(stageId: Int): Unit = { + stageToReadCelebornShuffleDep.remove(stageId) + } + + // it is called when the stage is completed + def addAppShuffleIdentifierToSize(appShuffleIdentifier: String, bytes: Long): Unit = + this.synchronized { + appShuffleIdentifierToSize += appShuffleIdentifier -> bytes + } + + // this is called when a shuffle is cleaned up + def queryShuffleSizeByAppShuffleIdentifier(appShuffleIdentifier: String): Long = + this.synchronized { + appShuffleIdentifierToSize.getOrElse( + appShuffleIdentifier, { 
+ logError(s"unexpected case: cannot find size information for shuffle identifier" + + s" $appShuffleIdentifier") + -1L + }) + } + + def removeShuffleAndStageDep(stageInfo: StageInfo): Unit = this.synchronized { + val stageId = stageInfo.stageId + val allReadCelebornIds = stageToReadCelebornShuffleDep.get(stageId) + allReadCelebornIds.foreach { celebornShuffleIds => + celebornShuffleIds.foreach { celebornShuffleId => + val allStages = readShuffleToStageDep.get(celebornShuffleId) + allStages.foreach { stages => + stages.remove(stageId) + if (stages.nonEmpty) { + readShuffleToStageDep += celebornShuffleId -> stages + } else { + val readyToDelete = { + if (shuffleManager.getLifecycleManager.conf.clientShuffleEarlyDeletionCheckProp) { + val propertySet = System.getProperty("CELEBORN_EARLY_SHUFFLE_DELETION", "false") + propertySet.toBoolean && stageOutputToShuffleOrS3(stageInfo) + } else { + stageOutputToShuffleOrS3(stageInfo) + } + } + if (readyToDelete) { + removeCelebornShuffleInternal(celebornShuffleId, stageId = Some(stageInfo.stageId)) + } else { + logInfo( + s"not ready to delete shuffle $celebornShuffleId while stage $stageId finished") + } + } + } + } + } + } + + def removeCelebornShuffleInternal( + celebornShuffleId: Int, + stageId: Option[Int]): Unit = { + shuffleIdsToBeCleaned.put(celebornShuffleId) + readShuffleToStageDep.remove(celebornShuffleId) + val appShuffleIdentifierOpt = celebornToAppShuffleIdentifier.get(celebornShuffleId) + if (appShuffleIdentifierOpt.isEmpty) { + logWarning(s"cannot find appShuffleIdentifier for celeborn shuffle: $celebornShuffleId") + return + } + val appShuffleIdentifier = appShuffleIdentifierOpt.get + val Array(appShuffleId, stageOfShuffleBeingDeleted, _) = + appShuffleIdentifier.split('-') + val shuffleSize = queryShuffleSizeByAppShuffleIdentifier(appShuffleIdentifier) + celebornToAppShuffleIdentifier.remove(celebornShuffleId) + logInfo(s"clean up app shuffle id $appShuffleIdentifier," + + s" celeborn shuffle id : $celebornShuffleId") + stageId.foreach(sid => removeStageAndReadInfo(sid)) + // ClientMetricsSystem.updateShuffleWrittenBytes(shuffleSize * -1) + stageId.foreach(sid => + CelebornSparkContextHelper.eventLogger.foreach(e => { + // for shuffles being deleted when no one refers to it, we need to make a record of + // stage reading it to calculate the cost saving accurately + e.onOtherEvent(CelebornShuffleEarlyCleanup( + celebornShuffleId, + appShuffleId.toInt, + stageOfShuffleBeingDeleted.toInt, + shuffleSize, + readStageId = sid, + timeToEnqueue = Instant.now().toEpochMilli)) + })) + } + + def queryCelebornShuffleIdByWriterStageId(stageId: Int): Option[Int] = this.synchronized { + stageToCelebornShuffleIdWritten.get(stageId) + } + + def getAppShuffleIdByStageId(stageId: Int): Int = this.synchronized { + // return -1 means the stage is not writing any shuffle + stageToAppShuffleIdWritten.getOrElse(stageId, -1) + } + + def getAllUpstreamAppShuffleIdsByStageId(stageId: Int): Array[Int] = this.synchronized { + val writtenAppShuffleId = stageToAppShuffleIdWritten.getOrElse( + stageId, + throw new IllegalStateException(s"cannot find app shuffle id written by stage $stageId")) + val allUpstreamAppShuffleIds = appShuffleIdToUpstream.getOrElse( + writtenAppShuffleId, + throw new IllegalStateException(s"cannot find upstream shuffle ids written of shuffle " + + s"$writtenAppShuffleId")) + allUpstreamAppShuffleIds.toArray + } + + def addStageToCelebornShuffleIdRef(celebornShuffleId: Int, appShuffleIdentifier: String): Unit = + this.synchronized { + val 
Array(appShuffleId, stageId, _) = appShuffleIdentifier.split('-') + stageToCelebornShuffleIdWritten += stageId.toInt -> celebornShuffleId + stageToAppShuffleIdWritten += stageId.toInt -> appShuffleId.toInt + } + + def addCelebornToAppShuffleIdMapping( + celebornShuffleId: Int, + appShuffleIdentifier: String): Unit = { + this.synchronized { + celebornToAppShuffleIdentifier += celebornShuffleId -> appShuffleIdentifier + } + } + + def addCelebornShuffleIdReadingStageDep( + celebornShuffleId: Int, + appShuffleIdentifier: String): Unit = { + this.synchronized { + val Array(_, stageId, _) = appShuffleIdentifier.split('-') + val stageIds = + readShuffleToStageDep.getOrElseUpdate(celebornShuffleId, new mutable.HashSet[Int]()) + stageIds += stageId.toInt + val celebornShuffleIds = + stageToReadCelebornShuffleDep.getOrElseUpdate(stageId.toInt, new mutable.HashSet[Int]()) + celebornShuffleIds += celebornShuffleId + } + } + + def addAppShuffleIdReadingStageDep(appShuffleId: Int, appShuffleIdentifier: String): Unit = { + this.synchronized { + val Array(_, sid, _) = appShuffleIdentifier.split('-') + val stageId = sid.toInt + // update shuffle id to all upstream + if (stageToAppShuffleIdWritten.contains(stageId)) { + val upstreamAppShuffleIds = appShuffleIdToUpstream.getOrElseUpdate( + stageToAppShuffleIdWritten(stageId), + new mutable.HashSet[Int]()) + if (!upstreamAppShuffleIds.contains(appShuffleId)) { + logInfo(s"new upstream shuffleId detected for shuffle" + + s" ${stageToAppShuffleIdWritten(stageId)}, latest: $appShuffleIdToUpstream") + upstreamAppShuffleIds += appShuffleId + } + } + } + } + + def hasAllUpstreamShuffleIdsInfo(stageId: Int): Boolean = this.synchronized { + stageToAppShuffleIdWritten.contains(stageId) && + appShuffleIdToUpstream.contains(stageToAppShuffleIdWritten(stageId)) + } + + private var stopped: Boolean = false + + def start(): Unit = { + val cleanerThread = new Thread() { + override def run(): Unit = { + while (!stopped) { + try { + val allShuffleIds = new util.ArrayList[Int] + shuffleIdsToBeCleaned.drainTo(allShuffleIds) + allShuffleIds.asScala.foreach { shuffleId => + shuffleManager.getLifecycleManager.unregisterShuffle(shuffleId) + logInfo( + s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") + } + Thread.sleep(cleanInterval) + } catch { + case t: Throwable => + logError("unexpected error in shuffle early cleaner thread", t) + } + } + } + } + + cleanerThread.setName("shuffle early cleaner thread") + cleanerThread.setDaemon(true) + cleanerThread.start() + } + + def stop(): Unit = { + stopped = true + } +} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala new file mode 100644 index 00000000000..02e9c3ac14a --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/CelebornShuffleEarlyCleanup.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.listener + +import org.apache.spark.scheduler.SparkListenerEvent + +case class CelebornShuffleEarlyCleanup( + celebornShuffleId: Int, + applicationShuffleId: Int, + stageId: Int, + shuffleSizeInBytes: Long, + readStageId: Int, + timeToEnqueue: Long) extends SparkListenerEvent diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala new file mode 100644 index 00000000000..10e80a802cc --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/listener/ShuffleStatsTrackingListener.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.listener + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerStageSubmitted, SparkListenerTaskEnd} +import org.apache.spark.shuffle.celeborn.SparkShuffleManager + +class ShuffleStatsTrackingListener extends SparkListener with Logging { + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + logInfo(s"stage ${stageSubmitted.stageInfo.stageId}.${stageSubmitted.stageInfo.attemptNumber()} started") + val stageId = stageSubmitted.stageInfo.stageId + val shuffleMgr = SparkEnv.get.shuffleManager.asInstanceOf[SparkShuffleManager] + val parentStages = stageSubmitted.stageInfo.parentIds + if (shuffleMgr.getLifecycleManager.conf.clientShuffleEarlyDeletion) { + parentStages.foreach { parentStageId => + val celebornShuffleId = shuffleMgr.getStageDepManager + .queryCelebornShuffleIdByWriterStageId(parentStageId) + celebornShuffleId.foreach { sid => + shuffleMgr.getStageDepManager.addShuffleAndStageDep(sid, stageId) + } + } + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageIdentifier = s"${stageCompleted.stageInfo.stageId}-" + + s"${stageCompleted.stageInfo.attemptNumber()}" + logInfo(s"stage $stageIdentifier finished with" + + s" ${stageCompleted.stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten} shuffle bytes") + val shuffleMgr = SparkEnv.get.shuffleManager.asInstanceOf[SparkShuffleManager] + if (shuffleMgr.getLifecycleManager.conf.clientShuffleEarlyDeletion) { + val shuffleIdOpt = stageCompleted.stageInfo.shuffleDepId + shuffleIdOpt.foreach { appShuffleId => + val appShuffleIdentifier = s"$appShuffleId-${stageCompleted.stageInfo.stageId}-" + + s"${stageCompleted.stageInfo.attemptNumber()}" + shuffleMgr.getStageDepManager.addAppShuffleIdentifierToSize( + appShuffleIdentifier, + stageCompleted.stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten) + } + } + if (shuffleMgr.getLifecycleManager.conf.clientShuffleEarlyDeletion && + stageCompleted.stageInfo.failureReason.isEmpty) { + shuffleMgr.getStageDepManager.removeShuffleAndStageDep(stageCompleted.stageInfo) + } + } +} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 55e036155b2..57057157e25 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -47,6 +47,7 @@ import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils} +import org.apache.celeborn.common.util.Utils.{KNOWN_MISSING_CELEBORN_SHUFFLE_ID, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID} class CelebornShuffleReader[K, C]( handle: CelebornShuffleHandle[K, _, C], @@ -99,32 +100,82 @@ class CelebornShuffleReader[K, C]( private val pushReplicateEnabled = conf.clientPushReplicateEnabled private val preferReplicaRead = context.attemptNumber % 2 == 1 - override def read(): Iterator[Product2[K, C]] = { + private def throwFetchFailureForMissingId(partitionId: Int, celebornShuffleId: Int): Unit = { + throw new 
FetchFailedException( + null, + handle.shuffleId, + -1, + -1, + partitionId, + SparkUtils.FETCH_FAILURE_ERROR_MSG + celebornShuffleId, + new CelebornIOException(s"cannot find shuffle id for ${handle.shuffleId}")) + } - val startTime = System.currentTimeMillis() - val serializerInstance = newSerializerInstance(dep) - val shuffleId = - try { - SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) - } catch { - case e: CelebornRuntimeException => - logError(s"Failed to get shuffleId for appShuffleId ${handle.shuffleId}", e) - if (stageRerunEnabled) { - throw new FetchFailedException( - null, + private def handleMissingCelebornShuffleId(celebornShuffleId: Int, stageId: Int): Unit = { + if (conf.clientShuffleEarlyDeletion) { + if (celebornShuffleId == UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) { + logError(s"cannot find celeborn shuffle id for app shuffle ${handle.shuffleId} which " + + s"never appear before, throwing FetchFailureException") + (startPartition until endPartition).foreach(partitionId => { + if (stageRerunEnabled && + shuffleClient.reportMissingShuffleId( handle.shuffleId, - -1, - -1, - startPartition, - SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + handle.shuffleId, - e) + context.stageId(), + context.stageAttemptNumber())) { + throwFetchFailureForMissingId(partitionId, celebornShuffleId) } else { + val e = new IllegalStateException(s"failed to handle missing celeborn id for app" + + s" shuffle ${handle.shuffleId}") + logError(s"failed to handle missing celeborn id for app shuffle ${handle.shuffleId}", e) throw e } + }) + } else if (celebornShuffleId == KNOWN_MISSING_CELEBORN_SHUFFLE_ID) { + logError(s"cannot find celeborn shuffle id for app shuffle ${handle.shuffleId} which " + + s"has appeared before, invalidating all upstream shuffle of this shuffle") + (startPartition until endPartition).foreach(partitionId => { + if (stageRerunEnabled) { + val invalidateAllUpstreamRet = shuffleClient.invalidateAllUpstreamShuffle( + context.stageId(), + context.stageAttemptNumber(), + handle.shuffleId) + if (invalidateAllUpstreamRet) { + throwFetchFailureForMissingId(partitionId, celebornShuffleId) + } else { + // if we cannot invalidate all upstream, we need to report regular fetch failure + // for this particular shuffle id + val fetchFailureResponse = shuffleClient.reportMissingShuffleId( + handle.shuffleId, + context.stageId(), + context.stageAttemptNumber()) + if (fetchFailureResponse) { + throwFetchFailureForMissingId(partitionId, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID) + } else { + val e = new IllegalStateException(s"failed to handle missing celeborn id for app" + + s" shuffle ${handle.shuffleId}") + logError( + s"failed to handle missing celeborn id for app shuffle" + + s" ${handle.shuffleId}", + e) + throw e + } + } + } + }) } - shuffleIdTracker.track(handle.shuffleId, shuffleId) + } + } + + override def read(): Iterator[Product2[K, C]] = { + + val startTime = System.currentTimeMillis() + val serializerInstance = newSerializerInstance(dep) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) + + handleMissingCelebornShuffleId(celebornShuffleId, context.stageId()) logDebug( - s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} attemptNum ${context.stageAttemptNumber()}") + s"get shuffleId $celebornShuffleId for appShuffleId ${handle.shuffleId}" + + s" attemptNum ${context.stageAttemptNumber()}") // Update the context task metrics for each record read. 
val metricsCallback = new MetricsCallback { @@ -151,22 +202,22 @@ class CelebornShuffleReader[K, C]( val fetchTimeoutMs = conf.clientFetchTimeoutMs val localFetchEnabled = conf.enableReadLocalShuffleFile val localHostAddress = Utils.localHostName(conf) - val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId) + val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, celebornShuffleId) var fileGroups: ReduceFileGroups = null var isShuffleStageEnd: Boolean = false var updateFileGroupsRetryTimes = 0 do { isShuffleStageEnd = try { - shuffleClient.isShuffleStageEnd(shuffleId) + shuffleClient.isShuffleStageEnd(celebornShuffleId) } catch { case e: Exception => - logInfo(s"Failed to check shuffle stage end for $shuffleId, assume ended", e) + logInfo(s"Failed to check shuffle stage end for $celebornShuffleId, assume ended", e) true } try { // startPartition is irrelevant - fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) + fileGroups = shuffleClient.updateFileGroup(celebornShuffleId, startPartition) } catch { case ce: CelebornIOException if ce.getCause != null && ce.getCause.isInstanceOf[ @@ -180,7 +231,7 @@ class CelebornShuffleReader[K, C]( // if a task is interrupted, should not report fetch failure // if a task update file group timeout, should not report fetch failure // if a task GetReducerFileGroupResponse failed via broadcast, should not report fetch failure - checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce) + checkAndReportFetchFailureForUpdateFileGroupFailure(celebornShuffleId, ce) } } while (fileGroups == null) @@ -333,7 +384,7 @@ class CelebornShuffleReader[K, C]( if (exceptionRef.get() == null) { try { val inputStream = shuffleClient.readPartition( - shuffleId, + celebornShuffleId, handle.shuffleId, partitionId, encodedAttemptId, @@ -380,7 +431,7 @@ class CelebornShuffleReader[K, C]( if (exceptionRef.get() != null) { exceptionRef.get() match { case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - handleFetchExceptions(handle.shuffleId, shuffleId, partitionId, ce) + handleFetchExceptions(handle.shuffleId, celebornShuffleId, partitionId, ce) case e => throw e } } @@ -424,7 +475,7 @@ class CelebornShuffleReader[K, C]( iter } catch { case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - handleFetchExceptions(handle.shuffleId, shuffleId, partitionId, e) + handleFetchExceptions(handle.shuffleId, celebornShuffleId, partitionId, e) } } diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 69cc3cd6f9c..1a8d78b2d97 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -236,4 +236,14 @@ public void initReducePartitionMap(int shuffleId, int numPartitions, int workerN public Map> getReducePartitionMap() { return reducePartitionMap; } + + @Override + public boolean invalidateAllUpstreamShuffle(int stageId, int attemptId, int appShuffleId) { + return true; + } + + @Override + public boolean reportMissingShuffleId(int appShuffleId, int readerStageId, int stageAttemptId) { + return true; + } } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index d37ec644231..d2371b1059f 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -332,4 +332,15 @@ public static ControlMessages.GetReducerFileGroupResponse deserializeReducerFile } return deserializeReducerFileGroupResponseFunction.get().apply(shuffleId, bytes); } + + /** + * report fetch failure for all upstream shuffles for a given stage id, It must be a sync call and + * make sure the cleanup is done, otherwise, incorrect shuffle data can be fetched in re-run tasks + */ + public abstract boolean invalidateAllUpstreamShuffle( + int stageId, int attemptId, int triggerAppShuffleId); + + /** report the failure to find the corresponding celeborn id for a shuffle id */ + public abstract boolean reportMissingShuffleId( + int appShuffleId, int readerStageId, int stageAttemptId); } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index cfe40a2968f..85b844c3f47 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -2159,4 +2159,36 @@ public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) { fetchExcludedWorkers.put(hostAndFetchPort, System.currentTimeMillis()); } } + + @Override + public boolean invalidateAllUpstreamShuffle(int stageId, int attemptId, int triggerAppId) { + PbInvalidateAllUpstreamShuffle pbInvalidateAllUpstreamShuffle = + PbInvalidateAllUpstreamShuffle.newBuilder() + .setReaderStageId(stageId) + .setAttemptId(attemptId) + .setTriggerAppShuffleId(triggerAppId) + .build(); + PbInvalidateAllUpstreamShuffleResponse pbInvalidateAllUpstreamShuffleResponse = + lifecycleManagerRef.askSync( + pbInvalidateAllUpstreamShuffle, + conf.clientRpcRegisterShuffleAskTimeout(), + ClassTag$.MODULE$.apply(PbInvalidateAllUpstreamShuffleResponse.class)); + return pbInvalidateAllUpstreamShuffleResponse.getSuccess(); + } + + @Override + public boolean reportMissingShuffleId(int appShuffleId, int readerStageId, int stageAttemptId) { + PbReportMissingShuffleId pbReportMissingShuffleId = + PbReportMissingShuffleId.newBuilder() + .setReaderStageId(readerStageId) + .setAttemptId(stageAttemptId) + .setTriggerAppShuffleId(appShuffleId) + .build(); + PbReportMissingShuffleIdResponse response = + lifecycleManagerRef.askSync( + pbReportMissingShuffleId, + conf.clientRpcRegisterShuffleAskTimeout(), + ClassTag$.MODULE$.apply(PbReportMissingShuffleIdResponse.class)); + return response.getSuccess(); + } } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 23189853544..7ea22d914f2 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -60,6 +60,7 @@ import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Ut // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.ThreadUtils.awaitResult +import org.apache.celeborn.common.util.Utils.{KNOWN_MISSING_CELEBORN_SHUFFLE_ID, UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID} import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID import org.apache.celeborn.common.write.LocationPushFailedBatches @@ -113,6 +114,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // app shuffle 
id -> whether shuffle is determinate, rerun of a indeterminate shuffle gets different result private val appShuffleDeterminateMap = JavaUtils.newConcurrentHashMap[Int, Boolean](); + // format ${stageid}.${attemptid} + private val stagesReceivedInvalidatingUpstream = + new mutable.HashMap[String, mutable.HashSet[Int]]() + private val rpcCacheSize = conf.clientRpcCacheSize private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime @@ -537,6 +542,23 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } else { context.reply(PbSerDeUtils.toPbApplicationMeta(applicationMeta)) } + + case pb: PbReportMissingShuffleId => + val appShuffleId = pb.getTriggerAppShuffleId + val readerStageId = pb.getReaderStageId + val stageAttemptId = pb.getAttemptId + logInfo( + s"Received ReportMissingShuffleId, appShuffleId $appShuffleId readerStageIdentifier:" + + s" $readerStageId.$stageAttemptId") + handleReportMissingShuffleId(context, appShuffleId, readerStageId, stageAttemptId) + + case pb: PbInvalidateAllUpstreamShuffle => + val readerStageId = pb.getReaderStageId + val attemptId = pb.getAttemptId + val triggerAppShuffleId = pb.getTriggerAppShuffleId + logInfo(s"received ReportFetchFailureForAllUpstream for stage $readerStageId," + + s" attemptId: $attemptId") + handleInvalidateAllUpstreamShuffle(context, readerStageId, attemptId, triggerAppShuffleId) } private def handleReducerPartitionEnd( @@ -969,6 +991,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends appShuffleIdentifier: String, isWriter: Boolean, isBarrierStage: Boolean): Unit = { + logDebug(s"get shuffle id for $appShuffleIdentifier isWriter: $isWriter") val shuffleIds = if (isWriter) { shuffleIdMapping.computeIfAbsent( @@ -979,7 +1002,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends override def apply(id: Int) : scala.collection.mutable.LinkedHashMap[String, (Int, Boolean)] = { val newShuffleId = shuffleIdGenerator.getAndIncrement() - logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") + logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId" + + s" appShuffleIdentifier $appShuffleIdentifier") + stageToWriteCelebornShuffleCallback.foreach(callback => + callback.accept(newShuffleId, appShuffleIdentifier)) + celebornToAppShuffleIdMappingCallback.foreach(callback => + callback.accept(newShuffleId, appShuffleIdentifier)) scala.collection.mutable.LinkedHashMap(appShuffleIdentifier -> (newShuffleId, true)) } }) @@ -1014,15 +1042,20 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // So if a barrier stage is getting reexecuted, previous stage/attempt needs to // be cleaned up as it is entirely unusuable if (determinate && !isBarrierStage && !isCelebornSkewShuffleOrChildShuffle( - appShuffleId)) - shuffleIds.values.toSeq.reverse.find(e => e._2 == true) - else + appShuffleId) && !conf.clientShuffleEarlyDeletion) { + val result = shuffleIds.values.toSeq.reverse.find(e => e._2 == true) + if (result.isEmpty) { + logWarning(s"cannot find candidate shuffleId for determinate" + + s" shuffle $appShuffleIdentifier") + } + result + } else None val shuffleId: Integer = if (determinate && candidateShuffle.isDefined) { val id = candidateShuffle.get._1 - logInfo(s"reuse existing shuffleId $id for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") +
logInfo(s"reuse existing shuffleId $id for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") id } else { // this branch means it is a redo of previous write stage @@ -1035,9 +1068,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds ++= mapUpdates } val newShuffleId = shuffleIdGenerator.getAndIncrement() - logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") + logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId" + + s" appShuffleIdentifier $appShuffleIdentifier") validateCelebornShuffleIdForClean.foreach(callback => callback.accept(appShuffleIdentifier)) + stageToWriteCelebornShuffleCallback.foreach { callback => + callback.accept(newShuffleId, appShuffleIdentifier) + } + celebornToAppShuffleIdMappingCallback.foreach(callback => + callback.accept(newShuffleId, appShuffleIdentifier)) shuffleIds.put(appShuffleIdentifier, (newShuffleId, true)) newShuffleId } @@ -1049,29 +1088,140 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends s"unexpected! unknown appShuffleId $appShuffleId when checking shuffle deterministic level")) } } else { - shuffleIds.values.filter(v => v._2).map(v => v._1).toSeq.reverse.find( areAllMapTasksEnd) match { - case Some(celebornShuffleId) => - val pbGetShuffleIdResponse = { - logDebug( - s"get shuffleId $celebornShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") - PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).setSuccess( true).build() - } - context.reply(pbGetShuffleIdResponse) - case None => - val pbGetShuffleIdResponse = { - logInfo( - s"there is no finished map stage associated with appShuffleId $appShuffleId") - PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess( false).build() + // this is not necessarily the most concise coding style, but it helps for debugging + // purposes + var found = false + val reversedShuffleIds = shuffleIds.values.map(v => v._1).toSeq.reverse + reversedShuffleIds.foreach { celebornShuffleId: Int => + if (!found) { + try { + if (areAllMapTasksEnd(celebornShuffleId)) { + getCelebornShuffleIdForReaderCallback.foreach(callback => + callback.accept(celebornShuffleId, appShuffleIdentifier)) + getAppShuffleIdForReaderCallback.foreach(callback => + callback.accept(appShuffleId, appShuffleIdentifier)) + val pbGetShuffleIdResponse = { + logDebug( + s"get shuffleId $celebornShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") + PbGetShuffleIdResponse.newBuilder() + .setShuffleId(celebornShuffleId) + .setSuccess(true) + .build() + } + context.reply(pbGetShuffleIdResponse) + found = true + } else { + logInfo(s"not all map tasks finished for shuffle $celebornShuffleId") + } + } catch { + case ise: IllegalStateException => + if (conf.clientShuffleEarlyDeletion) { + logError( + s"hit error when getting celeborn shuffle id $celebornShuffleId for" + + s" appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier", + ise) + val canInvalidateAllUpstream = + checkWhetherToInvalidateAllUpstreamCallback.exists(func => + func.apply(appShuffleIdentifier)) + val pbGetShuffleIdResponse = PbGetShuffleIdResponse + .newBuilder() + .setShuffleId({ + if (canInvalidateAllUpstream) { + KNOWN_MISSING_CELEBORN_SHUFFLE_ID + } else { + UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID + } + }) + .setSuccess(true) + .build() +
context.reply(pbGetShuffleIdResponse) + } else { + logError( + s"unexpected IllegalStateException without" + + s" ${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key} turning on", + ise) + throw ise; + } } - context.reply(pbGetShuffleIdResponse) + } } } } } + private def invalidateAllKnownUpstreamShuffleOutput(stageIdentifier: String): Unit = { + val Array(readerStageId, _) = stageIdentifier.split('.').map(_.toInt) + val invalidatedUpstreamIds = + stagesReceivedInvalidatingUpstream.getOrElseUpdate( + stageIdentifier, + new mutable.HashSet[Int]()) + logInfo(s"invalidating all upstream shuffles of stage $stageIdentifier") + val upstreamShuffleIds = getUpstreamAppShuffleIdsCallback.map(f => + f.apply(readerStageId)).getOrElse(Array()) + upstreamShuffleIds.foreach { upstreamAppShuffleId => + appShuffleTrackerCallback.foreach { callback => + logInfo(s"invalidated upstream app shuffle id $upstreamAppShuffleId for stage" + + s" $stageIdentifier") + callback.accept(upstreamAppShuffleId) + invalidatedUpstreamIds += upstreamAppShuffleId + val celebornShuffleIds = shuffleIdMapping.get(upstreamAppShuffleId) + val latestShuffle = celebornShuffleIds.maxBy(_._2._1) + celebornShuffleIds.put(latestShuffle._1, (KNOWN_MISSING_CELEBORN_SHUFFLE_ID, false)) + } + } + invalidateShuffleWrittenByStage(readerStageId) + } + + private def handleInvalidateAllUpstreamShuffle( + context: RpcCallContext, + readerStageId: Int, + readerStageAttemptId: Int, + triggerAppShuffleId: Int): Unit = stagesReceivedInvalidatingUpstream.synchronized { + require( + conf.clientShuffleEarlyDeletion, + "ReportFetchFailureForAllUpstream message is " + + s"supposed to be only received when turning on" + + s" ${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key}") + require( + getUpstreamAppShuffleIdsCallback.isDefined, + "no callback has been registered for" + + " invalidating all upstream shuffles for a reader stage") + var ret = true + try { + val stageIdentifier = s"$readerStageId.$readerStageAttemptId" + if (!stagesReceivedInvalidatingUpstream.contains(stageIdentifier)) { + invalidateAllKnownUpstreamShuffleOutput(stageIdentifier) + } else if (!stagesReceivedInvalidatingUpstream(stageIdentifier) + .contains(triggerAppShuffleId)) { + // in this case, it means that we haven't been able to capture a certain upstream app + // shuffle id for the current stage when we invalidate all upstream last time, + // and the new upstream shuffle id show up now, we need to add the new shuffle id + // dependency and then fallback to the fetchfailure error for this shuffle + // (since other captured upstream shuffles might have been regenerated) + logInfo(s"a new upstream shuffle id $triggerAppShuffleId show up for $stageIdentifier" + + s" after we have invalidated all known upstream shuffle outputs") + val appShuffleIdentifier = s"$triggerAppShuffleId-$readerStageId-$readerStageAttemptId" + getAppShuffleIdForReaderCallback.foreach(callback => + callback.accept(triggerAppShuffleId, appShuffleIdentifier)) + ret = false + } else { + logInfo(s"ignoring the message to invalidate all upstream shuffles for stage" + + s" $stageIdentifier (triggered appShuffleId $triggerAppShuffleId)," + + s" as it has been handled by another thread") + } + } catch { + case t: Throwable => + logError( + s"hit error when invalidating upstream shuffles for stage $readerStageId," + + s" attempt $readerStageAttemptId", + t) + ret = false + } + val pbInvalidateAllUpstreamShuffleResponse = + PbInvalidateAllUpstreamShuffleResponse.newBuilder().setSuccess(ret).build() + 
context.reply(pbInvalidateAllUpstreamShuffleResponse) + } + private def handleReportShuffleFetchFailure( context: RpcCallContext, appShuffleId: Int, @@ -1109,6 +1259,52 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends context.reply(pbReportShuffleFetchFailureResponse) } + private def handleReportMissingShuffleId( + context: RpcCallContext, + appShuffleId: Int, + stageId: Int, + stageAttemptId: Int): Unit = { + val shuffleIds = shuffleIdMapping.get(appShuffleId) + if (shuffleIds == null) { + throw new UnsupportedOperationException(s"unexpected! unknown appShuffleId $appShuffleId") + } + var ret = true + shuffleIds.synchronized { + val stageIdentifier = s"$stageId.$stageAttemptId" + val shuffleIdentifier = s"$appShuffleId.$stageId.$stageAttemptId" + if (stagesReceivedInvalidatingUpstream.getOrElse(stageIdentifier, new mutable.HashSet[Int]()) + .contains(appShuffleId)) { + logInfo(s"ignoring missing shuffle id report from stage $stageId.$stageAttemptId as" + + s" it is already reported by another reader and handled") + } else { + logInfo(s"handle missing shuffle id for appShuffleId $appShuffleId stage" + + s" $stageId.$stageAttemptId") + appShuffleTrackerCallback match { + case Some(callback) => + try { + callback.accept(appShuffleId) + } catch { + case t: Throwable => + logError(t.toString) + ret = false + } + shuffleIds.put(shuffleIdentifier, (UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID, false)) + case None => + throw new UnsupportedOperationException( + "unexpected! appShuffleTrackerCallback is not registered") + } + invalidateShuffleWrittenByStage(stageId) + stagesReceivedInvalidatingUpstream += stageIdentifier -> + (stagesReceivedInvalidatingUpstream.getOrElse( + stageIdentifier, + new mutable.HashSet[Int]()) ++ Set(appShuffleId)) + } + val pbReportMissingShuffleIdResponse = + PbReportMissingShuffleIdResponse.newBuilder().setSuccess(ret).build() + context.reply(pbReportMissingShuffleIdResponse) + } + } + private def handleReportBarrierStageAttemptFailure( context: RpcCallContext, appShuffleId: Int, @@ -1179,6 +1375,24 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } + private def invalidateShuffleWrittenByStage(stageId: Int): Unit = { + val writtenShuffleId = getAppShuffleIdByStageIdCallback.map { callback => + callback.apply(stageId) + } + writtenShuffleId.foreach { shuffleId => + if (shuffleId >= 0) { + val celebornShuffleIds = shuffleIdMapping.get(shuffleId) + if (celebornShuffleIds != null) { + logInfo(s"invalidating location of app shuffle id $shuffleId written" + + s" by stage $stageId") + val latestShuffleId = celebornShuffleIds.maxBy(_._2._1) + celebornShuffleIds.put(latestShuffleId._1, (latestShuffleId._2._1, false)) + appShuffleTrackerCallback.foreach(callback => callback.accept(shuffleId)) + } + } + } + } + private def handleStageEnd(shuffleId: Int): Unit = { // check whether shuffle has registered if (!registeredShuffle.contains(shuffleId)) { @@ -2004,6 +2218,47 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends celebornSkewShuffleCheckCallback = Some(callback) } + + @volatile private var getUpstreamAppShuffleIdsCallback + : Option[Function[Integer, Array[Integer]]] = None + def registerUpstreamAppShuffleIdsCallback(callback: Function[Integer, Array[Integer]]): Unit = { + getUpstreamAppShuffleIdsCallback = Some(callback) + } + + @volatile private var getAppShuffleIdByStageIdCallback: Option[Function[Integer, Integer]] = None + def
registerGetAppShuffleIdByStageIdCallback(callback: Function[Integer, Integer]): Unit = { + getAppShuffleIdByStageIdCallback = Some(callback) + } + + // expecting celeborn shuffle id and application shuffle identifier + @volatile private var getCelebornShuffleIdForReaderCallback: Option[BiConsumer[Integer, String]] = + None + def registerGetCelebornShuffleIdForReaderCallback(callback: BiConsumer[Integer, String]): Unit = { + getCelebornShuffleIdForReaderCallback = Some(callback) + } + + @volatile private var getAppShuffleIdForReaderCallback: Option[BiConsumer[Integer, String]] = None + def registerReaderStageToAppShuffleIdsCallback(callback: BiConsumer[Integer, String]): Unit = { + getAppShuffleIdForReaderCallback = Some(callback) + } + + @volatile private var stageToWriteCelebornShuffleCallback: Option[BiConsumer[Integer, String]] = + None + def registerStageToWriteCelebornShuffleCallback(callback: BiConsumer[Integer, String]): Unit = { + stageToWriteCelebornShuffleCallback = Some(callback) + } + + @volatile private var celebornToAppShuffleIdMappingCallback: Option[BiConsumer[Integer, String]] = + None + def registerCelebornToAppShuffleIdMappingCallback(callback: BiConsumer[Integer, String]): Unit = { + celebornToAppShuffleIdMappingCallback = Some(callback) + } + + @volatile private var checkWhetherToInvalidateAllUpstreamCallback + : Option[Function[String, Boolean]] = None + def registerInvalidateAllUpstreamCheckCallback(callback: Function[String, Boolean]): Unit = { + checkWhetherToInvalidateAllUpstreamCallback = Some(callback) + } + // Initialize at the end of LifecycleManager construction. initialize() diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index fb557ebbe00..ac844375339 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -293,7 +293,7 @@ class ReducePartitionCommitHandler( if (null != attempts) { attempts.length == shuffleToCompletedMappers.get(shuffleId) } else { - false + throw new IllegalStateException(s"cannot find mapper attempts record for shuffle $shuffleId") } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java index 1e729332f3e..96c43c9b6b4 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java @@ -125,6 +125,14 @@ public T getParsedPayload() throws InvalidProtoco return (T) PbNotifyRequiredSegment.parseFrom(payload); case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE: return (T) PbPushMergedDataSplitPartitionInfo.parseFrom(payload); + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_VALUE: + return (T) PbInvalidateAllUpstreamShuffle.parseFrom(payload); + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE_VALUE: + return (T) PbInvalidateAllUpstreamShuffleResponse.parseFrom(payload); + case REPORT_MISSING_SHUFFLE_ID_VALUE: + return (T) PbReportMissingShuffleId.parseFrom(payload); + case REPORT_MISSING_SHUFFLE_ID_RESPONSE_VALUE: + return (T) PbReportMissingShuffleIdResponse.parseFrom(payload); default: logger.error("Unexpected type {}", type); } diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto index 160c7e96201..707e08b9da6 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -117,6 +117,10 @@ enum MessageType { READ_REDUCER_PARTITION_END = 94; READ_REDUCER_PARTITION_END_RESPONSE = 95; REGISTER_APPLICATION_INFO = 96; + INVALIDATE_ALL_UPSTREAM_SHUFFLE = 97; + INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE = 98; + REPORT_MISSING_SHUFFLE_ID = 99; + REPORT_MISSING_SHUFFLE_ID_RESPONSE = 100; } enum StreamType { @@ -1072,3 +1076,23 @@ message PbRegisterApplicationInfo { map extraInfo = 3; string requestId = 4; } + +message PbReportMissingShuffleId { + int32 readerStageId = 1; + int32 attemptId = 2; + int32 triggerAppShuffleId = 3; +} + +message PbReportMissingShuffleIdResponse { + bool success = 1; +} + +message PbInvalidateAllUpstreamShuffle { + int32 readerStageId = 1; + int32 attemptId = 2; + int32 triggerAppShuffleId = 3; +} + +message PbInvalidateAllUpstreamShuffleResponse { + bool success = 1; +} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 7c1cc6875a4..5f6f89df4a2 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1036,6 +1036,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchCleanFailedShuffle: Boolean = get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE) def clientFetchCleanFailedShuffleIntervalMS: Long = get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL) + def clientShuffleEarlyDeletion: Boolean = get(CLIENT_SHUFFLE_EARLY_DELETION) + def clientShuffleEarlyDeletionCheckProp: Boolean = + get(CLIENT_SHUFFLE_EARLY_DELETION_CHECK_PROPERTY) + def clientShuffleEarlyDeletionIntervalMs: Long = get(CLIENT_SHUFFLE_EARLY_DELETION_INTERVAL_MS) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = get(CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED) def clientFetchExcludedWorkerExpireTimeout: Long = @@ -5085,6 +5089,31 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val CLIENT_SHUFFLE_EARLY_DELETION: ConfigEntry[Boolean] = + buildConf("celeborn.client.spark.fetch.shuffleEarlyDeletion") + .categories("client") + .version("0.7.0") + .doc("whether to delete shuffle when we determine a shuffle is not needed by any stage") + .booleanConf + .createWithDefault(false) + + val CLIENT_SHUFFLE_EARLY_DELETION_CHECK_PROPERTY: ConfigEntry[Boolean] = + buildConf("celeborn.client.spark.fetch.shuffleEarlyDeletion.checkProperty") + .categories("client") + .version("0.7.0") + .doc("when this is enabled, we only early delete shuffle when" + + " \"CELEBORN_EARLY_SHUFFLE_DELETION\" property is set to true") + .booleanConf + .createWithDefault(false) + + val CLIENT_SHUFFLE_EARLY_DELETION_INTERVAL_MS: ConfigEntry[Long] = + buildConf("celeborn.client.spark.fetch.shuffleEarlyDeletion.intervalMs") + .categories("client") + .version("0.7.0") + .doc("interval length to delete unused shuffle (ms)") + .longConf + .createWithDefault(5 * 60 * 1000) + val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL: ConfigEntry[Long] = buildConf("celeborn.client.spark.fetch.cleanFailedShuffleInterval") .categories("client") diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index eb9274632de..e097efa1ada 100644 --- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -1073,6 +1073,22 @@ object ControlMessages extends Logging { case pb: PbApplicationMetaRequest => new TransportMessage(MessageType.APPLICATION_META_REQUEST, pb.toByteArray) + + case pb: PbInvalidateAllUpstreamShuffle => + new TransportMessage(MessageType.INVALIDATE_ALL_UPSTREAM_SHUFFLE, pb.toByteArray) + + case pb: PbInvalidateAllUpstreamShuffleResponse => + new TransportMessage( + MessageType.INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE, + pb.toByteArray) + + case pb: PbReportMissingShuffleId => + new TransportMessage(MessageType.REPORT_MISSING_SHUFFLE_ID, pb.toByteArray) + + case pb: PbReportMissingShuffleIdResponse => + new TransportMessage( + MessageType.REPORT_MISSING_SHUFFLE_ID_RESPONSE, + pb.toByteArray) } // TODO change return type to GeneratedMessageV3 @@ -1516,6 +1532,18 @@ object ControlMessages extends Logging { PbSerDeUtils.fromPbUserIdentifier(pbRegisterApplicationInfo.getUserIdentifier), pbRegisterApplicationInfo.getExtraInfoMap, pbRegisterApplicationInfo.getRequestId) + + case REPORT_MISSING_SHUFFLE_ID_VALUE => + message.getParsedPayload() + + case REPORT_MISSING_SHUFFLE_ID_RESPONSE_VALUE => + message.getParsedPayload() + + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_VALUE => + message.getParsedPayload() + + case INVALIDATE_ALL_UPSTREAM_SHUFFLE_RESPONSE_VALUE => + message.getParsedPayload() } } } diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index e71db2eadff..3a76b48816c 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -1119,6 +1119,9 @@ object Utils extends Logging { val UNKNOWN_APP_SHUFFLE_ID = -1 + val UNKNOWN_MISSING_CELEBORN_SHUFFLE_ID = -2 + val KNOWN_MISSING_CELEBORN_SHUFFLE_ID = -3 + def isHdfsPath(path: String): Boolean = { path.matches(COMPATIBLE_HDFS_REGEX) } diff --git a/docs/configuration/client.md b/docs/configuration/client.md index fb56d8d7252..f5db7bc6f64 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -125,6 +125,9 @@ license: | | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. 
| 0.3.1 | | | celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 | | | celeborn.client.spark.fetch.cleanFailedShuffleInterval | 1s | false | the interval to clean the failed-to-fetch shuffle files, only valid when celeborn.client.spark.fetch.cleanFailedShuffle is enabled | 0.6.0 | | +| celeborn.client.spark.fetch.shuffleEarlyDeletion | false | false | whether to delete shuffle when we determine a shuffle is not needed by any stage | 0.7.0 | | +| celeborn.client.spark.fetch.shuffleEarlyDeletion.checkProperty | false | false | when this is enabled, we only early delete shuffle when "CELEBORN_EARLY_SHUFFLE_DELETION" property is set to true | 0.7.0 | | +| celeborn.client.spark.fetch.shuffleEarlyDeletion.intervalMs | 300000 | false | interval length to delete unused shuffle (ms) | 0.7.0 | | | celeborn.client.spark.push.dynamicWriteMode.enabled | false | false | Whether to dynamically switch push write mode based on conditions.If true, shuffle mode will be only determined by partition count | 0.5.0 | | | celeborn.client.spark.push.dynamicWriteMode.partitionNum.threshold | 2000 | false | Threshold of shuffle partition number for dynamically switching push writer mode. When the shuffle partition number is greater than this value, use the sort-based shuffle writer for memory efficiency; otherwise use the hash-based shuffle writer for speed. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is true. | 0.5.0 | | | celeborn.client.spark.push.sort.memory.maxMemoryFactor | 0.4 | false | the max portion of executor memory which can be used for SortBasedWriter buffer (only valid when celeborn.client.spark.push.sort.memory.useAdaptiveThreshold is enabled | 0.5.0 | | diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala index 936ea696113..a970db43bf8 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala @@ -16,8 +16,6 @@ */ package org.apache.celeborn.tests.spark -import java.io.File - import org.apache.spark.SparkConf import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} import org.apache.spark.sql.SparkSession @@ -73,10 +71,10 @@ class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite shuffleIdToBeDeleted = Seq(0)) TestCelebornShuffleManager.registerReaderGetHook(hook) val checkingThread = - triggerStorageCheckThread(Seq(0), Seq(1), sparkSession) + StorageCheckUtils.triggerStorageCheckThread(workerDirs, Seq(0), Seq(1), sparkSession) val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2) .map { i => (i, i) }.groupByKey(4).collect() - checkStorageValidation(checkingThread) + StorageCheckUtils.checkStorageValidation(checkingThread) // verify result assert(hook.executed.get()) assert(tuples.length == 10000) @@ -86,69 +84,4 @@ class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite sparkSession.stop() } } - - class CheckingThread( - shuffleIdShouldNotExist: Seq[Int], - shuffleIdMustExist: Seq[Int], - sparkSession: SparkSession) - extends Thread { - var exception: Exception = _ - - protected def checkDirStatus(): Boolean = { - val deletedSuccessfully = 
shuffleIdShouldNotExist.forall(shuffleId => { - workerDirs.forall(dir => - !new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) - }) - val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId => { - shuffleId.toString + ":" + - workerDirs.map(dir => - !new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList - }).mkString(",") - val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => { - workerDirs.exists(dir => - new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) - }) - val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => { - shuffleId.toString + ":" + - workerDirs.map(dir => - new File(s"$dir/celeborn-worker/shuffle_data/" + - s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList - }).mkString(",") - println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" + - s"shuffle-to-be-created status: $createdSuccessfullyString") - deletedSuccessfully && createdSuccessfully - } - - override def run(): Unit = { - var allDataInShape = checkDirStatus() - while (!allDataInShape) { - Thread.sleep(1000) - allDataInShape = checkDirStatus() - } - } - } - - protected def triggerStorageCheckThread( - shuffleIdShouldNotExist: Seq[Int], - shuffleIdMustExist: Seq[Int], - sparkSession: SparkSession): CheckingThread = { - val checkingThread = - new CheckingThread(shuffleIdShouldNotExist, shuffleIdMustExist, sparkSession) - checkingThread.setDaemon(true) - checkingThread.start() - checkingThread - } - - protected def checkStorageValidation(thread: Thread, timeout: Long = 1200 * 1000): Unit = { - val checkingThread = thread.asInstanceOf[CheckingThread] - checkingThread.join(timeout) - if (checkingThread.isAlive || checkingThread.exception != null) { - throw new IllegalStateException("the storage checking status failed," + - s"${checkingThread.isAlive} ${if (checkingThread.exception != null) checkingThread.exception.getMessage - else "NULL"}") - } - } } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala new file mode 100644 index 00000000000..f301fc02339 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleEarlyDeleteSuite.scala @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.tests.spark + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.sql.SparkSession + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.celeborn.service.deploy.worker.Worker +import org.apache.celeborn.tests.spark.fetch.failure.FailedCommitAndExpireDataReaderHook + +class CelebornShuffleEarlyDeleteSuite extends SparkTestBase { + + override def beforeAll(): Unit = { + logInfo("test initialized , setup Celeborn mini cluster") + setupMiniClusterWithRandomPorts(workerNum = 1) + } + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + override def createWorker(map: Map[String, String]): Worker = { + val storageDir = createTmpDir() + workerDirs = workerDirs :+ storageDir + super.createWorker(map ++ Map("celeborn.master.heartbeat.worker.timeout" -> "10s"), storageDir) + } + + private def createSparkSession(additionalConf: Map[String, String] = Map()): SparkSession = { + var builder = SparkSession + .builder() + .master("local[*, 4]") + .appName("celeborn early delete") + .config(updateSparkConf(new SparkConf(), ShuffleMode.SORT)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.shuffle.enabled", "true") + .config("spark.celeborn.client.shuffle.expired.checkInterval", "1s") + .config("spark.kryoserializer.buffer.max", "2047m") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config(s"spark.${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION.key}", "true") + .config(s"spark.${CelebornConf.CLIENT_SHUFFLE_EARLY_DELETION_INTERVAL_MS.key}", "1000") + additionalConf.foreach { case (k, v) => + builder = builder.config(k, v) + } + builder.getOrCreate() + } + + test("spark integration test - delete shuffle data from unneeded stages") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + .repartition(4) + val t = new Thread() { + override def run(): Unit = { + // shuffle 1 + rdd1.mapPartitions(iter => { + Thread.sleep(20000) + iter + }).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 2), // guard on 2 to prevent any stage retry + shuffleIdMustExist = Seq(1), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + } + + test("spark integration test - delete shuffle data only when all child stages finished") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + val rdd2 = rdd1.repartition(4) + val rdd3 = rdd1.repartition(4) + val t = new Thread() { + override def run(): Unit = { + rdd2.union(rdd3).mapPartitions(iter => { + Thread.sleep(20000) + iter + }).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 3), // guard on 3 to prevent any stage retry + shuffleIdMustExist = Seq(1, 2), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + 
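+        // Descriptive note on what this check encodes: shuffle 0 is the common parent read by
+        // both repartition stages (the ones that produce shuffles 1 and 2), so it may be deleted
+        // only after both child stages have finished; shuffles 1 and 2 feed the still-sleeping
+        // union stage and must stay on disk, while shuffle 3 would only appear on a stage retry.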
t.join() + } finally { + spark.stop() + } + } + } + + test("spark integration test - delete shuffle data only when all child stages finished" + + " (multi-level lineage)") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + val rdd2 = rdd1.repartition(4).repartition(2) + val rdd3 = rdd1.repartition(4).repartition(2) + val t = new Thread() { + override def run(): Unit = { + rdd2.union(rdd3).mapPartitions(iter => { + Thread.sleep(20000) + iter + }).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 1, 2, 5), // guard on 5 to prevent any stage retry + shuffleIdMustExist = Seq(3, 4), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + } + + test("spark integration test - when the stage has a skipped parent stage, we should still be" + + " able to delete data") { + if (Spark3OrNewer) { + val spark = createSparkSession() + try { + val rdd1 = spark.sparkContext.parallelize(0 until 20, 3).repartition(2) + rdd1.count() + val t = new Thread() { + override def run(): Unit = { + rdd1.mapPartitions(iter => { + Thread.sleep(20000) + iter + }).repartition(3).count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 2), + shuffleIdMustExist = Seq(1), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + } finally { + spark.stop() + } + } + } + + private def deleteTooEarlyTest( + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + spark: SparkSession): Unit = { + if (Spark3OrNewer) { + var r = 0L + try { + // shuffle 0 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + rdd1.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 1 + val rdd2 = rdd1.mapPartitions(iter => { + Thread.sleep(10000) + iter + }).repartition(3) + rdd2.count() + println("rdd2.count() finished") + // leaving enough time for shuffle 0 to be expired + Thread.sleep(10000) + // shuffle 2 + val rdd3 = rdd1.repartition(5).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd3.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = shuffleIdShouldNotExist, + shuffleIdMustExist = shuffleIdMustExist, + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 20) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when shuffle is deleted \"too early\"") { + val spark = createSparkSession() + deleteTooEarlyTest(Seq(0, 3, 5), Seq(1, 2, 4), spark) + } + + test("spark integration test - do not fail job when shuffle is deleted \"too early\"" + + " (with failed shuffle deletion)") { + val spark = createSparkSession( + Map(s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + deleteTooEarlyTest(Seq(0, 2, 3, 5), Seq(1, 4), spark) + } + + test("spark integration test - do not fail job when shuffle files" + + " are deleted \"too early\" (ancestor dependency)") { + val spark = createSparkSession() + if (Spark3OrNewer) { + var r = 0L + try { + // shuffle 0 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + rdd1.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 1 + val rdd2 = 
rdd1.repartition(3) + rdd2.count() + println("rdd2.count finished()") + // leaving enough time for shuffle 0 to be expired + Thread.sleep(10000) + // shuffle 2 + rdd2.repartition(4).count() + // leaving enough time for shuffle 1 to be expired + Thread.sleep(10000) + val rdd4 = rdd1.union(rdd2).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd4.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = Seq(0, 1, 5), + shuffleIdMustExist = Seq(3, 4), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 40) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when multiple shuffles (be unioned)" + + " are deleted \"too early\"") { + if (Spark3OrNewer) { + val spark = createSparkSession() + var r = 0L + try { + // shuffle 0&1 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd2 = spark.sparkContext.parallelize((0 until 30), 3).repartition(2) + rdd1.count() + rdd2.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 2&3 + val rdd3 = rdd1.repartition(3) + val rdd4 = rdd2.repartition(3) + rdd3.count() + rdd4.count() + // leaving enough time for shuffle 0&1 to be expired + Thread.sleep(10000) + // shuffle 4&5 + rdd3.repartition(4).count() + rdd4.repartition(4).count() + // leaving enough time for shuffle 2&3 to be expired + Thread.sleep(10000) + val rdd5 = rdd3.union(rdd4).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd5.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + // 4,5 are based on vanilla spark gc which are not necessarily stable in a test + // 6,7 is based on failed shuffle cleanup, which is not covered here + shuffleIdShouldNotExist = Seq(0, 1, 2, 3, 8, 9, 12), + shuffleIdMustExist = Seq(10, 11), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 50) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + " are deleted \"too early\"") { + if (Spark3OrNewer) { + val spark = createSparkSession() + var r = 0L + try { + // shuffle 0&1 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd2 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + rdd1.count() + rdd2.count() + val t = new Thread() { + override def run(): Unit = { + // shuffle 2&3 + val rdd3 = rdd1.repartition(3) + val rdd4 = rdd2.repartition(3) + rdd3.count() + rdd4.count() + // leaving enough time for shuffle 0&1 to be expired + Thread.sleep(10000) + // shuffle 4&5 + rdd3.repartition(4).count() + rdd4.repartition(4).count() + // leaving enough time for shuffle 2&3 to be expired + Thread.sleep(10000) + println("starting job for rdd 5") + val rdd5 = rdd3.zip(rdd4).mapPartitions(iter => { + Thread.sleep(10000) + iter + }) + r = rdd5.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + // 4,5 are based on vanilla spark gc which are not necessarily stable in a test + // 6,9 are based on failed shuffle cleanup, which is not covered here + shuffleIdShouldNotExist = Seq(0, 1, 2, 3, 7, 10, 12), + shuffleIdMustExist = Seq(8, 11), + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 20) + } finally { + spark.stop() + } + } + } + + private def 
multiShuffleFailureTest( + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + spark: SparkSession): Unit = { + if (Spark3OrNewer) { + val celebornConf = SparkUtils.fromSparkConf(spark.sparkContext.getConf) + val hook = new FailedCommitAndExpireDataReaderHook( + celebornConf, + triggerShuffleId = 6, + shuffleIdsToExpire = (0 to 5).toList) + TestCelebornShuffleManager.registerReaderGetHook(hook) + var r = 0L + try { + // shuffle 0&1&2 + val rdd1 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd2 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val rdd3 = spark.sparkContext.parallelize((0 until 20), 3).repartition(2) + val t = new Thread() { + override def run(): Unit = { + // shuffle 3&4&5 + val rdd4 = rdd1.repartition(3) + val rdd5 = rdd2.repartition(3) + val rdd6 = rdd3.repartition(3) + println("starting job for rdd 7") + val rdd7 = rdd4.zip(rdd5).zip(rdd6).repartition(2) + r = rdd7.count() + } + } + t.start() + val thread = StorageCheckUtils.triggerStorageCheckThread( + workerDirs, + shuffleIdShouldNotExist = shuffleIdShouldNotExist, + shuffleIdMustExist = shuffleIdMustExist, + sparkSession = spark) + StorageCheckUtils.checkStorageValidation(thread) + t.join() + assert(r === 20) + } finally { + spark.stop() + } + } + } + + test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + " are to be retried for fetching") { + val spark = createSparkSession(Map("spark.stage.maxConsecutiveAttempts" -> "3")) + multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5), Seq(17), spark) + } + + test("spark integration test - do not fail job when multiple shuffles (be zipped/joined)" + + " are to be retried for fetching (with failed shuffle deletion)") { + val spark = createSparkSession(Map( + "spark.stage.maxConsecutiveAttempts" -> "3", + s"spark.${CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key}" -> "true")) + multiShuffleFailureTest(Seq(0, 1, 2, 3, 4, 5, 8, 9, 10), Seq(17), spark) + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala new file mode 100644 index 00000000000..e456e331916 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/StorageCheckUtils.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.tests.spark + +import java.io.File + +import org.apache.spark.sql.SparkSession + +private[tests] object StorageCheckUtils { + + class CheckingThread( + workerDirs: Seq[String], + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + sparkSession: SparkSession) + extends Thread { + var exception: Exception = _ + + protected def checkDirStatus(): Boolean = { + val deletedSuccessfully = shuffleIdShouldNotExist.forall(shuffleId => { + workerDirs.forall(dir => + !new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) + }) + val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId => { + shuffleId.toString + ":" + + workerDirs.map(dir => + !new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList + }).mkString(",") + val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => { + workerDirs.exists(dir => + new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) + }) + val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => { + shuffleId.toString + ":" + + workerDirs.map(dir => + new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList + }).mkString(",") + println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" + + s"shuffle-to-be-created status: $createdSuccessfullyString") + deletedSuccessfully && createdSuccessfully + } + + override def run(): Unit = { + var allDataInShape = checkDirStatus() + while (!allDataInShape) { + Thread.sleep(1000) + allDataInShape = checkDirStatus() + } + } + } + + def triggerStorageCheckThread( + workerDirs: Seq[String], + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + sparkSession: SparkSession): CheckingThread = { + val checkingThread = + new CheckingThread(workerDirs, shuffleIdShouldNotExist, shuffleIdMustExist, sparkSession) + checkingThread.setDaemon(true) + checkingThread.start() + checkingThread + } + + def checkStorageValidation(thread: Thread, timeout: Long = 1200 * 1000): Unit = { + val checkingThread = thread.asInstanceOf[CheckingThread] + checkingThread.join(timeout) + if (checkingThread.isAlive || checkingThread.exception != null) { + throw new IllegalStateException("the storage checking status failed," + + s"${checkingThread.isAlive} ${if (checkingThread.exception != null) checkingThread.exception.getMessage + else "NULL"}") + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala index adac14242bd..176369fb0a2 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala @@ -20,13 +20,66 @@ package org.apache.celeborn.tests.spark.fetch.failure import java.io.File import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.shuffle.ShuffleHandle -import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} +import 
org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkUtils, TestCelebornShuffleManager} import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.client.commit.ReducePartitionCommitHandler import org.apache.celeborn.common.CelebornConf +class FailedCommitAndExpireDataReaderHook( + conf: CelebornConf, + triggerShuffleId: Int, + shuffleIdsToExpire: List[Int]) + extends ShuffleManagerHook { + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + + if (executed.get()) return + + lock.synchronized { + // this hook has to be used in local mode, since it relies on the lifecycle manager + // being in the same process as the reader + handle match { + case h: CelebornShuffleHandle[_, _, _] => + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val lifecycleManager = + SparkEnv.get.shuffleManager.asInstanceOf[TestCelebornShuffleManager] + .getLifecycleManager + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + if (celebornShuffleId == triggerShuffleId && !executed.get()) { + println(s"marking celeborn shuffle $celebornShuffleId as a commit failure") + val commitHandler = lifecycleManager.commitManager.getCommitHandler(celebornShuffleId) + commitHandler.asInstanceOf[ReducePartitionCommitHandler].dataLostShuffleSet.add( + celebornShuffleId) + shuffleIdsToExpire.foreach(sid => + SparkEnv.get.shuffleManager.asInstanceOf[TestCelebornShuffleManager] + .getStageDepManager.removeCelebornShuffleInternal(sid, None)) + // leaving enough time for all shuffles to expire + Thread.sleep(10000) + executed.set(true) + } else { + println(s"ignoring hook with $celebornShuffleId $triggerShuffleId and ${executed.get()}") + } + case _ => throw new RuntimeException("unexpected handle type, only CelebornShuffleHandle is supported here") + } + } + } +} + class ShuffleReaderGetHooks( conf: CelebornConf, workerDirs: Seq[String],