[SPARK-57688][SQL] Add spark.sql.execution.bypassPartialAggregation to skip partial agg#56777
[SPARK-57688][SQL] Add spark.sql.execution.bypassPartialAggregation to skip partial agg#56777xumingming wants to merge 3 commits into
Conversation
| "When false (default), uses a two-phase Partial+Final aggregation across a shuffle. " + | ||
| "This setting has no effect on queries containing DISTINCT aggregate functions, where " + | ||
| "the partial aggregation phases are required for correctness and are always applied.") | ||
| .version("3.3.1") |
There was a problem hiding this comment.
| .version("3.3.1") | |
| .version("4.3.0") |
| .booleanConf | ||
| .createWithDefault(true) | ||
|
|
||
| val BYPASS_PARTIAL_AGGREGATION = buildConf("spark.sql.execution.bypassPartialAggregation") |
There was a problem hiding this comment.
SparkConfigBindingPolicySuite requires every new config to declare a policy, please make sure to add withBindingPolicy.
| "the partial aggregation phases are required for correctness and are always applied.") | ||
| .version("3.3.1") | ||
| .booleanConf | ||
| .createWithDefault(false) |
There was a problem hiding this comment.
I actually want make it public, so users could utilize it to optimize performance, how do you think?
| s"Expected:\n${expected.mkString("\n")}\nActual:\n${actual.mkString("\n")}") | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Test gap: no test with AQE enabled.
Also, no TypedImperativeAggregate bypass test.
There was a problem hiding this comment.
You mean no test with AQE disabled right?(AQE is enabled by default). I will make all tests run with and without AQE enabled.
There was a problem hiding this comment.
cc @cloud-fan @viirya @ueshin for AggUtils/AQE interaction
HyukjinKwon
left a comment
There was a problem hiding this comment.
1 blocking, 0 non-blocking, 0 nits.
Well-scoped feature (session_window and DISTINCT correctly excluded) and correct for plain-column grouping, but the bypass plan is invalid for expression grouping keys. (uros-b's existing config + test-gap comments are valid and not repeated here.)
Correctness (1)
- AggUtils.scala:146:
Some(groupingAttributes)over the raw child references synthetic grouping attributes the child doesn't output (e.g.GROUP BY v % 10); should beSome(groupingExpressions)— see inline
Verification
Traced the rewrite Partial+shuffle+Final → shuffle+Complete: equivalent for plain-attribute grouping (empty/single/many rows), and session_window (gated by !hasSessionWindow) and DISTINCT (different planner) are excluded. The one non-equivalent input is a non-attribute grouping key: groupingAttributes = groupingExpressions.map(_.toAttribute), and SparkStrategies wraps v % 10 as Alias(v % 10, "k"), so .toAttribute is not in the raw child's output → HashPartitioning over a missing attribute → bind failure at execution. Output schema is preserved (resultExpressions unchanged).
| val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) | ||
| val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) | ||
| val completeAggregate = createAggregate( | ||
| requiredChildDistributionExpressions = Some(groupingAttributes), |
There was a problem hiding this comment.
requiredChildDistributionExpressions = Some(groupingAttributes) is applied over the raw child here, but groupingAttributes = groupingExpressions.map(_.toAttribute). For any grouping key that isn't a plain child attribute — e.g. GROUP BY v % 10 — SparkStrategies wraps it as Alias(v % 10, "k") (SparkStrategies.scala ~L708), so .toAttribute is a synthetic AttributeReference that is not in the raw child's output. The resulting ClusteredDistribution → HashPartitioning(thatAttr) then fails to bind against the child at execution.
The normal Final path (L186) can use Some(groupingAttributes) only because its child is the Partial agg that produces those attributes; the bypass's child is the raw input, so it should distribute on the expressions themselves:
| requiredChildDistributionExpressions = Some(groupingAttributes), | |
| requiredChildDistributionExpressions = Some(groupingExpressions), |
This is currently untested: the first test groups by (v % 10).as("k") but only checks executedPlan structure (no execution), and the SUM/COUNT/AVG tests group by a materialized plain column k. An executing GROUP BY <expression> test (collect + checkAnswer under the config) would catch it. (Side note: with the fix, the bypass evaluates the grouping expression twice — in the shuffle partitioning and in the Complete agg — vs once in the two-phase path; fine for deterministic keys, worth a thought for nondeterministic ones.)
There was a problem hiding this comment.
- "requiredChildDistributionExpressions = Some(groupingAttributes)" is a good catch, I will make the change. Under the hood, even if we don't make the change, current code would produce the right result because the PullOutGroupingExpressions rule extracts the non-attribute grouping expression before the plan enters planAggregateWithoutDistinct. But the change you suggest is indeed great to make the code more readable, reasonable, will make the change.
- For the nondeterministic concerns, PullOutNondeterministic pulls the nondeterministic grouping expressions into a upstream Project, so it will not be evaluated multiple times.
…tion to skip pre-shuffle partial agg Adds a new SQL config spark.sql.execution.bypassPartialAggregation (default false). When set to true, planAggregateWithoutDistinct skips the pre-shuffle Partial-mode aggregation and runs a single Complete-mode aggregation after the shuffle instead. This can improve performance when group cardinality is high and the pre-shuffle reduction ratio is low. The bypass is suppressed when a session_window grouping key is present, since MergingSessionsExec must be inserted in the Partial+Merge+Final path to correctly merge overlapping sessions. The config has no effect on queries containing DISTINCT aggregate functions, where the partial aggregation phases are required for correctness and are always applied.
HyukjinKwon
left a comment
There was a problem hiding this comment.
0 addressed, 1 remaining, 3 new. (3 new = 0 newly introduced, 3 late catches — my own misses; same commit, no code changed between rounds.)
0 blocking, 2 non-blocking, 2 nits. Well-scoped feature; my prior blocking finding was overstated — corrected below.
Correctness (1)
- AggUtils.scala:146 (prior thread, remaining): correction — my earlier "
GROUP BY v % 10bind-fails" claim was wrong;PullOutGroupingExpressionspulls that key into a childProject.Some(groupingExpressions)(which you agreed to) still helps for a narrow case — see Verification.
Design / architecture (1)
- AggUtils.scala:142: bypass also fires for global aggregation (no grouping keys) — all rows shuffle to one partition with no pre-agg, zero benefit — see inline
Nits: 2 minor items (see inline comments).
Verification
Re-traced the Partial+shuffle+Final → shuffle+Complete rewrite: row-equivalent for empty/single/many rows, NULL keys, and duplicates; session_window (!hasSessionWindow) and DISTINCT (separate planner) are excluded. On the distribution key: PullOutGroupingExpressions (Optimizer.scala:341; comment L343-344 "the grouping keys can only be attribute and literal") pulls complex keys like v % 10 into a child Project, so Some(groupingAttributes) binds fine there — my prior round was wrong about that. It differs from Some(groupingExpressions) only for foldable / childless keys (a constant literal, spark_partition_id()), which aren't pulled out and aren't in the raw child's output → bind failure under bypass=true; hence the agreed change is still worth keeping.
| // when a session_window grouping key is present so that the normal Partial+Merge+Final path | ||
| // runs and MergingSessionsExec is correctly inserted. | ||
| val hasSessionWindow = groupingExpressions.exists(_.metadata.contains(SessionWindow.marker)) | ||
| if (child.conf.bypassPartialAggregation && !hasSessionWindow) { |
There was a problem hiding this comment.
Non-blocking (perf): this gate also fires for global aggregation (groupingExpressions.isEmpty). There requiredChildDistributionExpressions = Some(groupingAttributes) is Some(Nil) → AllTuples, so all raw rows shuffle to a single partition with no pre-aggregation. For a cardinality-1 global agg that's a pure regression with zero upside, and a user who enables this session-wide for high-cardinality grouped queries silently pessimizes any global aggs in the same session. Consider also requiring grouping keys:
| if (child.conf.bypassPartialAggregation && !hasSessionWindow) { | |
| if (child.conf.bypassPartialAggregation && groupingExpressions.nonEmpty && !hasSessionWindow) { |
| // One event for key "b" stands alone. | ||
| val df = Seq( | ||
| ("2016-03-27 19:39:34", 1, "a"), | ||
| ("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" — same session |
There was a problem hiding this comment.
Nit: em-dash (non-ASCII) in a // comment — CLAUDE.md/scalastyle flag non-ASCII in comments.
| ("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" — same session | |
| ("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" - same session |
| val df = Seq( | ||
| ("2016-03-27 19:39:34", 1, "a"), | ||
| ("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" — same session | ||
| ("2016-03-27 19:39:56", 3, "a"), // > 10s gap — separate session |
There was a problem hiding this comment.
Nit: em-dash (non-ASCII) in a // comment.
| ("2016-03-27 19:39:56", 3, "a"), // > 10s gap — separate session | |
| ("2016-03-27 19:39:56", 3, "a"), // > 10s gap - separate session |
…or better diagnostics Switch scalar-aggregate tests (SUM, COUNT, AVG, session_window) to use checkAnswer instead of raw actual.toSeq == expected.toSeq, providing better error messages when comparisons fail by pinpointing the mismatched row and column. Keep manual zip-and-sort for the collect_list test since checkAnswer does not sort nested arrays — collect_list output order within groups is non-deterministic between Partial+Final and Complete aggregation paths. Also replace non-ASCII em-dashes with ASCII equivalents (--, -, :) in test names and comments to satisfy scalastyle.
…tions Global aggregations (no GROUP BY) always produce a single output row, so the pre-shuffle partial aggregation achieves the maximum possible reduction ratio. Bypassing it would shuffle all raw rows to a single partition with no benefit — strictly worse than Partial+Final. Extract hasGroupingKeys = groupingExpressions.nonEmpty and add it to the bypass gate alongside hasSessionWindow, so the bypass only fires when there are grouping keys to hash-partition on. Add a test verifying that global aggregations continue to produce Partial+Final plans even with bypassPartialAggregation=true.
fbcf3c9 to
c8a214a
Compare
|
@uros-b @HyukjinKwon Thanks for the review, made the following changes:
|
What changes were proposed in this pull request?
Adds a new SQL config spark.sql.execution.bypassPartialAggregation (default false). When set to true, planAggregateWithoutDistinct skips the pre-shuffle Partial-mode aggregation and runs a single Complete-mode aggregation after the shuffle instead. This can improve performance when group cardinality is high and the pre-shuffle reduction ratio is low.
The bypass is suppressed when a session_window grouping key is present, since MergingSessionsExec must be inserted in the Partial+Merge+Final path to correctly merge overlapping sessions.
The config has no effect on queries containing DISTINCT aggregate functions, where the partial aggregation phases are required for correctness and are always applied.
Why are the changes needed?
The standard two-phase aggregation plan (Partial → shuffle → Final) assumes that pre-shuffle partial aggregation meaningfully reduces data volume. This assumption breaks down in two scenarios.
Scenario 1: High group cardinality. When group cardinality is high relative to partition size, every input row maps to a distinct key, so the partial aggregation produces one output row per input row and adds CPU and memory overhead with zero shuffle benefit.
On a table with 500M rows and 200M distinct user_id values, the pre-shuffle HashAggregateExec in Partial mode churns through the full dataset, spills when the hash map overflows, and still emits ~200M rows into the shuffle. The partial phase wastes wall-clock time and memory without reducing shuffle write volume.
Scenario 2: Skewed input data. Even when partial aggregation can reduce data volume on average, skewed input partitions can make it harmful. If one partition contains a disproportionate share of rows for a small number of keys, the partial HashAggregateExec on that partition must hold a large hash map in memory, triggering spills. The skewed partition becomes the bottleneck and dominates wall-clock time — worse than if the data had been shuffled first and aggregated on already-partitioned, evenly distributed data.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
Added Unit Test.
Was this patch authored or co-authored using generative AI tooling?
No.