SimplifyCasts Logical Optimization

SimplifyCasts is a base logical optimization that eliminates redundant casts in the following cases:

  1. The input is already the type to cast to.

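  2. The input is of ArrayType or MapType and its type declares no null elements (containsNull or valueContainsNull is false), so casting it to the same type with nullable elements changes nothing.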

SimplifyCasts is part of the Operator Optimization before Inferring Filters fixed-point batch in the standard batches of the Catalyst Optimizer.

SimplifyCasts is simply a Catalyst rule for transforming logical plans, i.e. Rule[LogicalPlan].
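
The two cases can be sketched as a check on the source and target types of a cast. The following is a minimal toy model, not Spark's code: the DataType hierarchy and the CastCheck.isRedundant helper are hypothetical stand-ins for the real org.apache.spark.sql.types classes, assuming the rule only compares element types and nullability flags as the examples on this page suggest.

```scala
// Toy stand-ins for Catalyst data types (hypothetical, NOT org.apache.spark.sql.types).
sealed trait DataType
case object IntegerType extends DataType
case object LongType extends DataType
case object StringType extends DataType
final case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType
final case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean)
  extends DataType

object CastCheck {
  /** Returns true when a cast from `from` to `to` can be dropped in this toy model. */
  def isRedundant(from: DataType, to: DataType): Boolean = (from, to) match {
    // Case 1: the input is already the type to cast to.
    case (f, t) if f == t => true
    // Case 2A: same element type, the input has no null elements,
    // and the target merely allows them.
    case (ArrayType(f, false), ArrayType(t, true)) => f == t
    // Case 2B: same key and value types, the input has no null values,
    // and the target merely allows them.
    case (MapType(fk, fv, false), MapType(tk, tv, true)) => fk == tk && fv == tv
    // Anything else is a real conversion and must stay.
    case _ => false
  }
}
```

For example, CastCheck.isRedundant(ArrayType(IntegerType, false), ArrayType(IntegerType, true)) holds, while the opposite direction (nullable elements cast to non-nullable ones) does not, since that cast narrows what the type promises.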

// Case 1. The input is already the type to cast to
scala> val ds = spark.range(1)
ds: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> ds.printSchema
 |-- id: long (nullable = false)

scala> ds.selectExpr("CAST (id AS long)").explain(true)
TRACE SparkOptimizer:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.SimplifyCasts ===
!Project [cast(id#0L as bigint) AS id#7L]   Project [id#0L AS id#7L]
 +- Range (0, 1, step=1, splits=Some(8))    +- Range (0, 1, step=1, splits=Some(8))

TRACE SparkOptimizer:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.RemoveAliasOnlyProject ===
!Project [id#0L AS id#7L]                  Range (0, 1, step=1, splits=Some(8))
!+- Range (0, 1, step=1, splits=Some(8))

TRACE SparkOptimizer: Fixed point reached for batch Operator Optimizations after 2 iterations.
DEBUG SparkOptimizer:
=== Result of Batch Operator Optimizations ===
!Project [cast(id#0L as bigint) AS id#7L]   Range (0, 1, step=1, splits=Some(8))
!+- Range (0, 1, step=1, splits=Some(8))
== Parsed Logical Plan ==
'Project [unresolvedalias(cast('id as bigint), None)]
+- Range (0, 1, step=1, splits=Some(8))

== Analyzed Logical Plan ==
id: bigint
Project [cast(id#0L as bigint) AS id#7L]
+- Range (0, 1, step=1, splits=Some(8))

== Optimized Logical Plan ==
Range (0, 1, step=1, splits=Some(8))

== Physical Plan ==
*Range (0, 1, step=1, splits=Some(8))

// Case 2A. The input is of `ArrayType` type and contains no `null` elements.
scala> val intArray = Seq(Array(1)).toDS
intArray: org.apache.spark.sql.Dataset[Array[Int]] = [value: array<int>]

scala> intArray.printSchema
 |-- value: array (nullable = true)
 |    |-- element: integer (containsNull = false)

scala> intArray.map(arr => arr.sum).explain(true)
TRACE SparkOptimizer:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.SimplifyCasts ===
 SerializeFromObject [input[0, int, true] AS value#36]                                                       SerializeFromObject [input[0, int, true] AS value#36]
 +- MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int   +- MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int
!   +- DeserializeToObject cast(value#15 as array<int>).toIntArray, obj#34: [I                                  +- DeserializeToObject value#15.toIntArray, obj#34: [I
       +- LocalRelation [value#15]                                                                                 +- LocalRelation [value#15]

TRACE SparkOptimizer: Fixed point reached for batch Operator Optimizations after 2 iterations.
DEBUG SparkOptimizer:
=== Result of Batch Operator Optimizations ===
 SerializeFromObject [input[0, int, true] AS value#36]                                                       SerializeFromObject [input[0, int, true] AS value#36]
 +- MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int   +- MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int
!   +- DeserializeToObject cast(value#15 as array<int>).toIntArray, obj#34: [I                                  +- DeserializeToObject value#15.toIntArray, obj#34: [I
       +- LocalRelation [value#15]                                                                                 +- LocalRelation [value#15]
== Parsed Logical Plan ==
'SerializeFromObject [input[0, int, true] AS value#36]
+- 'MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int
   +- 'DeserializeToObject unresolveddeserializer(upcast(getcolumnbyordinal(0, ArrayType(IntegerType,false)), ArrayType(IntegerType,false), - root class: "scala.Array").toIntArray), obj#34: [I
      +- LocalRelation [value#15]

== Analyzed Logical Plan ==
value: int
SerializeFromObject [input[0, int, true] AS value#36]
+- MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int
   +- DeserializeToObject cast(value#15 as array<int>).toIntArray, obj#34: [I
      +- LocalRelation [value#15]

== Optimized Logical Plan ==
SerializeFromObject [input[0, int, true] AS value#36]
+- MapElements <function1>, class [I, [StructField(value,ArrayType(IntegerType,false),true)], obj#35: int
   +- DeserializeToObject value#15.toIntArray, obj#34: [I
      +- LocalRelation [value#15]

== Physical Plan ==
*SerializeFromObject [input[0, int, true] AS value#36]
+- *MapElements <function1>, obj#35: int
   +- *DeserializeToObject value#15.toIntArray, obj#34: [I
      +- LocalTableScan [value#15]

// Case 2B. The input is of `MapType` type and contains no `null` elements.
scala> val mapDF = Seq(("one", 1), ("two", 2)).toDF("k", "v").withColumn("m", map(col("k"), col("v")))
mapDF: org.apache.spark.sql.DataFrame = [k: string, v: int ... 1 more field]

scala> mapDF.printSchema
 |-- k: string (nullable = true)
 |-- v: integer (nullable = false)
 |-- m: map (nullable = false)
 |    |-- key: string
 |    |-- value: integer (valueContainsNull = false)

scala> mapDF.selectExpr("""CAST (m AS map<string, int>)""").explain(true)
TRACE SparkOptimizer:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.SimplifyCasts ===
!Project [cast(map(_1#250, _2#251) as map<string,int>) AS m#272]   Project [map(_1#250, _2#251) AS m#272]
 +- LocalRelation [_1#250, _2#251]                                 +- LocalRelation [_1#250, _2#251]
== Parsed Logical Plan ==
'Project [unresolvedalias(cast('m as map<string,int>), None)]
+- Project [k#253, v#254, map(k#253, v#254) AS m#258]
   +- Project [_1#250 AS k#253, _2#251 AS v#254]
      +- LocalRelation [_1#250, _2#251]

== Analyzed Logical Plan ==
m: map<string,int>
Project [cast(m#258 as map<string,int>) AS m#272]
+- Project [k#253, v#254, map(k#253, v#254) AS m#258]
   +- Project [_1#250 AS k#253, _2#251 AS v#254]
      +- LocalRelation [_1#250, _2#251]

== Optimized Logical Plan ==
LocalRelation [m#272]

== Physical Plan ==
LocalTableScan [m#272]

Executing Rule — apply Method

apply(plan: LogicalPlan): LogicalPlan

apply is part of the Rule Contract to execute (apply) a rule on a TreeNode (e.g. LogicalPlan).
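
To illustrate what executing such a rule on a tree means, here is a minimal sketch under stated assumptions. Expr, Attr, Cast, Alias, the Rule trait and SimplifyCastsSketch are all hypothetical miniatures defined here; Catalyst's real TreeNode and Rule classes are far richer, and this sketch covers only the case where the input already has the target type.

```scala
object RuleSketch {
  // Hypothetical miniature expression tree; data types are plain strings for brevity.
  sealed trait Expr { def dataType: String }
  final case class Attr(name: String, dataType: String) extends Expr
  final case class Cast(child: Expr, dataType: String) extends Expr
  final case class Alias(child: Expr, name: String) extends Expr {
    def dataType: String = child.dataType
  }

  // Stand-in for the Rule contract: a rule maps a tree to a (possibly rewritten) tree.
  trait Rule[T] { def apply(plan: T): T }

  object SimplifyCastsSketch extends Rule[Expr] {
    // Rewrites the tree bottom-up, dropping every cast whose child
    // already has the target type.
    def apply(plan: Expr): Expr = plan match {
      case Cast(child, to) =>
        val simplified = apply(child)
        if (simplified.dataType == to) simplified else Cast(simplified, to)
      case Alias(child, name) => Alias(apply(child), name)
      case leaf: Attr => leaf
    }
  }
}
```

With these definitions, RuleSketch.SimplifyCastsSketch(Alias(Cast(Attr("id", "bigint"), "bigint"), "id")) yields Alias(Attr("id", "bigint"), "id"), which mirrors the cast(id#0L as bigint) AS id#7L to id#0L AS id#7L rewrite in the Case 1 trace above. A cast that changes the type, such as Cast(Attr("id", "bigint"), "string"), is left in place.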


