TaskContext

TaskContext is the base abstraction for contextual information about a task.

You can access the active TaskContext instance using the TaskContext.get method.

import org.apache.spark.TaskContext
val ctx = TaskContext.get

TaskContext allows for registering task listeners and accessing local properties that were set on the driver.
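Local properties reach a task roughly as a snapshot of the key-value pairs set on the driver for the job. The sketch below models only that idea in plain Scala, assuming java.util.Properties as the carrier. The copying step is a simplification, and LocalPropsTaskContext and the myapp.batch key are hypothetical names, not Spark API.

```scala
import java.util.Properties

// Hypothetical stand-in for a task-side context (not Spark's TaskContext):
// it only keeps the local properties it was given and looks them up by key.
final class LocalPropsTaskContext(localProperties: Properties) {
  // As with TaskContext.getLocalProperty, a missing key gives null.
  def getLocalProperty(key: String): String = localProperties.getProperty(key)
}

// "Driver" side: a local property set before the job runs (hypothetical key).
val driverProps = new Properties()
driverProps.setProperty("myapp.batch", "batch-1")

// Simplification: a plain copy stands in for shipping the properties with the task.
val taskProps = new Properties()
taskProps.putAll(driverProps)
val ctx = new LocalPropsTaskContext(taskProps)

// A later change on the "driver" does not affect the copy the task already has.
driverProps.setProperty("myapp.batch", "batch-2")
```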

Note
TaskContext is serializable.
package org.apache.spark

abstract class TaskContext extends Serializable {
  // only the abstract methods (those with no implementation) are listed;
  // the concrete ones are described later
  def isCompleted(): Boolean
  def isInterrupted(): Boolean
  def isRunningLocally(): Boolean
  def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
  def stageId(): Int
  def stageAttemptNumber(): Int
  def partitionId(): Int
  def attemptNumber(): Int
  def taskAttemptId(): Long
  def getLocalProperty(key: String): String
  def taskMetrics(): TaskMetrics
  def getMetricsSources(sourceName: String): Seq[Source]
  private[spark] def killTaskIfInterrupted(): Unit
  private[spark] def getKillReason(): Option[String]
  private[spark] def taskMemoryManager(): TaskMemoryManager
  private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
  private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
}
Table 1. (Subset of) TaskContext Contract
Method Description

addTaskCompletionListener

Registers a TaskCompletionListener

Used when…

addTaskFailureListener

Registers a TaskFailureListener

Used when…

attemptNumber

Specifies how many times the task has been attempted (starting from 0).

Used when…

getLocalProperty

Gives the value of a local property set on the driver using SparkContext.setLocalProperty, or null if the property is missing.

Used when…

getMetricsSources

Gives all the metrics sources with the given sourceName that are associated with the instance that runs the task.

isCompleted

Indicates whether the task has completed.

Used when…

isInterrupted

A flag that is enabled when the task was killed.

Used when…

killTaskIfInterrupted

If the task was marked for interruption (i.e. cancellation), killTaskIfInterrupted is supposed to throw a TaskKilledException with the reason for the interruption, which in turn kills the task.

Used (to break a task execution) when:

  • InterruptibleIterator is requested to hasNext

  • TaskRunner is requested to run

  • SortedIterator and UnsafeSorterSpillReader are requested to loadNext

  • Spark SQL’s FileScanRDD is requested to compute

partitionId

Id of the Partition computed by the task.

Used when…

registerAccumulator

Registers an AccumulatorV2 that belongs to the task.

Used when…

stageId

Id of the Stage the task belongs to.

Used when…

taskAttemptId

Id of the task attempt, unique across all task attempts within the same SparkContext.

Used when…

taskMemoryManager

Gives the TaskMemoryManager of the task.

Used when…

taskMetrics

TaskMetrics of the active Task.

Used when…
Note
TaskContextImpl is the one and only known implementation of the TaskContext contract in Apache Spark.
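The interruption check described for killTaskIfInterrupted can be sketched in plain Scala. KillableContext, TaskKilledSketchException and InterruptibleIteratorSketch are hypothetical stand-ins, not Spark classes; the wrapper follows the pattern listed for InterruptibleIterator, checking for interruption in hasNext before delegating.

```scala
// Hypothetical exception standing in for Spark's TaskKilledException.
final class TaskKilledSketchException(val reason: String)
  extends RuntimeException(s"Task killed: $reason")

// Hypothetical, minimal stand-in for the kill-related part of a task context.
final class KillableContext {
  @volatile private var killReason: Option[String] = None

  // Marks the task for interruption (cancellation) with a reason.
  def markInterrupted(reason: String): Unit = killReason = Some(reason)

  // Throws if the task was marked for interruption; otherwise does nothing.
  def killTaskIfInterrupted(): Unit =
    killReason.foreach(r => throw new TaskKilledSketchException(r))
}

// Iterator wrapper that checks for interruption before every hasNext.
final class InterruptibleIteratorSketch[T](ctx: KillableContext, delegate: Iterator[T])
    extends Iterator[T] {
  override def hasNext: Boolean = {
    ctx.killTaskIfInterrupted()
    delegate.hasNext
  }
  override def next(): T = delegate.next()
}

val ctx = new KillableContext
val it = new InterruptibleIteratorSketch(ctx, Iterator(1, 2, 3))
val first = it.next() // next() itself does not check for interruption
ctx.markInterrupted("stage cancelled")
val outcome =
  try { it.hasNext; "not killed" }
  catch { case e: TaskKilledSketchException => e.reason }
```

Checking in hasNext means a cancelled task stops at the next record boundary instead of running to the end of its partition.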

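unset Method

unset(): Unit

unset removes the TaskContext associated with the current thread, i.e. clears the thread-local TaskContext.

setTaskContext Method

setTaskContext(tc: TaskContext): Unit

setTaskContext associates the given TaskContext with the current thread, i.e. sets the thread-local TaskContext.

Note
unset and setTaskContext are protected[spark], i.e. meant for Spark's own use rather than user code.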

Accessing Active TaskContext — get Method

get(): TaskContext

The get method returns the TaskContext instance of the active task (as a TaskContextImpl), or null when called outside a running task (e.g. on the driver). There can only be one TaskContext per task, and the task's code can use it to access contextual information about the task itself.

val rdd = sc.range(0, 3, numSlices = 3)

scala> rdd.partitions.size
res0: Int = 3

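// println runs inside tasks, so the output goes to the executors' stdout
// (the shell console in local mode)
rdd.foreach { _ =>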
  import org.apache.spark.TaskContext
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
Note
The TaskContext companion object uses a ThreadLocal to keep the active TaskContext thread-local, i.e. to associate it with the thread that runs the task.
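The thread-local arrangement can be sketched without Spark. SketchContext and ContextHolder are hypothetical, and runAsTask merely plays the role of a task runner that sets the context before running the task code and unsets it afterwards.

```scala
import java.util.concurrent.atomic.AtomicReference

// Hypothetical minimal context: just a partition id.
final case class SketchContext(partitionId: Int)

// Hypothetical holder in the spirit of the TaskContext object:
// a single ThreadLocal slot, so each thread only sees its own context.
object ContextHolder {
  private val current = new ThreadLocal[SketchContext]
  def get(): SketchContext = current.get() // null if this thread has no context
  def setTaskContext(tc: SketchContext): Unit = current.set(tc)
  def unset(): Unit = current.remove()
}

// Runs `body` on a new thread: set the context first, always unset it afterwards.
def runAsTask[A](tc: SketchContext)(body: => A): A = {
  val result = new AtomicReference[Option[A]](None)
  val thread = new Thread(() => {
    ContextHolder.setTaskContext(tc)
    try result.set(Some(body))
    finally ContextHolder.unset()
  })
  thread.start()
  thread.join()
  result.get().get
}

val seenByTask0 = runAsTask(SketchContext(0))(ContextHolder.get().partitionId)
val seenByTask1 = runAsTask(SketchContext(1))(ContextHolder.get().partitionId)
// The calling thread never set a context, so it sees null.
val seenOutside = ContextHolder.get()
```

Because each thread has its own slot, code on the calling thread sees null, just as TaskContext.get does outside a task.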

Registering Task Listeners

Using a TaskContext you can register two kinds of task listeners: completion listeners, which run when a task completes regardless of its final state, and failure listeners, which run only when a task fails.

addTaskCompletionListener Method

addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext

The addTaskCompletionListener methods register a TaskCompletionListener to be executed on task completion.

Note
The listener is executed regardless of the final state of the task: success, failure, or cancellation.
val rdd = sc.range(0, 5, numSlices = 1)

import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}

rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskCompletionListener(printTaskInfo)
}
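The rule that completion listeners run whatever the outcome of the task can be sketched with a hypothetical CompletionContext and runTask (neither is Spark API). The sketch ignores details such as listener ordering and exceptions thrown by the listeners themselves.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

// Hypothetical minimal context holding completion listeners only.
final class CompletionContext {
  private val listeners = ArrayBuffer.empty[CompletionContext => Unit]
  def addTaskCompletionListener(f: CompletionContext => Unit): CompletionContext = {
    listeners += f
    this
  }
  def markTaskCompleted(): Unit = listeners.foreach(f => f(this))
}

// Hypothetical runner: completion listeners run whether the body succeeds or throws.
def runTask(body: CompletionContext => Unit): Try[Unit] = {
  val ctx = new CompletionContext
  try Try(body(ctx))
  finally ctx.markTaskCompleted()
}

val log = ArrayBuffer.empty[String]
val ok = runTask { ctx => ctx.addTaskCompletionListener(_ => log += "done: ok") }
val failed = runTask { ctx =>
  ctx.addTaskCompletionListener(_ => log += "done: failed")
  throw new IllegalStateException("boom")
}
```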

addTaskFailureListener Method

addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

The addTaskFailureListener methods register a TaskFailureListener to be executed on task failure only. Since a failed task can be re-attempted, and each attempt runs the task code (and so registers the listener) anew, the listener can be executed once per failed attempt.

val rdd = sc.range(0, 2, numSlices = 2)

import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error:         ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}

val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}

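// The failure listener never fires here: foreachPartition ignores the iterator (_ =>),
// so throwExceptionForOddNumber never runs and the task does not fail.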
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}

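// Listener registration matters: mapPartitions registers the listener before
// count consumes the iterator, which is when throwExceptionForOddNumber throws.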
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count
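Why a failure listener can run several times: every attempt gets a fresh context, and the task code registers the listener again on each attempt. The sketch below models that with hypothetical AttemptContext and runWithRetries; maxFailures is only loosely modeled on spark.task.maxFailures.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical per-attempt context: a fresh one is created for every attempt.
final class AttemptContext(val attemptNumber: Int) {
  private val failureListeners = ArrayBuffer.empty[(AttemptContext, Throwable) => Unit]
  def addTaskFailureListener(f: (AttemptContext, Throwable) => Unit): AttemptContext = {
    failureListeners += f
    this
  }
  def markTaskFailed(e: Throwable): Unit = failureListeners.foreach(f => f(this, e))
}

// Hypothetical scheduler loop: runs at most maxFailures attempts of one task.
def runWithRetries(maxFailures: Int)(body: AttemptContext => Unit): Boolean = {
  var attempt = 0
  var succeeded = false
  while (!succeeded && attempt < maxFailures) {
    val ctx = new AttemptContext(attempt)
    try { body(ctx); succeeded = true }
    catch { case e: Exception => ctx.markTaskFailed(e) }
    attempt += 1
  }
  succeeded
}

val failures = ArrayBuffer.empty[String]
val jobSucceeded = runWithRetries(maxFailures = 3) { ctx =>
  ctx.addTaskFailureListener((tc, e) => failures += s"attempt ${tc.attemptNumber}: ${e.getMessage}")
  // Fails on the first two attempts and succeeds on the third.
  if (ctx.attemptNumber < 2) throw new RuntimeException("flaky")
}
```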

(Unused) Accessing Partition Id — getPartitionId Method

getPartitionId(): Int

getPartitionId gets the active TaskContext and returns its partitionId, or 0 if no TaskContext is available (e.g. when called outside a task).

Note
getPartitionId is not used.
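The fallback of getPartitionId can be sketched with a hypothetical PartitionContextHolder backed by a ThreadLocal, an assumption standing in for the TaskContext object.

```scala
// Hypothetical minimal context and holder (not Spark API).
final case class PartitionContext(partitionId: Int)

object PartitionContextHolder {
  private val current = new ThreadLocal[PartitionContext]
  def get(): PartitionContext = current.get()
  def setTaskContext(tc: PartitionContext): Unit = current.set(tc)
  def unset(): Unit = current.remove()

  // Same fallback as described for getPartitionId: 0 when no context is set.
  def getPartitionId(): Int = Option(get()).map(_.partitionId).getOrElse(0)
}

val outsideTask = PartitionContextHolder.getPartitionId()
PartitionContextHolder.setTaskContext(PartitionContext(7))
val insideTask = PartitionContextHolder.getPartitionId()
PartitionContextHolder.unset()
val afterUnset = PartitionContextHolder.getPartitionId()
```

One consequence of this design: from the return value alone, a caller cannot tell partition 0 apart from "no active task".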
