Skip to content

Commit 2b6dcc1

Browse files
committed
fix_review_comments
1 parent eaa1dc4 commit 2b6dcc1

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

spark/src/main/scala/org/apache/spark/Plugins.scala

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,8 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
5757
// register CometSparkSessionExtensions if it isn't already registered
5858
CometDriverPlugin.registerCometSessionExtension(sc.conf)
5959

60-
// Register Comet metrics source
61-
sc.env.metricsSystem.registerSource(CometSource)
62-
63-
// Register query execution listener via config
64-
CometDriverPlugin.registerQueryExecutionListener(sc.conf)
60+
// Register Comet metrics
61+
CometDriverPlugin.registerCometMetrics(sc)
6562

6663
if (CometSparkSessionExtensions.shouldOverrideMemoryConf(sc.getConf)) {
6764
val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) {
@@ -107,8 +104,23 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
107104
}
108105

109106
object CometDriverPlugin extends Logging {
110-
def registerQueryExecutionListener(conf: SparkConf): Unit = {
111-
conf.set("spark.sql.queryExecutionListeners", "org.apache.comet.CometMetricsListener")
107+
def registerCometMetrics(sc: SparkContext): Unit = {
108+
sc.env.metricsSystem.registerSource(CometSource)
109+
110+
val listenerKey = "spark.sql.queryExecutionListeners"
111+
val listenerClass = "org.apache.comet.CometMetricsListener"
112+
val listeners = sc.conf.get(listenerKey, "")
113+
if (listeners.isEmpty) {
114+
logInfo(s"Setting $listenerKey=$listenerClass")
115+
sc.conf.set(listenerKey, listenerClass)
116+
} else {
117+
val currentListeners = listeners.split(",").map(_.trim)
118+
if (!currentListeners.contains(listenerClass)) {
119+
val newValue = s"$listeners,$listenerClass"
120+
logInfo(s"Setting $listenerKey=$newValue")
121+
sc.conf.set(listenerKey, newValue)
122+
}
123+
}
112124
}
113125

114126
def registerCometSessionExtension(conf: SparkConf): Unit = {

spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ class CometPluginsSuite extends CometTestBase {
103103
val path = new File(dir, "test.parquet").toString
104104
spark.range(10000).toDF("id").write.mode(SaveMode.Overwrite).parquet(path)
105105

106-
val queriesBefore = CometSource.QUERIES_PLANNED.getCount
107106
spark.sparkContext.listenerBus.waitUntilEmpty()
107+
val queriesBefore = CometSource.QUERIES_PLANNED.getCount
108108
spark.read.parquet(path).filter("id > 100").collect()
109109
spark.read.parquet(path).filter("id > 200").collect()
110110
spark.sparkContext.listenerBus.waitUntilEmpty()

0 commit comments

Comments
 (0)