TaskContext

TaskContext is the base of task contexts that give running tasks access to contextual information about their execution (e.g. the partition and stage a task belongs to, attempt numbers) and services (e.g. registering task listeners, reading local properties).

A task can access its TaskContext instance using the TaskContext.get object method (which simply returns null unless called from the execution thread of a task).

import org.apache.spark.TaskContext
val ctx = TaskContext.get

TaskContext allows for registering task listeners and accessing local properties that were set on the driver.
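
Local properties let the driver pass settings to the tasks of the jobs it submits afterwards. A minimal sketch of the round trip (the key my.custom.key is arbitrary, chosen for this example):

// On the driver: set a local property for jobs submitted from this thread
sc.setLocalProperty("my.custom.key", "demo")

sc.range(0, 1).foreach { _ =>
  import org.apache.spark.TaskContext
  // Inside a task: prints "demo", i.e. the value set on the driver above
  println(TaskContext.get.getLocalProperty("my.custom.key"))
}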

Table 1. TaskContext Contract
Method Description

addTaskCompletionListener

addTaskCompletionListener(
  listener: TaskCompletionListener): TaskContext
addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext

Registers a TaskCompletionListener

Used when…​FIXME

addTaskFailureListener

addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

Registers a TaskFailureListener

Used when…​FIXME

attemptNumber

attemptNumber(): Int

The number of the current attempt of the task, starting from 0 for the first attempt (see the sketch after this table)

Used when…​FIXME

fetchFailed

fetchFailed: Option[FetchFailedException]

Used when…​FIXME

getKillReason

getKillReason(): Option[String]

Used when…​FIXME

getLocalProperties

getLocalProperties: Properties

Used when…​FIXME

getLocalProperty

getLocalProperty(key: String): String

Gets a local property set upstream on the driver (or null if the key is missing)

Used when…FIXME

getMetricsSources

getMetricsSources(sourceName: String): Seq[Source]

All metrics sources with the given sourceName that are associated with the instance that runs the task.

Used when…​FIXME

isCompleted

isCompleted(): Boolean

Whether the task has completed

Used when…FIXME

isInterrupted

isInterrupted(): Boolean

Whether the task has been killed

Used when…FIXME

isRunningLocally

isRunningLocally(): Boolean

Used when…​FIXME

killTaskIfInterrupted

killTaskIfInterrupted(): Unit

Used when…​FIXME

markInterrupted

markInterrupted(reason: String): Unit

Used when…​FIXME

markTaskCompleted

markTaskCompleted(error: Option[Throwable]): Unit

Used when…​FIXME

markTaskFailed

markTaskFailed(error: Throwable): Unit

Used when…​FIXME

partitionId

partitionId(): Int

ID of the Partition computed by the task

Used when…​FIXME

registerAccumulator

registerAccumulator(a: AccumulatorV2[_, _]): Unit

Used when…​FIXME

setFetchFailed

setFetchFailed(fetchFailed: FetchFailedException): Unit

Used when…​FIXME

stageAttemptNumber

stageAttemptNumber(): Int

Used when…​FIXME

stageId

stageId(): Int

ID of the Stage the task belongs to

Used when…​FIXME

taskAttemptId

taskAttemptId(): Long

Task (execution) attempt ID

Used when…​FIXME

taskMemoryManager

taskMemoryManager(): TaskMemoryManager

Used when…​FIXME

taskMetrics

taskMetrics(): TaskMetrics

Used when…​FIXME
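
As an example for attemptNumber, the following sketch fails the first attempt of a task on purpose so a later attempt observes a non-zero attempt number. It assumes task retries are enabled, e.g. with the master URL local[*, 4] where the second parameter is the number of allowed task attempts:

sc.range(0, 1).foreach { _ =>
  val tc = org.apache.spark.TaskContext.get
  // Fail the first attempt so the task gets re-attempted
  if (tc.attemptNumber == 0) {
    throw new RuntimeException("failing the first attempt on purpose")
  }
  println(s"succeeded on attempt number ${tc.attemptNumber}")
}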

Table 2. TaskContexts
TaskContext Description

BarrierTaskContext

TaskContext of the tasks of a barrier stage (in Barrier Execution Mode)

TaskContextImpl

The default TaskContext (of regular, non-barrier tasks)

Setting Thread-Local TaskContext — setTaskContext Object Method

setTaskContext(tc: TaskContext): Unit

setTaskContext binds the given TaskContext as a thread-local variable.

Note

setTaskContext is used when Task is requested to run (so the code executed in the task's thread can access the TaskContext via TaskContext.get).

Accessing Active TaskContext — get Object Method

get(): TaskContext

get returns the thread-local TaskContext instance (by requesting the taskContext thread-local variable to get the instance).

Note
get is a method of the TaskContext object in Scala, so there is just one such object available (per classloader). Since the underlying variable is a ThreadLocal (ThreadLocal[TaskContext]), however, the TaskContext instance itself is thread-local, which allows for associating state with the thread of a task.

val rdd = sc.range(0, 3, numSlices = 3)

assert(rdd.partitions.size == 3)

rdd.foreach { n =>
  import org.apache.spark.TaskContext
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
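
The null-outside-a-task behaviour of get follows from the thread-local storage. A stand-alone sketch of the pattern with a hypothetical ContextHolder (an illustration only, not Spark's actual internals):

object ContextHolder {
  private val ctx = new ThreadLocal[String]
  def set(value: String): Unit = ctx.set(value)
  def get: String = ctx.get  // null unless set in the current thread
}

ContextHolder.set("visible in this thread only")
assert(ContextHolder.get == "visible in this thread only")

val t = new Thread(new Runnable {
  def run(): Unit = assert(ContextHolder.get == null)  // other threads see null
})
t.start()
t.join()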

Registering Task Listeners

Using the TaskContext object you can register task listeners: TaskCompletionListeners are executed on task completion (regardless of the final state) while TaskFailureListeners are executed on task failures only.

addTaskCompletionListener Method

addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext

addTaskCompletionListener registers a TaskCompletionListener to be executed on task completion.

Note
It will be executed regardless of the final state of a task - success, failure, or cancellation.

val rdd = sc.range(0, 5, numSlices = 1)

import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}

rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskCompletionListener(printTaskInfo)
}
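
The listener variant of addTaskCompletionListener works the same way. A sketch with an explicit TaskCompletionListener (from org.apache.spark.util):

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

rdd.foreachPartition { _ =>
  TaskContext.get.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit =
      println(s"Task of partition ${context.partitionId} completed")
  })
}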

addTaskFailureListener Method

addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

addTaskFailureListener registers a TaskFailureListener to be executed on task failure only. It can be executed multiple times since a task can be re-attempted after a failure.

val rdd = sc.range(0, 2, numSlices = 2)

import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error:         ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}

val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}

// This won't work: the body of foreachPartition never consumes the
// partition iterator, so throwExceptionForOddNumber never runs, the
// task never fails, and the listener is never invoked.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}

// Where the listener is registered matters: here it is registered at the
// start of the task, before the failing map function is evaluated while
// the iterator is consumed, so the listener is invoked on failure.
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count
