risingwave_batch_executors/executor/
hash_agg.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::hash::BuildHasher;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use bytes::Bytes;
20use futures_async_stream::try_stream;
21use hashbrown::hash_map::Entry;
22use itertools::Itertools;
23use risingwave_common::array::{DataChunk, StreamChunk};
24use risingwave_common::bitmap::{Bitmap, FilterByBitmap};
25use risingwave_common::catalog::{Field, Schema};
26use risingwave_common::hash::{HashKey, HashKeyDispatcher, PrecomputedBuildHasher};
27use risingwave_common::memory::MemoryContext;
28use risingwave_common::row::{OwnedRow, Row, RowExt};
29use risingwave_common::types::{DataType, ToOwnedDatum};
30use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
31use risingwave_common::util::iter_util::ZipEqFast;
32use risingwave_common_estimate_size::EstimateSize;
33use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction};
34use risingwave_pb::Message;
35use risingwave_pb::batch_plan::HashAggNode;
36use risingwave_pb::batch_plan::plan_node::NodeBody;
37use risingwave_pb::data::DataChunk as PbDataChunk;
38
39use crate::error::{BatchError, Result};
40use crate::executor::aggregation::build as build_agg;
41use crate::executor::{
42    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
43    WrapStreamExecutor,
44};
45use crate::monitor::BatchSpillMetrics;
46use crate::spill::spill_op::SpillBackend::Disk;
47use crate::spill::spill_op::{
48    DEFAULT_SPILL_PARTITION_NUM, SPILL_AT_LEAST_MEMORY, SpillBackend, SpillBuildHasher, SpillOp,
49};
50use crate::task::{ShutdownToken, TaskId};
51
/// Hash table mapping each group key `K` to the states of all aggregate calls for
/// that group. Built with `PrecomputedBuildHasher` and a custom allocator `A` so
/// that memory usage can be charged to the executor's `MemoryContext`
/// (see `do_execute`, which passes `mem_context.global_allocator()`).
type AggHashMap<K, A> = hashbrown::HashMap<K, Vec<AggregateState>, PrecomputedBuildHasher, A>;
53
54/// A dispatcher to help create specialized hash agg executor.
55impl HashKeyDispatcher for HashAggExecutorBuilder {
56    type Output = BoxedExecutor;
57
58    fn dispatch_impl<K: HashKey>(self) -> Self::Output {
59        Box::new(HashAggExecutor::<K>::new(
60            Arc::new(self.aggs),
61            self.group_key_columns,
62            self.group_key_types,
63            self.schema,
64            self.child,
65            self.identity,
66            self.chunk_size,
67            self.mem_context,
68            self.spill_backend,
69            self.spill_metrics,
70            self.shutdown_rx,
71        ))
72    }
73
74    fn data_types(&self) -> &[DataType] {
75        &self.group_key_types
76    }
77}
78
/// Carries everything needed to construct a [`HashAggExecutor`]; the concrete
/// hash key type is chosen at runtime via the [`HashKeyDispatcher`] impl above.
pub struct HashAggExecutorBuilder {
    /// Aggregate functions, one per agg call in the plan node.
    aggs: Vec<BoxedAggregateFunction>,
    /// Column indexes of the child output that form the group key.
    group_key_columns: Vec<usize>,
    /// Data types of the group key columns.
    group_key_types: Vec<DataType>,
    /// Input executor producing the rows to aggregate.
    child: BoxedExecutor,
    /// Output schema: group key columns followed by one column per aggregate.
    schema: Schema,
    #[expect(dead_code)]
    task_id: TaskId,
    /// Human-readable executor identity, used for logging and memory context.
    identity: String,
    /// Max cardinality of each produced output chunk.
    chunk_size: usize,
    /// Memory context charged with this executor's hash table usage.
    mem_context: MemoryContext,
    /// Where to spill when memory is insufficient; `None` disables spilling.
    spill_backend: Option<SpillBackend>,
    /// Metrics for spill read/write byte counts.
    spill_metrics: Arc<BatchSpillMetrics>,
    /// Token to cooperatively stop execution on shutdown.
    shutdown_rx: ShutdownToken,
}
94
95impl HashAggExecutorBuilder {
96    fn deserialize(
97        hash_agg_node: &HashAggNode,
98        child: BoxedExecutor,
99        task_id: TaskId,
100        identity: String,
101        chunk_size: usize,
102        mem_context: MemoryContext,
103        spill_backend: Option<SpillBackend>,
104        spill_metrics: Arc<BatchSpillMetrics>,
105        shutdown_rx: ShutdownToken,
106    ) -> Result<BoxedExecutor> {
107        let aggs: Vec<_> = hash_agg_node
108            .get_agg_calls()
109            .iter()
110            .map(|agg| AggCall::from_protobuf(agg).and_then(|agg| build_agg(&agg)))
111            .try_collect()?;
112
113        let group_key_columns = hash_agg_node
114            .get_group_key()
115            .iter()
116            .map(|x| *x as usize)
117            .collect_vec();
118
119        let child_schema = child.schema();
120
121        let group_key_types = group_key_columns
122            .iter()
123            .map(|i| child_schema.fields[*i].data_type.clone())
124            .collect_vec();
125
126        let fields = group_key_types
127            .iter()
128            .cloned()
129            .chain(aggs.iter().map(|e| e.return_type()))
130            .map(Field::unnamed)
131            .collect::<Vec<Field>>();
132
133        let builder = HashAggExecutorBuilder {
134            aggs,
135            group_key_columns,
136            group_key_types,
137            child,
138            schema: Schema { fields },
139            task_id,
140            identity,
141            chunk_size,
142            mem_context,
143            spill_backend,
144            spill_metrics,
145            shutdown_rx,
146        };
147
148        Ok(builder.dispatch())
149    }
150}
151
152impl BoxedExecutorBuilder for HashAggExecutorBuilder {
153    async fn new_boxed_executor(
154        source: &ExecutorBuilder<'_>,
155        inputs: Vec<BoxedExecutor>,
156    ) -> Result<BoxedExecutor> {
157        let [child]: [_; 1] = inputs.try_into().unwrap();
158
159        let hash_agg_node = try_match_expand!(
160            source.plan_node().get_node_body().unwrap(),
161            NodeBody::HashAgg
162        )?;
163
164        let identity = source.plan_node().get_identity();
165
166        let spill_metrics = source.context().spill_metrics();
167
168        Self::deserialize(
169            hash_agg_node,
170            child,
171            source.task_id.clone(),
172            identity.clone(),
173            source.context().get_config().developer.chunk_size,
174            source.context().create_executor_mem_context(identity),
175            if source.context().get_config().enable_spill {
176                Some(Disk)
177            } else {
178                None
179            },
180            spill_metrics,
181            source.shutdown_rx().clone(),
182        )
183    }
184}
185
/// `HashAggExecutor` implements the hash aggregate algorithm.
pub struct HashAggExecutor<K> {
    /// Aggregate functions.
    aggs: Arc<Vec<BoxedAggregateFunction>>,
    /// Column indexes that specify a group
    group_key_columns: Vec<usize>,
    /// Data types of group key columns
    group_key_types: Vec<DataType>,
    /// Output schema
    schema: Schema,
    /// Input executor whose rows are aggregated.
    child: BoxedExecutor,
    /// Used to initialize the state of the aggregation from the spilled files.
    init_agg_state_executor: Option<BoxedExecutor>,
    /// Executor identity; also used to name the spill directory.
    identity: String,
    /// Max cardinality of each output chunk.
    chunk_size: usize,
    /// Tracks memory usage of the group hash table; drives the spill decision.
    mem_context: MemoryContext,
    /// Where to spill when memory is insufficient; `None` disables spilling.
    spill_backend: Option<SpillBackend>,
    /// Metrics for spill read/write byte counts.
    spill_metrics: Arc<BatchSpillMetrics>,
    /// The upper bound of memory usage for this executor.
    memory_upper_bound: Option<u64>,
    /// Token to cooperatively stop execution on shutdown.
    shutdown_rx: ShutdownToken,
    /// Pins the hash key type `K` chosen by the dispatcher.
    _phantom: PhantomData<K>,
}
209
210impl<K> HashAggExecutor<K> {
211    #[allow(clippy::too_many_arguments)]
212    pub fn new(
213        aggs: Arc<Vec<BoxedAggregateFunction>>,
214        group_key_columns: Vec<usize>,
215        group_key_types: Vec<DataType>,
216        schema: Schema,
217        child: BoxedExecutor,
218        identity: String,
219        chunk_size: usize,
220        mem_context: MemoryContext,
221        spill_backend: Option<SpillBackend>,
222        spill_metrics: Arc<BatchSpillMetrics>,
223        shutdown_rx: ShutdownToken,
224    ) -> Self {
225        Self::new_inner(
226            aggs,
227            group_key_columns,
228            group_key_types,
229            schema,
230            child,
231            None,
232            identity,
233            chunk_size,
234            mem_context,
235            spill_backend,
236            spill_metrics,
237            None,
238            shutdown_rx,
239        )
240    }
241
242    #[allow(clippy::too_many_arguments)]
243    fn new_inner(
244        aggs: Arc<Vec<BoxedAggregateFunction>>,
245        group_key_columns: Vec<usize>,
246        group_key_types: Vec<DataType>,
247        schema: Schema,
248        child: BoxedExecutor,
249        init_agg_state_executor: Option<BoxedExecutor>,
250        identity: String,
251        chunk_size: usize,
252        mem_context: MemoryContext,
253        spill_backend: Option<SpillBackend>,
254        spill_metrics: Arc<BatchSpillMetrics>,
255        memory_upper_bound: Option<u64>,
256        shutdown_rx: ShutdownToken,
257    ) -> Self {
258        HashAggExecutor {
259            aggs,
260            group_key_columns,
261            group_key_types,
262            schema,
263            child,
264            init_agg_state_executor,
265            identity,
266            chunk_size,
267            mem_context,
268            spill_backend,
269            spill_metrics,
270            memory_upper_bound,
271            shutdown_rx,
272            _phantom: PhantomData,
273        }
274    }
275}
276
impl<K: HashKey + Send + Sync> Executor for HashAggExecutor<K> {
    fn schema(&self) -> &Schema {
        // Group key columns followed by one column per aggregate.
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        // Delegate to the `try_stream`-generated implementation below.
        self.do_execute()
    }
}
290
/// `AggSpillManager` is used to manage how to write spill data file and read them back.
/// The spill data first need to be partitioned. Each partition contains 2 files: `agg_state_file` and `input_chunks_file`.
/// The spill file consume a data chunk and serialize the chunk into a protobuf bytes.
/// Finally, spill file content will look like the below.
/// The file write pattern is append-only and the read pattern is sequential scan.
/// This can maximize the disk IO performance.
///
/// ```text
/// [proto_len]
/// [proto_bytes]
/// ...
/// [proto_len]
/// [proto_bytes]
/// ```
pub struct AggSpillManager {
    /// Handle to the unique spill directory backing this manager.
    op: SpillOp,
    /// Number of spill partitions; rows are routed by `hash % partition_num`.
    partition_num: usize,
    /// One append-only writer per partition for serialized agg-state chunks.
    agg_state_writers: Vec<opendal::Writer>,
    /// Per-partition builder batching agg-state rows into chunks before writing.
    agg_state_chunk_builder: Vec<DataChunkBuilder>,
    /// One append-only writer per partition for spilled input chunks.
    input_writers: Vec<opendal::Writer>,
    /// Per-partition builder re-batching the partitioned input rows.
    input_chunk_builders: Vec<DataChunkBuilder>,
    /// Hasher used to assign rows to partitions; seeded per spill round.
    spill_build_hasher: SpillBuildHasher,
    /// Data types of the group key columns.
    group_key_types: Vec<DataType>,
    /// Data types of the child (input) columns.
    child_data_types: Vec<DataType>,
    /// Return types of the aggregate calls (encoded-state columns).
    agg_data_types: Vec<DataType>,
    /// Max cardinality of chunks written to spill files.
    spill_chunk_size: usize,
    /// Metrics for spill read/write byte counts.
    spill_metrics: Arc<BatchSpillMetrics>,
}
319
320impl AggSpillManager {
321    fn new(
322        spill_backend: SpillBackend,
323        agg_identity: &String,
324        partition_num: usize,
325        group_key_types: Vec<DataType>,
326        agg_data_types: Vec<DataType>,
327        child_data_types: Vec<DataType>,
328        spill_chunk_size: usize,
329        spill_metrics: Arc<BatchSpillMetrics>,
330    ) -> Result<Self> {
331        let suffix_uuid = uuid::Uuid::new_v4();
332        let dir = format!("{}-{}/", agg_identity, suffix_uuid);
333        let op = SpillOp::create(dir, spill_backend)?;
334        let agg_state_writers = Vec::with_capacity(partition_num);
335        let agg_state_chunk_builder = Vec::with_capacity(partition_num);
336        let input_writers = Vec::with_capacity(partition_num);
337        let input_chunk_builders = Vec::with_capacity(partition_num);
338        // Use uuid to generate an unique hasher so that when recursive spilling happens they would use a different hasher to avoid data skew.
339        let spill_build_hasher = SpillBuildHasher(suffix_uuid.as_u64_pair().1);
340        Ok(Self {
341            op,
342            partition_num,
343            agg_state_writers,
344            agg_state_chunk_builder,
345            input_writers,
346            input_chunk_builders,
347            spill_build_hasher,
348            group_key_types,
349            child_data_types,
350            agg_data_types,
351            spill_chunk_size,
352            spill_metrics,
353        })
354    }
355
356    async fn init_writers(&mut self) -> Result<()> {
357        for i in 0..self.partition_num {
358            let agg_state_partition_file_name = format!("agg-state-p{}", i);
359            let w = self.op.writer_with(&agg_state_partition_file_name).await?;
360            self.agg_state_writers.push(w);
361
362            let partition_file_name = format!("input-chunks-p{}", i);
363            let w = self.op.writer_with(&partition_file_name).await?;
364            self.input_writers.push(w);
365            self.input_chunk_builders.push(DataChunkBuilder::new(
366                self.child_data_types.clone(),
367                self.spill_chunk_size,
368            ));
369            self.agg_state_chunk_builder.push(DataChunkBuilder::new(
370                self.group_key_types
371                    .iter()
372                    .cloned()
373                    .chain(self.agg_data_types.iter().cloned())
374                    .collect(),
375                self.spill_chunk_size,
376            ));
377        }
378        Ok(())
379    }
380
381    async fn write_agg_state_row(&mut self, row: impl Row, hash_code: u64) -> Result<()> {
382        let partition = hash_code as usize % self.partition_num;
383        if let Some(output_chunk) = self.agg_state_chunk_builder[partition].append_one_row(row) {
384            let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
385            let buf = Message::encode_to_vec(&chunk_pb);
386            let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
387            self.spill_metrics
388                .batch_spill_write_bytes
389                .inc_by((buf.len() + len_bytes.len()) as u64);
390            self.agg_state_writers[partition].write(len_bytes).await?;
391            self.agg_state_writers[partition].write(buf).await?;
392        }
393        Ok(())
394    }
395
396    async fn write_input_chunk(&mut self, chunk: DataChunk, hash_codes: Vec<u64>) -> Result<()> {
397        let (columns, vis) = chunk.into_parts_v2();
398        for partition in 0..self.partition_num {
399            let new_vis = vis.clone()
400                & Bitmap::from_iter(
401                    hash_codes
402                        .iter()
403                        .map(|hash_code| (*hash_code as usize % self.partition_num) == partition),
404                );
405            let new_chunk = DataChunk::from_parts(columns.clone(), new_vis);
406            for output_chunk in self.input_chunk_builders[partition].append_chunk(new_chunk) {
407                let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
408                let buf = Message::encode_to_vec(&chunk_pb);
409                let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
410                self.spill_metrics
411                    .batch_spill_write_bytes
412                    .inc_by((buf.len() + len_bytes.len()) as u64);
413                self.input_writers[partition].write(len_bytes).await?;
414                self.input_writers[partition].write(buf).await?;
415            }
416        }
417        Ok(())
418    }
419
420    async fn close_writers(&mut self) -> Result<()> {
421        for partition in 0..self.partition_num {
422            if let Some(output_chunk) = self.agg_state_chunk_builder[partition].consume_all() {
423                let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
424                let buf = Message::encode_to_vec(&chunk_pb);
425                let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
426                self.spill_metrics
427                    .batch_spill_write_bytes
428                    .inc_by((buf.len() + len_bytes.len()) as u64);
429                self.agg_state_writers[partition].write(len_bytes).await?;
430                self.agg_state_writers[partition].write(buf).await?;
431            }
432
433            if let Some(output_chunk) = self.input_chunk_builders[partition].consume_all() {
434                let chunk_pb: PbDataChunk = output_chunk.to_protobuf();
435                let buf = Message::encode_to_vec(&chunk_pb);
436                let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes());
437                self.spill_metrics
438                    .batch_spill_write_bytes
439                    .inc_by((buf.len() + len_bytes.len()) as u64);
440                self.input_writers[partition].write(len_bytes).await?;
441                self.input_writers[partition].write(buf).await?;
442            }
443        }
444
445        for mut w in self.agg_state_writers.drain(..) {
446            w.close().await?;
447        }
448        for mut w in self.input_writers.drain(..) {
449            w.close().await?;
450        }
451        Ok(())
452    }
453
454    async fn read_agg_state_partition(&mut self, partition: usize) -> Result<BoxedDataChunkStream> {
455        let agg_state_partition_file_name = format!("agg-state-p{}", partition);
456        let r = self.op.reader_with(&agg_state_partition_file_name).await?;
457        Ok(SpillOp::read_stream(r, self.spill_metrics.clone()))
458    }
459
460    async fn read_input_partition(&mut self, partition: usize) -> Result<BoxedDataChunkStream> {
461        let input_partition_file_name = format!("input-chunks-p{}", partition);
462        let r = self.op.reader_with(&input_partition_file_name).await?;
463        Ok(SpillOp::read_stream(r, self.spill_metrics.clone()))
464    }
465
466    async fn estimate_partition_size(&self, partition: usize) -> Result<u64> {
467        let agg_state_partition_file_name = format!("agg-state-p{}", partition);
468        let agg_state_size = self
469            .op
470            .stat(&agg_state_partition_file_name)
471            .await?
472            .content_length();
473        let input_partition_file_name = format!("input-chunks-p{}", partition);
474        let input_size = self
475            .op
476            .stat(&input_partition_file_name)
477            .await?
478            .content_length();
479        Ok(agg_state_size + input_size)
480    }
481
482    async fn clear_partition(&mut self, partition: usize) -> Result<()> {
483        let agg_state_partition_file_name = format!("agg-state-p{}", partition);
484        self.op.delete(&agg_state_partition_file_name).await?;
485        let input_partition_file_name = format!("input-chunks-p{}", partition);
486        self.op.delete(&input_partition_file_name).await?;
487        Ok(())
488    }
489}
490
impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
    /// Main execution loop.
    ///
    /// Phase 1 (optional): recover agg states spilled by a parent executor.
    /// Phase 2: consume input chunks, folding each row into its group's states.
    /// Phase 3a (spill path): partition the hash table and the remaining input
    /// to disk, then aggregate each partition with a recursive sub-executor.
    /// Phase 3b (in-memory path): emit result chunks of up to `chunk_size` rows.
    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        let child_schema = self.child.schema().clone();
        let mut need_to_spill = false;
        // If the memory upper bound is less than 1MB, we don't need to check memory usage.
        let check_memory = match self.memory_upper_bound {
            Some(upper_bound) => upper_bound > SPILL_AT_LEAST_MEMORY,
            None => true,
        };

        // hash map for each agg groups, allocated through the memory context so
        // its footprint is tracked.
        let mut groups = AggHashMap::<K, _>::with_hasher_in(
            PrecomputedBuildHasher,
            self.mem_context.global_allocator(),
        );

        if let Some(init_agg_state_executor) = self.init_agg_state_executor {
            // `init_agg_state_executor` exists which means this is a sub `HashAggExecutor` used to consume spilling data.
            // The spilled agg states by its parent executor need to be recovered first.
            let mut init_agg_state_stream = init_agg_state_executor.execute();
            #[for_await]
            for chunk in &mut init_agg_state_stream {
                let chunk = chunk?;
                // Spilled agg-state rows are `group key ++ encoded agg states`,
                // so the key occupies the leading columns.
                let group_key_indices = (0..self.group_key_columns.len()).collect_vec();
                let keys = K::build_many(&group_key_indices, &chunk);
                let mut memory_usage_diff = 0;
                for (row_id, key) in keys.into_iter().enumerate() {
                    let mut agg_states = vec![];
                    for i in 0..self.aggs.len() {
                        let agg = &self.aggs[i];
                        // Agg state columns follow the group key columns.
                        let datum = chunk
                            .row_at(row_id)
                            .0
                            .datum_at(self.group_key_columns.len() + i)
                            .to_owned_datum();
                        let agg_state = agg.decode_state(datum)?;
                        memory_usage_diff += agg_state.estimated_size() as i64;
                        agg_states.push(agg_state);
                    }
                    // Keys are unique per spilled partition, so insertion must succeed.
                    groups.try_insert(key, agg_states).unwrap();
                }

                if !self.mem_context.add(memory_usage_diff) && check_memory {
                    warn!(
                        "not enough memory to load one partition agg state after spill which is not a normal case, so keep going"
                    );
                }
            }
        }

        let mut input_stream = self.child.execute();
        // consume all chunks to compute the agg result
        #[for_await]
        for chunk in &mut input_stream {
            let chunk = StreamChunk::from(chunk?);
            let keys = K::build_many(self.group_key_columns.as_slice(), &chunk);
            let mut memory_usage_diff = 0;
            for (row_id, key) in keys
                .into_iter()
                .enumerate()
                .filter_by_bitmap(chunk.visibility())
            {
                let mut new_group = false;
                let states = match groups.entry(key) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        new_group = true;
                        let states = self
                            .aggs
                            .iter()
                            .map(|agg| agg.create_state())
                            .try_collect()?;
                        entry.insert(states)
                    }
                };

                // TODO: currently not a vectorized implementation
                for (agg, state) in self.aggs.iter().zip_eq_fast(states) {
                    // Track the state-size delta: subtract the old size (except
                    // for brand-new groups, already counted as zero), add the new.
                    if !new_group {
                        memory_usage_diff -= state.estimated_size() as i64;
                    }
                    agg.update_range(state, &chunk, row_id..row_id + 1).await?;
                    memory_usage_diff += state.estimated_size() as i64;
                }
            }
            // update memory usage
            if !self.mem_context.add(memory_usage_diff) && check_memory {
                if self.spill_backend.is_some() {
                    // Stop consuming input here; the rest of `input_stream` is
                    // spilled to disk below.
                    need_to_spill = true;
                    break;
                } else {
                    Err(BatchError::OutOfMemory(self.mem_context.mem_limit()))?;
                }
            }
        }

        if need_to_spill {
            // A spilling version of aggregation based on the RFC: Spill Hash Aggregation https://github.com/risingwavelabs/rfcs/pull/89
            // When HashAggExecutor told memory is insufficient, AggSpillManager will start to partition the hash table and spill to disk.
            // After spilling the hash table, AggSpillManager will consume all chunks from the input executor,
            // partition and spill to disk with the same hash function as the hash table spilling.
            // Finally, we would get e.g. 20 partitions. Each partition should contain a portion of the original hash table and input data.
            // A sub HashAggExecutor would be used to consume each partition one by one.
            // If memory is still not enough in the sub HashAggExecutor, it will spill its hash table and input recursively.
            info!(
                "batch hash agg executor {} starts to spill out",
                &self.identity
            );
            let mut agg_spill_manager = AggSpillManager::new(
                self.spill_backend.clone().unwrap(),
                &self.identity,
                DEFAULT_SPILL_PARTITION_NUM,
                self.group_key_types.clone(),
                self.aggs.iter().map(|agg| agg.return_type()).collect(),
                child_schema.data_types(),
                self.chunk_size,
                self.spill_metrics.clone(),
            )?;
            agg_spill_manager.init_writers().await?;

            let mut memory_usage_diff = 0;
            // Spill agg states.
            for (key, states) in groups {
                let key_row = key.deserialize(&self.group_key_types)?;
                let mut agg_datums = vec![];
                for (agg, state) in self.aggs.iter().zip_eq_fast(states) {
                    let encode_state = agg.encode_state(&state)?;
                    memory_usage_diff -= state.estimated_size() as i64;
                    agg_datums.push(encode_state);
                }
                let agg_state_row = OwnedRow::from_iter(agg_datums.into_iter());
                let hash_code = agg_spill_manager.spill_build_hasher.hash_one(key);
                agg_spill_manager
                    .write_agg_state_row(key_row.chain(agg_state_row), hash_code)
                    .await?;
            }

            // Release memory occupied by agg hash map.
            self.mem_context.add(memory_usage_diff);

            // Spill input chunks.
            #[for_await]
            for chunk in input_stream {
                let chunk: DataChunk = chunk?;
                // Same hasher as the agg-state spilling above, so matching rows
                // land in the same partition.
                let hash_codes = chunk.get_hash_values(
                    self.group_key_columns.as_slice(),
                    agg_spill_manager.spill_build_hasher,
                );
                agg_spill_manager
                    .write_input_chunk(
                        chunk,
                        hash_codes
                            .into_iter()
                            .map(|hash_code| hash_code.value())
                            .collect(),
                    )
                    .await?;
            }

            agg_spill_manager.close_writers().await?;

            // Process each partition one by one.
            for i in 0..agg_spill_manager.partition_num {
                let partition_size = agg_spill_manager.estimate_partition_size(i).await?;

                let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?;
                let input_stream = agg_spill_manager.read_input_partition(i).await?;

                // The sub-executor recovers the spilled agg states first, then
                // continues aggregating the spilled input of this partition.
                let sub_hash_agg_executor: HashAggExecutor<K> = HashAggExecutor::new_inner(
                    self.aggs.clone(),
                    self.group_key_columns.clone(),
                    self.group_key_types.clone(),
                    self.schema.clone(),
                    Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)),
                    Some(Box::new(WrapStreamExecutor::new(
                        self.schema.clone(),
                        agg_state_stream,
                    ))),
                    format!("{}-sub{}", self.identity.clone(), i),
                    self.chunk_size,
                    self.mem_context.clone(),
                    self.spill_backend.clone(),
                    self.spill_metrics.clone(),
                    Some(partition_size),
                    self.shutdown_rx.clone(),
                );

                debug!(
                    "create sub_hash_agg {} for hash_agg {} to spill",
                    sub_hash_agg_executor.identity, self.identity
                );

                let sub_hash_agg_stream = Box::new(sub_hash_agg_executor).execute();

                #[for_await]
                for chunk in sub_hash_agg_stream {
                    let chunk = chunk?;
                    yield chunk;
                }

                // Clear files of the current partition.
                agg_spill_manager.clear_partition(i).await?;
            }
        } else {
            // Don't use `into_iter` here, it may cause memory leak.
            let mut result = groups.iter_mut();
            let cardinality = self.chunk_size;
            loop {
                let mut group_builders: Vec<_> = self
                    .group_key_types
                    .iter()
                    .map(|datatype| datatype.create_array_builder(cardinality))
                    .collect();

                let mut agg_builders: Vec<_> = self
                    .aggs
                    .iter()
                    .map(|agg| agg.return_type().create_array_builder(cardinality))
                    .collect();

                // Drain up to `cardinality` groups into the builders per chunk.
                let mut has_next = false;
                let mut array_len = 0;
                for (key, states) in result.by_ref().take(cardinality) {
                    self.shutdown_rx.check()?;
                    has_next = true;
                    array_len += 1;
                    key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?;
                    for ((agg, state), builder) in (self.aggs.iter())
                        .zip_eq_fast(states)
                        .zip_eq_fast(&mut agg_builders)
                    {
                        let result = agg.get_result(state).await?;
                        builder.append(result);
                    }
                }
                if !has_next {
                    break; // exit loop
                }

                let columns = group_builders
                    .into_iter()
                    .chain(agg_builders)
                    .map(|b| b.finish().into())
                    .collect::<Vec<_>>();

                let output = DataChunk::new(columns, array_len);
                yield output;
            }
        }
    }
}
743
#[cfg(test)]
mod tests {
    use std::alloc::{AllocError, Allocator, Global, Layout};
    use std::ptr::NonNull;
    use std::sync::atomic::{AtomicBool, Ordering};

