flatMapGroupsWithState Operator — Arbitrary Stateful Streaming Aggregation (with Explicit State Logic)

flatMapGroupsWithState[S: Encoder, U: Encoder](
  outputMode: OutputMode,
  timeoutConf: GroupStateTimeout)(
  func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
Note
flatMapGroupsWithState requires Append or Update output modes.
Note
Every time the state function func is executed for a key, the state (as GroupState[S]) is for this key only.
Caution
FIXME Why can’t flatMapGroupsWithState work with Complete output mode?
Note
  • K is the type of the keys in KeyValueGroupedDataset

  • V is the type of the values (per key) in KeyValueGroupedDataset

  • S is the user-defined type of the state as maintained for each group

  • U is the type of rows in the result Dataset
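
The notes above can be pictured with a plain-Scala model that does not use Spark at all. It is only a sketch: SimpleState and runBatch are hypothetical stand-ins (not Spark APIs) that mimic how the state function is called once per key in a micro-batch and sees the state of that key only. Here K = Int, V = Long, S = Long and U = (Int, Long).

```scala
// Plain-Scala model of per-key state handling (NOT the Spark implementation).
// SimpleState and runBatch are hypothetical names used for illustration only.
object PerKeyStateModel {
  // Stand-in for GroupState[S]: wraps the state of exactly ONE key.
  final class SimpleState[S](private var value: Option[S]) {
    def getOption: Option[S] = value
    def update(newValue: S): Unit = { value = Some(newValue) }
  }

  // The "state function": K = Int, V = Long, S = Long (running count), U = (Int, Long).
  def countPerKey(
      key: Int,
      values: Iterator[Long],
      state: SimpleState[Long]): Iterator[(Int, Long)] = {
    val newCount = state.getOption.getOrElse(0L) + values.size
    state.update(newCount)
    Iterator(key -> newCount)
  }

  // One micro-batch: group the rows by key and call the state function once
  // per key present in the batch, handing it the state of that key only.
  // Keys are visited in ascending order to keep the output deterministic.
  def runBatch(
      store: Map[Int, Long],
      batch: Seq[(Int, Long)]): (Map[Int, Long], Seq[(Int, Long)]) =
    batch.groupBy(_._1).toSeq.sortBy(_._1)
      .foldLeft((store, Seq.empty[(Int, Long)])) {
        case ((current, out), (key, rows)) =>
          val state = new SimpleState[Long](current.get(key))
          val result = countPerKey(key, rows.iterator.map(_._2), state).toList
          val updated = state.getOption.fold(current)(v => current.updated(key, v))
          (updated, out ++ result)
      }
}
```

Feeding the model two batches, e.g. Seq(3 -> 1L, 8 -> 0L, 7 -> 2L) followed by Seq(3 -> 6L, 3 -> 9L, 1 -> 0L), gives device 3 a count of 1 and then 3, while device 1 starts from an undefined state.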

scala> spark.version
res0: String = 2.3.0-SNAPSHOT

import java.sql.Timestamp
type DeviceId = Int
case class Signal(timestamp: java.sql.Timestamp, value: Long, deviceId: DeviceId)

// input stream
import org.apache.spark.sql.functions._
val signals = spark.
  readStream.
  format("rate").
  option("rowsPerSecond", 1).
  load.
  withColumn("value", $"value" % 10).  // <-- keep the values in the range 0 to 9 (just for fun)
  withColumn("deviceId", rint(rand() * 10) cast "int"). // <-- device IDs 0 to 10 randomly assigned to values
  as[Signal] // <-- convert to our type (from "unpleasant" Row)
scala> signals.explain
== Physical Plan ==
*Project [timestamp#0, (value#1L % 10) AS value#5L, cast(ROUND((rand(4440296395341152993) * 10.0)) as int) AS deviceId#9]
+- StreamingRelation rate, [timestamp#0, value#1L]

// stream processing using flatMapGroupsWithState operator
val device: Signal => DeviceId = { case Signal(_, _, deviceId) => deviceId }
val signalsByDevice = signals.groupByKey(device)

import org.apache.spark.sql.streaming.GroupState
type Key = Int
type Count = Long
type State = Map[Key, Count]
case class EventsCounted(deviceId: DeviceId, count: Long)
def countValuesPerKey(deviceId: Int, signalsPerDevice: Iterator[Signal], state: GroupState[State]): Iterator[EventsCounted] = {
  val values = signalsPerDevice.toList
  println(s"Device: $deviceId")
  println(s"Signals (${values.size}):")
  values.zipWithIndex.foreach { case (v, idx) => println(s"$idx. $v") }
  println(s"State: $state")

  // update the state with the count of elements for the key
  val initialState: State = Map(deviceId -> 0)
  val oldState = state.getOption.getOrElse(initialState)
  // the name to highlight that the state is for the key only
  val newValue = oldState(deviceId) + values.size
  val newState = Map(deviceId -> newValue)
  state.update(newState)

  // you must not return signalsPerDevice as it has already been consumed (by toList above)
  // iterators are one-pass data structures, so returning it would lead to
  // a very subtle error where the resulting iterator has no elements
  Iterator(EventsCounted(deviceId, newValue))
}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
val signalCounter = signalsByDevice.flatMapGroupsWithState(
  outputMode = OutputMode.Append,
  timeoutConf = GroupStateTimeout.NoTimeout)(func = countValuesPerKey)

import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import scala.concurrent.duration._
val sq = signalCounter.
  writeStream.
  format("console").
  option("truncate", false).
  trigger(Trigger.ProcessingTime(10.seconds)).
  outputMode(OutputMode.Append).
  start
...
-------------------------------------------
Batch: 0
-------------------------------------------
+--------+-----+
|deviceId|count|
+--------+-----+
+--------+-----+
...
17/08/21 08:57:29 INFO StreamExecution: Streaming query made progress: {
  "id" : "a43822a6-500b-4f02-9133-53e9d39eedbf",
  "runId" : "79cb037e-0f28-4faf-a03e-2572b4301afe",
  "name" : null,
  "timestamp" : "2017-08-21T06:57:26.719Z",
  "batchId" : 0,
  "numInputRows" : 0,
  "processedRowsPerSecond" : 0.0,
  "durationMs" : {
    "addBatch" : 2404,
    "getBatch" : 22,
    "getOffset" : 0,
    "queryPlanning" : 141,
    "triggerExecution" : 2626,
    "walCommit" : 41
  },
  "stateOperators" : [ {
    "numRowsTotal" : 0,
    "numRowsUpdated" : 0,
    "memoryUsedBytes" : 12599
  } ],
  "sources" : [ {
    "description" : "RateSource[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=8]",
    "startOffset" : null,
    "endOffset" : 0,
    "numInputRows" : 0,
    "processedRowsPerSecond" : 0.0
  } ],
  "sink" : {
    "description" : "ConsoleSink[numRows=20, truncate=false]"
  }
}
17/08/21 08:57:29 DEBUG StreamExecution: batch 0 committed
...
-------------------------------------------
Batch: 1
-------------------------------------------
Device: 3
Signals (1):
0. Signal(2017-08-21 08:57:27.682,1,3)
State: GroupState(<undefined>)
Device: 8
Signals (1):
0. Signal(2017-08-21 08:57:26.682,0,8)
State: GroupState(<undefined>)
Device: 7
Signals (1):
0. Signal(2017-08-21 08:57:28.682,2,7)
State: GroupState(<undefined>)
+--------+-----+
|deviceId|count|
+--------+-----+
|3       |1    |
|8       |1    |
|7       |1    |
+--------+-----+
...
17/08/21 08:57:31 INFO StreamExecution: Streaming query made progress: {
  "id" : "a43822a6-500b-4f02-9133-53e9d39eedbf",
  "runId" : "79cb037e-0f28-4faf-a03e-2572b4301afe",
  "name" : null,
  "timestamp" : "2017-08-21T06:57:30.004Z",
  "batchId" : 1,
  "numInputRows" : 3,
  "inputRowsPerSecond" : 0.91324200913242,
  "processedRowsPerSecond" : 2.2388059701492535,
  "durationMs" : {
    "addBatch" : 1245,
    "getBatch" : 22,
    "getOffset" : 0,
    "queryPlanning" : 23,
    "triggerExecution" : 1340,
    "walCommit" : 44
  },
  "stateOperators" : [ {
    "numRowsTotal" : 3,
    "numRowsUpdated" : 3,
    "memoryUsedBytes" : 18095
  } ],
  "sources" : [ {
    "description" : "RateSource[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=8]",
    "startOffset" : 0,
    "endOffset" : 3,
    "numInputRows" : 3,
    "inputRowsPerSecond" : 0.91324200913242,
    "processedRowsPerSecond" : 2.2388059701492535
  } ],
  "sink" : {
    "description" : "ConsoleSink[numRows=20, truncate=false]"
  }
}
17/08/21 08:57:31 DEBUG StreamExecution: batch 1 committed
...
-------------------------------------------
Batch: 2
-------------------------------------------
Device: 1
Signals (1):
0. Signal(2017-08-21 08:57:36.682,0,1)
State: GroupState(<undefined>)
Device: 3
Signals (2):
0. Signal(2017-08-21 08:57:32.682,6,3)
1. Signal(2017-08-21 08:57:35.682,9,3)
State: GroupState(Map(3 -> 1))
Device: 5
Signals (1):
0. Signal(2017-08-21 08:57:34.682,8,5)
State: GroupState(<undefined>)
Device: 4
Signals (1):
0. Signal(2017-08-21 08:57:29.682,3,4)
State: GroupState(<undefined>)
Device: 8
Signals (2):
0. Signal(2017-08-21 08:57:31.682,5,8)
1. Signal(2017-08-21 08:57:33.682,7,8)
State: GroupState(Map(8 -> 1))
Device: 7
Signals (2):
0. Signal(2017-08-21 08:57:30.682,4,7)
1. Signal(2017-08-21 08:57:37.682,1,7)
State: GroupState(Map(7 -> 1))
Device: 0
Signals (1):
0. Signal(2017-08-21 08:57:38.682,2,0)
State: GroupState(<undefined>)
+--------+-----+
|deviceId|count|
+--------+-----+
|1       |1    |
|3       |3    |
|5       |1    |
|4       |1    |
|8       |3    |
|7       |3    |
|0       |1    |
+--------+-----+
...
17/08/21 08:57:41 INFO StreamExecution: Streaming query made progress: {
  "id" : "a43822a6-500b-4f02-9133-53e9d39eedbf",
  "runId" : "79cb037e-0f28-4faf-a03e-2572b4301afe",
  "name" : null,
  "timestamp" : "2017-08-21T06:57:40.005Z",
  "batchId" : 2,
  "numInputRows" : 10,
  "inputRowsPerSecond" : 0.9999000099990002,
  "processedRowsPerSecond" : 9.242144177449168,
  "durationMs" : {
    "addBatch" : 1032,
    "getBatch" : 8,
    "getOffset" : 0,
    "queryPlanning" : 19,
    "triggerExecution" : 1082,
    "walCommit" : 21
  },
  "stateOperators" : [ {
    "numRowsTotal" : 7,
    "numRowsUpdated" : 7,
    "memoryUsedBytes" : 19023
  } ],
  "sources" : [ {
    "description" : "RateSource[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=8]",
    "startOffset" : 3,
    "endOffset" : 13,
    "numInputRows" : 10,
    "inputRowsPerSecond" : 0.9999000099990002,
    "processedRowsPerSecond" : 9.242144177449168
  } ],
  "sink" : {
    "description" : "ConsoleSink[numRows=20, truncate=false]"
  }
}
17/08/21 08:57:41 DEBUG StreamExecution: batch 2 committed

// In the end...
sq.stop

// Use stateOperators to access the stats
scala> println(sq.lastProgress.stateOperators(0).prettyJson)
{
  "numRowsTotal" : 7,
  "numRowsUpdated" : 7,
  "memoryUsedBytes" : 19023
}

Internally, flatMapGroupsWithState operator creates a Dataset with FlatMapGroupsWithState unary logical operator.

scala> :type signalCounter
org.apache.spark.sql.Dataset[EventsCounted]
scala> println(signalCounter.queryExecution.logical.numberedTreeString)
00 'SerializeFromObject [assertnotnull(assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true])).deviceId AS deviceId#25, assertnotnull(assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true])).count AS count#26L]
01 +- 'FlatMapGroupsWithState <function3>, unresolveddeserializer(upcast(getcolumnbyordinal(0, IntegerType), IntegerType, - root class: "scala.Int"), value#20), unresolveddeserializer(newInstance(class $line17.$read$$iw$$iw$Signal), timestamp#0, value#5L, deviceId#9), [value#20], [timestamp#0, value#5L, deviceId#9], obj#24: $line27.$read$$iw$$iw$EventsCounted, class[value[0]: map<int,bigint>], Append, false, NoTimeout
02    +- AppendColumns <function1>, class $line17.$read$$iw$$iw$Signal, [StructField(timestamp,TimestampType,true), StructField(value,LongType,false), StructField(deviceId,IntegerType,false)], newInstance(class $line17.$read$$iw$$iw$Signal), [input[0, int, false] AS value#20]
03       +- Project [timestamp#0, value#5L, cast(ROUND((rand(4440296395341152993) * cast(10 as double))) as int) AS deviceId#9]
04          +- Project [timestamp#0, (value#1L % cast(10 as bigint)) AS value#5L]
05             +- StreamingRelation DataSource(org.apache.spark.sql.SparkSession@385c6d6b,rate,List(),None,List(),None,Map(rowsPerSecond -> 1),None), rate, [timestamp#0, value#1L]

scala> signalCounter.explain
== Physical Plan ==
*SerializeFromObject [assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true]).deviceId AS deviceId#25, assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true]).count AS count#26L]
+- FlatMapGroupsWithState <function3>, value#20: int, newInstance(class $line17.$read$$iw$$iw$Signal), [value#20], [timestamp#0, value#5L, deviceId#9], obj#24: $line27.$read$$iw$$iw$EventsCounted, StatefulOperatorStateInfo(<unknown>,50c7ece5-0716-4e43-9b56-09842db8baf1,0,0), class[value[0]: map<int,bigint>], Append, NoTimeout, 0, 0
   +- *Sort [value#20 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(value#20, 200)
         +- AppendColumns <function1>, newInstance(class $line17.$read$$iw$$iw$Signal), [input[0, int, false] AS value#20]
            +- *Project [timestamp#0, (value#1L % 10) AS value#5L, cast(ROUND((rand(4440296395341152993) * 10.0)) as int) AS deviceId#9]
               +- StreamingRelation rate, [timestamp#0, value#1L]

flatMapGroupsWithState reports an IllegalArgumentException when the input outputMode is neither Append nor Update.

scala> val result = signalsByDevice.flatMapGroupsWithState(
     |   outputMode = OutputMode.Complete,
     |   timeoutConf = GroupStateTimeout.NoTimeout)(func = countValuesPerKey)
java.lang.IllegalArgumentException: The output mode of function should be append or update
  at org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState(KeyValueGroupedDataset.scala:381)
  ... 54 elided
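
The validation can be sketched as a simple guard on the requested output mode. The helper name requireAppendOrUpdate is hypothetical and the snippet is only an approximation of the behaviour shown in the stack trace above, not a copy of Spark's code. It assumes spark-sql is on the classpath.

```scala
import org.apache.spark.sql.streaming.OutputMode

object OutputModeGuard {
  // Hypothetical helper that mimics the check described above:
  // only Append and Update are accepted as the output mode of the state function.
  def requireAppendOrUpdate(outputMode: OutputMode): Unit =
    if (outputMode != OutputMode.Append() && outputMode != OutputMode.Update()) {
      throw new IllegalArgumentException(
        "The output mode of function should be append or update")
    }
}
```

Calling OutputModeGuard.requireAppendOrUpdate(OutputMode.Complete()) throws, while Append and Update pass through silently.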
Caution
FIXME Examples for append and update output modes (to demo the difference)
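
As a starting point for such a comparison, the following is a minimal sketch of the Update variant of the counting query. It assumes a local SparkSession and the rate source. The names UpdateModeSketch, runningCount and nextCount are made up for this example. The state here is a plain Long rather than a Map. My understanding is that with OutputMode.Update the rows returned by the state function are treated as updates of earlier results, and that the streaming query should then be started in Update output mode as well. A running count is arguably closer to these semantics than to Append.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{rand, rint}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

object UpdateModeSketch {
  case class Signal(timestamp: Timestamp, value: Long, deviceId: Int)
  case class EventsCounted(deviceId: Int, count: Long)

  // Pure helper (hypothetical): the new count given the previous state, if any.
  def nextCount(previous: Option[Long], newRows: Int): Long =
    previous.getOrElse(0L) + newRows

  // State function: emits the latest running count for the key on every call.
  def runningCount(
      deviceId: Int,
      signals: Iterator[Signal],
      state: GroupState[Long]): Iterator[EventsCounted] = {
    val newCount = nextCount(state.getOption, signals.size)
    state.update(newCount)
    Iterator(EventsCounted(deviceId, newCount))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("update-mode-sketch")
      .getOrCreate()
    import spark.implicits._

    val signals = spark.readStream
      .format("rate")
      .option("rowsPerSecond", 1)
      .load()
      .withColumn("deviceId", rint(rand() * 10).cast("int"))
      .as[Signal]

    // Update is used for the state function...
    val counts = signals
      .groupByKey(_.deviceId)
      .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.NoTimeout())(runningCount)

    // ...and for the streaming query itself.
    val query = counts.writeStream
      .format("console")
      .outputMode(OutputMode.Update())
      .start()

    query.awaitTermination(30000) // let it run for about 30 seconds
    query.stop()
    spark.stop()
  }
}
```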
Caution
FIXME Examples for GroupStateTimeout.EventTimeTimeout with withWatermark operator
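
A possible shape of such an example is sketched below, assuming a local SparkSession and the rate source. The names EventTimeTimeoutSketch, countUntilTimeout, timeoutAt and the 10-second gap are assumptions made for illustration. The idea is to keep counting signals per device without emitting anything, and to emit the final count (and remove the state) once the watermark has passed the timeout timestamp of the device.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{rand, rint}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

object EventTimeTimeoutSketch {
  case class Signal(timestamp: Timestamp, value: Long, deviceId: Int)
  case class EventsCounted(deviceId: Int, count: Long)

  // Hypothetical inactivity gap after which the state of a device expires.
  val gapMs = 10000L

  // Pure helper (hypothetical): timeout = latest event time seen + gap.
  // Expects a non-empty sequence.
  def timeoutAt(eventTimesMs: Seq[Long]): Long = eventTimesMs.max + gapMs

  def countUntilTimeout(
      deviceId: Int,
      signals: Iterator[Signal],
      state: GroupState[Long]): Iterator[EventsCounted] =
    if (state.hasTimedOut) {
      // Called without input rows once the watermark passed the timeout timestamp.
      val finalCount = state.getOption.getOrElse(0L)
      state.remove()
      Iterator(EventsCounted(deviceId, finalCount))
    } else {
      // Not timed out, so the function was called because the key has new rows.
      val eventTimes = signals.map(_.timestamp.getTime).toList
      state.update(state.getOption.getOrElse(0L) + eventTimes.size)
      state.setTimeoutTimestamp(timeoutAt(eventTimes))
      Iterator.empty // nothing is emitted until the state times out
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("event-time-timeout-sketch")
      .getOrCreate()
    import spark.implicits._

    val signals = spark.readStream
      .format("rate")
      .option("rowsPerSecond", 1)
      .load()
      .withColumn("deviceId", rint(rand() * 10).cast("int"))
      .as[Signal]

    val counts = signals
      .withWatermark("timestamp", "10 seconds") // required for EventTimeTimeout
      .groupByKey(_.deviceId)
      .flatMapGroupsWithState(OutputMode.Append(), GroupStateTimeout.EventTimeTimeout())(countUntilTimeout)

    val query = counts.writeStream
      .format("console")
      .outputMode(OutputMode.Append())
      .start()

    query.awaitTermination(60000) // let it run for about a minute
    query.stop()
    spark.stop()
  }
}
```

With randomly assigned device IDs a device has to stay silent long enough for the watermark to pass its timeout timestamp, so final counts may show up only occasionally.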
