Skip to content

Commit 70691d0

Browse files
authored
Use thread context classloader for Iceberg class loading (#3738)
1 parent 0177afc commit 70691d0

2 files changed

Lines changed: 40 additions & 26 deletions

File tree

spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,29 @@ object IcebergReflection extends Logging {
7777
val UNKNOWN = "unknown"
7878
}
7979

80+
/**
81+
* Loads a class using the thread context classloader first, then falls back to the system
82+
* classloader.
83+
*
84+
* @param className
85+
* Fully qualified class name to load
86+
* @return
87+
* The loaded Class object
88+
*/
89+
def loadClass(className: String): Class[_] = {
90+
val classLoader = Thread.currentThread().getContextClassLoader
91+
if (classLoader != null) {
92+
// scalastyle:off classforname
93+
Class.forName(className, true, classLoader)
94+
// scalastyle:on classforname
95+
} else {
96+
// Fallback to default classloader if context classloader is null
97+
// scalastyle:off classforname
98+
Class.forName(className)
99+
// scalastyle:on classforname
100+
}
101+
}
102+
80103
/**
81104
* Searches through class hierarchy to find a method (including protected methods).
82105
*/
@@ -124,9 +147,7 @@ object IcebergReflection extends Logging {
124147
*/
125148
def extractFileLocation(file: Any): Option[String] = {
126149
try {
127-
// scalastyle:off classforname
128-
val contentFileClass = Class.forName(ClassNames.CONTENT_FILE)
129-
// scalastyle:on classforname
150+
val contentFileClass = loadClass(ClassNames.CONTENT_FILE)
130151
extractFileLocation(contentFileClass, file)
131152
} catch {
132153
case _: Exception => None
@@ -387,9 +408,7 @@ object IcebergReflection extends Logging {
387408
*/
388409
def getEqualityFieldIds(deleteFile: Any): java.util.List[_] = {
389410
try {
390-
// scalastyle:off classforname
391-
val deleteFileClass = Class.forName(ClassNames.DELETE_FILE)
392-
// scalastyle:on classforname
411+
val deleteFileClass = loadClass(ClassNames.DELETE_FILE)
393412
val equalityFieldIdsMethod = deleteFileClass.getMethod("equalityFieldIds")
394413
val ids = equalityFieldIdsMethod.invoke(deleteFile).asInstanceOf[java.util.List[_]]
395414
if (ids == null) new java.util.ArrayList[Any]() else ids
@@ -515,9 +534,7 @@ object IcebergReflection extends Logging {
515534
val fieldsMethod = partitionSpec.getClass.getMethod("fields")
516535
val fields = fieldsMethod.invoke(partitionSpec).asInstanceOf[java.util.List[_]]
517536

518-
// scalastyle:off classforname
519-
val partitionFieldClass = Class.forName(ClassNames.PARTITION_FIELD)
520-
// scalastyle:on classforname
537+
val partitionFieldClass = loadClass(ClassNames.PARTITION_FIELD)
521538
val sourceIdMethod = partitionFieldClass.getMethod("sourceId")
522539
val findFieldMethod = schema.getClass.getMethod("findField", classOf[Int])
523540

spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
227227
fileScanTaskClass: Class[_],
228228
fileIO: Option[Any]): Seq[OperatorOuterClass.IcebergDeleteFile] = {
229229
try {
230-
// scalastyle:off classforname
231-
val deleteFileClass = Class.forName(IcebergReflection.ClassNames.DELETE_FILE)
232-
// scalastyle:on classforname
230+
val deleteFileClass = IcebergReflection.loadClass(IcebergReflection.ClassNames.DELETE_FILE)
233231

234232
val deletes = IcebergReflection.getDeleteFilesFromTask(task, fileScanTaskClass)
235233

@@ -336,13 +334,11 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
336334
if (spec != null) {
337335
// Deduplicate partition spec
338336
try {
339-
// scalastyle:off classforname
340337
val partitionSpecParserClass =
341-
Class.forName(IcebergReflection.ClassNames.PARTITION_SPEC_PARSER)
338+
IcebergReflection.loadClass(IcebergReflection.ClassNames.PARTITION_SPEC_PARSER)
342339
val toJsonMethod = partitionSpecParserClass.getMethod(
343340
"toJson",
344-
Class.forName(IcebergReflection.ClassNames.PARTITION_SPEC))
345-
// scalastyle:on classforname
341+
IcebergReflection.loadClass(IcebergReflection.ClassNames.PARTITION_SPEC))
346342
val partitionSpecJson = toJsonMethod
347343
.invoke(null, spec)
348344
.asInstanceOf[String]
@@ -685,9 +681,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
685681
*/
686682
private def convertIcebergLiteral(icebergLiteral: Any, sparkType: DataType): Literal = {
687683
// Load Literal interface to get value() method (use interface to avoid package-private issues)
688-
// scalastyle:off classforname
689-
val literalClass = Class.forName(IcebergReflection.ClassNames.LITERAL)
690-
// scalastyle:on classforname
684+
val literalClass = IcebergReflection.loadClass(IcebergReflection.ClassNames.LITERAL)
691685
val valueMethod = literalClass.getMethod("value")
692686
val value = valueMethod.invoke(icebergLiteral)
693687

@@ -790,13 +784,16 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
790784
}
791785

792786
// Load Iceberg classes once (avoid repeated class loading in loop)
793-
// scalastyle:off classforname
794-
val contentScanTaskClass = Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK)
795-
val fileScanTaskClass = Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK)
796-
val contentFileClass = Class.forName(IcebergReflection.ClassNames.CONTENT_FILE)
797-
val schemaParserClass = Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER)
798-
val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA)
799-
// scalastyle:on classforname
787+
val contentScanTaskClass =
788+
IcebergReflection.loadClass(IcebergReflection.ClassNames.CONTENT_SCAN_TASK)
789+
val fileScanTaskClass =
790+
IcebergReflection.loadClass(IcebergReflection.ClassNames.FILE_SCAN_TASK)
791+
val contentFileClass =
792+
IcebergReflection.loadClass(IcebergReflection.ClassNames.CONTENT_FILE)
793+
val schemaParserClass =
794+
IcebergReflection.loadClass(IcebergReflection.ClassNames.SCHEMA_PARSER)
795+
val schemaClass =
796+
IcebergReflection.loadClass(IcebergReflection.ClassNames.SCHEMA)
800797

801798
// Cache method lookups (avoid repeated getMethod in loop)
802799
val fileMethod = contentScanTaskClass.getMethod("file")

0 commit comments

Comments
 (0)