    use allocator_api2::alloc::{AllocError as AllocErrorApi2, Allocator as AllocatorApi2};
    use futures_async_stream::for_await;
    use risingwave_common::metrics::LabelGuardedIntGauge;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
    use risingwave_pb::data::PbDataType;
    use risingwave_pb::data::data_type::TypeName;
    use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
    use risingwave_pb::expr::{AggCall, InputRef};

    use super::*;
    use crate::executor::SortExecutor;
    use crate::executor::test_utils::{MockExecutor, diff_executor_output};

    // Output chunk size used by all executors built in these tests.
    const CHUNK_SIZE: usize = 1024;

    /// Runs `sum(c2)` grouped by `(c0, c1)` over a small in-memory chunk and
    /// checks both the aggregated output and the executor's memory accounting.
    #[tokio::test]
    async fn execute_int32_grouped() {
        let parent_mem = MemoryContext::root(LabelGuardedIntGauge::test_int_gauge::<4>(), u64::MAX);
        // Inner scope: the executor (and its child memory context counter) must
        // be fully dropped before the final `parent_mem` check below.
        {
            let src_exec = Box::new(MockExecutor::with_chunk(
                DataChunk::from_pretty(
                    "i i i
                 0 1 1
                 1 1 1
                 0 0 1
                 1 1 2
                 1 0 1
                 0 0 2
                 1 1 3
                 0 1 2",
                ),
                Schema::new(vec![
                    Field::unnamed(DataType::Int32),
                    Field::unnamed(DataType::Int32),
                    Field::unnamed(DataType::Int64),
                ]),
            ));

            // sum(c2): aggregates the Int32 column at index 2 into an Int64.
            let agg_call = AggCall {
                kind: PbAggKind::Sum as i32,
                args: vec![InputRef {
                    index: 2,
                    r#type: Some(PbDataType {
                        type_name: TypeName::Int32 as i32,
                        ..Default::default()
                    }),
                }],
                return_type: Some(PbDataType {
                    type_name: TypeName::Int64 as i32,
                    ..Default::default()
                }),
                distinct: false,
                order_by: vec![],
                filter: None,
                direct_args: vec![],
                udf: None,
                scalar: None,
            };

