Skip to content

Commit 75757e1

Browse files
ilicmarkodb authored and cloud-fan committed
[SPARK-56169][SQL] Fix ClassCastException in error reporting when GetStructField child type is changed by plan transformation
### What changes were proposed in this pull request? SPARK-53470 added `ExpectsInputTypes` to `GetStructField` so that `checkInputDataTypes()` catches the case where a plan transformation changes the child's type from `StructType` to something else. This can happen when an analyzer rule inserts a projection that changes a column's output type after `GetStructField` was already created referencing that column. However, when `CheckAnalysis` detects this mismatch, the error formatting path (`toPrettySQL` -> `usePrettyExpression`) accesses `GetStructField.dataType` which calls `childSchema` -> `child.dataType.asInstanceOf[StructType]`, throwing a raw `ClassCastException` before the proper `DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE` error can be reported. This PR fixes two things: 1. `usePrettyExpression` checks `child.dataType` before accessing `childSchema`, falling back to a safe representation when the child is not a `StructType` 2. `childSchema` uses pattern matching instead of an unsafe cast, throwing a clear `SparkException.internalError` instead of `ClassCastException` ### Why are the changes needed? Without this fix, users see a raw `ClassCastException: StringType$ cannot be cast to StructType` instead of the proper `DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE` error that `checkInputDataTypes()` was trying to report. ### Does this PR introduce _any_ user-facing change? Yes - users now get a proper `AnalysisException` with error class `DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE` instead of a raw `ClassCastException`. ### How was this patch tested? New tests. ### Was this patch authored or co-authored using generative AI tooling? Yes Co-authored-by: Claude Closes #54970 from ilicmarkodb/fix-get-struct-field-classcast. Authored-by: ilicmarkodb <marko.ilic@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 34ac927 commit 75757e1

3 files changed

Lines changed: 44 additions & 7 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import scala.util.{Either, Left, Right}
2121

22-
import org.apache.spark.QueryContext
22+
import org.apache.spark.{QueryContext, SparkException}
2323
import org.apache.spark.sql.catalyst.InternalRow
2424
import org.apache.spark.sql.catalyst.analysis._
2525
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
@@ -184,7 +184,13 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
184184

185185
override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegralType)
186186

187-
lazy val childSchema = child.dataType.asInstanceOf[StructType]
187+
lazy val childSchema = child.dataType match {
188+
case st: StructType => st
189+
case other =>
190+
throw SparkException.internalError(
191+
s"GetStructField requires a StructType child, but got ${other.catalogString}. " +
192+
s"The child type may have been changed by a plan transformation.")
193+
}
188194

189195
override lazy val canonicalized: Expression = {
190196
copy(child = child.canonicalized, name = None)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,21 @@ package object util extends Logging {
9494
case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t)
9595
case Literal(null, dataType) => PrettyAttribute("NULL", dataType)
9696
case e: GetStructField =>
97-
val name = e.name.getOrElse(e.childSchema(e.ordinal).name)
98-
PrettyAttribute(
99-
usePrettyExpression(e.child, shouldTrimTempResolvedColumn).sql + "." + name,
100-
e.dataType
101-
)
97+
e.child.dataType match {
98+
case st: StructType =>
99+
val name = e.name.getOrElse(st(e.ordinal).name)
100+
PrettyAttribute(
101+
usePrettyExpression(e.child, shouldTrimTempResolvedColumn).sql + "." + name,
102+
st(e.ordinal).dataType
103+
)
104+
case _ =>
105+
// Child type was changed by a plan transformation.
106+
val name = e.name.getOrElse(s"_${e.ordinal}")
107+
PrettyAttribute(
108+
usePrettyExpression(e.child, shouldTrimTempResolvedColumn).sql + "." + name,
109+
e.child.dataType
110+
)
111+
}
102112
case e: GetArrayStructFields =>
103113
PrettyAttribute(
104114
s"${usePrettyExpression(e.child, shouldTrimTempResolvedColumn)}.${e.field.name}",

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,27 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
177177
assert(getStructField(nullStruct, "a").nullable)
178178
}
179179

180+
test("GetStructField checkInputDataTypes should fail when child is not StructType") {
181+
// Simulate a plan transformation that changes the child's type from StructType to
182+
// StringType after the GetStructField was created.
183+
val stringAttr = AttributeReference("c2", StringType)()
184+
val getField = GetStructField(stringAttr, 0, Some("f1"))
185+
186+
assert(getField.checkInputDataTypes().isFailure)
187+
val result = getField.checkInputDataTypes().asInstanceOf[DataTypeMismatch]
188+
assert(result.errorSubClass == "UNEXPECTED_INPUT_TYPE")
189+
}
190+
191+
test("GetStructField toPrettySQL should not crash when child is not StructType") {
192+
val stringAttr = AttributeReference("c2", StringType)()
193+
val getField = GetStructField(stringAttr, 0, Some("f1"))
194+
195+
// This should not throw ClassCastException
196+
val prettySQL = toPrettySQL(getField)
197+
assert(prettySQL.contains("c2"))
198+
assert(prettySQL.contains("f1"))
199+
}
200+
180201
test("GetArrayStructFields") {
181202
// test 4 types: struct field nullability X array element nullability
182203
val type1 = ArrayType(StructType(StructField("a", IntegerType) :: Nil))

0 commit comments

Comments (0)