From e41fae92d041f73109f275babda974b12e875142 Mon Sep 17 00:00:00 2001 From: Anupam Yadav Date: Wed, 3 Jun 2026 00:46:45 +0000 Subject: [PATCH 1/6] [SPARK-57220][SQL] Extend block-chunked segment-tree window frame to shrinking frames ### What changes were proposed in this pull request? Extends `SegmentTreeWindowFunctionFrame` (introduced in SPARK-56546 for sliding aggregates) to also handle shrinking frames of the form `... ROWS/RANGE BETWEEN <lower> AND UNBOUNDED FOLLOWING`. The class is parameterized with `ubound: Option[BoundOrdering]` (`None` = shrinking, `Some(ub)` = sliding) and a `fallbackFactory` for the small-partition path so the same machinery (build, spill, eligibility, metrics) serves both shapes. The dispatcher in `WindowEvaluatorFactoryBase` gains a shrinking-frame branch that consults the existing `eligibleForSegTree` gate and, on success, builds the unified frame with `ubound = None`. ### Why are the changes needed? The legacy `UnboundedFollowingWindowFunctionFrame` recomputes the suffix aggregate from scratch for every output row -- O(n * (n - 1) / 2). Its own scaladoc acknowledges this (`WindowFunctionFrame.scala:636`): > This is a very expensive operator to use, O(n * (n - 1) / 2), because > we need to maintain a buffer and must do full recalculation after each > row. The segment tree built by SPARK-56546 already supports arbitrary `[lower, upper)` queries; routing shrinking frames into it is purely a dispatch + parameter change. Workloads with shrinking frames -- common in retention / cohort / "remaining-lifetime" analytics -- become orders of magnitude faster. ### Does this PR introduce _any_ user-facing change? No. Same opt-in conf (`spark.sql.window.segmentTree.enabled`, default false), same eligibility allowlist (DeclarativeAggregate with mergeExpressions, no FILTER, no DISTINCT), same `minPartitionRows` fallback (now to `UnboundedFollowingWindowFunctionFrame` instead of `SlidingWindowFunctionFrame`), no analyzer / SQL grammar / plan-shape changes. 
### How was this patch tested? New `UnboundedFollowingSegmentTreeSuite` (26 tests, all green): basic aggregates, ROWS lower-bound variations, multi-aggregate shared frame, single-row / empty / fallback partitions, NULL / NaN / Infinity, type coverage (Int/Long/Double/Decimal/String/Date/Timestamp), allowlist fallback, RANGE frames (uniform, non-uniform, ties, NULL keys, INTERVAL Timestamp), feature-flag off. Existing `SegmentTreeWindowFunctionSuite` (41 sliding tests), `WindowSegmentTreeSuite`, `WindowSegmentTreePropertySuite`, `WindowSegmentTreeMemorySuite`, `SegmentTreeWindowMetricsSuite`, `WindowSegmentTreeAllowlistSuite`, and `DataFrameWindowFunctionsSuite` all still pass (172 tests total, 0 failures), confirming the unified rewrite preserves sliding-frame semantics. Benchmark (`UnboundedFollowingWindowBenchmark`, JDK 17, EC2 c5.4xlarge): | N | naive | segtree | speedup | |------|-------------|---------|---------| | 5K | 620 ms | 73 ms | 8.5X | | 10K | 2 471 ms | 110 ms | 22.5X | | 25K | 14 259 ms | 119 ms | 119.3X | | 50K | 57 022 ms | 181 ms | 314.2X | | 100K | (~4 min) | 269 ms | -- | | 200K | (~16 min) | 480 ms | -- | The naive curve is a clean O(N^2). The segtree path is O(N log N) overall (logarithmic per row); its measured wall-clock grows sub-linearly in N over this range. ### Was this patch authored or co-authored using generative AI tooling? Yes. Authored with assistance from Claude (Anthropic). 
--- ...oundedFollowingWindowBenchmark-results.txt | 118 +++++ .../SegmentTreeWindowFunctionFrame.scala | 125 ++++-- .../window/WindowEvaluatorFactoryBase.scala | 50 ++- .../UnboundedFollowingWindowBenchmark.scala | 187 ++++++++ .../window/SegmentTreeWindowTestHelper.scala | 7 +- .../UnboundedFollowingSegmentTreeSuite.scala | 412 ++++++++++++++++++ 6 files changed, 843 insertions(+), 56 deletions(-) create mode 100644 sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala diff --git a/sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt b/sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt new file mode 100644 index 0000000000000..c8eafed7fb032 --- /dev/null +++ b/sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt @@ -0,0 +1,118 @@ +================================================================================================ +Section A - SUM (non-invertible suffix) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) 2471 2495 14 0.0 241298.5 1.0X +SUM segtree 110 115 4 0.1 10744.6 22.5X + + +================================================================================================ +Section A - MIN +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on 
Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +MIN shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +MIN naive (master O(N^2)) 2417 2438 23 0.0 236035.8 1.0X +MIN segtree 215 219 5 0.0 21015.3 11.2X + + +================================================================================================ +Section A - MAX +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +MAX shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +MAX naive (master O(N^2)) 2396 2401 5 0.0 233937.5 1.0X +MAX segtree 228 229 1 0.0 22259.2 10.5X + + +================================================================================================ +Section A - COUNT +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +COUNT shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +COUNT naive (master O(N^2)) 2203 2222 16 0.0 215139.0 1.0X +COUNT segtree 80 88 9 0.1 7846.1 27.4X + + +================================================================================================ +Section A - AVG (multi-buffer) +================================================================================================ + 
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +AVG shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +AVG naive (master O(N^2)) 2886 2900 18 0.0 281837.8 1.0X +AVG segtree 84 86 4 0.1 8165.1 34.5X + + +================================================================================================ +Section B - N=5K +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=5K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) N=5K 620 628 7 0.0 121170.2 1.0X +SUM segtree N=5K 73 74 1 0.1 14302.8 8.5X + + +================================================================================================ +Section B - N=25K (stress) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=25K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) N=25K 14259 14341 108 0.0 556977.9 1.0X +SUM segtree N=25K 119 120 0 0.2 4667.1 119.3X + + +================================================================================================ +Section B - N=50K (stress, last naive run) 
+================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=50K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) N=50K 57022 57659 987 0.0 1113704.1 1.0X +SUM segtree N=50K 181 182 1 0.3 3544.3 314.2X + + +================================================================================================ +Section B - N=100K (segtree-only, stress) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=100K rows (segtree-only): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------- +SUM segtree N=100K 269 270 2 0.4 2627.9 1.0X + + +================================================================================================ +Section B - N=200K (segtree-only, stress) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=200K rows (segtree-only): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------- +SUM segtree N=200K 480 481 1 0.4 2343.7 1.0X + + diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala index 51648e31e3498..a6de28c0bfda8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala @@ -27,14 +27,21 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf /** - * Moving-frame window function frame backed by [[WindowSegmentTree]]. Produces - * the same outputs as [[SlidingWindowFunctionFrame]] for RowFrame or - * single-column RangeFrame moving frames whose aggregates are all - * [[DeclarativeAggregate]] with no FILTER/DISTINCT. For partitions below - * `spark.sql.window.segmentTree.minPartitionRows`, delegates to a wrapped - * [[SlidingWindowFunctionFrame]]. Under RANGE, two forward-only cursors - * (`lowerIter` / `upperIter`) advance the bounds in O(n) total; the segtree - * answers `[lowerBound, upperBound)` in O(log n). + * Window function frame backed by [[WindowSegmentTree]]. Handles two frame + * shapes: + * - **Sliding** (`ubound = Some(...)`): both edges move; mirrors + * [[SlidingWindowFunctionFrame]]. O(N log W) total. + * - **Shrinking** (`ubound = None`): upper edge pinned to partition end + * (`BETWEEN AND UNBOUNDED FOLLOWING`); replaces + * [[UnboundedFollowingWindowFunctionFrame]]'s O(N^2) full recompute with + * O(N log N). + * + * Eligibility, build, spill, and memory accounting are identical for both + * shapes; only the per-row cursor logic differs (admit+drop for sliding, + * drop-only for shrinking). + * + * For partitions below `spark.sql.window.segmentTree.minPartitionRows`, + * delegates to a frame produced by `fallbackFactory`. * * @note Not thread-safe. 
*/ @@ -45,7 +52,8 @@ private[window] final class SegmentTreeWindowFunctionFrame( inputSchema: Seq[Attribute], frameType: FrameType, lbound: BoundOrdering, - ubound: BoundOrdering, + ubound: Option[BoundOrdering], + fallbackFactory: () => WindowFunctionFrame, newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, conf: SQLConf, maxCachedBlocks: Option[Int], @@ -57,16 +65,18 @@ private[window] final class SegmentTreeWindowFunctionFrame( require(frameType == RowFrame || frameType == RangeFrame, s"SegmentTreeWindowFunctionFrame supports RowFrame or RangeFrame, got $frameType") - private[this] var fallback: SlidingWindowFunctionFrame = _ + // True when this is a shrinking-frame (UnboundedFollowing) instance. + // Shorthand to avoid repeated `ubound.isEmpty` reads in hot loops. + private[this] val shrinking: Boolean = ubound.isEmpty + + private[this] var fallback: WindowFunctionFrame = _ private[this] var tree: WindowSegmentTree = _ /** - * Allocate a fresh fallback sliding-window frame. Called lazily from - * `prepare()` on the small-partition path. Factored out for testability - * (subclasses can inject a throwing fallback for prepare-failure tests). + * Allocate a fresh fallback frame via `fallbackFactory`. Called lazily + * from `prepare()` on the small-partition path. */ - private[window] def newFallback(): SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(target, processor, lbound, ubound) + private[window] def newFallback(): WindowFunctionFrame = fallbackFactory() /** Test hook: whether the fallback frame has been lazily allocated. */ private[window] def fallbackAllocated: Boolean = fallback != null @@ -100,8 +110,11 @@ private[window] final class SegmentTreeWindowFunctionFrame( /** * Runtime dispatch flag: when `true`, `write()`, `currentLowerBound()`, and - * `currentUpperBound()` delegate to the wrapped [[SlidingWindowFunctionFrame]] - * (small-partition path). Set by `prepare()` based on partition size vs. 
+ * `currentUpperBound()` delegate to the wrapped fallback frame produced by + * `fallbackFactory` (small-partition path). The fallback type is sliding- + * dependent: [[SlidingWindowFunctionFrame]] for moving frames and + * [[UnboundedFollowingWindowFunctionFrame]] for shrinking frames. Set by + * `prepare()` based on partition size vs. * `spark.sql.window.segmentTree.minPartitionRows`. */ private[window] var fallbackUsed: Boolean = false @@ -155,19 +168,31 @@ private[window] final class SegmentTreeWindowFunctionFrame( // Count only on the successful segtree path: if `tree.build` throws, // the counter is not bumped. numSegmentTreeFrames.foreach(_ += 1) - frameType match { - case RowFrame => - boundIter = rows.generateIterator() - nextRow = WindowFunctionFrame.getNextOrNull(boundIter) - case RangeFrame => - lowerIter = rows.generateIterator() - upperIter = rows.generateIterator() - // Pre-seed cursor heads so `RangeBoundOrdering.compare` never - // dereferences null on round 0. Either may be null if `rows` is - // empty; the advance loops' `!= null` / `< upperBound` guards - // handle that. - lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter) - upperRow = WindowFunctionFrame.getNextOrNull(upperIter) + if (shrinking) { + // Upper bound pinned to partition end; never moves. + upperBound = tree.size + frameType match { + case RowFrame => + // RowFrame lower-bound advance is pure index arithmetic; no iterator. + case RangeFrame => + lowerIter = rows.generateIterator() + lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter) + } + } else { + frameType match { + case RowFrame => + boundIter = rows.generateIterator() + nextRow = WindowFunctionFrame.getNextOrNull(boundIter) + case RangeFrame => + lowerIter = rows.generateIterator() + upperIter = rows.generateIterator() + // Pre-seed cursor heads so `RangeBoundOrdering.compare` never + // dereferences null on round 0. 
Either may be null if `rows` is + // empty; the advance loops' `!= null` / `< upperBound` guards + // handle that. + lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter) + upperRow = WindowFunctionFrame.getNextOrNull(upperIter) + } } } @@ -206,17 +231,20 @@ private[window] final class SegmentTreeWindowFunctionFrame( private def writeRow(index: Int, current: InternalRow): Unit = { var boundsChanged = index == 0 - // admit loop: extend upperBound; if a candidate is already below the - // lower bound, advance lowerBound in lock-step to preserve invariant - // (0 <= lowerBound <= upperBound <= tree.size). - while (nextRow != null && - ubound.compare(nextRow, upperBound, current, index) <= 0) { - if (lbound.compare(nextRow, lowerBound, current, index) < 0) { - lowerBound += 1 + if (!shrinking) { + val ub = ubound.get + // admit loop: extend upperBound; if a candidate is already below the + // lower bound, advance lowerBound in lock-step to preserve invariant + // (0 <= lowerBound <= upperBound <= tree.size). + while (nextRow != null && + ub.compare(nextRow, upperBound, current, index) <= 0) { + if (lbound.compare(nextRow, lowerBound, current, index) < 0) { + lowerBound += 1 + } + nextRow = WindowFunctionFrame.getNextOrNull(boundIter) + upperBound += 1 + boundsChanged = true } - nextRow = WindowFunctionFrame.getNextOrNull(boundIter) - upperBound += 1 - boundsChanged = true } // drop loop: advance lowerBound to the frame's left edge. RowFrame's // `lbound.compare` is pure index arithmetic so the input row is unread; @@ -235,13 +263,16 @@ private[window] final class SegmentTreeWindowFunctionFrame( private def writeRange(index: Int, current: InternalRow): Unit = { var boundsChanged = index == 0 - // admit loop (upper edge). `RangeBoundOrdering.compare` ignores its index - // arguments; we pass `upperBound` for API symmetry with RowBoundOrdering. 
- while (upperRow != null && - ubound.compare(upperRow, upperBound, current, index) <= 0) { - upperBound += 1 - upperRow = WindowFunctionFrame.getNextOrNull(upperIter) - boundsChanged = true + if (!shrinking) { + val ub = ubound.get + // admit loop (upper edge). `RangeBoundOrdering.compare` ignores its index + // arguments; we pass `upperBound` for API symmetry with RowBoundOrdering. + while (upperRow != null && + ub.compare(upperRow, upperBound, current, index) <= 0) { + upperBound += 1 + upperRow = WindowFunctionFrame.getNextOrNull(upperIter) + boundsChanged = true + } } // drop loop (lower edge): strict `< 0`, guarded by diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index 2ae10ce9d711c..b90dd535dd26d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -281,11 +281,43 @@ trait WindowEvaluatorFactoryBase { // Shrinking Frame. 
case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) + if (eligibleForSegTree(functions, aggFilters, frameType, conf)) { + val segFns = functions.map(_.asInstanceOf[DeclarativeAggregate]) + val cacheHint = estimateMaxCachedBlocks( + lower, UnboundedFollowing, frameType, blockSize) + target: InternalRow => { + val tc = TaskContext.get() + if (tc == null) { + throw SparkException.internalError( + "WindowEvaluatorFactoryBase.shrinkingSegTreeFrameFactory requires " + + "an active TaskContext") + } + val tmm = tc.taskMemoryManager() + val lb = createBoundOrdering(frameType, lower, timeZone) + new SegmentTreeWindowFunctionFrame( + target, + processor, + segFns, + childOutput, + frameType, + lb, + ubound = None, + fallbackFactory = () => + new UnboundedFollowingWindowFunctionFrame(target, processor, lb), + (e, s) => MutableProjection.create(e, s), + conf, + cacheHint, + tmm, + numSegmentTreeFrames, + numSegmentTreeFallbackFrames) + } + } else { + target: InternalRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, lower, timeZone)) + } } // Moving Frame. 
@@ -305,14 +337,18 @@ trait WindowEvaluatorFactoryBase { "an active TaskContext") } val tmm = tc.taskMemoryManager() + val lb = createBoundOrdering(frameType, lower, timeZone) + val ub = createBoundOrdering(frameType, upper, timeZone) new SegmentTreeWindowFunctionFrame( target, processor, segFns, childOutput, frameType, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone), + lb, + ubound = Some(ub), + fallbackFactory = () => + new SlidingWindowFunctionFrame(target, processor, lb, ub), (e, s) => MutableProjection.create(e, s), conf, cacheHint, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala new file mode 100644 index 0000000000000..35a169fecaddb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark for shrinking ROWS frames (... BETWEEN <lower> AND UNBOUNDED FOLLOWING). + * + * Today's `UnboundedFollowingWindowFunctionFrame` runs the suffix aggregate + * O(n * (n - 1) / 2) per partition (acknowledged inline at + * `WindowFunctionFrame.scala:636`). The segtree path replaces this with + * O(n log n) when `spark.sql.window.segmentTree.enabled=true`. + * + * Layout: single partition on a non-partitioned ORDER BY (the worst case; + * with PARTITION BY the cost decomposes as sum-of-partition N^2 and is + * dominated by the largest partition). + * + * Sections: + * - A: per-aggregate equivalence at N=10K (naive ~3-5s/iter target). + * - B: N-sweep for SUM, naive vs segtree, demonstrating the algorithmic + * gap. N=50K is the largest naive run (~60s/iter); N=100K and 200K + * are segtree-only because naive would take ~4-16 min/iter. + */ +object UnboundedFollowingWindowBenchmark extends SqlBasedBenchmark { + + // Section A: calibrated so naive baseline lands ~3s/iter at A_N. 
+ private val A_N: Long = 10L * 1024L // ~2.4s naive @ N=10K (smoke: 2391ms) + + // Section B: N-sweep + private val B_N_SMALL: Long = 5L * 1024L // ~1.2s naive + private val B_N_MID: Long = 25L * 1024L // ~14s naive + private val B_N_LARGE: Long = 50L * 1024L // ~57s naive (last naive run) + private val B_N_HUGE: Long = 100L * 1024L // segtree-only, naive would be ~4 min + private val B_N_GIANT: Long = 200L * 1024L // segtree-only, naive would be ~16 min + + private val ITERS_NORMAL: Int = 5 + private val ITERS_STRESS: Int = 3 + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val smokeMode = mainArgs.nonEmpty + val smokeRowCount = if (smokeMode) mainArgs(0).toLong else 0L + + def setupIntTable(n: Long): Unit = { + spark.range(n) + .selectExpr("id", "cast(rand(42) * 1000000 as int) as v") + .coalesce(1) + .createOrReplaceTempView("t") + } + + // Shrinking frame: [current, end-of-partition). + val frame = "OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)" + + // Digest comparison ensuring naive and segtree produce identical results. + // SUM/COUNT/MIN/MAX on integers are bit-exact across paths. + def digest(aggFn: String, sqlConfs: (String, String)*): Long = { + withSQLConf(sqlConfs: _*) { + spark.sql(s"SELECT SUM(HASH(m)) FROM (SELECT $aggFn(v) $frame AS m FROM t)") + .head().getLong(0) + } + } + + def rowsLabel(rows: Long): String = { + if (rows >= 1000000) s"${rows / 1000000}M" + else if (rows >= 1024) s"${rows / 1024}K" + else rows.toString + } + + /** + * Section A: Run the same SQL with conf off and on. The naive case is the + * baseline, so iterations must be cheap enough to fit ~3-5s/iter. 
+ */ + def runSectionA(aggFn: String, iters: Int, rows: Long): Unit = { + val dNaive = digest(aggFn) + val dSeg = digest(aggFn, SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"$aggFn shrinking-frame digest mismatch: naive=$dNaive seg=$dSeg") + + val benchmark = new Benchmark( + s"$aggFn shrinking frame, N=${rowsLabel(rows)} rows", + rows, output = output) + benchmark.addCase(s"$aggFn naive (master O(N^2))", numIters = iters) { _ => + spark.sql(s"SELECT $aggFn(v) $frame FROM t").noop() + } + benchmark.addCase(s"$aggFn segtree", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT $aggFn(v) $frame FROM t").noop() + } + } + benchmark.run() + } + + /** + * Section B: SUM-only N-sweep. At N <= 50K we run both paths. At N >= 100K + * we run segtree-only because naive would dominate the benchmark wall-clock. + */ + def runSectionB(rows: Long, includeNaive: Boolean, iters: Int): Unit = { + if (includeNaive) { + val dNaive = digest("SUM") + val dSeg = digest("SUM", SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"Section B N=${rowsLabel(rows)} digest mismatch: naive=$dNaive seg=$dSeg") + } + val benchmark = new Benchmark( + s"SUM shrinking frame, N=${rowsLabel(rows)} rows" + + (if (!includeNaive) " (segtree-only)" else ""), + rows, output = output) + if (includeNaive) { + benchmark.addCase(s"SUM naive (master O(N^2)) N=${rowsLabel(rows)}", + numIters = iters) { _ => + spark.sql(s"SELECT SUM(v) $frame FROM t").noop() + } + } + benchmark.addCase(s"SUM segtree N=${rowsLabel(rows)}", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT SUM(v) $frame FROM t").noop() + } + } + benchmark.run() + } + + if (smokeMode) { + setupIntTable(smokeRowCount) + runBenchmark("SMOKE Section A SUM") { + runSectionA("SUM", ITERS_STRESS, smokeRowCount) + } + } else { + // Section A: per-aggregate (SUM, MIN, MAX, 
COUNT, AVG) at calibrated N=10K. + // STDDEV omitted: shrinking frame doesn't widen multi-buffer aggregates' + // win profile vs sliding (the gain is purely algorithmic, not buffer-pack). + setupIntTable(A_N) + runBenchmark("Section A - SUM (non-invertible suffix)") { + runSectionA("SUM", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - MIN") { + runSectionA("MIN", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - MAX") { + runSectionA("MAX", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - COUNT") { + runSectionA("COUNT", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - AVG (multi-buffer)") { + runSectionA("AVG", ITERS_NORMAL, A_N) + } + + // Section B: N-sweep showing the algorithmic gap widening with N. + setupIntTable(B_N_SMALL) + runBenchmark("Section B - N=5K") { + runSectionB(B_N_SMALL, includeNaive = true, ITERS_NORMAL) + } + setupIntTable(B_N_MID) + runBenchmark("Section B - N=25K (stress)") { + runSectionB(B_N_MID, includeNaive = true, ITERS_STRESS) + } + setupIntTable(B_N_LARGE) + runBenchmark("Section B - N=50K (stress, last naive run)") { + runSectionB(B_N_LARGE, includeNaive = true, ITERS_STRESS) + } + setupIntTable(B_N_HUGE) + runBenchmark("Section B - N=100K (segtree-only, stress)") { + runSectionB(B_N_HUGE, includeNaive = false, ITERS_STRESS) + } + setupIntTable(B_N_GIANT) + runBenchmark("Section B - N=200K (segtree-only, stress)") { + runSectionB(B_N_GIANT, includeNaive = false, ITERS_STRESS) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala index cd5237c9b310f..6ac7b36f412ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala @@ -92,14 +92,17 @@ private[window] object SegmentTreeWindowTestHelper { /** Create a new 
frame. Caller owns lifecycle unless tracked via `track()`. */ def newFrame(): SegmentTreeWindowFunctionFrame = { val target = new SpecificInternalRow(Seq(bufAttrs.head.dataType)) + val lb = RowBoundOrdering(-1) + val ub = RowBoundOrdering(1) val frame = new SegmentTreeWindowFunctionFrame( target, processor, Array(fn), input, RowFrame, - RowBoundOrdering(-1), - RowBoundOrdering(1), + lb, + ubound = Some(ub), + fallbackFactory = () => new SlidingWindowFunctionFrame(target, processor, lb, ub), (es, s) => GenerateMutableProjection.generate(es, s), conf, None, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala new file mode 100644 index 0000000000000..e27267b6b5b5b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * End-to-end correctness tests for the segment-tree shrinking-frame path + * (`... ROWS/RANGE BETWEEN AND UNBOUNDED FOLLOWING`). + * + * Mirrors the structure of [[SegmentTreeWindowFunctionSuite]]: every test + * runs the same SQL with `spark.sql.window.segmentTree.enabled` off and on + * and asserts row-set equality. The "off" path runs through + * [[UnboundedFollowingWindowFunctionFrame]] (the O(N^2) baseline); the "on" + * path runs through the new shrinking branch in + * [[SegmentTreeWindowFunctionFrame]] (`ubound = None`). + */ +class UnboundedFollowingSegmentTreeSuite extends SharedSparkSession { + + import testImplicits._ + + private val enableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1") + + private val disableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false") + + /** Baseline (flag off) vs segtree (flag on); compare row-sets. */ + private def checkEquivalence(build: () => DataFrame): Unit = { + val baseline: Seq[Row] = withSQLConf(disableSegTree.toSeq: _*) { + build().collect().toSeq + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = build().collect().toSeq + QueryTest.sameRows(baseline, actual, isSorted = false).foreach { err => + fail(s"shrinking-frame segtree output differs from baseline.\n$err") + } + } + } + + /** SQL-level variant that accepts a query string. 
*/ + private def checkSqlEquivalence(df: DataFrame, query: String): Unit = { + df.createOrReplaceTempView("t") + try { + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + spark.sql(query).collect().sortBy(_.toString) + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = spark.sql(query).collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq, + s"shrinking-frame segtree output differs from baseline.\n" + + s"Expected: ${baseline.toSeq}\nActual: ${actual.toSeq}") + } + } finally { + spark.catalog.dropTempView("t") + } + } + + /** 3 partitions, 40 rows each; values = row index. */ + private def baseDF: DataFrame = + spark.range(0, 120).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS v") + + /** Shrinking ROWS frame: [lo, end-of-partition). */ + private def shrinkingRowsFrame(lo: Int) = + Window.partitionBy($"pk").orderBy($"id") + .rowsBetween(lo, Window.unboundedFollowing) + + // ============================================================ + // ROWS frame: basic aggregate equivalence (CURRENT ROW lower) + // ============================================================ + + test("MIN over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("MAX over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", max($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("SUM over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("COUNT over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", count($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("AVG over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", 
avg($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + // ============================================================ + // ROWS frame: lower-bound variations + // ============================================================ + + test("ROWS BETWEEN 5 PRECEDING AND UNBOUNDED FOLLOWING (suffix + lookback)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(-5)).as("agg"))) + } + + test("ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING is NOT this path") { + // Both-unbounded routes to UnboundedWindowFunctionFrame (different case + // in the dispatcher) and is one-shot O(1). This test just verifies the + // segtree flag doesn't break it. + val frame = Window.partitionBy($"pk").orderBy($"id") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(frame).as("agg"))) + } + + test("ROWS BETWEEN 5 FOLLOWING AND UNBOUNDED FOLLOWING (lower bound is positive)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(5)).as("agg"))) + } + + // ============================================================ + // Multi-aggregate: shared frame + // ============================================================ + + test("MIN + MAX + SUM share a single shrinking frame") { + checkEquivalence(() => + baseDF.select( + $"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"))) + } + + // ============================================================ + // Partition / boundary edge cases + // ============================================================ + + test("single-row partition") { + val df = spark.range(0, 5).selectExpr("id", "id AS pk", "CAST(id AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("empty result table (no rows)") { + val df = 
spark.emptyDataFrame.selectExpr("CAST(NULL AS BIGINT) AS id", + "CAST(NULL AS BIGINT) AS pk", "CAST(NULL AS INT) AS v") + .where("id IS NOT NULL") + checkEquivalence(() => + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("partition below minPartitionRows falls back to UnboundedFollowingWindowFunctionFrame") { + // With minRows=1024 the segtree path forces fallback; baseline (off) and + // forced-fallback (on, but min=1024) must match. The point is that the + // small-partition path goes through the legacy frame, not segtree. + val df = baseDF + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("s")) + .collect().toSeq + } + withSQLConf( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") { + val actual = df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("s")) + .collect().toSeq + QueryTest.sameRows(baseline, actual, isSorted = false).foreach { err => + fail(s"forced-fallback path diverges from baseline.\n$err") + } + } + } + + // ============================================================ + // NULL / NaN / numeric edge cases + // ============================================================ + + test("all-NULL column: SUM/MIN/MAX/AVG/COUNT") { + val df = spark.range(0, 30).selectExpr("id", "(id % 3) AS pk", + "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + sum($"v").over(shrinkingRowsFrame(0)).as("s"), + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + avg($"v").over(shrinkingRowsFrame(0)).as("a"), + count($"v").over(shrinkingRowsFrame(0)).as("c"))) + } + + test("mixed NULL and non-NULL: NULLs must not leak into MIN/MAX") { + val df = (0 until 60).map { i => + val v: Option[Int] = if (i % 4 == 0) None else Some(i) + (i.toLong, (i % 3).toLong, v) + }.toDF("id", "pk", "v") + checkEquivalence(() => + 
df.select($"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"), + count($"v").over(shrinkingRowsFrame(0)).as("c"))) + } + + test("Double NaN and +/-Infinity propagate correctly through MIN/MAX/SUM") { + val df = Seq( + (0L, 0L, 1.0d), (1L, 0L, Double.NaN), (2L, 0L, 3.0d), + (3L, 0L, Double.PositiveInfinity), (4L, 0L, 5.0d), + (5L, 0L, Double.NegativeInfinity), (6L, 0L, 7.0d), (7L, 0L, 9.0d) + ).toDF("id", "pk", "v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"))) + } + + // ============================================================ + // Type coverage + // ============================================================ + + test("numeric types: Int / Long / Double / Decimal") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS vi", + "id * 1000000000L AS vl", + "CAST(id AS DOUBLE) * 1.5 AS vd", + "CAST(id AS DECIMAL(20, 5)) AS vdec") + checkEquivalence(() => + df.select($"id", $"pk", + sum($"vi").over(shrinkingRowsFrame(0)).as("si"), + sum($"vl").over(shrinkingRowsFrame(0)).as("sl"), + sum($"vd").over(shrinkingRowsFrame(0)).as("sd"), + sum($"vdec").over(shrinkingRowsFrame(0)).as("sdec"))) + } + + test("String lexicographic MIN/MAX") { + val df = spark.range(0, 30).selectExpr( + "id", + "(id % 3) AS pk", + "CONCAT('s', LPAD(CAST((id * 7) % 31 AS STRING), 3, '0')) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"))) + } + + test("Date / Timestamp MIN/MAX") { + val df = spark.range(0, 24).selectExpr( + "id", + "(id % 3) AS pk", + "DATE_ADD(DATE'2024-01-01', CAST(id AS INT)) AS d", + "TIMESTAMPADD(HOUR, CAST(id AS INT), TIMESTAMP'2024-01-01 00:00:00') AS ts") + 
checkEquivalence(() => + df.select($"id", $"pk", + min($"d").over(shrinkingRowsFrame(0)).as("dmn"), + max($"ts").over(shrinkingRowsFrame(0)).as("tsmx"))) + } + + // ============================================================ + // Allow-list: non-DeclarativeAggregate paths must fall back + // ============================================================ + + test("collect_list falls back cleanly (non-DeclarativeAggregate)") { + // collect_list is ImperativeAggregate; segtree path must not engage. + // The result should still be correct via the legacy frame. + checkEquivalence(() => + baseDF.select($"id", $"pk", + collect_list($"v").over(shrinkingRowsFrame(0)).as("lst"))) + } + + test("DISTINCT shrinking aggregate is rejected by analyzer regardless of seg-tree flag") { + def run(): Unit = { + baseDF.select($"id", $"pk", + count_distinct($"v").over(shrinkingRowsFrame(0)).as("cd")).collect() + } + withSQLConf(disableSegTree.toSeq: _*) { + val e = intercept[org.apache.spark.sql.AnalysisException](run()) + assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED")) + } + withSQLConf(enableSegTree.toSeq: _*) { + val e = intercept[org.apache.spark.sql.AnalysisException](run()) + assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED")) + } + } + + // ============================================================ + // RANGE shrinking frame (single-order-expr) + // ============================================================ + + test("RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING with non-uniform gaps") { + val df = spark.range(0, 40).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST(CASE CAST(id AS INT) % 7 " + + "WHEN 0 THEN 1 WHEN 1 THEN 3 WHEN 2 THEN 4 WHEN 3 THEN 4 " + + "WHEN 4 THEN 7 WHEN 5 THEN 10 ELSE 15 END + (CAST(id AS INT) / 7) * 20 AS INT) AS k", + "CAST((id * 31) % 97 AS INT) AS v") + checkSqlEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND 
UNBOUNDED FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING") { + val df = spark.range(0, 40).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST((id * 7) % 31 AS INT) AS k", + "CAST((id * 11) % 53 AS INT) AS v") + checkSqlEquivalence(df, + """SELECT id, pk, k, + | SUM(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS s + |FROM t""".stripMargin) + } + + test("RANGE with tie (duplicate order keys): full tie group at lower edge") { + // Trap: at the lower edge, the FULL tie group at the lower offset must + // be retained, not just the first row. + val rows = (0 until 40).map { i => + val k = Seq(1, 2, 2, 2, 3, 4, 5)(i % 7) + (i, i % 2, k, (i * 13) % 41) + } + val df = rows.toDF("id", "pk", "k", "v") + checkSqlEquivalence(df, + """SELECT id, pk, k, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("RANGE with NULL order key (NULLS FIRST / NULLS LAST)") { + val rows = (0 until 36).map { i => + val kOpt: Option[Int] = (i % 6) match { + case 0 | 1 | 5 => None + case 2 => Some(1) + case 3 => Some(2) + case _ => Some(3) + } + (i, i % 2, kOpt, (i * 11) % 37) + } + val df = rows.toDF("id", "pk", "k", "v") + checkSqlEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS FIRST + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mn_nf, + | MAX(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS LAST + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mx_nl + |FROM t""".stripMargin) + } + + test("RANGE Timestamp with INTERVAL offset (MAX) and shrinking upper") { + val df = spark.range(0, 30).selectExpr( + 
"CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST(TIMESTAMP'2024-01-01 10:00:00' + " + + "make_interval(0, 0, 0, 0, 0, 30 * CAST(id AS INT) * " + + "(CASE CAST(id AS INT) % 3 WHEN 0 THEN 1 WHEN 1 THEN 3 ELSE 4 END), 0) " + + "AS TIMESTAMP) AS ts", + "CAST((id * 17) % 53 AS INT) AS v") + checkSqlEquivalence(df, + """SELECT id, pk, + | MAX(v) OVER (PARTITION BY pk ORDER BY ts + | RANGE BETWEEN INTERVAL '1' HOUR PRECEDING AND UNBOUNDED FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + // ============================================================ + // Feature-flag off: legacy frame is used + // ============================================================ + + test("feature flag off: segmentTree.enabled=false yields baseline semantics") { + val df = baseDF + val expected = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("mn")) + .collect().sortBy(_.toString).toSeq + } + withSQLConf( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") { + val actual = df.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("mn")) + .collect().sortBy(_.toString).toSeq + assert(actual === expected) + } + } +} From 68fd5769b0192ba7a87ec2dda4d01b537af58c1a Mon Sep 17 00:00:00 2001 From: Anupam Yadav Date: Thu, 4 Jun 2026 21:20:37 +0000 Subject: [PATCH 2/6] [SPARK-57220][SQL][FOLLOWUP] Use explicit cacheHint=Some(2) for shrinking frames The shrinking-frame branch in `WindowEvaluatorFactoryBase` previously called `estimateMaxCachedBlocks(lower, UnboundedFollowing, ...)`, which silently returns the default `Some(8)` because no `IntegerLiteral` upper-bound case matches. That value is numerically correct but misleading -- a reader inspecting the call site reasonably worries that the LRU will thrash on partitions large enough to span more than 8 blocks (>= 512K rows at the default 64K block size). 
In fact the shrinking-frame access pattern needs at most 2 cached block-levels regardless of partition size: - Middle blocks of `[lower, n)` are answered directly from the always-resident `blockAggregates`, never via the per-block LRU. - The lower-edge cursor advances monotonically with the output row, so each partial block is needed for at most `blockSize` consecutive queries and then never revisited. - One slot for the active block plus one for brief overlap at the boundary covers the entire pattern. Replace the indirect call with `Some(2)` and a comment documenting the shrinking-frame access pattern. Numerically equivalent to the prior behaviour for any partition size; the change is documentation about what the shrinking-frame path actually needs. Tests: existing `UnboundedFollowingSegmentTreeSuite` (26) and `SegmentTreeWindowFunctionSuite` (41) -- 67/67 pass. Scalastyle clean. Was this patch authored or co-authored using generative AI tooling? Yes. Authored with assistance from Claude (Anthropic). 
--- .../window/WindowEvaluatorFactoryBase.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index b90dd535dd26d..38ca4caa712e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -283,8 +283,17 @@ trait WindowEvaluatorFactoryBase { case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => if (eligibleForSegTree(functions, aggFilters, frameType, conf)) { val segFns = functions.map(_.asInstanceOf[DeclarativeAggregate]) - val cacheHint = estimateMaxCachedBlocks( - lower, UnboundedFollowing, frameType, blockSize) + // Shrinking frames touch the LRU only for one partial block at + // the lower edge -- middle blocks of `[lower, n)` are answered + // directly from `blockAggregates`, not the LRU. The cursor + // advances monotonically with the output row, so blocks behind + // the cursor are never revisited. Hint = 2 (active block + 1 + // slack for the brief overlap when the cursor crosses a + // boundary) suffices regardless of partition size; routing + // through `estimateMaxCachedBlocks` would produce 8 by default + // (no `IntegerLiteral` upper match) -- correct numerically but + // misleading about what the shrinking path actually needs. 
+ val cacheHint = Some(2) target: InternalRow => { val tc = TaskContext.get() if (tc == null) { From 98d14b821a70348f3d58b3d53da53c85e8604a70 Mon Sep 17 00:00:00 2001 From: Anupam Yadav Date: Tue, 9 Jun 2026 19:01:30 -0700 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Wenchen Fan --- .../sql/execution/window/SegmentTreeWindowFunctionFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala index a6de28c0bfda8..be7a838492b6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala @@ -111,7 +111,7 @@ private[window] final class SegmentTreeWindowFunctionFrame( /** * Runtime dispatch flag: when `true`, `write()`, `currentLowerBound()`, and * `currentUpperBound()` delegate to the wrapped fallback frame produced by - * `fallbackFactory` (small-partition path). The fallback type is sliding- + * `fallbackFactory` (small-partition path). The fallback type is shape- * dependent: [[SlidingWindowFunctionFrame]] for moving frames and * [[UnboundedFollowingWindowFunctionFrame]] for shrinking frames. Set by * `prepare()` based on partition size vs. From 4caeb0a538fadbb7aa9c4baa3dc48c65b9e2b34e Mon Sep 17 00:00:00 2001 From: Anupam Yadav Date: Wed, 10 Jun 2026 02:10:00 +0000 Subject: [PATCH 4/6] [SPARK-57220][SQL][FOLLOWUP] Address review nits: clarify cacheHint rationale and writeRow comment Address two of cloud-fan's review nits on PR #56291: 1. 
WindowEvaluatorFactoryBase.scala: the cacheHint=Some(2) rationale was incomplete -- it claimed shrinking frames only touch the LRU for the lower-edge partial block, but WindowSegmentTree.query also fetches the partition's last block via ensureBlockLevels(bhi) on every multi-block query. Rewrote the comment to reflect both LRU slots and warn against tuning the hint down to 1 (which would thrash by evicting the last block on every query). 2. SegmentTreeWindowFunctionFrame.scala: the writeRow/writeRange header comment described only the sliding admit-then-drop path. Restructured it to cover both shapes -- sliding (admit-then-drop, equivalence guarded by SegmentTreeWindowFunctionSuite) and shrinking (drop-only, equivalence guarded by UnboundedFollowingSegmentTreeSuite). Comment-only changes; no behavior change. --- .../SegmentTreeWindowFunctionFrame.scala | 26 ++++++++++++++----- .../window/WindowEvaluatorFactoryBase.scala | 23 +++++++++------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala index be7a838492b6e..496126150f54b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala @@ -221,13 +221,25 @@ private[window] final class SegmentTreeWindowFunctionFrame( } } - // `writeRow`/`writeRange` mirror the `(lowerBound, upperBound)` monotone - // cursor invariant of `SlidingWindowFunctionFrame.write`, but run - // admit-then-drop (no buffer to maintain) instead of drop-then-admit. 
- // Any future fix to Sliding's boundary semantics must be mirrored here; - // equivalence is guarded by `SegmentTreeWindowFunctionSuite` flag-on/off - // tests (`checkRangeEquivalence`, `feature flag off ...`, fallback tests) - // which compare against the Sliding baseline. + // `writeRow`/`writeRange` maintain the `(lowerBound, upperBound)` monotone + // cursor invariant for both sliding and shrinking frame shapes: + // + // - Sliding (`ubound.isDefined`, mirrors `SlidingWindowFunctionFrame.write`): + // run admit-then-drop (no buffer to maintain) instead of drop-then-admit. + // The admit loop below (`if (!shrinking)`) extends `upperBound`; the drop + // loop advances `lowerBound`. Any future fix to Sliding's boundary + // semantics must be mirrored here; equivalence is guarded by + // `SegmentTreeWindowFunctionSuite` flag-on/off tests + // (`checkRangeEquivalence`, `feature flag off ...`, fallback tests) + // against the Sliding baseline. + // + // - Shrinking (`ubound.isEmpty`, upper is `tree.size`): drop-only. The admit + // loop is skipped; only `lowerBound` advances each step. Equivalence is + // guarded by `UnboundedFollowingSegmentTreeSuite` against the + // `UnboundedFollowingWindowFunctionFrame` baseline. + // + // In both shapes, the segtree's `query(lowerBound, upperBound, ...)` is + // re-issued only when `boundsChanged` is true. 
private def writeRow(index: Int, current: InternalRow): Unit = { var boundsChanged = index == 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index 38ca4caa712e7..40cba3d5ceb4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -283,16 +283,19 @@ trait WindowEvaluatorFactoryBase { case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => if (eligibleForSegTree(functions, aggFilters, frameType, conf)) { val segFns = functions.map(_.asInstanceOf[DeclarativeAggregate]) - // Shrinking frames touch the LRU only for one partial block at - // the lower edge -- middle blocks of `[lower, n)` are answered - // directly from `blockAggregates`, not the LRU. The cursor - // advances monotonically with the output row, so blocks behind - // the cursor are never revisited. Hint = 2 (active block + 1 - // slack for the brief overlap when the cursor crosses a - // boundary) suffices regardless of partition size; routing - // through `estimateMaxCachedBlocks` would produce 8 by default - // (no `IntegerLiteral` upper match) -- correct numerically but - // misleading about what the shrinking path actually needs. + // Shrinking-frame queries `[lower, n)` on `WindowSegmentTree` touch the LRU + // for exactly two blocks per query: (1) the lower-edge partial block, and + // (2) the partition's last block (the right-partial `mergeBlockRange(bhi, 0, + // ...)` calls `ensureBlockLevels(bhi)` on every multi-block query). Middle + // blocks of `[lower, n)` are answered directly from `blockAggregates` and + // never go through the LRU. 
The lower-edge block advances monotonically with + // the output row, so once the cursor crosses a boundary the previous block + // is never revisited; the last block stays hot because every query touches + // it. Hint = 2 keeps both resident; routing through `estimateMaxCachedBlocks` + // would produce 8 by default (no `IntegerLiteral` upper match) -- correct + // numerically but misleading about what the shrinking path actually needs. + // Note: tuning this down to 1 would thrash, evicting the last block on every + // query and forcing it to be rebuilt. val cacheHint = Some(2) target: InternalRow => { val tc = TaskContext.get() From 44677a37245a3630b0f96df2f7be60ca07084cfc Mon Sep 17 00:00:00 2001 From: Anupam Yadav Date: Fri, 12 Jun 2026 01:47:00 +0000 Subject: [PATCH 5/6] [SPARK-57220][SQL][TESTS] Fix benchmark comment: MAIN_N -> A_N --- .../execution/benchmark/UnboundedFollowingWindowBenchmark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala index 35a169fecaddb..7fbce2f35be7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf */ object UnboundedFollowingWindowBenchmark extends SqlBasedBenchmark { - // Section A: calibrated so naive baseline lands ~3s/iter at MAIN_N. + // Section A: calibrated so naive baseline lands ~3s/iter at A_N. 
private val A_N: Long = 10L * 1024L // ~2.4s naive @ N=10K (smoke: 2391ms) // Section B: N-sweep From d05984ec8ee57f617313621999b312d229cacacf Mon Sep 17 00:00:00 2001 From: Anupam Yadav Date: Sat, 13 Jun 2026 04:35:02 +0000 Subject: [PATCH 6/6] [SPARK-57424][SQL] Add First/Last to segment-tree window aggregate allowlist Adds `classOf[First]` and `classOf[Last]` to `WindowSegmentTree.EligibleAggregates`, routing First/Last window aggregates through the segment-tree path established by SPARK-56546 (sliding) and SPARK-57220 (shrinking) instead of the legacy O(N x W) sliding / O(N^2) shrinking frame implementations. No new frame class, no new SQLConf, no dispatcher changes -- the existing dispatcher branches (WindowEvaluatorFactoryBase: shrinking at line 283, moving at line 336) already gate on `eligibleForSegTree`, which calls `WindowSegmentTree.isEligible`. Why this is correct under the segment-tree combine: `First.mergeExpressions = if(valueSet.left, left, right)` and `Last.mergeExpressions = if(valueSet.right, right, left)` are order-dependent but correct under the left-to-right combine traversal produced by `WindowSegmentTree.query` (left partial -> full blocks ascending -> right partial; within a block, `queryDescend` walks children in ascending index order). Under that traversal both produce the row-order extreme across any contiguous range, matching the legacy result row-for-row. For IGNORE NULLS the same merge is mode-agnostic: per-row `updateExpressions` only set `valueSet=true` on non-null values, so a per-block partial of `(null, false)` for an all-NULL block is correctly skipped when merged with a later non-null block via mergeExpressions. The earlier docstring labeled First/Last as "Intentionally excluded ... order-dependent". This was over-conservative -- order-dependent in row-traversal order is exactly what the segment tree provides. Updated the docstring to enumerate First/Last alongside Min/Max/Sum/etc and document the audit explicitly. 
Tests: * WindowSegmentTreeAllowlistSuite: 4 routing tests for first / last / first_ignore_nulls / last_ignore_nulls; flipped the previous "first/last falls through" negative tests; updated the mixed-allowlist test to use collect_list (still on the denylist). * SegmentTreeWindowFunctionSuite: 6 oracle equivalence tests covering sliding First/Last respect-nulls and ignore-nulls, all-NULL columns in both modes, and a dedicated test for stretches of consecutive NULLs in IGNORE NULLS mode (the merge-path stress case). * UnboundedFollowingSegmentTreeSuite: 5 oracle equivalence tests covering shrinking First/Last respect-nulls and ignore-nulls plus all-NULL column boundary case. * All 97 tests in the three suites pass; 33 adjacent segtree tests pass unchanged; scalastyle clean. Benchmark (FirstLastSegmentTreeWindowBenchmark, Linux x86_64, Intel Xeon Platinum 8259CL @ 2.50GHz, OpenJDK 25.0.3+9-LTS): Sliding frame [-1000, +1000] at N=10K: | Aggregate | Naive | Segtree | Speedup | | FIRST respect-nulls | 414 ms | 94 ms | 4.4x | | LAST respect-nulls | 728 ms | 101 ms | 7.2x | | FIRST ignore-nulls | 528 ms | 86 ms | 6.1x | | LAST ignore-nulls | 913 ms | 91 ms | 10.0x | Shrinking frame [CURRENT ROW, UNBOUNDED FOLLOWING] at N=10K: | Aggregate | Naive | Segtree | Speedup | | FIRST respect-nulls | 2,158 ms | 79 ms | 27.5x | | LAST respect-nulls | 2,412 ms | 79 ms | 30.6x | | FIRST ignore-nulls | 2,363 ms | 76 ms | 30.9x | | LAST ignore-nulls | 3,399 ms | 79 ms | 43.0x | N-sweep on FIRST shrinking: | N | Naive | Segtree | Speedup | | 5K | 580 ms | 64 ms | 9.1x | | 25K | 13,407 ms | 107 ms | 125.5x | | 50K | 53,784 ms | 172 ms | 312.0x | | 100K | -- | 287 ms | -- | Same opt-in conf (`spark.sql.window.segmentTree.enabled`, default off); same eligibility allowlist mechanism; same fallback for partitions below `minPartitionRows`; same SQLMetrics. No public API changes. 
--- .../execution/window/WindowSegmentTree.scala | 29 ++- .../FirstLastSegmentTreeWindowBenchmark.scala | 212 ++++++++++++++++++ .../SegmentTreeWindowFunctionSuite.scala | 74 ++++++ .../UnboundedFollowingSegmentTreeSuite.scala | 53 +++++ .../WindowSegmentTreeAllowlistSuite.scala | 28 +-- 5 files changed, 368 insertions(+), 28 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala index 27a7736361341..cdc6556f18629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, DeclarativeAggregate, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, DeclarativeAggregate, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.types.DataType @@ -572,20 +572,29 @@ private[window] object WindowSegmentTree { /** * Explicit allowlist of [[DeclarativeAggregate]] subclasses safe for - * segment-tree execution. Safe iff combine semantics form a commutative - * monoid on the partial-buffer representation (associativity + - * compatibility with `mergeExpressions`): + * segment-tree execution. 
Safe iff combine semantics are correct under the + * left-to-right combine order produced by [[WindowSegmentTree.query]] + * (left partial -> full blocks ascending -> right partial; within a block, + * `queryDescend` walks children in ascending index order). * - * - [[Min]], [[Max]]: idempotent semilattice. - * - [[Sum]], [[Count]]: additive monoid. + * - [[Min]], [[Max]]: idempotent semilattice (associative + commutative). + * - [[Sum]], [[Count]]: additive monoid (associative + commutative). * - [[Average]]: sum + count, both additive monoids. * - [[StddevPop]], [[StddevSamp]], [[VariancePop]], [[VarianceSamp]]: * Welford (count, mean, M2) is associative -- see * CentralMomentAgg.mergeExpressions. + * - [[First]], [[Last]]: order-dependent but correct under left-to-right + * combine. `First.mergeExpressions` is `if(valueSet.left, left, right)` + * and `Last.mergeExpressions` is `if(valueSet.right, right, left)`; + * under the left-to-right traversal both pick the row-order extreme + * across any contiguous range. `IGNORE NULLS` is also handled: per-row + * `updateExpressions` only sets `valueSet=true` on non-null values, so + * a per-block partial of `(null, false)` for an all-NULL block is + * correctly skipped when merged with a later non-null block. * * Intentionally excluded (tracked as follow-up): HyperLogLogPlusPlus / - * ApproxCountDistinct (sketch-buffer interaction unaudited), First / Last - * (order-dependent), CollectList / CollectSet (unbounded buffer growth), + * ApproxCountDistinct (sketch-buffer interaction unaudited), + * CollectList / CollectSet (unbounded buffer growth), * Percentile / ApproxPercentile (sorted-sketch buffer), and any * ImperativeAggregate (excluded by the type check). 
* @@ -600,7 +609,9 @@ private[window] object WindowSegmentTree { classOf[StddevPop], classOf[StddevSamp], classOf[VariancePop], - classOf[VarianceSamp] + classOf[VarianceSamp], + classOf[First], + classOf[Last] ) /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala new file mode 100644 index 0000000000000..8a9978065beff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark for FIRST/LAST window aggregates over sliding and shrinking + * ROWS frames, comparing the legacy O(N x W) / O(N^2) frame paths against + * the segment-tree path enabled by adding `classOf[First]` / `classOf[Last]` + * to `WindowSegmentTree.EligibleAggregates`. 
+ * + * Today's slow paths: + * - Sliding: `SlidingWindowFunctionFrame.write` rebuilds the per-row + * buffer aggregate by iterating `processor.update` over every row in + * the buffer (O(W) per output row, O(N*W) total). + * - Shrinking: `UnboundedFollowingWindowFunctionFrame.write` walks the + * remaining suffix on every output row (O(N^2) total; class scaladoc + * literally says O(n*(n-1)/2)). + * + * Sections: + * - A: FIRST/LAST per-mode at N=10K, sliding wide frame. + * - B: FIRST/LAST per-mode at N=10K, shrinking frame. + * - C: N-sweep for FIRST shrinking, naive vs segtree, demonstrating the + * algorithmic gap. Mirrors UnboundedFollowingWindowBenchmark layout. + */ +object FirstLastSegmentTreeWindowBenchmark extends SqlBasedBenchmark { + + // Section A/B: per-mode per-frame-shape at calibrated N + private val AB_N: Long = 10L * 1024L + + // Section C: N-sweep for shrinking FIRST + private val C_N_SMALL: Long = 5L * 1024L + private val C_N_MID: Long = 25L * 1024L + private val C_N_LARGE: Long = 50L * 1024L + private val C_N_HUGE: Long = 100L * 1024L + + private val ITERS_NORMAL: Int = 5 + private val ITERS_STRESS: Int = 3 + + // Sliding frame width tuned so the O(N*W) baseline is observable but not + // catastrophic. With N=10K and W=2001 the legacy SlidingWindowFunctionFrame + // performs ~20M update calls (10K rows x ~2K-row buffer rebuild on each + // boundary change) which is enough to expose the gap without dominating + // wall-clock at calibration time. 
+ private val SLIDING_FRAME = + "OVER (ORDER BY id ROWS BETWEEN 1000 PRECEDING AND 1000 FOLLOWING)" + private val SHRINKING_FRAME = + "OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)" + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val smokeMode = mainArgs.nonEmpty + val smokeRowCount = if (smokeMode) mainArgs(0).toLong else 0L + + def setupIntTable(n: Long): Unit = { + // Sprinkle ~10% NULLs so IGNORE NULLS exercises the merge path + // distinctly from respect-nulls; integer values otherwise. + spark.range(n) + .selectExpr("id", + "CASE WHEN rand(7) < 0.1 THEN NULL " + + "ELSE cast(rand(42) * 1000000 as int) END as v") + .coalesce(1) + .createOrReplaceTempView("t") + } + + def rowsLabel(rows: Long): String = { + if (rows >= 1000000) s"${rows / 1000000}M" + else if (rows >= 1024) s"${rows / 1024}K" + else rows.toString + } + + /** + * Equivalence digest. FIRST/LAST in respect-nulls mode are bit-exact + * across naive and segtree paths; in IGNORE NULLS the per-block merge + * also yields the same result row-for-row. We hash the result column + * directly (not COALESCE'd): HASH maps a NULL input to its seed, so NULL + * rows still count; SUM is NULL only for an empty input (mapped to 0). 
+ */ + def digest(sql: String, sqlConfs: (String, String)*): Long = { + withSQLConf(sqlConfs: _*) { + val r = spark.sql(s"SELECT SUM(HASH(m)) FROM (SELECT $sql AS m FROM t)") + .head().get(0) + if (r == null) 0L else r.asInstanceOf[Long] + } + } + + def runCase(label: String, sql: String, iters: Int, rows: Long): Unit = { + val dNaive = digest(sql) + val dSeg = digest(sql, SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"$label digest mismatch: naive=$dNaive seg=$dSeg") + + val benchmark = new Benchmark( + s"$label, N=${rowsLabel(rows)} rows", rows, output = output) + benchmark.addCase(s"$label naive", numIters = iters) { _ => + spark.sql(s"SELECT $sql FROM t").noop() + } + benchmark.addCase(s"$label segtree", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT $sql FROM t").noop() + } + } + benchmark.run() + } + + def runSweepCase(rows: Long, includeNaive: Boolean, iters: Int): Unit = { + val sql = s"FIRST(v) $SHRINKING_FRAME" + if (includeNaive) { + val dNaive = digest(sql) + val dSeg = digest(sql, SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"Section C N=${rowsLabel(rows)} digest mismatch: naive=$dNaive seg=$dSeg") + } + val benchmark = new Benchmark( + s"FIRST shrinking frame, N=${rowsLabel(rows)} rows" + + (if (!includeNaive) " (segtree-only)" else ""), + rows, output = output) + if (includeNaive) { + benchmark.addCase(s"FIRST naive N=${rowsLabel(rows)}", numIters = iters) { _ => + spark.sql(s"SELECT $sql FROM t").noop() + } + } + benchmark.addCase(s"FIRST segtree N=${rowsLabel(rows)}", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT $sql FROM t").noop() + } + } + benchmark.run() + } + + if (smokeMode) { + setupIntTable(smokeRowCount) + runBenchmark("SMOKE Section A FIRST sliding") { + runCase("FIRST sliding respect-nulls", + s"FIRST(v) $SLIDING_FRAME", ITERS_STRESS, 
smokeRowCount) + } + } else { + setupIntTable(AB_N) + + // Section A: sliding frame, all four mode/function combinations. + runBenchmark("Section A - FIRST sliding respect-nulls") { + runCase("FIRST sliding respect-nulls", + s"FIRST(v) $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section A - LAST sliding respect-nulls") { + runCase("LAST sliding respect-nulls", + s"LAST(v) $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section A - FIRST sliding IGNORE NULLS") { + runCase("FIRST sliding ignore-nulls", + s"FIRST(v) IGNORE NULLS $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section A - LAST sliding IGNORE NULLS") { + runCase("LAST sliding ignore-nulls", + s"LAST(v) IGNORE NULLS $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + + // Section B: shrinking frame, all four mode/function combinations. + runBenchmark("Section B - FIRST shrinking respect-nulls") { + runCase("FIRST shrinking respect-nulls", + s"FIRST(v) $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section B - LAST shrinking respect-nulls") { + runCase("LAST shrinking respect-nulls", + s"LAST(v) $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section B - FIRST shrinking IGNORE NULLS") { + runCase("FIRST shrinking ignore-nulls", + s"FIRST(v) IGNORE NULLS $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section B - LAST shrinking IGNORE NULLS") { + runCase("LAST shrinking ignore-nulls", + s"LAST(v) IGNORE NULLS $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + + // Section C: shrinking-frame N-sweep on FIRST, demonstrating O(N^2) + // legacy vs O(N log N) segtree gap widens with N. 
+ setupIntTable(C_N_SMALL) + runBenchmark("Section C - N=5K") { + runSweepCase(C_N_SMALL, includeNaive = true, ITERS_NORMAL) + } + setupIntTable(C_N_MID) + runBenchmark("Section C - N=25K (stress)") { + runSweepCase(C_N_MID, includeNaive = true, ITERS_STRESS) + } + setupIntTable(C_N_LARGE) + runBenchmark("Section C - N=50K (stress, last naive run)") { + runSweepCase(C_N_LARGE, includeNaive = true, ITERS_STRESS) + } + setupIntTable(C_N_HUGE) + runBenchmark("Section C - N=100K (segtree-only, stress)") { + runSweepCase(C_N_HUGE, includeNaive = false, ITERS_STRESS) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala index 5e644cceb1426..ffe4e58449053 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala @@ -99,6 +99,19 @@ class SegmentTreeWindowFunctionSuite extends SharedSparkSession { baseDF.select($"id", $"pk", avg($"v").over(winSpec(-3, 3)).as("agg"))) } + // First / Last basic equivalence (respect-nulls; the default for + // first()/last()). Order-correctness depends on the segment-tree + // combine being left-to-right; see WindowSegmentTree.EligibleAggregates. 
+ test("FIRST over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", first($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("LAST over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", last($"v").over(winSpec(-3, 3)).as("agg"))) + } + test("MIN + MAX + SUM share a single window frame") { checkEquivalence(() => baseDF.select( @@ -229,6 +242,67 @@ class SegmentTreeWindowFunctionSuite extends SharedSparkSession { count($"v").over(winSpec(-4, 4)).as("cn"))) } + // First / Last respect-nulls: NULL is a valid value. If the first row in the + // frame is NULL, FIRST returns NULL. The seg-tree merge must preserve this. + test("FIRST/LAST respect-nulls with mixed NULL frame contents") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(winSpec(-4, 4)).as("fv"), + last($"v").over(winSpec(-4, 4)).as("lv"))) + } + + // First / Last IGNORE NULLS: per-row updates only set valueSet on non-null + // values. A per-block partial of (null, false) for an all-NULL block must + // be correctly skipped when merged with a later non-null block. + test("FIRST/LAST ignore-nulls with mixed NULL frame contents") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v", ignoreNulls = true).over(winSpec(-4, 4)).as("fv_ign"), + last($"v", ignoreNulls = true).over(winSpec(-4, 4)).as("lv_ign"))) + } + + // All-NULL column edge case for First/Last in both modes. + // Respect-nulls: returns NULL. Ignore-nulls: also returns NULL (no + // non-null candidate ever sets valueSet). 
+ test("all-NULL column: FIRST/LAST in both modes") { + val df = spark.range(0, 30).selectExpr( + "id", "(id % 3) AS pk", "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(winSpec(-3, 3)).as("fv"), + last($"v").over(winSpec(-3, 3)).as("lv"), + first($"v", ignoreNulls = true).over(winSpec(-3, 3)).as("fv_ign"), + last($"v", ignoreNulls = true).over(winSpec(-3, 3)).as("lv_ign"))) + } + + // Adversarial NULL distribution for IGNORE NULLS: per-block aggregates need + // to compose correctly when an entire block is all-NULL. With block size + // 65536 and partition size 120 we cannot literally produce a fully-NULL + // block via the standard fixture, but a long stretch of consecutive NULLs + // exercises the same merge path (per-row updates produce intermediate + // valueSet=false buffers which then merge with a later valueSet=true buffer + // via mergeExpressions). Combined with a wide frame to force tree queries + // crossing the all-NULL stretch. + test("FIRST/LAST ignore-nulls: stretches of consecutive NULLs cross-merge correctly") { + val df = spark.range(0, 90).selectExpr( + "id", + "0 AS pk", + // First 30 rows non-null, next 30 all NULL, last 30 non-null again. + "CASE WHEN id BETWEEN 30 AND 59 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v", ignoreNulls = true).over(winSpec(-20, 20)).as("fv_ign"), + last($"v", ignoreNulls = true).over(winSpec(-20, 20)).as("lv_ign"))) + } + test("Double NaN and +/-Infinity propagate correctly through MIN/MAX/SUM") { // Trap: NaN > +Inf in Spark's MIN/MAX ordering; +Inf + -Inf = NaN in SUM. // Seg-tree uses DeclarativeAggregate.merge; behavior must match baseline. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala index e27267b6b5b5b..4ea836bff9566 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala @@ -117,6 +117,22 @@ class UnboundedFollowingSegmentTreeSuite extends SharedSparkSession { baseDF.select($"id", $"pk", avg($"v").over(shrinkingRowsFrame(0)).as("agg"))) } + // First / Last over a shrinking frame: in respect-nulls mode, FIRST is just + // `rows[lower]` (the first row of the suffix), and LAST is always the + // partition's final row however far the lower bound advances. Both functions + // exercise the segment-tree merge path through a series of `[lower, n)` + // queries; correctness depends on the same left-to-right combine that + // makes First/Last safe in the sliding case. + test("FIRST over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", first($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("LAST over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", last($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + // ============================================================ // ROWS frame: lower-bound variations // ============================================================ @@ -208,6 +224,43 @@ class UnboundedFollowingSegmentTreeSuite extends SharedSparkSession { count($"v").over(shrinkingRowsFrame(0)).as("c"))) } + // First / Last over a shrinking frame with NULL distribution. Mirrors the + // sliding-suite NULL tests; verifies the merge path is correct when the + // lower-edge partial block of the segtree query crosses a NULL/non-NULL + // boundary. 
+ test("FIRST/LAST over shrinking frame: respect-nulls with mixed NULLs") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(shrinkingRowsFrame(0)).as("fv"), + last($"v").over(shrinkingRowsFrame(0)).as("lv"))) + } + + test("FIRST/LAST over shrinking frame: ignore-nulls with mixed NULLs") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("fv_ign"), + last($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("lv_ign"))) + } + + test("all-NULL column: FIRST/LAST shrinking frame in both modes") { + val df = spark.range(0, 30).selectExpr("id", "(id % 3) AS pk", + "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(shrinkingRowsFrame(0)).as("fv"), + last($"v").over(shrinkingRowsFrame(0)).as("lv"), + first($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("fv_ign"), + last($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("lv_ign"))) + } + test("mixed NULL and non-NULL: NULLs must not leak into MIN/MAX") { val df = (0 until 60).map { i => val v: Option[Int] = if (i % 4 == 0) None else Some(i) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala index 236d38cc6a910..8dbb4ceecf1bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala @@ -81,7 +81,13 @@ class WindowSegmentTreeAllowlistSuite ("stddev_pop", (c: org.apache.spark.sql.Column) 
=> stddev_pop(c)), ("stddev_samp", (c: org.apache.spark.sql.Column) => stddev_samp(c)), ("var_pop", (c: org.apache.spark.sql.Column) => var_pop(c)), - ("var_samp", (c: org.apache.spark.sql.Column) => var_samp(c)) + ("var_samp", (c: org.apache.spark.sql.Column) => var_samp(c)), + ("first", (c: org.apache.spark.sql.Column) => first(c)), + ("last", (c: org.apache.spark.sql.Column) => last(c)), + ("first_ignore_nulls", + (c: org.apache.spark.sql.Column) => first(c, ignoreNulls = true)), + ("last_ignore_nulls", + (c: org.apache.spark.sql.Column) => last(c, ignoreNulls = true)) ).foreach { case (name, fn) => test(s"$name routes to the segment-tree path") { withSQLConf(enableSegTree.toSeq: _*) { @@ -96,22 +102,6 @@ class WindowSegmentTreeAllowlistSuite // Negative: non-allowlisted aggregates fall through - test("first_value falls through (order-dependent aggregate)") { - withSQLConf(enableSegTree.toSeq: _*) { - val df = baseDF.withColumn("agg", first($"v").over(winSpec)) - val (seg, _) = segTreeCounters(df) - assert(seg == 0, s"first_value should not use segment tree (got $seg frames)") - } - } - - test("last_value falls through (order-dependent aggregate)") { - withSQLConf(enableSegTree.toSeq: _*) { - val df = baseDF.withColumn("agg", last($"v").over(winSpec)) - val (seg, _) = segTreeCounters(df) - assert(seg == 0, s"last_value should not use segment tree (got $seg frames)") - } - } - test("collect_list falls through (unbounded buffer)") { withSQLConf(enableSegTree.toSeq: _*) { val df = baseDF.withColumn("agg", collect_list($"v").over(winSpec)) @@ -176,10 +166,10 @@ class WindowSegmentTreeAllowlistSuite withSQLConf(enableSegTree.toSeq: _*) { val df = baseDF .withColumn("s", sum($"v").over(winSpec)) - .withColumn("fv", first($"v").over(winSpec)) + .withColumn("cl", collect_list($"v").over(winSpec)) val (seg, _) = segTreeCounters(df) // Both aggregates share the same Window node; gating is forall(isEligible), - // so `first_value` drops the whole group. 
+ // so `collect_list` (unbounded-buffer denylist) drops the whole group. assert(seg == 0, s"Window group containing a non-allowlisted agg must fall through (got $seg)") }