Skip to content
40 changes: 38 additions & 2 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use crate::AggregateFunctionRef;
use crate::BlockEntry;
use crate::ColumnBuilder;
use crate::ProjectedBlock;
use crate::block::DataBlock;
use crate::types::DataType;

const SMALL_CAPACITY_RESIZE_COUNT: usize = 4;
Expand Down Expand Up @@ -74,10 +75,11 @@ impl AggregateHashTable {
Self {
direct_append: false,
current_radix_bits: config.initial_radix_bits,
payload: PartitionedPayload::new(
payload: PartitionedPayload::new_with_start_bit(
group_types,
aggrs,
1 << config.initial_radix_bits,
config.partition_start_bit,
vec![arena],
),
hash_index: HashIndex::new(&config, capacity),
Expand Down Expand Up @@ -105,10 +107,11 @@ impl AggregateHashTable {
Self {
direct_append: !need_init_entry,
current_radix_bits: config.initial_radix_bits,
payload: PartitionedPayload::new(
payload: PartitionedPayload::new_with_start_bit(
group_types,
aggrs,
1 << config.initial_radix_bits,
config.partition_start_bit,
vec![arena],
),
hash_index,
Expand Down Expand Up @@ -335,6 +338,39 @@ impl AggregateHashTable {
Ok(())
}

/// Directly merge a serialized DataBlock into this hash table without
/// creating an intermediate PartitionedPayload. This avoids the 2×
/// memory peak that occurs when `convert_to_partitioned_payload` +
/// `combine_payloads` are used, because aggregate states are only ever
/// allocated in the main arena.
pub fn combine_serialized_block(
&mut self,
data_block: &DataBlock,
num_states: usize,
group_len: usize,
) -> Result<()> {
let row_count = data_block.num_rows();
if row_count == 0 {
return Ok(());
}

let states_index: Vec<usize> = (0..num_states).collect();
let agg_states = ProjectedBlock::project(&states_index, data_block);

let group_index: Vec<usize> = (num_states..(num_states + group_len)).collect();
let group_columns = ProjectedBlock::project(&group_index, data_block);

let mut state = ProbeState::default();
self.add_groups(
&mut state,
group_columns,
&[(&[]).into()],
agg_states,
row_count,
)?;
Ok(())
}

pub fn merge_result(&mut self, flush_state: &mut PayloadFlushState) -> Result<bool> {
if !self.payload.flush(flush_state) {
return Ok(false);
Expand Down
7 changes: 7 additions & 0 deletions src/query/expression/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ pub struct HashTableConfig {
// Max radix bits across all threads, this is a hint to repartition
pub current_max_radix_bits: Arc<AtomicU64>,
pub initial_radix_bits: u64,
pub partition_start_bit: u64,
pub max_radix_bits: u64,
pub repartition_radix_bits_incr: u64,
pub block_fill_factor: f64,
Expand All @@ -167,6 +168,7 @@ impl Default for HashTableConfig {
Self {
current_max_radix_bits: Arc::new(AtomicU64::new(3)),
initial_radix_bits: 3,
partition_start_bit: 0,
max_radix_bits: MAX_RADIX_BITS,
repartition_radix_bits_incr: 2,
block_fill_factor: 1.8,
Expand Down Expand Up @@ -211,6 +213,11 @@ impl HashTableConfig {
self
}

pub fn with_partition_start_bit(mut self, partition_start_bit: u64) -> Self {
self.partition_start_bit = partition_start_bit;
self
}

pub fn with_experiment_hash_index(mut self, enable: bool) -> Self {
self.enable_experiment_hash_index = enable;
self
Expand Down
51 changes: 46 additions & 5 deletions src/query/expression/src/aggregate/partitioned_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,15 @@ struct PartitionMask {

impl PartitionMask {
fn new(partition_count: u64) -> Self {
Self::with_start_bit(partition_count, 0)
}

fn with_start_bit(partition_count: u64, start_bit: u64) -> Self {
let radix_bits = partition_count.trailing_zeros() as u64;
debug_assert_eq!(1 << radix_bits, partition_count);
debug_assert!(start_bit + radix_bits <= 48);

let shift = 48 - radix_bits;
let shift = 48 - start_bit - radix_bits;
let mask = ((1 << radix_bits) - 1) << shift;

Self { mask, shift }
Expand All @@ -59,6 +64,7 @@ pub struct PartitionedPayload {

pub arenas: Vec<Arc<Bump>>,

partition_start_bit: u64,
partition_mask: PartitionMask,
}

Expand All @@ -71,6 +77,16 @@ impl PartitionedPayload {
aggrs: Vec<AggregateFunctionRef>,
partition_count: u64,
arenas: Vec<Arc<Bump>>,
) -> Self {
Self::new_with_start_bit(group_types, aggrs, partition_count, 0, arenas)
}

pub fn new_with_start_bit(
group_types: Vec<DataType>,
aggrs: Vec<AggregateFunctionRef>,
partition_count: u64,
partition_start_bit: u64,
arenas: Vec<Arc<Bump>>,
) -> Self {
let states_layout = if !aggrs.is_empty() {
Some(get_states_layout(&aggrs).unwrap())
Expand Down Expand Up @@ -101,7 +117,8 @@ impl PartitionedPayload {
row_layout,

arenas,
partition_mask: PartitionMask::new(partition_count),
partition_start_bit,
partition_mask: PartitionMask::with_start_bit(partition_count, partition_start_bit),
}
}

Expand Down Expand Up @@ -169,11 +186,17 @@ impl PartitionedPayload {
group_types,
aggrs,
arenas,
partition_start_bit,
..
} = self;

let mut new_partition_payload =
PartitionedPayload::new(group_types, aggrs, new_partition_count as u64, arenas);
let mut new_partition_payload = PartitionedPayload::new_with_start_bit(
group_types,
aggrs,
new_partition_count as u64,
partition_start_bit,
arenas,
);

state.clear();
for payload in payloads.into_iter() {
Expand All @@ -184,7 +207,9 @@ impl PartitionedPayload {
}

pub fn combine(&mut self, other: PartitionedPayload, state: &mut PayloadFlushState) {
if other.partition_count() == self.partition_count() {
if other.partition_count() == self.partition_count()
&& other.partition_start_bit == self.partition_start_bit
{
for (l, r) in self.payloads.iter_mut().zip(other.payloads.into_iter()) {
l.combine(r);
}
Expand Down Expand Up @@ -293,3 +318,19 @@ impl PartitionedPayload {
self.payloads.iter().map(|x| x.memory_size()).sum()
}
}

#[cfg(test)]
mod tests {
use super::PartitionMask;

#[test]
fn test_partition_mask_with_start_bit() {
let top_bit_mask = PartitionMask::new(2);
assert_eq!(top_bit_mask.index(1_u64 << 47), 1);
assert_eq!(top_bit_mask.index(1_u64 << 44), 0);

let shifted_mask = PartitionMask::with_start_bit(2, 3);
assert_eq!(shifted_mask.index(1_u64 << 47), 0);
assert_eq!(shifted_mask.index(1_u64 << 44), 1);
}
}
Loading
Loading