Spark DAGScheduler Source Code Walkthrough, Part 1: Stage Division
- March 31, 2020
- Notes
First, let's look at the DAGScheduler's entry point for running a job, runJob:
```scala
/**
 * Run an action job on the given RDD and pass all the results to the resultHandler function as
 * they arrive.
 *
 * @param rdd target RDD to run tasks on
 * @param func a function to run on each partition of the RDD
 * @param partitions set of partitions to run on; some jobs may not want to compute on all
 *   partitions of the target RDD, e.g. for operations like first()
 * @param callSite where in the user program this job was called
 * @param resultHandler callback to pass each result to
 * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
 *
 * @throws Exception when the job fails
 */
def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  // Key line: submit the job
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
  // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
  // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
  // safe to pass in null here. For more detail, see SPARK-13747.
  val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
  waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}
```
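For orientation, user code never calls this method directly: an RDD action goes through SparkContext.runJob, which delegates here. A minimal driver sketch (the object name and master setting are my own, for illustration only):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Minimal sketch: calling an action such as count() is what eventually
// reaches DAGScheduler.runJob via SparkContext.runJob.
object RunJobDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("runJob-demo").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100, 4)
    println(rdd.count()) // count() -> SparkContext.runJob -> DAGScheduler.runJob
    sc.stop()
  }
}
```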
runJob then hands the job off to the scheduler via submitJob:
```scala
/**
 * Submit an action job to the scheduler.
 *
 * @param rdd target RDD to run tasks on
 * @param func a function to run on each partition of the RDD
 * @param partitions set of partitions to run on; some jobs may not want to compute on all
 *   partitions of the target RDD, e.g. for operations like first()
 * @param callSite where in the user program this job was called
 * @param resultHandler callback to pass each result to
 * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
 *
 * @return a JobWaiter object that can be used to block until the job finishes executing
 *         or can be used to cancel the job.
 *
 * @throws IllegalArgumentException when partitions ids are illegal
 */
def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check to make sure we are not launching a task on a partition that does not exist.
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    // Return immediately if the job is running 0 tasks
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  // Key lines: post the JobSubmitted event to the event loop
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}
```
Note that JobSubmitted here is not a method call but a case class (a DAGSchedulerEvent), so posting it simply enqueues an event for the scheduler's event loop. Its definition is shown below.
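The fields can be read straight off the call in submitJob above; in the Spark 2.x source the case class looks roughly like this (reconstructed from memory, so treat the exact field names as approximate):

```scala
// Rough sketch of the event posted by submitJob (field names approximate).
private[scheduler] case class JobSubmitted(
    jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    callSite: CallSite,
    listener: JobListener,
    properties: Properties = null)
  extends DAGSchedulerEvent
```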

When the scheduler processes the submitted job (in handleJobSubmitted), it first creates the final ResultStage:
```scala
finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
```
This creates a stage and registers it in the scheduler's HashMaps so it can be tracked:
```scala
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
```
Second, an ActiveJob is created from the finalStage:
```scala
val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
```
Third, the job is added to the scheduler's caches:
```scala
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.setActiveJob(job)
```
Fourth, and this is the crucial step, the final stage is submitted:
```scala
submitStage(finalStage)
```
Now we get to the heart of the matter, stage division in submitStage:
```scala
/** Recursively compute parent stages, starting from the last stage. */
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      // The recursion bottoms out at the first stage, which has no parent stages;
      // every other stage ends up in waitingStages.
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        // Submitting a stage creates a batch of tasks, one task per partition.
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          // Neat trick: recurse into each missing parent stage first ...
          submitStage(parent)
        }
        // ... and only then add the current stage to the waiting queue,
        // so parents are always submitted before their children.
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}
```
The lookup of parent stages is implemented with a stack:
```scala
// Core code of stage division.
private def getMissingParentStages(stage: Stage): List[Stage] = {
  val missing = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // A stack holds the RDDs that are still waiting to be visited.
  val waitingForVisit = new Stack[RDD[_]]
  def visit(rdd: RDD[_]) {
    if (!visited(rdd)) {
      visited += rdd
      val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
      if (rddHasUncachedPartitions) {
        // Walk the RDD's dependencies.
        for (dep <- rdd.dependencies) {
          dep match {
            // Wide (shuffle) dependency: a new parent stage is needed here.
            case shufDep: ShuffleDependency[_, _, _] =>
              val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
              if (!mapStage.isAvailable) {
                missing += mapStage
              }
            // Narrow dependency: push the parent RDD back onto the stack,
            // so it stays inside the current stage.
            case narrowDep: NarrowDependency[_] =>
              waitingForVisit.push(narrowDep.rdd)
          }
        }
      }
    }
  }
  // Start from the stage's own last RDD.
  waitingForVisit.push(stage.rdd)
  while (waitingForVisit.nonEmpty) {
    // Call the visit method defined above.
    visit(waitingForVisit.pop())
  }
  missing.toList
}
```
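To see the two kinds of dependency from the user side, here is a small spark-shell style sketch (sc is the SparkContext that spark-shell predefines; the variable names are my own):

```scala
// map over a parallelized collection creates a narrow (OneToOne) dependency,
// while reduceByKey introduces a ShuffleDependency, i.e. a stage boundary.
val pairs   = sc.parallelize(1 to 100, 4).map(x => (x % 3, x))
val reduced = pairs.reduceByKey(_ + _)

println(pairs.dependencies)   // e.g. List(org.apache.spark.OneToOneDependency@...)
println(reduced.dependencies) // e.g. List(org.apache.spark.ShuffleDependency@...)
```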
Note in particular that when a wide dependency is handled here, getShuffleMapStage creates (or reuses) the parent shuffle stage:
```scala
val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
```
The main work inside this call is building the ShuffleMapStage that carries the shuffle dependency (shuffleDep).
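A rough sketch of what newOrUsedShuffleStage does, paraphrased from the Spark 2.x source rather than quoted verbatim (the branch bodies are summarized in comments, so details may differ):

```scala
// Paraphrased sketch, not the verbatim Spark source.
private def newOrUsedShuffleStage(
    shuffleDep: ShuffleDependency[_, _, _],
    firstJobId: Int): ShuffleMapStage = {
  val rdd = shuffleDep.rdd
  val numTasks = rdd.partitions.length
  // Build a ShuffleMapStage whose last RDD is the parent side of the wide dependency.
  val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
  if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
    // The shuffle has run before (e.g. after a failure): copy the map output
    // locations that are still available into the new stage instead of recomputing them.
  } else {
    // First time this shuffle is seen: register it with the MapOutputTracker.
    mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
  }
  stage
}
```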

The result is that the last stage of a job is not a ShuffleMapStage (it is the ResultStage), while every stage before it is a ShuffleMapStage. This is how stage division is achieved: for a given stage, if every dependency reachable from its last RDD is narrow, no new stage is created; but as soon as the stage has a wide dependency on some RDD, a new stage is created from that RDD (the parent side of the wide dependency) and is immediately returned as a missing parent stage.
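To make the rule concrete, here is a minimal word-count style example (my own, not taken from the source walkthrough): everything up to the reduceByKey shuffle forms one ShuffleMapStage, and the part ending in the collect action forms the ResultStage, so the job runs as exactly two stages (README.md is just an assumed local input file):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Two stages: textFile -> flatMap -> map are narrow and stay in the ShuffleMapStage;
// reduceByKey introduces the ShuffleDependency that cuts the stage boundary.
object StageSplitDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("stage-split-demo").setMaster("local[2]"))
    val counts = sc.textFile("README.md")
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKey(_ + _) // stage boundary
      .collect()          // action: Stage 0 = ShuffleMapStage, Stage 1 = ResultStage
    println(counts.length)
    sc.stop()
  }
}
```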
There is one more key piece here, the creation of tasks, but that deserves an article of its own, so I will cover it in a separate post.