
Spark DAGScheduler Source Code Walkthrough 1: Stage Division

  • March 31, 2020
  • Notes

First, let's look at the method DAGScheduler uses to launch a job, runJob:

```scala
/**
 * Run an action job on the given RDD and pass all the results to the resultHandler function as
 * they arrive.
 * @param rdd target RDD to run tasks on
 * @param func a function to run on each partition of the RDD
 * @param partitions set of partitions to run on; some jobs may not want to compute on all
 *   partitions of the target RDD, e.g. for operations like first()
 * @param callSite where in the user program this job was called
 * @param resultHandler callback to pass each result to
 * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
 *
 * @throws Exception when the job fails
 */
def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  // Key step: submit the job
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
  // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
  // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
  // safe to pass in null here. For more detail, see SPARK-13747.
  val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
  waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}
```
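At its core, runJob does two things: it submits the job, then blocks on a future until every task has reported back. That future is driven by the waiter. The class below is a minimal sketch of the idea and is entirely hypothetical (SimpleJobWaiter is not Spark's JobWaiter). It counts successful tasks, forwards each result to resultHandler, and completes a Promise once all tasks are done. Spark's real JobWaiter also handles cancellation, which this sketch leaves out.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration

// Hypothetical, simplified stand-in for Spark's JobWaiter.
final class SimpleJobWaiter[U](totalTasks: Int, resultHandler: (Int, U) => Unit) {
  private val finishedTasks = new AtomicInteger(0)
  private val promise = Promise[Unit]()
  // A job with zero tasks is finished immediately (compare submitJob's early return).
  if (totalTasks == 0) promise.success(())

  def completionFuture: Future[Unit] = promise.future

  // Called once per successful task; index is the task's position in the job.
  def taskSucceeded(index: Int, result: U): Unit = {
    // Serialize handler calls, since tasks may report from different threads.
    synchronized { resultHandler(index, result) }
    if (finishedTasks.incrementAndGet() == totalTasks) promise.success(())
  }

  // A failure completes the future with the exception instead.
  def jobFailed(e: Throwable): Unit = { promise.tryFailure(e); () }
}

// Usage: three "tasks" report in, then the caller blocks on the future.
val results = new Array[Int](3)
val waiter = new SimpleJobWaiter[Int](3, (i, r) => results(i) = r)
(0 until 3).foreach(i => waiter.taskSucceeded(i, i * 10))
Await.ready(waiter.completionFuture, Duration.Inf)
```

For brevity, the sketch uses Await.ready. Spark deliberately avoids it, as the SPARK-13747 comment in runJob explains.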

runJob hands the job to the scheduler through submitJob:

```scala
/**
 * Submit an action job to the scheduler.
 *
 * @param rdd target RDD to run tasks on
 * @param func a function to run on each partition of the RDD
 * @param partitions set of partitions to run on; some jobs may not want to compute on all
 *   partitions of the target RDD, e.g. for operations like first()
 * @param callSite where in the user program this job was called
 * @param resultHandler callback to pass each result to
 * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
 *
 * @return a JobWaiter object that can be used to block until the job finishes executing
 *         or can be used to cancel the job.
 *
 * @throws IllegalArgumentException when partitions ids are illegal
 */
def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check to make sure we are not launching a task on a partition that does not exist.
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    // Return immediately if the job is running 0 tasks
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  // Key step: post a JobSubmitted event carrying the job's information
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}
```
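Before anything is posted, submitJob validates the requested partition ids and short-circuits empty jobs. The sketch below pulls those checks into isolation. SubmitChecks and checkPartitions are hypothetical names, since Spark inlines this logic in submitJob.

```scala
import java.util.concurrent.atomic.AtomicInteger

object SubmitChecks {
  // Job ids come from a shared counter, like DAGScheduler's nextJobId.
  private val nextJobId = new AtomicInteger(0)

  // Hypothetical helper mirroring submitJob's guard:
  // every requested partition id must lie in [0, maxPartitions).
  def checkPartitions(maxPartitions: Int, partitions: Seq[Int]): Unit =
    partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
      throw new IllegalArgumentException(
        s"Attempting to access a non-existent partition: $p. " +
          s"Total number of partitions: $maxPartitions")
    }

  // Returns (jobId, numTasks). Validation happens before an id is allocated.
  // A job with no partitions still gets an id, but it has 0 tasks and would
  // never be posted to the event loop.
  def submit(maxPartitions: Int, partitions: Seq[Int]): (Int, Int) = {
    checkPartitions(maxPartitions, partitions)
    val jobId = nextJobId.getAndIncrement()
    (jobId, partitions.size)
  }
}
```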

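Note that JobSubmitted is not a method but a case class. submitJob wraps the job's information in a JobSubmitted event and posts it to eventProcessLoop. That loop's handler matches on the event and passes its fields to DAGScheduler.handleJobSubmitted.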
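The pattern behind eventProcessLoop.post is a single-threaded event loop. Callers enqueue events, and one background thread dequeues them and pattern-matches on the case class. Below is a minimal sketch built on hypothetical types (MiniEventLoop, SchedulerEvent, and a trimmed-down JobSubmitted), loosely following the shape of Spark's EventLoop: a blocking queue, a daemon thread, start/post methods, and an abstract onReceive.

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer

// Hypothetical event hierarchy, loosely modeled on DAGSchedulerEvent.
sealed trait SchedulerEvent
final case class JobSubmitted(jobId: Int, partitions: Array[Int]) extends SchedulerEvent
case object StopLoop extends SchedulerEvent

// One consumer thread drains the queue, so onReceive never runs concurrently.
abstract class MiniEventLoop(name: String) {
  private val queue = new LinkedBlockingQueue[SchedulerEvent]()
  private val thread = new Thread(name) {
    override def run(): Unit = {
      var running = true
      while (running) {
        queue.take() match {
          case StopLoop => running = false
          case event    => onReceive(event)
        }
      }
    }
  }
  thread.setDaemon(true)

  protected def onReceive(event: SchedulerEvent): Unit

  def start(): Unit = thread.start()
  // Callers such as submitJob only enqueue; they never wait for handling.
  def post(event: SchedulerEvent): Unit = queue.put(event)
  // Enqueue a stop marker and wait until every earlier event has been handled.
  def stopAndJoin(): Unit = { queue.put(StopLoop); thread.join() }
}

// Plays the role of the DAGScheduler's loop: dispatch on the event's case class.
final class RecordingLoop extends MiniEventLoop("mini-dag-scheduler-event-loop") {
  val handledJobIds = ArrayBuffer.empty[Int]
  override protected def onReceive(event: SchedulerEvent): Unit = event match {
    case JobSubmitted(jobId, _) => handledJobIds += jobId // stand-in for handleJobSubmitted
    case StopLoop               => ()
  }
}

// Usage: two jobs are posted and handled in order on the loop thread.
val loop = new RecordingLoop
loop.start()
loop.post(JobSubmitted(0, Array(0, 1)))
loop.post(JobSubmitted(1, Array(0)))
loop.stopAndJoin()
```

All events are handled by one thread, so in this sketch the handler's state needs no locking.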

Step one: when the scheduler handles the submitted job, it creates the ResultStage:

```scala
finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
```

This creates the stage and registers it in the scheduler's HashMaps, which are used to track stages:

```scala
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
```
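These two calls keep two maps in sync: stage id to stage, and job id to the ids of every stage the job needs. In Spark, updateJobIdStageIdMaps also walks up to the stage's ancestors, so each ancestor records the job too. Below is a toy version of that bookkeeping using hypothetical MiniStage and Registry types. It registers only the final stage in stageIdToStage, because in Spark the parent stages are registered when they themselves are created.

```scala
import scala.annotation.tailrec
import scala.collection.mutable

// Hypothetical toy stage: an id, its parent stages, and the jobs that need it.
final class MiniStage(val id: Int, val parents: List[MiniStage]) {
  val jobIds = mutable.HashSet.empty[Int]
}

object Registry {
  val stageIdToStage = mutable.HashMap.empty[Int, MiniStage]
  val jobIdToStageIds = mutable.HashMap.empty[Int, mutable.HashSet[Int]]

  def register(jobId: Int, stage: MiniStage): Unit = {
    stageIdToStage(stage.id) = stage
    updateJobIdStageIdMaps(jobId, stage)
  }

  // Record jobId on the stage and all its ancestors, using a work list
  // so the recursion stays tail-recursive.
  private def updateJobIdStageIdMaps(jobId: Int, stage: MiniStage): Unit = {
    @tailrec
    def loop(stages: List[MiniStage]): Unit = stages match {
      case Nil => ()
      case s :: rest =>
        s.jobIds += jobId
        jobIdToStageIds.getOrElseUpdate(jobId, mutable.HashSet.empty[Int]) += s.id
        // Skip parents already tagged with this job, so shared ancestors are not re-expanded.
        val untagged = s.parents.filterNot(_.jobIds.contains(jobId))
        loop(untagged ++ rest)
    }
    loop(List(stage))
  }
}
```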

Step two: create an ActiveJob from finalStage:

```scala
val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
```

Step three: register the job in the scheduler's caches:

```scala
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.setActiveJob(job)
```
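These caches are what submitStage consults first. Its activeJobForStage call picks the smallest job id among the stage's jobIds that is still present in jobIdToActiveJob. The sketch below reproduces that lookup, along with steps two and three, using hypothetical simplified types (MiniActiveJob, JobStage, JobCaches). It is an illustration of the lookup's shape, not Spark's code.

```scala
import scala.collection.mutable

// Hypothetical simplified types; only what the lookup needs.
final case class MiniActiveJob(jobId: Int, finalStageId: Int)
final class JobStage(val id: Int) {
  val jobIds = mutable.HashSet.empty[Int]
  var activeJob: Option[MiniActiveJob] = None
}

object JobCaches {
  val jobIdToActiveJob = mutable.HashMap.empty[Int, MiniActiveJob]
  val activeJobs = mutable.HashSet.empty[MiniActiveJob]

  // Steps two and three: create the job and register it in the caches.
  def activate(jobId: Int, finalStage: JobStage): MiniActiveJob = {
    val job = MiniActiveJob(jobId, finalStage.id)
    // In Spark, updateJobIdStageIdMaps has already tagged the stage by this point.
    finalStage.jobIds += jobId
    jobIdToActiveJob(jobId) = job
    activeJobs += job
    finalStage.activeJob = Some(job)
    job
  }

  // The earliest still-active job that needs this stage, if any.
  def activeJobForStage(stage: JobStage): Option[Int] =
    stage.jobIds.toSeq.sorted.find(jobIdToActiveJob.contains)
}
```

If the lookup returns None, submitStage aborts the stage instead of submitting it.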

Step four, and this one is crucial: submit the stage:

```scala
submitStage(finalStage)
```

Now we reach the heart of it, stage division:

```scala
/** Starting from the final stage, recursively submit parent stages first */
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      // The recursion bottoms out at stages with no missing parents (the first stages);
      // every stage above them ends up in waitingStages
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        // Submitting a stage creates a batch of tasks, one per partition that still needs computing
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          // The clever part: recurse into each parent stage
          submitStage(parent)
        }
        // Paired with the missing.isEmpty branch above: parents are submitted first,
        // then this stage is added to waitingStages to wait its turn
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}
```
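To see the order this recursion produces, here is a toy model built on hypothetical ToyStage and ToySubmitter types. Each stage's missing parents are known up front. Spark, by contrast, recomputes them with getMissingParentStages, checks for an active job, and also tracks failedStages. Submitting the last stage of a chain starts only the root stages, while every stage below them waits in waitingStages.

```scala
import scala.collection.mutable

// Hypothetical toy stage whose missing parents are known up front.
final case class ToyStage(id: Int, missingParents: List[ToyStage])

final class ToySubmitter {
  val waitingStages = mutable.LinkedHashSet.empty[ToyStage]
  val runningStages = mutable.LinkedHashSet.empty[ToyStage]
  val submittedOrder = mutable.ArrayBuffer.empty[Int]

  def submitStage(stage: ToyStage): Unit = {
    if (!waitingStages(stage) && !runningStages(stage)) {
      val missing = stage.missingParents.sortBy(_.id)
      if (missing.isEmpty) {
        // Stand-in for submitMissingTasks: the stage actually starts running.
        submittedOrder += stage.id
        runningStages += stage
      } else {
        missing.foreach(submitStage) // recurse into parents first
        waitingStages += stage       // then park this stage until they finish
      }
    }
  }
}
```

For a lineage where stage 3 depends on stage 2, and stage 2 depends on stages 0 and 1, calling submitStage on stage 3 runs stages 0 and 1 and leaves stages 2 and 3 waiting.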

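To find the parent stages, getMissingParentStages walks the RDD lineage with an explicit stack rather than recursion, which avoids a StackOverflowError on very deep lineages: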

```scala
// Core code of stage division
private def getMissingParentStages(stage: Stage): List[Stage] = {
  val missing = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // Explicit stack of RDDs still to visit
  val waitingForVisit = new Stack[RDD[_]]
  def visit(rdd: RDD[_]) {
    if (!visited(rdd)) {
      visited += rdd
      // If every partition of this RDD is cached, nothing upstream needs recomputing
      val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
      if (rddHasUncachedPartitions) {
        // Walk the RDD's dependencies
        for (dep <- rdd.dependencies) {
          dep match {
            // Wide (shuffle) dependency: a stage boundary
            case shufDep: ShuffleDependency[_, _, _] =>
              val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
              if (!mapStage.isAvailable) {
                missing += mapStage
              }
            // Narrow dependency: the parent RDD belongs to the same stage,
            // so push it onto the stack to be visited
            case narrowDep: NarrowDependency[_] =>
              waitingForVisit.push(narrowDep.rdd)
          }
        }
      }
    }
  }
  // Start from the stage's own (last) RDD
  waitingForVisit.push(stage.rdd)
  while (waitingForVisit.nonEmpty) {
    // Call the local visit function defined above
    visit(waitingForVisit.pop())
  }
  missing.toList
}
```
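The same traversal works on a toy lineage. In the sketch below, ToyRDD, NarrowDep and ShuffleDep are hypothetical types. Caching is not modeled, so every RDD counts as having uncached partitions. Each shuffle is identified by its id instead of a ShuffleMapStage, and `available` plays the role of mapStage.isAvailable. Narrow edges are followed through the stack, while a shuffle edge is recorded and never crossed.

```scala
import scala.collection.mutable

// Hypothetical toy lineage: dependencies are either narrow (same stage)
// or shuffle (stage boundary, identified by a shuffle id).
sealed trait Dep { def rdd: ToyRDD }
final case class NarrowDep(rdd: ToyRDD) extends Dep
final case class ShuffleDep(shuffleId: Int, rdd: ToyRDD) extends Dep
final case class ToyRDD(name: String, deps: List[Dep])

object MissingParents {
  // Shuffle ids whose map output this stage still needs, found with an explicit stack.
  def missingShuffles(finalRdd: ToyRDD, available: Set[Int]): List[Int] = {
    val missing = mutable.LinkedHashSet.empty[Int]
    val visited = mutable.HashSet.empty[ToyRDD]
    val waitingForVisit = mutable.Stack.empty[ToyRDD]
    waitingForVisit.push(finalRdd)
    while (waitingForVisit.nonEmpty) {
      val rdd = waitingForVisit.pop()
      if (!visited(rdd)) {
        visited += rdd
        rdd.deps.foreach {
          case ShuffleDep(id, _) => if (!available(id)) missing += id // boundary: stop here
          case NarrowDep(parent) => waitingForVisit.push(parent)      // same stage: keep walking
        }
      }
    }
    missing.toList.sorted
  }
}
```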

Note in particular that, when a wide dependency is handled, getShuffleMapStage creates the stage for that shuffle dependency:

```scala
val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
```

What gets created here is a ShuffleMapStage that carries the corresponding ShuffleDependency (shuffleDep).
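Creation here is really get-or-create keyed by shuffle id: a shuffle dependency reached by several jobs, or along several paths of the DAG, maps to a single ShuffleMapStage. Spark keeps such a map, named shuffleToMapStage or shuffleIdToMapStage depending on the version. The sketch below shows only that memoization, with hypothetical MapStageCache and ToyShuffleMapStage types. It does not model how newOrUsedShuffleStage reuses map output that was already computed.

```scala
import scala.collection.mutable

// Hypothetical toy ShuffleMapStage: its stage id and the shuffle it produces.
final case class ToyShuffleMapStage(stageId: Int, shuffleId: Int)

final class MapStageCache {
  private val shuffleIdToMapStage = mutable.HashMap.empty[Int, ToyShuffleMapStage]
  private var nextStageId = 0

  // Get-or-create: the first request for a shuffle id creates its stage;
  // later requests (other jobs, other paths in the DAG) reuse it.
  def getShuffleMapStage(shuffleId: Int): ToyShuffleMapStage =
    shuffleIdToMapStage.getOrElseUpdate(shuffleId, {
      val stage = ToyShuffleMapStage(nextStageId, shuffleId)
      nextStageId += 1
      stage
    })
}
```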

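As a result, the final stage is not a ShuffleMapStage (it is the ResultStage), while every stage before it is a ShuffleMapStage. This is how stages are divided. The walk starts from a stage's last RDD, and as long as only narrow dependencies are encountered, it stays within the same stage and no new stage is created. As soon as a wide dependency on some RDD is found, a ShuffleMapStage is created (or reused) for that RDD and returned right away as a parent stage, and the walk does not go past it.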
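The rule applied to a whole lineage: cut at every shuffle dependency, so the job ends in one ResultStage and each distinct shuffle contributes one ShuffleMapStage. The sketch below uses a hypothetical Node type, where `shuffle` lists the map-side RDDs of this node's shuffle dependencies. It returns each stage's RDDs, parents first.

```scala
import scala.collection.mutable

// Hypothetical toy lineage: narrow deps stay in a stage, shuffle deps cut it.
final case class Node(name: String, narrow: List[Node] = Nil, shuffle: List[Node] = Nil)
final case class StageInfo(kind: String, rdds: List[String])

object StageDivision {
  // Returns stages parents-first; the last one is the ResultStage.
  def divide(finalRdd: Node): List[StageInfo] = {
    val out = mutable.ListBuffer.empty[StageInfo]
    val seenBoundaries = mutable.HashSet.empty[Node]

    def build(tail: Node, kind: String): Unit = {
      val members = mutable.LinkedHashSet.empty[Node]
      val stack = mutable.Stack(tail)
      val parents = mutable.ListBuffer.empty[Node]
      while (stack.nonEmpty) {
        val rdd = stack.pop()
        if (members.add(rdd)) {
          rdd.narrow.foreach(n => stack.push(n)) // same stage: keep walking
          parents ++= rdd.shuffle               // map side of a shuffle: a parent stage
        }
      }
      // Build each parent ShuffleMapStage once, before this stage.
      parents.foreach { p => if (seenBoundaries.add(p)) build(p, "ShuffleMapStage") }
      out += StageInfo(kind, members.toList.map(_.name))
    }

    build(finalRdd, "ResultStage")
    out.toList
  }
}
```

For example, map followed by reduceByKey and then filter splits into two stages: a ShuffleMapStage holding the map side, and a ResultStage starting at the shuffled RDD.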

There is one more core piece here, task creation. For reasons of space, I'll cover it in a separate post.