diff --git a/sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt b/sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt new file mode 100644 index 0000000000000..c8eafed7fb032 --- /dev/null +++ b/sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt @@ -0,0 +1,118 @@ +================================================================================================ +Section A - SUM (non-invertible suffix) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) 2471 2495 14 0.0 241298.5 1.0X +SUM segtree 110 115 4 0.1 10744.6 22.5X + + +================================================================================================ +Section A - MIN +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +MIN shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +MIN naive (master O(N^2)) 2417 2438 23 0.0 236035.8 1.0X +MIN segtree 215 219 5 0.0 21015.3 11.2X + + +================================================================================================ +Section A - MAX +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) 
Platinum 8259CL CPU @ 2.50GHz +MAX shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +MAX naive (master O(N^2)) 2396 2401 5 0.0 233937.5 1.0X +MAX segtree 228 229 1 0.0 22259.2 10.5X + + +================================================================================================ +Section A - COUNT +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +COUNT shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +COUNT naive (master O(N^2)) 2203 2222 16 0.0 215139.0 1.0X +COUNT segtree 80 88 9 0.1 7846.1 27.4X + + +================================================================================================ +Section A - AVG (multi-buffer) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +AVG shrinking frame, N=10K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +AVG naive (master O(N^2)) 2886 2900 18 0.0 281837.8 1.0X +AVG segtree 84 86 4 0.1 8165.1 34.5X + + +================================================================================================ +Section B - N=5K +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 
5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=5K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) N=5K 620 628 7 0.0 121170.2 1.0X +SUM segtree N=5K 73 74 1 0.1 14302.8 8.5X + + +================================================================================================ +Section B - N=25K (stress) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=25K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) N=25K 14259 14341 108 0.0 556977.9 1.0X +SUM segtree N=25K 119 120 0 0.2 4667.1 119.3X + + +================================================================================================ +Section B - N=50K (stress, last naive run) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=50K rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SUM naive (master O(N^2)) N=50K 57022 57659 987 0.0 1113704.1 1.0X +SUM segtree N=50K 181 182 1 0.3 3544.3 314.2X + + +================================================================================================ +Section B - N=100K (segtree-only, stress) 
+================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=100K rows (segtree-only): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------- +SUM segtree N=100K 269 270 2 0.4 2627.9 1.0X + + +================================================================================================ +Section B - N=200K (segtree-only, stress) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64 +Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz +SUM shrinking frame, N=200K rows (segtree-only): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------- +SUM segtree N=200K 480 481 1 0.4 2343.7 1.0X + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala index 51648e31e3498..496126150f54b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala @@ -27,14 +27,21 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf /** - * Moving-frame window function frame backed by [[WindowSegmentTree]]. 
Produces - * the same outputs as [[SlidingWindowFunctionFrame]] for RowFrame or - * single-column RangeFrame moving frames whose aggregates are all - * [[DeclarativeAggregate]] with no FILTER/DISTINCT. For partitions below - * `spark.sql.window.segmentTree.minPartitionRows`, delegates to a wrapped - * [[SlidingWindowFunctionFrame]]. Under RANGE, two forward-only cursors - * (`lowerIter` / `upperIter`) advance the bounds in O(n) total; the segtree - * answers `[lowerBound, upperBound)` in O(log n). + * Window function frame backed by [[WindowSegmentTree]]. Handles two frame + * shapes: + * - **Sliding** (`ubound = Some(...)`): both edges move; mirrors + * [[SlidingWindowFunctionFrame]]. O(N log W) total. + * - **Shrinking** (`ubound = None`): upper edge pinned to partition end + * (`BETWEEN ... AND UNBOUNDED FOLLOWING`); replaces + * [[UnboundedFollowingWindowFunctionFrame]]'s O(N^2) full recompute with + * O(N log N). + * + * Eligibility, build, spill, and memory accounting are identical for both + * shapes; only the per-row cursor logic differs (admit+drop for sliding, + * drop-only for shrinking). + * + * For partitions below `spark.sql.window.segmentTree.minPartitionRows`, + * delegates to a frame produced by `fallbackFactory`. * * @note Not thread-safe. 
*/ @@ -45,7 +52,8 @@ private[window] final class SegmentTreeWindowFunctionFrame( inputSchema: Seq[Attribute], frameType: FrameType, lbound: BoundOrdering, - ubound: BoundOrdering, + ubound: Option[BoundOrdering], + fallbackFactory: () => WindowFunctionFrame, newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, conf: SQLConf, maxCachedBlocks: Option[Int], @@ -57,16 +65,18 @@ private[window] final class SegmentTreeWindowFunctionFrame( require(frameType == RowFrame || frameType == RangeFrame, s"SegmentTreeWindowFunctionFrame supports RowFrame or RangeFrame, got $frameType") - private[this] var fallback: SlidingWindowFunctionFrame = _ + // True when this is a shrinking-frame (UnboundedFollowing) instance. + // Shorthand to avoid repeated `ubound.isEmpty` reads in hot loops. + private[this] val shrinking: Boolean = ubound.isEmpty + + private[this] var fallback: WindowFunctionFrame = _ private[this] var tree: WindowSegmentTree = _ /** - * Allocate a fresh fallback sliding-window frame. Called lazily from - * `prepare()` on the small-partition path. Factored out for testability - * (subclasses can inject a throwing fallback for prepare-failure tests). + * Allocate a fresh fallback frame via `fallbackFactory`. Called lazily + * from `prepare()` on the small-partition path. */ - private[window] def newFallback(): SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(target, processor, lbound, ubound) + private[window] def newFallback(): WindowFunctionFrame = fallbackFactory() /** Test hook: whether the fallback frame has been lazily allocated. */ private[window] def fallbackAllocated: Boolean = fallback != null @@ -100,8 +110,11 @@ private[window] final class SegmentTreeWindowFunctionFrame( /** * Runtime dispatch flag: when `true`, `write()`, `currentLowerBound()`, and - * `currentUpperBound()` delegate to the wrapped [[SlidingWindowFunctionFrame]] - * (small-partition path). Set by `prepare()` based on partition size vs. 
+ * `currentUpperBound()` delegate to the wrapped fallback frame produced by + * `fallbackFactory` (small-partition path). The fallback type is shape- + * dependent: [[SlidingWindowFunctionFrame]] for moving frames and + * [[UnboundedFollowingWindowFunctionFrame]] for shrinking frames. Set by + * `prepare()` based on partition size vs. * `spark.sql.window.segmentTree.minPartitionRows`. */ private[window] var fallbackUsed: Boolean = false @@ -155,19 +168,31 @@ private[window] final class SegmentTreeWindowFunctionFrame( // Count only on the successful segtree path: if `tree.build` throws, // the counter is not bumped. numSegmentTreeFrames.foreach(_ += 1) - frameType match { - case RowFrame => - boundIter = rows.generateIterator() - nextRow = WindowFunctionFrame.getNextOrNull(boundIter) - case RangeFrame => - lowerIter = rows.generateIterator() - upperIter = rows.generateIterator() - // Pre-seed cursor heads so `RangeBoundOrdering.compare` never - // dereferences null on round 0. Either may be null if `rows` is - // empty; the advance loops' `!= null` / `< upperBound` guards - // handle that. - lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter) - upperRow = WindowFunctionFrame.getNextOrNull(upperIter) + if (shrinking) { + // Upper bound pinned to partition end; never moves. + upperBound = tree.size + frameType match { + case RowFrame => + // RowFrame lower-bound advance is pure index arithmetic; no iterator. + case RangeFrame => + lowerIter = rows.generateIterator() + lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter) + } + } else { + frameType match { + case RowFrame => + boundIter = rows.generateIterator() + nextRow = WindowFunctionFrame.getNextOrNull(boundIter) + case RangeFrame => + lowerIter = rows.generateIterator() + upperIter = rows.generateIterator() + // Pre-seed cursor heads so `RangeBoundOrdering.compare` never + // dereferences null on round 0. 
Either may be null if `rows` is + // empty; the advance loops' `!= null` / `< upperBound` guards + // handle that. + lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter) + upperRow = WindowFunctionFrame.getNextOrNull(upperIter) + } } } @@ -196,27 +221,42 @@ private[window] final class SegmentTreeWindowFunctionFrame( } } - // `writeRow`/`writeRange` mirror the `(lowerBound, upperBound)` monotone - // cursor invariant of `SlidingWindowFunctionFrame.write`, but run - // admit-then-drop (no buffer to maintain) instead of drop-then-admit. - // Any future fix to Sliding's boundary semantics must be mirrored here; - // equivalence is guarded by `SegmentTreeWindowFunctionSuite` flag-on/off - // tests (`checkRangeEquivalence`, `feature flag off ...`, fallback tests) - // which compare against the Sliding baseline. + // `writeRow`/`writeRange` maintain the `(lowerBound, upperBound)` monotone + // cursor invariant for both sliding and shrinking frame shapes: + // + // - Sliding (`ubound.isDefined`, mirrors `SlidingWindowFunctionFrame.write`): + // run admit-then-drop (no buffer to maintain) instead of drop-then-admit. + // The admit loop below (`if (!shrinking)`) extends `upperBound`; the drop + // loop advances `lowerBound`. Any future fix to Sliding's boundary + // semantics must be mirrored here; equivalence is guarded by + // `SegmentTreeWindowFunctionSuite` flag-on/off tests + // (`checkRangeEquivalence`, `feature flag off ...`, fallback tests) + // against the Sliding baseline. + // + // - Shrinking (`ubound.isEmpty`, upper is `tree.size`): drop-only. The admit + // loop is skipped; only `lowerBound` advances each step. Equivalence is + // guarded by `UnboundedFollowingSegmentTreeSuite` against the + // `UnboundedFollowingWindowFunctionFrame` baseline. + // + // In both shapes, the segtree's `query(lowerBound, upperBound, ...)` is + // re-issued only when `boundsChanged` is true. 
private def writeRow(index: Int, current: InternalRow): Unit = { var boundsChanged = index == 0 - // admit loop: extend upperBound; if a candidate is already below the - // lower bound, advance lowerBound in lock-step to preserve invariant - // (0 <= lowerBound <= upperBound <= tree.size). - while (nextRow != null && - ubound.compare(nextRow, upperBound, current, index) <= 0) { - if (lbound.compare(nextRow, lowerBound, current, index) < 0) { - lowerBound += 1 + if (!shrinking) { + val ub = ubound.get + // admit loop: extend upperBound; if a candidate is already below the + // lower bound, advance lowerBound in lock-step to preserve invariant + // (0 <= lowerBound <= upperBound <= tree.size). + while (nextRow != null && + ub.compare(nextRow, upperBound, current, index) <= 0) { + if (lbound.compare(nextRow, lowerBound, current, index) < 0) { + lowerBound += 1 + } + nextRow = WindowFunctionFrame.getNextOrNull(boundIter) + upperBound += 1 + boundsChanged = true } - nextRow = WindowFunctionFrame.getNextOrNull(boundIter) - upperBound += 1 - boundsChanged = true } // drop loop: advance lowerBound to the frame's left edge. RowFrame's // `lbound.compare` is pure index arithmetic so the input row is unread; @@ -235,13 +275,16 @@ private[window] final class SegmentTreeWindowFunctionFrame( private def writeRange(index: Int, current: InternalRow): Unit = { var boundsChanged = index == 0 - // admit loop (upper edge). `RangeBoundOrdering.compare` ignores its index - // arguments; we pass `upperBound` for API symmetry with RowBoundOrdering. - while (upperRow != null && - ubound.compare(upperRow, upperBound, current, index) <= 0) { - upperBound += 1 - upperRow = WindowFunctionFrame.getNextOrNull(upperIter) - boundsChanged = true + if (!shrinking) { + val ub = ubound.get + // admit loop (upper edge). `RangeBoundOrdering.compare` ignores its index + // arguments; we pass `upperBound` for API symmetry with RowBoundOrdering. 
+ while (upperRow != null && + ub.compare(upperRow, upperBound, current, index) <= 0) { + upperBound += 1 + upperRow = WindowFunctionFrame.getNextOrNull(upperIter) + boundsChanged = true + } } // drop loop (lower edge): strict `< 0`, guarded by diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index 2ae10ce9d711c..40cba3d5ceb4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -281,11 +281,55 @@ trait WindowEvaluatorFactoryBase { // Shrinking Frame. case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) + if (eligibleForSegTree(functions, aggFilters, frameType, conf)) { + val segFns = functions.map(_.asInstanceOf[DeclarativeAggregate]) + // Shrinking-frame queries `[lower, n)` on `WindowSegmentTree` touch the LRU + // for exactly two blocks per query: (1) the lower-edge partial block, and + // (2) the partition's last block (the right-partial `mergeBlockRange(bhi, 0, + // ...)` calls `ensureBlockLevels(bhi)` on every multi-block query). Middle + // blocks of `[lower, n)` are answered directly from `blockAggregates` and + // never go through the LRU. The lower-edge block advances monotonically with + // the output row, so once the cursor crosses a boundary the previous block + // is never revisited; the last block stays hot because every query touches + // it. Hint = 2 keeps both resident; routing through `estimateMaxCachedBlocks` + // would produce 8 by default (no `IntegerLiteral` upper match) -- correct + // numerically but misleading about what the shrinking path actually needs. 
+ // Note: tuning this down to 1 would thrash, evicting the last block on every + // query and forcing it to be rebuilt. + val cacheHint = Some(2) + target: InternalRow => { + val tc = TaskContext.get() + if (tc == null) { + throw SparkException.internalError( + "WindowEvaluatorFactoryBase.shrinkingSegTreeFrameFactory requires " + + "an active TaskContext") + } + val tmm = tc.taskMemoryManager() + val lb = createBoundOrdering(frameType, lower, timeZone) + new SegmentTreeWindowFunctionFrame( + target, + processor, + segFns, + childOutput, + frameType, + lb, + ubound = None, + fallbackFactory = () => + new UnboundedFollowingWindowFunctionFrame(target, processor, lb), + (e, s) => MutableProjection.create(e, s), + conf, + cacheHint, + tmm, + numSegmentTreeFrames, + numSegmentTreeFallbackFrames) + } + } else { + target: InternalRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, lower, timeZone)) + } } // Moving Frame. @@ -305,14 +349,18 @@ trait WindowEvaluatorFactoryBase { "an active TaskContext") } val tmm = tc.taskMemoryManager() + val lb = createBoundOrdering(frameType, lower, timeZone) + val ub = createBoundOrdering(frameType, upper, timeZone) new SegmentTreeWindowFunctionFrame( target, processor, segFns, childOutput, frameType, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone), + lb, + ubound = Some(ub), + fallbackFactory = () => + new SlidingWindowFunctionFrame(target, processor, lb, ub), (e, s) => MutableProjection.create(e, s), conf, cacheHint, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala index 27a7736361341..cdc6556f18629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala @@ 
-25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, DeclarativeAggregate, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, DeclarativeAggregate, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.types.DataType @@ -572,20 +572,29 @@ private[window] object WindowSegmentTree { /** * Explicit allowlist of [[DeclarativeAggregate]] subclasses safe for - * segment-tree execution. Safe iff combine semantics form a commutative - * monoid on the partial-buffer representation (associativity + - * compatibility with `mergeExpressions`): + * segment-tree execution. Safe iff combine semantics are correct under the + * left-to-right combine order produced by [[WindowSegmentTree.query]] + * (left partial -> full blocks ascending -> right partial; within a block, + * `queryDescend` walks children in ascending index order). * - * - [[Min]], [[Max]]: idempotent semilattice. - * - [[Sum]], [[Count]]: additive monoid. + * - [[Min]], [[Max]]: idempotent semilattice (associative + commutative). + * - [[Sum]], [[Count]]: additive monoid (associative + commutative). * - [[Average]]: sum + count, both additive monoids. * - [[StddevPop]], [[StddevSamp]], [[VariancePop]], [[VarianceSamp]]: * Welford (count, mean, M2) is associative -- see * CentralMomentAgg.mergeExpressions. + * - [[First]], [[Last]]: order-dependent but correct under left-to-right + * combine. 
`First.mergeExpressions` is `if(valueSet.left, left, right)` + * and `Last.mergeExpressions` is `if(valueSet.right, right, left)`; + * under the left-to-right traversal both pick the row-order extreme + * across any contiguous range. `IGNORE NULLS` is also handled: per-row + * `updateExpressions` only sets `valueSet=true` on non-null values, so + * a per-block partial of `(null, false)` for an all-NULL block is + * correctly skipped when merged with a later non-null block. * * Intentionally excluded (tracked as follow-up): HyperLogLogPlusPlus / - * ApproxCountDistinct (sketch-buffer interaction unaudited), First / Last - * (order-dependent), CollectList / CollectSet (unbounded buffer growth), + * ApproxCountDistinct (sketch-buffer interaction unaudited), + * CollectList / CollectSet (unbounded buffer growth), * Percentile / ApproxPercentile (sorted-sketch buffer), and any * ImperativeAggregate (excluded by the type check). * @@ -600,7 +609,9 @@ private[window] object WindowSegmentTree { classOf[StddevPop], classOf[StddevSamp], classOf[VariancePop], - classOf[VarianceSamp] + classOf[VarianceSamp], + classOf[First], + classOf[Last] ) /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala new file mode 100644 index 0000000000000..8a9978065beff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FirstLastSegmentTreeWindowBenchmark.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark for FIRST/LAST window aggregates over sliding and shrinking + * ROWS frames, comparing the legacy O(N x W) / O(N^2) frame paths against + * the segment-tree path enabled by adding `classOf[First]` / `classOf[Last]` + * to `WindowSegmentTree.EligibleAggregates`. + * + * Today's slow paths: + * - Sliding: `SlidingWindowFunctionFrame.write` rebuilds the per-row + * buffer aggregate by iterating `processor.update` over every row in + * the buffer (O(W) per output row, O(N*W) total). + * - Shrinking: `UnboundedFollowingWindowFunctionFrame.write` walks the + * remaining suffix on every output row (O(N^2) total; class scaladoc + * literally says O(n*(n-1)/2)). + * + * Sections: + * - A: FIRST/LAST per-mode at N=10K, sliding wide frame. + * - B: FIRST/LAST per-mode at N=10K, shrinking frame. + * - C: N-sweep for FIRST shrinking, naive vs segtree, demonstrating the + * algorithmic gap. Mirrors UnboundedFollowingWindowBenchmark layout. 
+ */ +object FirstLastSegmentTreeWindowBenchmark extends SqlBasedBenchmark { + + // Section A/B: per-mode per-frame-shape at calibrated N + private val AB_N: Long = 10L * 1024L + + // Section C: N-sweep for shrinking FIRST + private val C_N_SMALL: Long = 5L * 1024L + private val C_N_MID: Long = 25L * 1024L + private val C_N_LARGE: Long = 50L * 1024L + private val C_N_HUGE: Long = 100L * 1024L + + private val ITERS_NORMAL: Int = 5 + private val ITERS_STRESS: Int = 3 + + // Sliding frame width tuned so the O(N*W) baseline is observable but not + // catastrophic. With N=10K and W=2001 the legacy SlidingWindowFunctionFrame + // performs ~20M update calls (10K rows x ~2K-row buffer rebuild on each + // boundary change) which is enough to expose the gap without dominating + // wall-clock at calibration time. + private val SLIDING_FRAME = + "OVER (ORDER BY id ROWS BETWEEN 1000 PRECEDING AND 1000 FOLLOWING)" + private val SHRINKING_FRAME = + "OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)" + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val smokeMode = mainArgs.nonEmpty + val smokeRowCount = if (smokeMode) mainArgs(0).toLong else 0L + + def setupIntTable(n: Long): Unit = { + // Sprinkle ~10% NULLs so IGNORE NULLS exercises the merge path + // distinctly from respect-nulls; integer values otherwise. + spark.range(n) + .selectExpr("id", + "CASE WHEN rand(7) < 0.1 THEN NULL " + + "ELSE cast(rand(42) * 1000000 as int) END as v") + .coalesce(1) + .createOrReplaceTempView("t") + } + + def rowsLabel(rows: Long): String = { + if (rows >= 1000000) s"${rows / 1000000}M" + else if (rows >= 1024) s"${rows / 1024}K" + else rows.toString + } + + /** + * Equivalence digest. FIRST/LAST in respect-nulls mode are bit-exact + * across naive and segtree paths; in IGNORE NULLS the per-block merge + * also yields the same result row-for-row. 
We hash the result column + * directly (not COALESCE'd); HASH of a NULL value yields the hash seed rather + * than NULL, so the SUM digest is NULL only when the input table is empty. + */ + def digest(sql: String, sqlConfs: (String, String)*): Long = { + withSQLConf(sqlConfs: _*) { + val r = spark.sql(s"SELECT SUM(HASH(m)) FROM (SELECT $sql AS m FROM t)") + .head().get(0) + if (r == null) 0L else r.asInstanceOf[Long] + } + } + + def runCase(label: String, sql: String, iters: Int, rows: Long): Unit = { + val dNaive = digest(sql) + val dSeg = digest(sql, SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"$label digest mismatch: naive=$dNaive seg=$dSeg") + + val benchmark = new Benchmark( + s"$label, N=${rowsLabel(rows)} rows", rows, output = output) + benchmark.addCase(s"$label naive", numIters = iters) { _ => + spark.sql(s"SELECT $sql FROM t").noop() + } + benchmark.addCase(s"$label segtree", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT $sql FROM t").noop() + } + } + benchmark.run() + } + + def runSweepCase(rows: Long, includeNaive: Boolean, iters: Int): Unit = { + val sql = s"FIRST(v) $SHRINKING_FRAME" + if (includeNaive) { + val dNaive = digest(sql) + val dSeg = digest(sql, SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"Section C N=${rowsLabel(rows)} digest mismatch: naive=$dNaive seg=$dSeg") + } + val benchmark = new Benchmark( + s"FIRST shrinking frame, N=${rowsLabel(rows)} rows" + + (if (!includeNaive) " (segtree-only)" else ""), + rows, output = output) + if (includeNaive) { + benchmark.addCase(s"FIRST naive N=${rowsLabel(rows)}", numIters = iters) { _ => + spark.sql(s"SELECT $sql FROM t").noop() + } + } + benchmark.addCase(s"FIRST segtree N=${rowsLabel(rows)}", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT $sql FROM t").noop() + } + } + benchmark.run() + } + + if (smokeMode) { + 
setupIntTable(smokeRowCount) + runBenchmark("SMOKE Section A FIRST sliding") { + runCase("FIRST sliding respect-nulls", + s"FIRST(v) $SLIDING_FRAME", ITERS_STRESS, smokeRowCount) + } + } else { + setupIntTable(AB_N) + + // Section A: sliding frame, all four mode/function combinations. + runBenchmark("Section A - FIRST sliding respect-nulls") { + runCase("FIRST sliding respect-nulls", + s"FIRST(v) $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section A - LAST sliding respect-nulls") { + runCase("LAST sliding respect-nulls", + s"LAST(v) $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section A - FIRST sliding IGNORE NULLS") { + runCase("FIRST sliding ignore-nulls", + s"FIRST(v) IGNORE NULLS $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section A - LAST sliding IGNORE NULLS") { + runCase("LAST sliding ignore-nulls", + s"LAST(v) IGNORE NULLS $SLIDING_FRAME", ITERS_NORMAL, AB_N) + } + + // Section B: shrinking frame, all four mode/function combinations. + runBenchmark("Section B - FIRST shrinking respect-nulls") { + runCase("FIRST shrinking respect-nulls", + s"FIRST(v) $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section B - LAST shrinking respect-nulls") { + runCase("LAST shrinking respect-nulls", + s"LAST(v) $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section B - FIRST shrinking IGNORE NULLS") { + runCase("FIRST shrinking ignore-nulls", + s"FIRST(v) IGNORE NULLS $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + runBenchmark("Section B - LAST shrinking IGNORE NULLS") { + runCase("LAST shrinking ignore-nulls", + s"LAST(v) IGNORE NULLS $SHRINKING_FRAME", ITERS_NORMAL, AB_N) + } + + // Section C: shrinking-frame N-sweep on FIRST, demonstrating O(N^2) + // legacy vs O(N log N) segtree gap widens with N. 
+ setupIntTable(C_N_SMALL) + runBenchmark("Section C - N=5K") { + runSweepCase(C_N_SMALL, includeNaive = true, ITERS_NORMAL) + } + setupIntTable(C_N_MID) + runBenchmark("Section C - N=25K (stress)") { + runSweepCase(C_N_MID, includeNaive = true, ITERS_STRESS) + } + setupIntTable(C_N_LARGE) + runBenchmark("Section C - N=50K (stress, last naive run)") { + runSweepCase(C_N_LARGE, includeNaive = true, ITERS_STRESS) + } + setupIntTable(C_N_HUGE) + runBenchmark("Section C - N=100K (segtree-only, stress)") { + runSweepCase(C_N_HUGE, includeNaive = false, ITERS_STRESS) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala new file mode 100644 index 0000000000000..7fbce2f35be7c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnboundedFollowingWindowBenchmark.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark for shrinking ROWS frames (`... BETWEEN lower_bound AND UNBOUNDED FOLLOWING`). + * + * Today's `UnboundedFollowingWindowFunctionFrame` runs the suffix aggregate + * O(n * (n - 1) / 2) per partition (acknowledged inline at + * `WindowFunctionFrame.scala:636`). The segtree path replaces this with + * O(n log n) when `spark.sql.window.segmentTree.enabled=true`. + * + * Layout: single partition on a non-partitioned ORDER BY (the worst case; + * with PARTITION BY the cost decomposes as sum-of-partition N^2 and is + * dominated by the largest partition). + * + * Sections: + * - A: per-aggregate equivalence at N=10K (naive ~3-5s/iter target). + * - B: N-sweep for SUM, naive vs segtree, demonstrating the algorithmic + * gap. N=50K is the largest naive run (~60s/iter); N=100K and 200K + * are segtree-only because naive would take ~4-16 min/iter. + */ +object UnboundedFollowingWindowBenchmark extends SqlBasedBenchmark { + + // Section A: calibrated so naive baseline lands ~3s/iter at A_N.
+ private val A_N: Long = 10L * 1024L // ~2.4s naive @ N=10K (smoke: 2391ms) + + // Section B: N-sweep + private val B_N_SMALL: Long = 5L * 1024L // ~1.2s naive + private val B_N_MID: Long = 25L * 1024L // ~14s naive + private val B_N_LARGE: Long = 50L * 1024L // ~57s naive (last naive run) + private val B_N_HUGE: Long = 100L * 1024L // segtree-only, naive would be ~4 min + private val B_N_GIANT: Long = 200L * 1024L // segtree-only, naive would be ~16 min + + private val ITERS_NORMAL: Int = 5 + private val ITERS_STRESS: Int = 3 + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val smokeMode = mainArgs.nonEmpty + val smokeRowCount = if (smokeMode) mainArgs(0).toLong else 0L + + def setupIntTable(n: Long): Unit = { + spark.range(n) + .selectExpr("id", "cast(rand(42) * 1000000 as int) as v") + .coalesce(1) + .createOrReplaceTempView("t") + } + + // Shrinking frame: [current, end-of-partition). + val frame = "OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)" + + // Digest comparison ensuring naive and segtree produce identical results. + // SUM/COUNT/MIN/MAX on integers are bit-exact across paths. + def digest(aggFn: String, sqlConfs: (String, String)*): Long = { + withSQLConf(sqlConfs: _*) { + spark.sql(s"SELECT SUM(HASH(m)) FROM (SELECT $aggFn(v) $frame AS m FROM t)") + .head().getLong(0) + } + } + + def rowsLabel(rows: Long): String = { + if (rows >= 1000000) s"${rows / 1000000}M" + else if (rows >= 1024) s"${rows / 1024}K" + else rows.toString + } + + /** + * Section A: Run the same SQL with conf off and on. The naive case is the + * baseline, so iterations must be cheap enough to fit ~3-5s/iter. 
+ */ + def runSectionA(aggFn: String, iters: Int, rows: Long): Unit = { + val dNaive = digest(aggFn) + val dSeg = digest(aggFn, SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"$aggFn shrinking-frame digest mismatch: naive=$dNaive seg=$dSeg") + + val benchmark = new Benchmark( + s"$aggFn shrinking frame, N=${rowsLabel(rows)} rows", + rows, output = output) + benchmark.addCase(s"$aggFn naive (master O(N^2))", numIters = iters) { _ => + spark.sql(s"SELECT $aggFn(v) $frame FROM t").noop() + } + benchmark.addCase(s"$aggFn segtree", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT $aggFn(v) $frame FROM t").noop() + } + } + benchmark.run() + } + + /** + * Section B: SUM-only N-sweep. At N <= 50K we run both paths. At N >= 100K + * we run segtree-only because naive would dominate the benchmark wall-clock. + */ + def runSectionB(rows: Long, includeNaive: Boolean, iters: Int): Unit = { + if (includeNaive) { + val dNaive = digest("SUM") + val dSeg = digest("SUM", SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") + require(dNaive == dSeg, + s"Section B N=${rowsLabel(rows)} digest mismatch: naive=$dNaive seg=$dSeg") + } + val benchmark = new Benchmark( + s"SUM shrinking frame, N=${rowsLabel(rows)} rows" + + (if (!includeNaive) " (segtree-only)" else ""), + rows, output = output) + if (includeNaive) { + benchmark.addCase(s"SUM naive (master O(N^2)) N=${rowsLabel(rows)}", + numIters = iters) { _ => + spark.sql(s"SELECT SUM(v) $frame FROM t").noop() + } + } + benchmark.addCase(s"SUM segtree N=${rowsLabel(rows)}", numIters = iters) { _ => + withSQLConf(SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true") { + spark.sql(s"SELECT SUM(v) $frame FROM t").noop() + } + } + benchmark.run() + } + + if (smokeMode) { + setupIntTable(smokeRowCount) + runBenchmark("SMOKE Section A SUM") { + runSectionA("SUM", ITERS_STRESS, smokeRowCount) + } + } else { + // Section A: per-aggregate (SUM, MIN, MAX, 
COUNT, AVG) at calibrated N=10K. + // STDDEV omitted: shrinking frame doesn't widen multi-buffer aggregates' + // win profile vs sliding (the gain is purely algorithmic, not buffer-pack). + setupIntTable(A_N) + runBenchmark("Section A - SUM (non-invertible suffix)") { + runSectionA("SUM", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - MIN") { + runSectionA("MIN", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - MAX") { + runSectionA("MAX", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - COUNT") { + runSectionA("COUNT", ITERS_NORMAL, A_N) + } + runBenchmark("Section A - AVG (multi-buffer)") { + runSectionA("AVG", ITERS_NORMAL, A_N) + } + + // Section B: N-sweep showing the algorithmic gap widening with N. + setupIntTable(B_N_SMALL) + runBenchmark("Section B - N=5K") { + runSectionB(B_N_SMALL, includeNaive = true, ITERS_NORMAL) + } + setupIntTable(B_N_MID) + runBenchmark("Section B - N=25K (stress)") { + runSectionB(B_N_MID, includeNaive = true, ITERS_STRESS) + } + setupIntTable(B_N_LARGE) + runBenchmark("Section B - N=50K (stress, last naive run)") { + runSectionB(B_N_LARGE, includeNaive = true, ITERS_STRESS) + } + setupIntTable(B_N_HUGE) + runBenchmark("Section B - N=100K (segtree-only, stress)") { + runSectionB(B_N_HUGE, includeNaive = false, ITERS_STRESS) + } + setupIntTable(B_N_GIANT) + runBenchmark("Section B - N=200K (segtree-only, stress)") { + runSectionB(B_N_GIANT, includeNaive = false, ITERS_STRESS) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala index 5e644cceb1426..ffe4e58449053 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala @@ -99,6 +99,19 @@ class SegmentTreeWindowFunctionSuite extends 
SharedSparkSession { baseDF.select($"id", $"pk", avg($"v").over(winSpec(-3, 3)).as("agg"))) } + // First / Last basic equivalence (respect-nulls; the default for + // first()/last()). Order-correctness depends on the segment-tree + // combine being left-to-right; see WindowSegmentTree.EligibleAggregates. + test("FIRST over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", first($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("LAST over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", last($"v").over(winSpec(-3, 3)).as("agg"))) + } + test("MIN + MAX + SUM share a single window frame") { checkEquivalence(() => baseDF.select( @@ -229,6 +242,67 @@ class SegmentTreeWindowFunctionSuite extends SharedSparkSession { count($"v").over(winSpec(-4, 4)).as("cn"))) } + // First / Last respect-nulls: NULL is a valid value. If the first row in the + // frame is NULL, FIRST returns NULL. The seg-tree merge must preserve this. + test("FIRST/LAST respect-nulls with mixed NULL frame contents") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(winSpec(-4, 4)).as("fv"), + last($"v").over(winSpec(-4, 4)).as("lv"))) + } + + // First / Last IGNORE NULLS: per-row updates only set valueSet on non-null + // values. A per-block partial of (null, false) for an all-NULL block must + // be correctly skipped when merged with a later non-null block. 
+ test("FIRST/LAST ignore-nulls with mixed NULL frame contents") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v", ignoreNulls = true).over(winSpec(-4, 4)).as("fv_ign"), + last($"v", ignoreNulls = true).over(winSpec(-4, 4)).as("lv_ign"))) + } + + // All-NULL column edge case for First/Last in both modes. + // Respect-nulls: returns NULL. Ignore-nulls: also returns NULL (no + // non-null candidate ever sets valueSet). + test("all-NULL column: FIRST/LAST in both modes") { + val df = spark.range(0, 30).selectExpr( + "id", "(id % 3) AS pk", "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(winSpec(-3, 3)).as("fv"), + last($"v").over(winSpec(-3, 3)).as("lv"), + first($"v", ignoreNulls = true).over(winSpec(-3, 3)).as("fv_ign"), + last($"v", ignoreNulls = true).over(winSpec(-3, 3)).as("lv_ign"))) + } + + // Adversarial NULL distribution for IGNORE NULLS: per-block aggregates need + // to compose correctly when an entire block is all-NULL. With block size + // 65536 and partition size 120 we cannot literally produce a fully-NULL + // block via the standard fixture, but a long stretch of consecutive NULLs + // exercises the same merge path (per-row updates produce intermediate + // valueSet=false buffers which then merge with a later valueSet=true buffer + // via mergeExpressions). Combined with a wide frame to force tree queries + // crossing the all-NULL stretch. + test("FIRST/LAST ignore-nulls: stretches of consecutive NULLs cross-merge correctly") { + val df = spark.range(0, 90).selectExpr( + "id", + "0 AS pk", + // First 30 rows non-null, next 30 all NULL, last 30 non-null again. 
+ "CASE WHEN id BETWEEN 30 AND 59 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v", ignoreNulls = true).over(winSpec(-20, 20)).as("fv_ign"), + last($"v", ignoreNulls = true).over(winSpec(-20, 20)).as("lv_ign"))) + } + test("Double NaN and +/-Infinity propagate correctly through MIN/MAX/SUM") { // Trap: NaN > +Inf in Spark's MIN/MAX ordering; +Inf + -Inf = NaN in SUM. // Seg-tree uses DeclarativeAggregate.merge; behavior must match baseline. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala index cd5237c9b310f..6ac7b36f412ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowTestHelper.scala @@ -92,14 +92,17 @@ private[window] object SegmentTreeWindowTestHelper { /** Create a new frame. Caller owns lifecycle unless tracked via `track()`. 
*/ def newFrame(): SegmentTreeWindowFunctionFrame = { val target = new SpecificInternalRow(Seq(bufAttrs.head.dataType)) + val lb = RowBoundOrdering(-1) + val ub = RowBoundOrdering(1) val frame = new SegmentTreeWindowFunctionFrame( target, processor, Array(fn), input, RowFrame, - RowBoundOrdering(-1), - RowBoundOrdering(1), + lb, + ubound = Some(ub), + fallbackFactory = () => new SlidingWindowFunctionFrame(target, processor, lb, ub), (es, s) => GenerateMutableProjection.generate(es, s), conf, None, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala new file mode 100644 index 0000000000000..4ea836bff9566 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/UnboundedFollowingSegmentTreeSuite.scala @@ -0,0 +1,465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * End-to-end correctness tests for the segment-tree shrinking-frame path + * (`... ROWS/RANGE BETWEEN lower_bound AND UNBOUNDED FOLLOWING`). + * + * Mirrors the structure of [[SegmentTreeWindowFunctionSuite]]: every test + * runs the same SQL with `spark.sql.window.segmentTree.enabled` off and on + * and asserts row-set equality. The "off" path runs through + * [[UnboundedFollowingWindowFunctionFrame]] (the O(N^2) baseline); the "on" + * path runs through the new shrinking branch in + * [[SegmentTreeWindowFunctionFrame]] (`ubound = None`). + */ +class UnboundedFollowingSegmentTreeSuite extends SharedSparkSession { + + import testImplicits._ + + private val enableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1") + + private val disableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false") + + /** Baseline (flag off) vs segtree (flag on); compare row-sets. */ + private def checkEquivalence(build: () => DataFrame): Unit = { + val baseline: Seq[Row] = withSQLConf(disableSegTree.toSeq: _*) { + build().collect().toSeq + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = build().collect().toSeq + QueryTest.sameRows(baseline, actual, isSorted = false).foreach { err => + fail(s"shrinking-frame segtree output differs from baseline.\n$err") + } + } + } + + /** SQL-level variant that accepts a query string.
*/ + private def checkSqlEquivalence(df: DataFrame, query: String): Unit = { + df.createOrReplaceTempView("t") + try { + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + spark.sql(query).collect().sortBy(_.toString) + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = spark.sql(query).collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq, + s"shrinking-frame segtree output differs from baseline.\n" + + s"Expected: ${baseline.toSeq}\nActual: ${actual.toSeq}") + } + } finally { + spark.catalog.dropTempView("t") + } + } + + /** 3 partitions, 40 rows each; values = row index. */ + private def baseDF: DataFrame = + spark.range(0, 120).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS v") + + /** Shrinking ROWS frame: [lo, end-of-partition). */ + private def shrinkingRowsFrame(lo: Int) = + Window.partitionBy($"pk").orderBy($"id") + .rowsBetween(lo, Window.unboundedFollowing) + + // ============================================================ + // ROWS frame: basic aggregate equivalence (CURRENT ROW lower) + // ============================================================ + + test("MIN over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("MAX over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", max($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("SUM over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("COUNT over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", count($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("AVG over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", 
avg($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + // First / Last over a shrinking frame: in respect-nulls mode, FIRST is just + // `rows[lower]` (the first row of the suffix). LAST advances with the + // shrinking lower bound but always sees the partition's end. Both modes + // exercise the segment-tree merge path through a series of `[lower, n)` + // queries; correctness depends on the same left-to-right combine that + // makes First/Last safe in the sliding case. + test("FIRST over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", first($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("LAST over ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", last($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + // ============================================================ + // ROWS frame: lower-bound variations + // ============================================================ + + test("ROWS BETWEEN 5 PRECEDING AND UNBOUNDED FOLLOWING (suffix + lookback)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(-5)).as("agg"))) + } + + test("ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING is NOT this path") { + // Both-unbounded routes to UnboundedWindowFunctionFrame (different case + // in the dispatcher) and is one-shot O(1). This test just verifies the + // segtree flag doesn't break it. 
+ val frame = Window.partitionBy($"pk").orderBy($"id") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(frame).as("agg"))) + } + + test("ROWS BETWEEN 5 FOLLOWING AND UNBOUNDED FOLLOWING (lower bound is positive)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(5)).as("agg"))) + } + + // ============================================================ + // Multi-aggregate: shared frame + // ============================================================ + + test("MIN + MAX + SUM share a single shrinking frame") { + checkEquivalence(() => + baseDF.select( + $"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"))) + } + + // ============================================================ + // Partition / boundary edge cases + // ============================================================ + + test("single-row partition") { + val df = spark.range(0, 5).selectExpr("id", "id AS pk", "CAST(id AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("empty result table (no rows)") { + val df = spark.emptyDataFrame.selectExpr("CAST(NULL AS BIGINT) AS id", + "CAST(NULL AS BIGINT) AS pk", "CAST(NULL AS INT) AS v") + .where("id IS NOT NULL") + checkEquivalence(() => + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("agg"))) + } + + test("partition below minPartitionRows falls back to UnboundedFollowingWindowFunctionFrame") { + // With minRows=1024 the segtree path forces fallback; baseline (off) and + // forced-fallback (on, but min=1024) must match. The point is that the + // small-partition path goes through the legacy frame, not segtree. 
+ val df = baseDF + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("s")) + .collect().toSeq + } + withSQLConf( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") { + val actual = df.select($"id", $"pk", sum($"v").over(shrinkingRowsFrame(0)).as("s")) + .collect().toSeq + QueryTest.sameRows(baseline, actual, isSorted = false).foreach { err => + fail(s"forced-fallback path diverges from baseline.\n$err") + } + } + } + + // ============================================================ + // NULL / NaN / numeric edge cases + // ============================================================ + + test("all-NULL column: SUM/MIN/MAX/AVG/COUNT") { + val df = spark.range(0, 30).selectExpr("id", "(id % 3) AS pk", + "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + sum($"v").over(shrinkingRowsFrame(0)).as("s"), + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + avg($"v").over(shrinkingRowsFrame(0)).as("a"), + count($"v").over(shrinkingRowsFrame(0)).as("c"))) + } + + // First / Last over a shrinking frame with NULL distribution. Mirrors the + // sliding-suite NULL tests; verifies the merge path is correct when the + // lower-edge partial block of the segtree query crosses a NULL/non-NULL + // boundary. 
+ test("FIRST/LAST over shrinking frame: respect-nulls with mixed NULLs") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(shrinkingRowsFrame(0)).as("fv"), + last($"v").over(shrinkingRowsFrame(0)).as("lv"))) + } + + test("FIRST/LAST over shrinking frame: ignore-nulls with mixed NULLs") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("fv_ign"), + last($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("lv_ign"))) + } + + test("all-NULL column: FIRST/LAST shrinking frame in both modes") { + val df = spark.range(0, 30).selectExpr("id", "(id % 3) AS pk", + "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + first($"v").over(shrinkingRowsFrame(0)).as("fv"), + last($"v").over(shrinkingRowsFrame(0)).as("lv"), + first($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("fv_ign"), + last($"v", ignoreNulls = true).over(shrinkingRowsFrame(0)).as("lv_ign"))) + } + + test("mixed NULL and non-NULL: NULLs must not leak into MIN/MAX") { + val df = (0 until 60).map { i => + val v: Option[Int] = if (i % 4 == 0) None else Some(i) + (i.toLong, (i % 3).toLong, v) + }.toDF("id", "pk", "v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"), + count($"v").over(shrinkingRowsFrame(0)).as("c"))) + } + + test("Double NaN and +/-Infinity propagate correctly through MIN/MAX/SUM") { + val df = Seq( + (0L, 0L, 1.0d), (1L, 0L, Double.NaN), (2L, 0L, 3.0d), + (3L, 0L, Double.PositiveInfinity), (4L, 0L, 5.0d), + (5L, 0L, Double.NegativeInfinity), (6L, 0L, 
7.0d), (7L, 0L, 9.0d) + ).toDF("id", "pk", "v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"), + sum($"v").over(shrinkingRowsFrame(0)).as("s"))) + } + + // ============================================================ + // Type coverage + // ============================================================ + + test("numeric types: Int / Long / Double / Decimal") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS vi", + "id * 1000000000L AS vl", + "CAST(id AS DOUBLE) * 1.5 AS vd", + "CAST(id AS DECIMAL(20, 5)) AS vdec") + checkEquivalence(() => + df.select($"id", $"pk", + sum($"vi").over(shrinkingRowsFrame(0)).as("si"), + sum($"vl").over(shrinkingRowsFrame(0)).as("sl"), + sum($"vd").over(shrinkingRowsFrame(0)).as("sd"), + sum($"vdec").over(shrinkingRowsFrame(0)).as("sdec"))) + } + + test("String lexicographic MIN/MAX") { + val df = spark.range(0, 30).selectExpr( + "id", + "(id % 3) AS pk", + "CONCAT('s', LPAD(CAST((id * 7) % 31 AS STRING), 3, '0')) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(shrinkingRowsFrame(0)).as("mn"), + max($"v").over(shrinkingRowsFrame(0)).as("mx"))) + } + + test("Date / Timestamp MIN/MAX") { + val df = spark.range(0, 24).selectExpr( + "id", + "(id % 3) AS pk", + "DATE_ADD(DATE'2024-01-01', CAST(id AS INT)) AS d", + "TIMESTAMPADD(HOUR, CAST(id AS INT), TIMESTAMP'2024-01-01 00:00:00') AS ts") + checkEquivalence(() => + df.select($"id", $"pk", + min($"d").over(shrinkingRowsFrame(0)).as("dmn"), + max($"ts").over(shrinkingRowsFrame(0)).as("tsmx"))) + } + + // ============================================================ + // Allow-list: non-DeclarativeAggregate paths must fall back + // ============================================================ + + test("collect_list falls back cleanly (non-DeclarativeAggregate)") { + // collect_list is ImperativeAggregate; segtree path must 
not engage. + // The result should still be correct via the legacy frame. + checkEquivalence(() => + baseDF.select($"id", $"pk", + collect_list($"v").over(shrinkingRowsFrame(0)).as("lst"))) + } + + test("DISTINCT shrinking aggregate is rejected by analyzer regardless of seg-tree flag") { + def run(): Unit = { + baseDF.select($"id", $"pk", + count_distinct($"v").over(shrinkingRowsFrame(0)).as("cd")).collect() + } + withSQLConf(disableSegTree.toSeq: _*) { + val e = intercept[org.apache.spark.sql.AnalysisException](run()) + assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED")) + } + withSQLConf(enableSegTree.toSeq: _*) { + val e = intercept[org.apache.spark.sql.AnalysisException](run()) + assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED")) + } + } + + // ============================================================ + // RANGE shrinking frame (single-order-expr) + // ============================================================ + + test("RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING with non-uniform gaps") { + val df = spark.range(0, 40).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST(CASE CAST(id AS INT) % 7 " + + "WHEN 0 THEN 1 WHEN 1 THEN 3 WHEN 2 THEN 4 WHEN 3 THEN 4 " + + "WHEN 4 THEN 7 WHEN 5 THEN 10 ELSE 15 END + (CAST(id AS INT) / 7) * 20 AS INT) AS k", + "CAST((id * 31) % 97 AS INT) AS v") + checkSqlEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING") { + val df = spark.range(0, 40).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST((id * 7) % 31 AS INT) AS k", + "CAST((id * 11) % 53 AS INT) AS v") + checkSqlEquivalence(df, + """SELECT id, pk, k, + | SUM(v) OVER (PARTITION BY pk ORDER 
BY k + | RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS s + |FROM t""".stripMargin) + } + + test("RANGE with tie (duplicate order keys): full tie group at lower edge") { + // Trap: at the lower edge, the FULL tie group at the lower offset must + // be retained, not just the first row. + val rows = (0 until 40).map { i => + val k = Seq(1, 2, 2, 2, 3, 4, 5)(i % 7) + (i, i % 2, k, (i * 13) % 41) + } + val df = rows.toDF("id", "pk", "k", "v") + checkSqlEquivalence(df, + """SELECT id, pk, k, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("RANGE with NULL order key (NULLS FIRST / NULLS LAST)") { + val rows = (0 until 36).map { i => + val kOpt: Option[Int] = (i % 6) match { + case 0 | 1 | 5 => None + case 2 => Some(1) + case 3 => Some(2) + case _ => Some(3) + } + (i, i % 2, kOpt, (i * 11) % 37) + } + val df = rows.toDF("id", "pk", "k", "v") + checkSqlEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS FIRST + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mn_nf, + | MAX(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS LAST + | RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS mx_nl + |FROM t""".stripMargin) + } + + test("RANGE Timestamp with INTERVAL offset (MAX) and shrinking upper") { + val df = spark.range(0, 30).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST(TIMESTAMP'2024-01-01 10:00:00' + " + + "make_interval(0, 0, 0, 0, 0, 30 * CAST(id AS INT) * " + + "(CASE CAST(id AS INT) % 3 WHEN 0 THEN 1 WHEN 1 THEN 3 ELSE 4 END), 0) " + + "AS TIMESTAMP) AS ts", + "CAST((id * 17) % 53 AS INT) AS v") + checkSqlEquivalence(df, + """SELECT id, pk, + | MAX(v) OVER (PARTITION BY pk ORDER BY ts + | RANGE BETWEEN INTERVAL '1' HOUR PRECEDING AND UNBOUNDED FOLLOWING) AS mx + |FROM t""".stripMargin) + 
} + + // ============================================================ + // Feature-flag off: legacy frame is used + // ============================================================ + + test("feature flag off: segmentTree.enabled=false yields baseline semantics") { + val df = baseDF + val expected = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("mn")) + .collect().sortBy(_.toString).toSeq + } + withSQLConf( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") { + val actual = df.select($"id", $"pk", min($"v").over(shrinkingRowsFrame(0)).as("mn")) + .collect().sortBy(_.toString).toSeq + assert(actual === expected) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala index 236d38cc6a910..8dbb4ceecf1bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/window/WindowSegmentTreeAllowlistSuite.scala @@ -81,7 +81,13 @@ class WindowSegmentTreeAllowlistSuite ("stddev_pop", (c: org.apache.spark.sql.Column) => stddev_pop(c)), ("stddev_samp", (c: org.apache.spark.sql.Column) => stddev_samp(c)), ("var_pop", (c: org.apache.spark.sql.Column) => var_pop(c)), - ("var_samp", (c: org.apache.spark.sql.Column) => var_samp(c)) + ("var_samp", (c: org.apache.spark.sql.Column) => var_samp(c)), + ("first", (c: org.apache.spark.sql.Column) => first(c)), + ("last", (c: org.apache.spark.sql.Column) => last(c)), + ("first_ignore_nulls", + (c: org.apache.spark.sql.Column) => first(c, ignoreNulls = true)), + ("last_ignore_nulls", + (c: org.apache.spark.sql.Column) => last(c, ignoreNulls = true)) ).foreach { case (name, fn) => test(s"$name routes to the segment-tree path") { 
withSQLConf(enableSegTree.toSeq: _*) { @@ -96,22 +102,6 @@ class WindowSegmentTreeAllowlistSuite // Negative: non-allowlisted aggregates fall through - test("first_value falls through (order-dependent aggregate)") { - withSQLConf(enableSegTree.toSeq: _*) { - val df = baseDF.withColumn("agg", first($"v").over(winSpec)) - val (seg, _) = segTreeCounters(df) - assert(seg == 0, s"first_value should not use segment tree (got $seg frames)") - } - } - - test("last_value falls through (order-dependent aggregate)") { - withSQLConf(enableSegTree.toSeq: _*) { - val df = baseDF.withColumn("agg", last($"v").over(winSpec)) - val (seg, _) = segTreeCounters(df) - assert(seg == 0, s"last_value should not use segment tree (got $seg frames)") - } - } - test("collect_list falls through (unbounded buffer)") { withSQLConf(enableSegTree.toSeq: _*) { val df = baseDF.withColumn("agg", collect_list($"v").over(winSpec)) @@ -176,10 +166,10 @@ class WindowSegmentTreeAllowlistSuite withSQLConf(enableSegTree.toSeq: _*) { val df = baseDF .withColumn("s", sum($"v").over(winSpec)) - .withColumn("fv", first($"v").over(winSpec)) + .withColumn("cl", collect_list($"v").over(winSpec)) val (seg, _) = segTreeCounters(df) // Both aggregates share the same Window node; gating is forall(isEligible), - // so `first_value` drops the whole group. + // so `collect_list` (unbounded-buffer denylist) drops the whole group. assert(seg == 0, s"Window group containing a non-allowlisted agg must fall through (got $seg)") }