Skip to content

Commit c6da902

Browse files
committed
fix_review_comments
1 parent eaa1dc4 commit c6da902

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

spark/src/main/scala/org/apache/spark/Plugins.scala

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,8 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
5757
// register CometSparkSessionExtensions if it isn't already registered
5858
CometDriverPlugin.registerCometSessionExtension(sc.conf)
5959

60-
// Register Comet metrics source
61-
sc.env.metricsSystem.registerSource(CometSource)
62-
63-
// Register query execution listener via config
64-
CometDriverPlugin.registerQueryExecutionListener(sc.conf)
60+
// Register Comet metrics
61+
CometDriverPlugin.registerCometMetrics(sc)
6562

6663
if (CometSparkSessionExtensions.shouldOverrideMemoryConf(sc.getConf)) {
6764
val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) {
@@ -107,8 +104,23 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
107104
}
108105

109106
object CometDriverPlugin extends Logging {
110-
def registerQueryExecutionListener(conf: SparkConf): Unit = {
111-
conf.set("spark.sql.queryExecutionListeners", "org.apache.comet.CometMetricsListener")
107+
def registerCometMetrics(sc: SparkContext): Unit = {
108+
sc.env.metricsSystem.registerSource(CometSource)
109+
110+
val listenerKey = "spark.sql.queryExecutionListeners"
111+
val listenerClass = "org.apache.comet.CometMetricsListener"
112+
val listeners = sc.conf.get(listenerKey, "")
113+
if (listeners.isEmpty) {
114+
logInfo(s"Setting $listenerKey=$listenerClass")
115+
sc.conf.set(listenerKey, listenerClass)
116+
} else {
117+
val currentListeners = listeners.split(",").map(_.trim)
118+
if (!currentListeners.contains(listenerClass)) {
119+
val newValue = s"$listeners,$listenerClass"
120+
logInfo(s"Setting $listenerKey=$newValue")
121+
sc.conf.set(listenerKey, newValue)
122+
}
123+
}
112124
}
113125

114126
def registerCometSessionExtension(conf: SparkConf): Unit = {

0 commit comments

Comments
 (0)