Skip to content

Commit eaa1dc4

Browse files
committed
add_listener_comet_plugin_ini
1 parent 2eb506f commit eaa1dc4

File tree

4 files changed

+12
-18
lines changed

4 files changed

+12
-18
lines changed

spark/src/main/scala/org/apache/comet/CometMetricsListener.scala

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,11 @@
1919

2020
package org.apache.comet
2121

22-
import java.util.concurrent.atomic.AtomicBoolean
23-
2422
import org.apache.spark.CometSource
25-
import org.apache.spark.sql.SparkSession
2623
import org.apache.spark.sql.execution.QueryExecution
2724
import org.apache.spark.sql.util.QueryExecutionListener
2825

29-
object CometMetricsListener extends QueryExecutionListener {
30-
31-
private val registered = new AtomicBoolean(false)
32-
33-
def ensureRegistered(session: SparkSession): Unit = {
34-
if (registered.compareAndSet(false, true)) {
35-
session.listenerManager.register(this)
36-
}
37-
}
26+
class CometMetricsListener extends QueryExecutionListener {
3827

3928
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
4029
val stats = CometCoverageStats.forPlan(qe.executedPlan)

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.window.WindowExec
4747
import org.apache.spark.sql.internal.SQLConf
4848
import org.apache.spark.sql.types._
4949

50-
import org.apache.comet.{CometConf, CometExplainInfo, CometMetricsListener, ExtendedExplainInfo}
50+
import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
5151
import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
5252
import org.apache.comet.CometSparkSessionExtensions._
5353
import org.apache.comet.rules.CometExecRule.allExecs
@@ -387,8 +387,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
387387
normalizedPlan
388388
}
389389

390-
CometMetricsListener.ensureRegistered(session)
391-
392390
var newPlan = transform(planWithJoinRewritten)
393391

394392
// if the plan cannot be run fully natively then explain why (when appropriate

spark/src/main/scala/org/apache/spark/Plugins.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
6060
// Register Comet metrics source
6161
sc.env.metricsSystem.registerSource(CometSource)
6262

63+
// Register query execution listener via config
64+
CometDriverPlugin.registerQueryExecutionListener(sc.conf)
65+
6366
if (CometSparkSessionExtensions.shouldOverrideMemoryConf(sc.getConf)) {
6467
val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) {
6568
sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)
@@ -104,6 +107,10 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
104107
}
105108

106109
object CometDriverPlugin extends Logging {
110+
def registerQueryExecutionListener(conf: SparkConf): Unit = {
111+
conf.set("spark.sql.queryExecutionListeners", "org.apache.comet.CometMetricsListener")
112+
}
113+
107114
def registerCometSessionExtension(conf: SparkConf): Unit = {
108115
val extensionKey = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
109116
val extensionClass = classOf[CometSparkSessionExtensions].getName

spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class CometPluginsSuite extends CometTestBase {
8888
spark.range(1000).toDF("id").write.mode(SaveMode.Overwrite).parquet(path)
8989
spark.read.parquet(path).filter("id > 500").collect()
9090
}
91-
91+
spark.sparkContext.listenerBus.waitUntilEmpty()
9292
assert(
9393
CometSource.QUERIES_PLANNED.getCount > queriesBefore,
9494
"queries.planned should increment after query")
@@ -104,10 +104,10 @@ class CometPluginsSuite extends CometTestBase {
104104
spark.range(10000).toDF("id").write.mode(SaveMode.Overwrite).parquet(path)
105105

106106
val queriesBefore = CometSource.QUERIES_PLANNED.getCount
107-
107+
spark.sparkContext.listenerBus.waitUntilEmpty()
108108
spark.read.parquet(path).filter("id > 100").collect()
109109
spark.read.parquet(path).filter("id > 200").collect()
110-
110+
spark.sparkContext.listenerBus.waitUntilEmpty()
111111
val queriesAfter = CometSource.QUERIES_PLANNED.getCount
112112
assert(
113113
queriesAfter == queriesBefore + 2,

0 commit comments

Comments
 (0)