risingwave_batch_executors/executor/
hash_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::hash::BuildHasher;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use bytes::Bytes;
20use futures_async_stream::try_stream;
21use hashbrown::hash_map::Entry;
22use itertools::Itertools;
23use risingwave_common::array::{DataChunk, StreamChunk};
24use risingwave_common::bitmap::{Bitmap, FilterByBitmap};
25use risingwave_common::catalog::{Field, Schema};
26use risingwave_common::hash::{HashKey, HashKeyDispatcher, PrecomputedBuildHasher};
27use risingwave_common::memory::MemoryContext;
28use risingwave_common::row::{OwnedRow, Row, RowExt};
29use risingwave_common::types::{DataType, ToOwnedDatum};
30use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
31use risingwave_common::util::iter_util::ZipEqFast;
32use risingwave_common_estimate_size::EstimateSize;
33use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction};
34use risingwave_pb::Message;
35use risingwave_pb::batch_plan::HashAggNode;
36use risingwave_pb::batch_plan::plan_node::NodeBody;
37use risingwave_pb::data::DataChunk as PbDataChunk;
38
39use crate::error::{BatchError, Result};
40use crate::executor::aggregation::build as build_agg;
41use crate::executor::{
42    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
43    WrapStreamExecutor,
44};
45use crate::monitor::BatchSpillMetrics;
46use crate::spill::spill_op::SpillBackend::Disk;
47use crate::spill::spill_op::{
48    DEFAULT_SPILL_PARTITION_NUM, SPILL_AT_LEAST_MEMORY, SpillBackend, SpillBuildHasher, SpillOp,
49};
50use crate::task::{ShutdownToken, TaskId};
51
/// Hash table holding the aggregation states of every group, keyed by the group key `K`.
///
/// Each value holds one `AggregateState` per aggregate call, in the same order as the
/// executor's `aggs`. `A` is the allocator; the executor passes the memory context's global
/// allocator. NOTE(review): `PrecomputedBuildHasher` presumably reuses a hash value already
/// computed inside `K` rather than rehashing — confirm in `risingwave_common::hash`.
type AggHashMap<K, A> = hashbrown::HashMap<K, Vec<AggregateState>, PrecomputedBuildHasher, A>;
53
54/// A dispatcher to help create specialized hash agg executor.
55impl HashKeyDispatcher for HashAggExecutorBuilder {
56    type Output = BoxedExecutor;
57
58    fn dispatch_impl<K: HashKey>(self) -> Self::Output {
59        Box::new(HashAggExecutor::<K>::new(
60            Arc::new(self.aggs),
61            self.group_key_columns,
62            self.group_key_types,
63            self.schema,
64            self.child,
65            self.identity,
66            self.chunk_size,
67            self.mem_context,
68            self.spill_backend,
69            self.spill_metrics,
70            self.shutdown_rx,
71        ))
72    }
73
74    fn data_types(&self) -> &[DataType] {
75        &self.group_key_types
76    }
77}
78
/// Collects everything needed to build a `HashAggExecutor` before dispatching on the group
/// key types to a concrete `HashKey` implementation.
pub struct HashAggExecutorBuilder {
    /// Aggregate functions, one per aggregate call of the plan node.
    aggs: Vec<BoxedAggregateFunction>,
    /// Indices of the group key columns in the child's output.
    group_key_columns: Vec<usize>,
    /// Types of the group key columns, in `group_key_columns` order.
    group_key_types: Vec<DataType>,
    /// Input executor.
    child: BoxedExecutor,
    /// Output schema: the group key columns followed by one column per aggregate.
    schema: Schema,
    /// Not forwarded to the executor by `dispatch_impl`.
    #[expect(dead_code)]
    task_id: TaskId,
    identity: String,
    /// Rows per output chunk (also used as the spill chunk size).
    chunk_size: usize,
    mem_context: MemoryContext,
    /// `Some` when spilling is enabled; otherwise running out of memory is an error.
    spill_backend: Option<SpillBackend>,
    spill_metrics: Arc<BatchSpillMetrics>,
    shutdown_rx: ShutdownToken,
}
94
95impl HashAggExecutorBuilder {
96    fn deserialize(
97        hash_agg_node: &HashAggNode,
98        child: BoxedExecutor,
99        task_id: TaskId,
100        identity: String,
101        chunk_size: usize,
102        mem_context: MemoryContext,
103        spill_backend: Option<SpillBackend>,
104        spill_metrics: Arc<BatchSpillMetrics>,
105        shutdown_rx: ShutdownToken,
106    ) -> Result<BoxedExecutor> {
107        let aggs: Vec<_> = hash_agg_node
108            .get_agg_calls()
109            .iter()
110            .map(|agg| AggCall::from_protobuf(agg).and_then(|agg| build_agg(&agg)))
111            .try_collect()?;
112
113        let group_key_columns = hash_agg_node
114            .get_group_key()
115            .iter()
116            .map(|x| *x as usize)
117            .collect_vec();
118
119        let child_schema = child.schema();
120
121        let group_key_types = group_key_columns
122            .iter()
123            .map(|i| child_schema.fields[*i].data_type.clone())
124            .collect_vec();
125
126        let fields = group_key_types
127            .iter()
128            .cloned()
129            .chain(aggs.iter().map(|e| e.return_type()))
130            .map(Field::unnamed)
131            .collect::<Vec<Field>>();
132
133        let builder = HashAggExecutorBuilder {
134            aggs,
135            group_key_columns,
136            group_key_types,
137            child,
138            schema: Schema { fields },
139            task_id,
140            identity,
141            chunk_size,
142            mem_context,
143            spill_backend,
144            spill_metrics,
145            shutdown_rx,
146        };
147
148        Ok(builder.dispatch())
149    }
150}
151
152impl BoxedExecutorBuilder for HashAggExecutorBuilder {
153    async fn new_boxed_executor(
154        source: &ExecutorBuilder<'_>,
155        inputs: Vec<BoxedExecutor>,
156    ) -> Result<BoxedExecutor> {
157        let [child]: [_; 1] = inputs.try_into().unwrap();
158
159        let hash_agg_node = try_match_expand!(
160            source.plan_node().get_node_body().unwrap(),
161            NodeBody::HashAgg
162        )?;
163
164        let identity = source.plan_node().get_identity();
165
166        let spill_metrics = source.context().spill_metrics();
167
168        Self::deserialize(
169            hash_agg_node,
170            child,
171            source.task_id.clone(),
172            identity.clone(),
173            source.context().get_config().developer.chunk_size,
174            source.context().create_executor_mem_context(identity),
175            if source.context().get_config().enable_spill {
176                Some(Disk)
177            } else {
178                None
179            },
180            spill_metrics,
181            source.shutdown_rx().clone(),
182        )
183    }
184}
185
/// `HashAggExecutor` implements the hash aggregate algorithm, spilling to disk and
/// recursing into per-partition sub-executors when memory runs out and spilling is enabled.
pub struct HashAggExecutor<K> {
    /// Aggregate functions; shared via `Arc` with spill sub-executors.
    aggs: Arc<Vec<BoxedAggregateFunction>>,
    /// Column indexes that specify a group (positions in the child's output)
    group_key_columns: Vec<usize>,
    /// Data types of group key columns
    group_key_types: Vec<DataType>,
    /// Output schema: group key columns followed by one column per aggregate
    schema: Schema,
    child: BoxedExecutor,
    /// Used to initialize the state of the aggregation from the spilled files.
    /// Only set for sub-executors that process a spilled partition.
    init_agg_state_executor: Option<BoxedExecutor>,
    identity: String,
    /// Rows per output chunk; also the spill chunk size.
    chunk_size: usize,
    /// Tracks the memory used by the aggregation states.
    mem_context: MemoryContext,
    /// `Some` when spilling is enabled.
    spill_backend: Option<SpillBackend>,
    spill_metrics: Arc<BatchSpillMetrics>,
    /// The upper bound of memory usage for this executor.
    /// `None` for a top-level executor; a spilled partition's size for sub-executors.
    memory_upper_bound: Option<u64>,
    /// Checked once per group while emitting output.
    shutdown_rx: ShutdownToken,
    _phantom: PhantomData<K>,
}
209
210impl<K> HashAggExecutor<K> {
211    #[allow(clippy::too_many_arguments)]
212    pub fn new(
213        aggs: Arc<Vec<BoxedAggregateFunction>>,
214        group_key_columns: Vec<usize>,
215        group_key_types: Vec<DataType>,
216        schema: Schema,
217        child: BoxedExecutor,
218        identity: String,
219        chunk_size: usize,
220        mem_context: MemoryContext,
221        spill_backend: Option<SpillBackend>,
222        spill_metrics: Arc<BatchSpillMetrics>,
223        shutdown_rx: ShutdownToken,
224    ) -> Self {
225        Self::new_inner(
226            aggs,
227            group_key_columns,
228            group_key_types,
229            schema,
230            child,
231            None,
232            identity,
233            chunk_size,
234            mem_context,
235            spill_backend,
236            spill_metrics,
237            None,
238            shutdown_rx,
239        )
240    }
241
242    #[allow(clippy::too_many_arguments)]
243    fn new_inner(
244        aggs: Arc<Vec<BoxedAggregateFunction>>,
245        group_key_columns: Vec<usize>,
246        group_key_types: Vec<DataType>,
247        schema: Schema,
248        child: BoxedExecutor,
249        init_agg_state_executor: Option<BoxedExecutor>,
250        identity: String,
251        chunk_size: usize,
252        mem_context: MemoryContext,
253        spill_backend: Option<SpillBackend>,
254        spill_metrics: Arc<BatchSpillMetrics>,
255        memory_upper_bound: Option<u64>,
256        shutdown_rx: ShutdownToken,
257    ) -> Self {
258        HashAggExecutor {
259            aggs,
260            group_key_columns,
261            group_key_types,
262            schema,
263            child,
264            init_agg_state_executor,
265            identity,
266            chunk_size,
267            mem_context,
268            spill_backend,
269            spill_metrics,
270            memory_upper_bound,
271            shutdown_rx,
272            _phantom: PhantomData,
273        }
274    }
275}
276
277impl<K: HashKey + Send + Sync> Executor for HashAggExecutor<K> {
278    fn schema(&self) -> &Schema {
279        &self.schema
280    }
281
282    fn identity(&self) -> &str {
283        &self.identity
284    }
285
286    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
287        self.do_execute()
288    }
289}
290
/// `AggSpillManager` manages how spill data files are written and read back.
/// The spill data first needs to be partitioned. Each partition contains 2 files: an agg state
/// file (`agg-state-p{i}`) and an input chunks file (`input-chunks-p{i}`).
/// Each data chunk written to a spill file is serialized into protobuf bytes, so the spill
/// file content looks like the layout below, where `proto_len` is the byte length of the
/// following protobuf message, encoded as a little-endian `u32`.
/// Files are written append-only and read back with a sequential scan, an access pattern
/// chosen to maximize disk IO performance.
///
/// ```text
/// [proto_len]
/// [proto_bytes]
/// ...
/// [proto_len]
/// [proto_bytes]
/// ```
pub struct AggSpillManager {
    /// Spill storage operator rooted at a unique `/{identity}-{uuid}/` directory.
    op: SpillOp,
    /// Number of partitions the hash table and input are split into.
    partition_num: usize,
    /// One writer per partition for agg state rows; filled by `init_writers`.
    agg_state_writers: Vec<opendal::Writer>,
    /// Per-partition builders buffering agg state rows (group keys then encoded states).
    agg_state_chunk_builder: Vec<DataChunkBuilder>,
    /// One writer per partition for input chunks; filled by `init_writers`.
    input_writers: Vec<opendal::Writer>,
    /// Per-partition builders buffering input rows.
    input_chunk_builders: Vec<DataChunkBuilder>,
    /// Hasher seeded from the directory uuid so that recursive spills partition differently.
    spill_build_hasher: SpillBuildHasher,
    group_key_types: Vec<DataType>,
    child_data_types: Vec<DataType>,
    agg_data_types: Vec<DataType>,
    /// Capacity, in rows, of each spill chunk builder.
    spill_chunk_size: usize,
    spill_metrics: Arc<BatchSpillMetrics>,
}
319
320impl AggSpillManager {
321    fn new(
322        spill_backend: SpillBackend,
323        agg_identity: &String,
324        partition_num: usize,
325        group_key_types: Vec<DataType>,
326        agg_data_types: Vec<DataType>,
327        child_data_types: Vec<DataType>,
328        spill_chunk_size: usize,
329        spill_metrics: Arc<BatchSpillMetrics>,
330    ) -> Result<Self> {
331        let suffix_uuid = uuid::Uuid::new_v4();
332        let dir = format!("/{}-{}/", agg_identity, suffix_uuid);
333        let op = SpillOp::create(dir, spill_backend)?;
334        let agg_state_writers = Vec::with_capacity(partition_num);
335        let agg_state_chunk_builder = Vec::with_capacity(partition_num);
336        let input_writers = Vec::with_capacity(partition_num);
337        let input_chunk_builders = Vec::with_capacity(partition_num);
338        // Use uuid to generate an unique hasher so that when recursive spilling happens they would use a different hasher to avoid data skew.
339        let spill_build_hasher = SpillBuildHasher(suffix_uuid.as_u64_pair().1);
340        Ok(Self {
341            op,
342            partition_num,
343            agg_state_writers,
344            agg_state_chunk_builder,
345            input_writers,
346            input_chunk_builders,
347            spill_build_hasher,
348            group_key_types,
349            child_data_types,
350            agg_data_types,
351            spill_chunk_size,
352            spill_metrics,
353        })
354    }
355
356    async fn init_writers(&mut self) -> Result<()> {
357        for i in 0..self.partition_num {
358            let agg_state_partition_file_name = format!("agg-state-p{}", i);
359            let w = self.op.writer_with(&agg_state_partition_file_name).await?;
360            self.agg_state_writers.push(w);
361
362            let partition_file_name = format!("input-chunks-p{}", i);
363            let w = self.op.writer_with(&partition_file_name).await?;
364            self.input_writers.push(w);
365            self.input_chunk_builders.push(DataChunkBuilder::new(
366                self.child_data_types.clone(),
367                self.spill_chunk_size,
368            ));
369            self.agg_state_chunk_builder.push(DataChunkBuilder::new(
370                self.group_key_types
371                    .iter()
372                    .cloned()
373                    .chain(self.agg_data_types.iter().cloned())
374                    .collect(),
375                self.spill_chunk_size,
376            ));
377        }
378        Ok(())
379    }
380
381    async fn write_agg_state_row(&mut self, row: impl Row, hash_code: u64) -> Result<()> {
382        let partition = hash_code as usize % self.partition_num;
383        if let Some(output_chunk) = self.agg_state_chunk_builder[partition].append_one_row(row) {
384            let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
385            let buf = Message::encode_to_vec(&chunk_pb);
386            let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
387            self.spill_metrics
388                .batch_spill_write_bytes
389                .inc_by((buf.len() + len_bytes.len()) as u64);
390            self.agg_state_writers[partition].write(len_bytes).await?;
391            self.agg_state_writers[partition].write(buf).await?;
392        }
393        Ok(())
394    }
395
396    async fn write_input_chunk(&mut self, chunk: DataChunk, hash_codes: Vec<u64>) -> Result<()> {
397        let (columns, vis) = chunk.into_parts_v2();
398        for partition in 0..self.partition_num {
399            let new_vis = vis.clone()
400                & Bitmap::from_iter(
401                    hash_codes
402                        .iter()
403                        .map(|hash_code| (*hash_code as usize % self.partition_num) == partition),
404                );
405            let new_chunk = DataChunk::from_parts(columns.clone(), new_vis);
406            for output_chunk in self.input_chunk_builders[partition].append_chunk(new_chunk) {
407                let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
408                let buf = Message::encode_to_vec(&chunk_pb);
409                let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
410                self.spill_metrics
411                    .batch_spill_write_bytes
412                    .inc_by((buf.len() + len_bytes.len()) as u64);
413                self.input_writers[partition].write(len_bytes).await?;
414                self.input_writers[partition].write(buf).await?;
415            }
416        }
417        Ok(())
418    }
419
420    async fn close_writers(&mut self) -> Result<()> {
421        for partition in 0..self.partition_num {
422            if let Some(output_chunk) = self.agg_state_chunk_builder[partition].consume_all() {
423                let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
424                let buf = Message::encode_to_vec(&chunk_pb);
425                let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
426                self.spill_metrics
427                    .batch_spill_write_bytes
428                    .inc_by((buf.len() + len_bytes.len()) as u64);
429                self.agg_state_writers[partition].write(len_bytes).await?;
430                self.agg_state_writers[partition].write(buf).await?;
431            }
432
433            if let Some(output_chunk) = self.input_chunk_builders[partition].consume_all() {
434                let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
435                let buf = Message::encode_to_vec(&chunk_pb);
436                let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
437                self.spill_metrics
438                    .batch_spill_write_bytes
439                    .inc_by((buf.len() + len_bytes.len()) as u64);
440                self.input_writers[partition].write(len_bytes).await?;
441                self.input_writers[partition].write(buf).await?;
442            }
443        }
444
445        for mut w in self.agg_state_writers.drain(..) {
446            w.close().await?;
447        }
448        for mut w in self.input_writers.drain(..) {
449            w.close().await?;
450        }
451        Ok(())
452    }
453
454    async fn read_agg_state_partition(&mut self, partition: usize) -> Result<BoxedDataChunkStream> {
455        let agg_state_partition_file_name = format!("agg-state-p{}", partition);
456        let r = self.op.reader_with(&agg_state_partition_file_name).await?;
457        Ok(SpillOp::read_stream(r, self.spill_metrics.clone()))
458    }
459
460    async fn read_input_partition(&mut self, partition: usize) -> Result<BoxedDataChunkStream> {
461        let input_partition_file_name = format!("input-chunks-p{}", partition);
462        let r = self.op.reader_with(&input_partition_file_name).await?;
463        Ok(SpillOp::read_stream(r, self.spill_metrics.clone()))
464    }
465
466    async fn estimate_partition_size(&self, partition: usize) -> Result<u64> {
467        let agg_state_partition_file_name = format!("agg-state-p{}", partition);
468        let agg_state_size = self
469            .op
470            .stat(&agg_state_partition_file_name)
471            .await?
472            .content_length();
473        let input_partition_file_name = format!("input-chunks-p{}", partition);
474        let input_size = self
475            .op
476            .stat(&input_partition_file_name)
477            .await?
478            .content_length();
479        Ok(agg_state_size + input_size)
480    }
481
482    async fn clear_partition(&mut self, partition: usize) -> Result<()> {
483        let agg_state_partition_file_name = format!("agg-state-p{}", partition);
484        self.op.delete(&agg_state_partition_file_name).await?;
485        let input_partition_file_name = format!("input-chunks-p{}", partition);
486        self.op.delete(&input_partition_file_name).await?;
487        Ok(())
488    }
489}
490
491impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
492    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
493    async fn do_execute(self: Box<Self>) {
494        let child_schema = self.child.schema().clone();
495        let mut need_to_spill = false;
496        // If the memory upper bound is less than 1MB, we don't need to check memory usage.
497        let check_memory = match self.memory_upper_bound {
498            Some(upper_bound) => upper_bound > SPILL_AT_LEAST_MEMORY,
499            None => true,
500        };
501
502        // hash map for each agg groups
503        let mut groups = AggHashMap::<K, _>::with_hasher_in(
504            PrecomputedBuildHasher,
505            self.mem_context.global_allocator(),
506        );
507
508        if let Some(init_agg_state_executor) = self.init_agg_state_executor {
509            // `init_agg_state_executor` exists which means this is a sub `HashAggExecutor` used to consume spilling data.
510            // The spilled agg states by its parent executor need to be recovered first.
511            let mut init_agg_state_stream = init_agg_state_executor.execute();
512            #[for_await]
513            for chunk in &mut init_agg_state_stream {
514                let chunk = chunk?;
515                let group_key_indices = (0..self.group_key_columns.len()).collect_vec();
516                let keys = K::build_many(&group_key_indices, &chunk);
517                let mut memory_usage_diff = 0;
518                for (row_id, key) in keys.into_iter().enumerate() {
519                    let mut agg_states = vec![];
520                    for i in 0..self.aggs.len() {
521                        let agg = &self.aggs[i];
522                        let datum = chunk
523                            .row_at(row_id)
524                            .0
525                            .datum_at(self.group_key_columns.len() + i)
526                            .to_owned_datum();
527                        let agg_state = agg.decode_state(datum)?;
528                        memory_usage_diff += agg_state.estimated_size() as i64;
529                        agg_states.push(agg_state);
530                    }
531                    groups.try_insert(key, agg_states).unwrap();
532                }
533
534                if !self.mem_context.add(memory_usage_diff) && check_memory {
535                    warn!(
536                        "not enough memory to load one partition agg state after spill which is not a normal case, so keep going"
537                    );
538                }
539            }
540        }
541
542        let mut input_stream = self.child.execute();
543        // consume all chunks to compute the agg result
544        #[for_await]
545        for chunk in &mut input_stream {
546            let chunk = StreamChunk::from(chunk?);
547            let keys = K::build_many(self.group_key_columns.as_slice(), &chunk);
548            let mut memory_usage_diff = 0;
549            for (row_id, key) in keys
550                .into_iter()
551                .enumerate()
552                .filter_by_bitmap(chunk.visibility())
553            {
554                let mut new_group = false;
555                let states = match groups.entry(key) {
556                    Entry::Occupied(entry) => entry.into_mut(),
557                    Entry::Vacant(entry) => {
558                        new_group = true;
559                        let states = self
560                            .aggs
561                            .iter()
562                            .map(|agg| agg.create_state())
563                            .try_collect()?;
564                        entry.insert(states)
565                    }
566                };
567
568                // TODO: currently not a vectorized implementation
569                for (agg, state) in self.aggs.iter().zip_eq_fast(states) {
570                    if !new_group {
571                        memory_usage_diff -= state.estimated_size() as i64;
572                    }
573                    agg.update_range(state, &chunk, row_id..row_id + 1).await?;
574                    memory_usage_diff += state.estimated_size() as i64;
575                }
576            }
577            // update memory usage
578            if !self.mem_context.add(memory_usage_diff) && check_memory {
579                if self.spill_backend.is_some() {
580                    need_to_spill = true;
581                    break;
582                } else {
583                    Err(BatchError::OutOfMemory(self.mem_context.mem_limit()))?;
584                }
585            }
586        }
587
588        if need_to_spill {
589            // A spilling version of aggregation based on the RFC: Spill Hash Aggregation https://github.com/risingwavelabs/rfcs/pull/89
590            // When HashAggExecutor told memory is insufficient, AggSpillManager will start to partition the hash table and spill to disk.
591            // After spilling the hash table, AggSpillManager will consume all chunks from the input executor,
592            // partition and spill to disk with the same hash function as the hash table spilling.
593            // Finally, we would get e.g. 20 partitions. Each partition should contain a portion of the original hash table and input data.
594            // A sub HashAggExecutor would be used to consume each partition one by one.
595            // If memory is still not enough in the sub HashAggExecutor, it will spill its hash table and input recursively.
596            info!(
597                "batch hash agg executor {} starts to spill out",
598                &self.identity
599            );
600            let mut agg_spill_manager = AggSpillManager::new(
601                self.spill_backend.clone().unwrap(),
602                &self.identity,
603                DEFAULT_SPILL_PARTITION_NUM,
604                self.group_key_types.clone(),
605                self.aggs.iter().map(|agg| agg.return_type()).collect(),
606                child_schema.data_types(),
607                self.chunk_size,
608                self.spill_metrics.clone(),
609            )?;
610            agg_spill_manager.init_writers().await?;
611
612            let mut memory_usage_diff = 0;
613            // Spill agg states.
614            for (key, states) in groups {
615                let key_row = key.deserialize(&self.group_key_types)?;
616                let mut agg_datums = vec![];
617                for (agg, state) in self.aggs.iter().zip_eq_fast(states) {
618                    let encode_state = agg.encode_state(&state)?;
619                    memory_usage_diff -= state.estimated_size() as i64;
620                    agg_datums.push(encode_state);
621                }
622                let agg_state_row = OwnedRow::from_iter(agg_datums.into_iter());
623                let hash_code = agg_spill_manager.spill_build_hasher.hash_one(key);
624                agg_spill_manager
625                    .write_agg_state_row(key_row.chain(agg_state_row), hash_code)
626                    .await?;
627            }
628
629            // Release memory occupied by agg hash map.
630            self.mem_context.add(memory_usage_diff);
631
632            // Spill input chunks.
633            #[for_await]
634            for chunk in input_stream {
635                let chunk: DataChunk = chunk?;
636                let hash_codes = chunk.get_hash_values(
637                    self.group_key_columns.as_slice(),
638                    agg_spill_manager.spill_build_hasher,
639                );
640                agg_spill_manager
641                    .write_input_chunk(
642                        chunk,
643                        hash_codes
644                            .into_iter()
645                            .map(|hash_code| hash_code.value())
646                            .collect(),
647                    )
648                    .await?;
649            }
650
651            agg_spill_manager.close_writers().await?;
652
653            // Process each partition one by one.
654            for i in 0..agg_spill_manager.partition_num {
655                let partition_size = agg_spill_manager.estimate_partition_size(i).await?;
656
657                let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?;
658                let input_stream = agg_spill_manager.read_input_partition(i).await?;
659
660                let sub_hash_agg_executor: HashAggExecutor<K> = HashAggExecutor::new_inner(
661                    self.aggs.clone(),
662                    self.group_key_columns.clone(),
663                    self.group_key_types.clone(),
664                    self.schema.clone(),
665                    Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)),
666                    Some(Box::new(WrapStreamExecutor::new(
667                        self.schema.clone(),
668                        agg_state_stream,
669                    ))),
670                    format!("{}-sub{}", self.identity.clone(), i),
671                    self.chunk_size,
672                    self.mem_context.clone(),
673                    self.spill_backend.clone(),
674                    self.spill_metrics.clone(),
675                    Some(partition_size),
676                    self.shutdown_rx.clone(),
677                );
678
679                debug!(
680                    "create sub_hash_agg {} for hash_agg {} to spill",
681                    sub_hash_agg_executor.identity, self.identity
682                );
683
684                let sub_hash_agg_stream = Box::new(sub_hash_agg_executor).execute();
685
686                #[for_await]
687                for chunk in sub_hash_agg_stream {
688                    let chunk = chunk?;
689                    yield chunk;
690                }
691
692                // Clear files of the current partition.
693                agg_spill_manager.clear_partition(i).await?;
694            }
695        } else {
696            // Don't use `into_iter` here, it may cause memory leak.
697            let mut result = groups.iter_mut();
698            let cardinality = self.chunk_size;
699            loop {
700                let mut group_builders: Vec<_> = self
701                    .group_key_types
702                    .iter()
703                    .map(|datatype| datatype.create_array_builder(cardinality))
704                    .collect();
705
706                let mut agg_builders: Vec<_> = self
707                    .aggs
708                    .iter()
709                    .map(|agg| agg.return_type().create_array_builder(cardinality))
710                    .collect();
711
712                let mut has_next = false;
713                let mut array_len = 0;
714                for (key, states) in result.by_ref().take(cardinality) {
715                    self.shutdown_rx.check()?;
716                    has_next = true;
717                    array_len += 1;
718                    key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?;
719                    for ((agg, state), builder) in (self.aggs.iter())
720                        .zip_eq_fast(states)
721                        .zip_eq_fast(&mut agg_builders)
722                    {
723                        let result = agg.get_result(state).await?;
724                        builder.append(result);
725                    }
726                }
727                if !has_next {
728                    break; // exit loop
729                }
730
731                let columns = group_builders
732                    .into_iter()
733                    .chain(agg_builders)
734                    .map(|b| b.finish().into())
735                    .collect::<Vec<_>>();
736
737                let output = DataChunk::new(columns, array_len);
738                yield output;
739            }
740        }
741    }
742}
743
#[cfg(test)]
mod tests {
    use std::alloc::{AllocError, Allocator, Global, Layout};
    use std::ptr::NonNull;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures_async_stream::for_await;
    use risingwave_common::metrics::LabelGuardedIntGauge;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
    use risingwave_pb::data::PbDataType;
    use risingwave_pb::data::data_type::TypeName;
    use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
    use risingwave_pb::expr::{AggCall, InputRef};