            // Group by columns 0 and 1.
            let agg_prost = HashAggNode {
                group_key: vec![0, 1],
                agg_calls: vec![agg_call],
            };

            // Child context of `parent_mem` so this executor's usage can be
            // inspected in isolation.
            let mem_context = MemoryContext::new(
                Some(parent_mem.clone()),
                LabelGuardedIntGauge::test_int_gauge::<4>(),
            );
            let actual_exec = HashAggExecutorBuilder::deserialize(
                &agg_prost,
                src_exec,
                TaskId::default(),
                "HashAggExecutor".to_owned(),
                CHUNK_SIZE,
                mem_context.clone(),
                None,
                BatchSpillMetrics::for_test(),
                ShutdownToken::empty(),
            )
            .unwrap();

            // The expected row order below depends on hash-map iteration order.
            // TODO: currently the order is fixed unless the hasher is changed
            let expect_exec = Box::new(MockExecutor::with_chunk(
                DataChunk::from_pretty(
                    "i i I
                 1 0 1
                 0 0 3
                 0 1 3
                 1 1 6",
                ),
                Schema::new(vec![
                    Field::unnamed(DataType::Int32),
                    Field::unnamed(DataType::Int32),
                    Field::unnamed(DataType::Int64),
                ]),
            ));
            diff_executor_output(actual_exec, expect_exec).await;