    use super::*;
    use crate::executor::SortExecutor;
    use crate::executor::test_utils::{MockExecutor, diff_executor_output};

    const CHUNK_SIZE: usize = 1024;

    /// `sum(#2)` over an `Int32` input column, returning `Int64`.
    fn sum_of_third_column() -> AggCall {
        AggCall {
            kind: PbAggKind::Sum as i32,
            args: vec![InputRef {
                index: 2,
                r#type: Some(PbDataType {
                    type_name: TypeName::Int32 as i32,
                    ..Default::default()
                }),
            }],
            return_type: Some(PbDataType {
                type_name: TypeName::Int64 as i32,
                ..Default::default()
            }),
            distinct: false,
            order_by: vec![],
            filter: None,
            direct_args: vec![],
            udf: None,
            scalar: None,
        }
    }

    /// Eight rows over three `Int32` columns `(k0, k1, v)`.
    ///
    /// Grouped by `(k0, k1)`, the sums of `v` are
    /// `(0,0) -> 3`, `(0,1) -> 3`, `(1,0) -> 1`, `(1,1) -> 6`.
    fn grouped_int32_source() -> BoxedExecutor {
        Box::new(MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i i i
                 0 1 1
                 1 1 1
                 0 0 1
                 1 1 2
                 1 0 1
                 0 0 2
                 1 1 3
                 0 1 2",
            ),
            // The pretty-format `i` produces `Int32` columns; the schema must agree with the
            // data (and with the `Int32` argument type of `sum_of_third_column`).
            Schema::new(vec![Field::unnamed(DataType::Int32); 3]),
        ))
    }

    /// Expected `(k0, k1, sum(v))` rows for [`grouped_int32_source`], in no particular order.
    fn grouped_int32_expected() -> BoxedExecutor {
        Box::new(MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i i I
                 1 0 1
                 0 0 3
                 0 1 3
                 1 1 6",
            ),
            Schema::new(vec![
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int64),
            ]),
        ))
    }

    /// Groups by the first two columns and sums the third, then checks the result and the
    /// memory accounted to the executor's memory context.
    #[tokio::test]
    async fn execute_int32_grouped() {
        let parent_mem = MemoryContext::root(LabelGuardedIntGauge::<4>::test_int_gauge(), u64::MAX);
        {
            let agg_prost = HashAggNode {
                group_key: vec![0, 1],
                agg_calls: vec![sum_of_third_column()],
            };

            let mem_context = MemoryContext::new(
                Some(parent_mem.clone()),
                LabelGuardedIntGauge::<4>::test_int_gauge(),
            );
            let actual_exec = HashAggExecutorBuilder::deserialize(
                &agg_prost,
                grouped_int32_source(),
                TaskId::default(),
                "HashAggExecutor".to_owned(),
                CHUNK_SIZE,
                mem_context.clone(),
                None,
                BatchSpillMetrics::for_test(),
                ShutdownToken::empty(),
            )
            .unwrap();

            // Group output order follows hash-table iteration order, which depends on the
            // hasher; sort both sides so the comparison does not rely on it.
            diff_executor_output(
                create_order_by_executor(actual_exec),
                create_order_by_executor(grouped_int32_expected()),
            )
            .await;

            // check estimated memory usage = 4 groups x state size
            assert_eq!(mem_context.get_bytes_used() as usize, 4 * 24);
        }

        // Ensure that agg memory counter has been dropped.
        assert_eq!(0, parent_mem.get_bytes_used());
    }

    /// `count(*)` with no group key produces a single row holding the input row count.
    #[tokio::test]
    async fn execute_count_star() {
        let src_exec = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i
                 0
                 1
                 0
                 1
                 1
                 0
                 1
                 0",
            ),
            Schema::new(vec![Field::unnamed(DataType::Int32)]),
        );

        let agg_call = AggCall {
            kind: PbAggKind::Count as i32,
            args: vec![],
            return_type: Some(PbDataType {
                type_name: TypeName::Int64 as i32,
                ..Default::default()
            }),
            distinct: false,
            order_by: vec![],
            filter: None,
            direct_args: vec![],
            udf: None,
            scalar: None,
        };

        let agg_prost = HashAggNode {
            group_key: vec![],
            agg_calls: vec![agg_call],
        };

        let actual_exec = HashAggExecutorBuilder::deserialize(
            &agg_prost,
            Box::new(src_exec),
            TaskId::default(),
            "HashAggExecutor".to_owned(),
            CHUNK_SIZE,
            MemoryContext::none(),
            None,
            BatchSpillMetrics::for_test(),
            ShutdownToken::empty(),
        )
        .unwrap();

        let expect_exec = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "I
                 8",
            ),
            Schema::new(vec![Field::unnamed(DataType::Int64)]),
        );
        diff_executor_output(actual_exec, Box::new(expect_exec)).await;
    }

    /// A test to verify that `HashMap` may leak memory counter when using `into_iter`.
    ///
    /// The map is built with a custom allocator whose inner state sets `dropped` when it is
    /// freed. After the map is consumed by `into_iter` and goes out of scope, the final
    /// assertion claims the allocator was NOT dropped, which is the leaking behavior. With
    /// the leak fixed (see the TODO on `should_panic`), the allocator is dropped, the
    /// assertion fails, and `#[should_panic]` turns that failure into a pass.
    #[test]
    #[should_panic] // TODO(MrCroxx): This bug is fixed and the test should panic. Remove the test and fix the related code later.
    fn test_hashmap_into_iter_bug() {
        let dropped: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

        {
            struct MyAllocInner {
                drop_flag: Arc<AtomicBool>,
            }

            /// Delegates to `Global`; exists only so its drop can be observed.
            #[derive(Clone)]
            struct MyAlloc {
                #[expect(dead_code)]
                inner: Arc<MyAllocInner>,
            }

            impl Drop for MyAllocInner {
                fn drop(&mut self) {
                    println!("MyAlloc freed.");
                    self.drop_flag.store(true, Ordering::SeqCst);
                }
            }

            unsafe impl Allocator for MyAlloc {
                fn allocate(
                    &self,
                    layout: Layout,
                ) -> std::result::Result<NonNull<[u8]>, AllocError> {
                    let g = Global;
                    g.allocate(layout)
                }

                unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
                    // SAFETY: every block this allocator hands out comes from `Global`, so
                    // returning it to `Global` with the same layout upholds the caller's contract.
                    unsafe {
                        let g = Global;
                        g.deallocate(ptr, layout)
                    }
                }
            }

            let mut map = hashbrown::HashMap::with_capacity_in(
                10,
                MyAlloc {
                    inner: Arc::new(MyAllocInner {
                        drop_flag: dropped.clone(),
                    }),
                },
            );
            for i in 0..10 {
                map.entry(i).or_insert_with(|| "i".to_owned());
            }

            // Consume the map by value; this is the code path under test.
            for (k, v) in map {
                println!("{k}, {v}");
            }
        }

        assert!(!dropped.load(Ordering::SeqCst));
    }

    /// Cancelling the shutdown token before execution must make the executor yield an error.
    #[tokio::test]
    async fn test_shutdown() {
        let src_exec = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i i i
                 0 1 1",
            ),
            Schema::new(vec![Field::unnamed(DataType::Int32); 3]),
        );

        let agg_prost = HashAggNode {
            group_key: vec![0, 1],
            agg_calls: vec![sum_of_third_column()],
        };

        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
        let actual_exec = HashAggExecutorBuilder::deserialize(
            &agg_prost,
            Box::new(src_exec),
            TaskId::default(),
            "HashAggExecutor".to_owned(),
            CHUNK_SIZE,
            MemoryContext::none(),
            None,
            BatchSpillMetrics::for_test(),
            shutdown_rx,
        )
        .unwrap();

        shutdown_tx.cancel();

        // Only the first item matters. Track that the loop actually ran, otherwise an empty
        // stream would let this test pass without ever observing the shutdown error.
        let mut observed_first_item = false;
        #[for_await]
        for data in actual_exec.execute() {
            assert!(data.is_err(), "expected an error after shutdown, got a data chunk");
            observed_first_item = true;
            break;
        }
        assert!(observed_first_item, "executor produced no output after shutdown");
    }

    /// Wraps `child` in a [`SortExecutor`] ordering ascending on every column, so that
    /// output comparisons do not depend on the order in which `child` emits rows.
    fn create_order_by_executor(child: BoxedExecutor) -> BoxedExecutor {
        let column_orders = (0..child.schema().fields.len())
            .map(|column_index| ColumnOrder {
                column_index,
                order_type: OrderType::ascending(),
            })
            .collect_vec();

        Box::new(SortExecutor::new(
            child,
            Arc::new(column_orders),
            "SortExecutor".into(),
            CHUNK_SIZE,
            MemoryContext::none(),
            None,
            BatchSpillMetrics::for_test(),
        ))
    }

    /// Same aggregation as `execute_int32_grouped`, but with a memory limit of 0 and the
    /// in-memory spill backend, so the spilling code path is exercised.
    #[tokio::test]
    async fn test_spill_hash_agg() {
        let agg_prost = HashAggNode {
            group_key: vec![0, 1],
            agg_calls: vec![sum_of_third_column()],
        };

        let mem_context =
            MemoryContext::new_with_mem_limit(None, LabelGuardedIntGauge::<4>::test_int_gauge(), 0);
        let actual_exec = HashAggExecutorBuilder::deserialize(
            &agg_prost,
            grouped_int32_source(),
            TaskId::default(),
            "HashAggExecutor".to_owned(),
            CHUNK_SIZE,
            mem_context.clone(),
            Some(SpillBackend::Memory),
            BatchSpillMetrics::for_test(),
            ShutdownToken::empty(),
        )
        .unwrap();

        // Spilling reorders groups across partitions, so compare sorted outputs.
        diff_executor_output(
            create_order_by_executor(actual_exec),
            create_order_by_executor(grouped_int32_expected()),
        )
        .await;
    }
}