            // check estimated memory usage = 4 groups x state size
            assert_eq!(mem_context.get_bytes_used() as usize, 4 * 24);
        }

        // Ensure that agg memory counter has been dropped.
        assert_eq!(0, parent_mem.get_bytes_used());
    }

    /// `count(*)` with an empty group key: all input rows collapse into a
    /// single output row containing the row count (8).
    #[tokio::test]
    async fn execute_count_star() {
        let src_exec = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i
                 0
                 1
                 0
                 1
                 1
                 0
                 1
                 0",
            ),
            Schema::new(vec![Field::unnamed(DataType::Int32)]),
        );

        // count(*): no arguments, Int64 result.
        let agg_call = AggCall {
            kind: PbAggKind::Count as i32,
            args: vec![],
            return_type: Some(PbDataType {
                type_name: TypeName::Int64 as i32,
                ..Default::default()
            }),
            distinct: false,
            order_by: vec![],
            filter: None,
            direct_args: vec![],
            udf: None,
            scalar: None,
        };

        // No group key: a single global aggregation group.
        let agg_prost = HashAggNode {
            group_key: vec![],
            agg_calls: vec![agg_call],
        };

        let actual_exec = HashAggExecutorBuilder::deserialize(
            &agg_prost,
            Box::new(src_exec),
            TaskId::default(),
            "HashAggExecutor".to_owned(),
            CHUNK_SIZE,
            MemoryContext::none(),
            None,
            BatchSpillMetrics::for_test(),
            ShutdownToken::empty(),
        )
        .unwrap();

        let expect_exec = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "I
                 8",
            ),
            Schema::new(vec![Field::unnamed(DataType::Int64)]),
        );
        diff_executor_output(actual_exec, Box::new(expect_exec)).await;
    }

    /// A test to verify that `HashMap` may leak memory counter when using `into_iter`.
    ///
    /// The body asserts the buggy behavior (allocator NOT dropped after
    /// `into_iter`); since the bug is fixed upstream, that assertion now fails,
    /// which is why the test is marked `#[should_panic]`.
    #[test]
    #[should_panic] // TODO(MrCroxx): This bug is fixed and the test should panic. Remove the test and fix the related code later.
    fn test_hashmap_into_iter_bug() {
        let dropped: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

        {
            // Shared allocator state; its `Drop` impl sets `drop_flag`, letting
            // us observe whether the map released its allocator.
            struct MyAllocInner {
                drop_flag: Arc<AtomicBool>,
            }

            #[derive(Clone)]
            struct MyAlloc {
                #[expect(dead_code)]
                inner: Arc<MyAllocInner>,
            }

            impl Drop for MyAllocInner {
                fn drop(&mut self) {
                    println!("MyAlloc freed.");
                    self.drop_flag.store(true, Ordering::SeqCst);
                }
            }

            // Both allocator traits simply delegate to `Global`; the allocator
            // exists only to track its own drop.
            unsafe impl Allocator for MyAlloc {
                fn allocate(
                    &self,
                    layout: Layout,
                ) -> std::result::Result<NonNull<[u8]>, AllocError> {
                    let g = Global;
                    g.allocate(layout)
                }

                unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
                    unsafe {
                        let g = Global;
                        g.deallocate(ptr, layout)
                    }
                }
            }

            // Same delegation for the `allocator-api2` trait used by hashbrown.
            unsafe impl AllocatorApi2 for MyAlloc {
                fn allocate(
                    &self,
                    layout: Layout,
                ) -> std::result::Result<NonNull<[u8]>, AllocErrorApi2> {
                    let g = Global;
                    g.allocate(layout).map_err(|_| AllocErrorApi2)
                }

                unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
                    unsafe {
                        let g = Global;
                        g.deallocate(ptr, layout)
                    }
                }
            }

            let mut map = hashbrown::HashMap::with_capacity_in(
                10,
                MyAlloc {
                    inner: Arc::new(MyAllocInner {
                        drop_flag: dropped.clone(),
                    }),
                },
            );
            for i in 0..10 {
                map.entry(i).or_insert_with(|| "i".to_owned());
            }

            // Consume the map via `into_iter` (the code path under test).
            for (k, v) in map {
                println!("{}, {}", k, v);
            }
        }

        // Asserts the allocator was leaked — this is the (now-fixed) bug, so
        // the assertion panics and satisfies `#[should_panic]`.
        assert!(!dropped.load(Ordering::SeqCst));
    }

    /// Cancelling the shutdown token before execution must make the output
    /// stream yield an error instead of data.
    #[tokio::test]
    async fn test_shutdown() {
        let src_exec = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i i i
                 0 1 1",
            ),
            Schema::new(vec![Field::unnamed(DataType::Int32); 3]),
        );

        // sum(c2), same shape as in `execute_int32_grouped`.
        let agg_call = AggCall {
            kind: PbAggKind::Sum as i32,
            args: vec![InputRef {
                index: 2,
                r#type: Some(PbDataType {
                    type_name: TypeName::Int32 as i32,
                    ..Default::default()
                }),
            }],
            return_type: Some(PbDataType {
                type_name: TypeName::Int64 as i32,
                ..Default::default()
            }),
            distinct: false,
            order_by: vec![],
            filter: None,
            direct_args: vec![],
            udf: None,
            scalar: None,
        };

        let agg_prost = HashAggNode {
            group_key: vec![0, 1],
            agg_calls: vec![agg_call],
        };

        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
        let actual_exec = HashAggExecutorBuilder::deserialize(
            &agg_prost,
            Box::new(src_exec),
            TaskId::default(),
            "HashAggExecutor".to_owned(),
            CHUNK_SIZE,
            MemoryContext::none(),
            None,
            BatchSpillMetrics::for_test(),
            shutdown_rx,
        )
        .unwrap();

        // Cancel before polling the stream at all.
        shutdown_tx.cancel();

        // The first item produced must be an error reflecting the cancellation.
        #[for_await]
        for data in actual_exec.execute() {
            assert!(data.is_err());
            break;
        }
    }

    /// Wraps `child` in a `SortExecutor` that orders by every column ascending,
    /// so outputs whose row order is not deterministic (e.g. after spilling)
    /// can be compared with `diff_executor_output`.
    fn create_order_by_executor(child: BoxedExecutor) -> BoxedExecutor {
        let column_orders = child
            .schema()
            .fields
            .iter()
            .enumerate()
            .map(|(i, _)| ColumnOrder {
                column_index: i,
                order_type: OrderType::ascending(),
            })
            .collect_vec();

        Box::new(SortExecutor::new(
            child,
            Arc::new(column_orders),
            "SortExecutor".into(),
            CHUNK_SIZE,
            MemoryContext::none(),
            None,
            BatchSpillMetrics::for_test(),
        ))
    }

    /// Same aggregation as `execute_int32_grouped`, but with a zero memory
    /// limit and `SpillBackend::Memory` to exercise the spill path. Both actual
    /// and expected outputs are sorted first so the comparison is
    /// order-insensitive.
    #[tokio::test]
    async fn test_spill_hash_agg() {
        let src_exec = Box::new(MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i i i
                 0 1 1
                 1 1 1
                 0 0 1
                 1 1 2
                 1 0 1
                 0 0 2
                 1 1 3
                 0 1 2",
            ),
            Schema::new(vec![
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int64),
            ]),
        ));

        // sum(c2), identical to the non-spilling grouped test above.
        let agg_call = AggCall {
            kind: PbAggKind::Sum as i32,
            args: vec![InputRef {
                index: 2,
                r#type: Some(PbDataType {
                    type_name: TypeName::Int32 as i32,
                    ..Default::default()
                }),
            }],
            return_type: Some(PbDataType {
                type_name: TypeName::Int64 as i32,
                ..Default::default()
            }),
            distinct: false,
            order_by: vec![],
            filter: None,
            direct_args: vec![],
            udf: None,
            scalar: None,
        };

        let agg_prost = HashAggNode {
            group_key: vec![0, 1],
            agg_calls: vec![agg_call],
        };

        // Memory limit of 0: any tracked allocation exceeds the budget, which
        // should drive the executor into the spill path backed by memory.
        let mem_context =
            MemoryContext::new_with_mem_limit(None, LabelGuardedIntGauge::test_int_gauge::<4>(), 0);
        let actual_exec = HashAggExecutorBuilder::deserialize(
            &agg_prost,
            src_exec,
            TaskId::default(),
            "HashAggExecutor".to_owned(),
            CHUNK_SIZE,
            mem_context.clone(),
            Some(SpillBackend::Memory),
            BatchSpillMetrics::for_test(),
            ShutdownToken::empty(),
        )
        .unwrap();

        // Normalize row order on both sides before diffing.
        let actual_exec = create_order_by_executor(actual_exec);

        let expect_exec = Box::new(MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i i I
                 1 0 1
                 0 0 3
                 0 1 3
                 1 1 6",
            ),
            Schema::new(vec![
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int64),
            ]),
        ));

        let expect_exec = create_order_by_executor(expect_exec);
        diff_executor_output(actual_exec, expect_exec).await;
    }
}