risingwave_batch_executors/executor/sort_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16
17use futures_async_stream::try_stream;
18use itertools::Itertools;
19use risingwave_common::array::{Array, ArrayBuilderImpl, ArrayImpl, DataChunk, StreamChunk};
20use risingwave_common::catalog::{Field, Schema};
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction};
23use risingwave_expr::expr::{BoxedExpression, build_from_prost};
24use risingwave_pb::batch_plan::plan_node::NodeBody;
25
26use crate::error::{BatchError, Result};
27use crate::executor::aggregation::build as build_agg;
28use crate::executor::{
29    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
30};
31use crate::task::ShutdownToken;
32
33/// `SortAggExecutor` implements the sort aggregate algorithm, which assumes
34/// that the input chunks has already been sorted by group columns.
35/// The aggregation will be applied to tuples within the same group.
36/// And the output schema is `[group columns, agg result]`.
37///
38/// As a special case, simple aggregate without groups satisfies the requirement
39/// automatically because all tuples should be aggregated together.
40pub struct SortAggExecutor {
41    aggs: Vec<BoxedAggregateFunction>,
42    group_key: Vec<BoxedExpression>,
43    child: BoxedExecutor,
44    schema: Schema,
45    identity: String,
46    output_size_limit: usize, // make unit test easy
47    shutdown_rx: ShutdownToken,
48}
49
50impl BoxedExecutorBuilder for SortAggExecutor {
51    async fn new_boxed_executor(
52        source: &ExecutorBuilder<'_>,
53        inputs: Vec<BoxedExecutor>,
54    ) -> Result<BoxedExecutor> {
55        let [child]: [_; 1] = inputs.try_into().unwrap();
56
57        let sort_agg_node = try_match_expand!(
58            source.plan_node().get_node_body().unwrap(),
59            NodeBody::SortAgg
60        )?;
61
62        let aggs: Vec<_> = sort_agg_node
63            .get_agg_calls()
64            .iter()
65            .map(|agg| AggCall::from_protobuf(agg).and_then(|agg| build_agg(&agg)))
66            .try_collect()?;
67
68        let group_key: Vec<_> = sort_agg_node
69            .get_group_key()
70            .iter()
71            .map(build_from_prost)
72            .try_collect()?;
73
74        let fields = group_key
75            .iter()
76            .map(|e| e.return_type())
77            .chain(aggs.iter().map(|e| e.return_type()))
78            .map(Field::unnamed)
79            .collect::<Vec<Field>>();
80
81        Ok(Box::new(Self {
82            aggs,
83            group_key,
84            child,
85            schema: Schema { fields },
86            identity: source.plan_node().get_identity().clone(),
87            output_size_limit: source.context().get_config().developer.chunk_size,
88            shutdown_rx: source.shutdown_rx().clone(),
89        }))
90    }
91}
92
93impl Executor for SortAggExecutor {
94    fn schema(&self) -> &Schema {
95        &self.schema
96    }
97
98    fn identity(&self) -> &str {
99        &self.identity
100    }
101
102    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
103        self.do_execute()
104    }
105}
106
107impl SortAggExecutor {
108    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
109    async fn do_execute(mut self: Box<Self>) {
110        let mut left_capacity = self.output_size_limit;
111        let mut agg_states: Vec<_> = self
112            .aggs
113            .iter()
114            .map(|agg| agg.create_state())
115            .try_collect()?;
116        let (mut group_builders, mut agg_builders) =
117            Self::create_builders(&self.group_key, &self.aggs);
118        let mut curr_group = if self.group_key.is_empty() {
119            Some(Vec::new())
120        } else {
121            None
122        };
123
124        #[for_await]
125        for child_chunk in self.child.execute() {
126            let child_chunk = StreamChunk::from(child_chunk?.compact());
127            let mut group_columns = Vec::with_capacity(self.group_key.len());
128            for expr in &mut self.group_key {
129                self.shutdown_rx.check()?;
130                let result = expr.eval(&child_chunk).await?;
131                group_columns.push(result);
132            }
133
134            let groups = if group_columns.is_empty() {
135                EqGroups::single_with_len(child_chunk.cardinality())
136            } else {
137                let groups: Vec<_> = group_columns
138                    .iter()
139                    .map(|col| EqGroups::detect(col))
140                    .try_collect()?;
141                EqGroups::intersect(&groups)
142            };
143
144            for range in groups.ranges() {
145                self.shutdown_rx.check()?;
146                let group: Vec<_> = group_columns
147                    .iter()
148                    .map(|col| col.datum_at(range.start))
149                    .collect();
150
151                if curr_group.as_ref() != Some(&group) {
152                    if let Some(group) = curr_group.replace(group) {
153                        group_builders
154                            .iter_mut()
155                            .zip_eq_fast(group.into_iter())
156                            .for_each(|(builder, datum)| {
157                                builder.append(datum);
158                            });
159                        Self::output_agg_states(&self.aggs, &mut agg_states, &mut agg_builders)
160                            .await?;
161                        left_capacity -= 1;
162
163                        if left_capacity == 0 {
164                            let output = DataChunk::new(
165                                group_builders
166                                    .into_iter()
167                                    .chain(agg_builders)
168                                    .map(|b| b.finish().into())
169                                    .collect(),
170                                self.output_size_limit,
171                            );
172                            yield output;
173
174                            (group_builders, agg_builders) =
175                                Self::create_builders(&self.group_key, &self.aggs);
176                            left_capacity = self.output_size_limit;
177                        }
178                    }
179                }
180
181                Self::update_agg_states(&self.aggs, &mut agg_states, &child_chunk, range).await?;
182            }
183        }
184
185        if let Some(group) = curr_group.take() {
186            group_builders
187                .iter_mut()
188                .zip_eq_fast(group.into_iter())
189                .for_each(|(builder, datum)| {
190                    builder.append(datum);
191                });
192            Self::output_agg_states(&self.aggs, &mut agg_states, &mut agg_builders).await?;
193            left_capacity -= 1;
194
195            let output = DataChunk::new(
196                group_builders
197                    .into_iter()
198                    .chain(agg_builders)
199                    .map(|b| b.finish().into())
200                    .collect(),
201                self.output_size_limit - left_capacity,
202            );
203            yield output;
204        }
205    }
206
207    async fn update_agg_states(
208        aggs: &[BoxedAggregateFunction],
209        agg_states: &mut [AggregateState],
210        child_chunk: &StreamChunk,
211        range: Range<usize>,
212    ) -> Result<()> {
213        for (agg, state) in aggs.iter().zip_eq_fast(agg_states.iter_mut()) {
214            agg.update_range(state, child_chunk, range.clone()).await?;
215        }
216        Ok(())
217    }
218
219    async fn output_agg_states(
220        aggs: &[BoxedAggregateFunction],
221        agg_states: &mut [AggregateState],
222        agg_builders: &mut [ArrayBuilderImpl],
223    ) -> Result<()> {
224        for ((agg, state), builder) in aggs
225            .iter()
226            .zip_eq_fast(agg_states.iter_mut())
227            .zip_eq_fast(agg_builders)
228        {
229            let result = agg.get_result(state).await?;
230            builder.append(result);
231            *state = agg.create_state()?;
232        }
233        Ok(())
234    }
235
236    fn create_builders(
237        group_key: &[BoxedExpression],
238        aggs: &[BoxedAggregateFunction],
239    ) -> (Vec<ArrayBuilderImpl>, Vec<ArrayBuilderImpl>) {
240        let group_builders = group_key
241            .iter()
242            .map(|e| e.return_type().create_array_builder(1))
243            .collect();
244
245        let agg_builders = aggs
246            .iter()
247            .map(|e| e.return_type().create_array_builder(1))
248            .collect();
249
250        (group_builders, agg_builders)
251    }
252}
253
/// Partition of an array's rows into runs of consecutive equal values.
#[derive(Debug, Default)]
struct EqGroups {
    /// Run boundaries: `[0, I1, ..., In, Len]` describes the half-open ranges
    /// `[[0, I1), [I1, I2), ..., [In, Len)]`, while `[0]` describes no range at all.
    indices: Vec<usize>,
}
260
261impl EqGroups {
262    fn new(indices: Vec<usize>) -> Self {
263        EqGroups { indices }
264    }
265
266    fn single_with_len(len: usize) -> Self {
267        EqGroups {
268            indices: vec![0, len],
269        }
270    }
271
272    fn ranges(&self) -> impl Iterator<Item = Range<usize>> + '_ {
273        EqGroupsIter {
274            indices: &self.indices,
275            curr: 0,
276        }
277    }
278
279    /// Detect the equality groups in the given array.
280    fn detect(array: &ArrayImpl) -> Result<EqGroups> {
281        dispatch_array_variants!(array, array, { Ok(Self::detect_inner(array)) })
282    }
283
284    fn detect_inner<T>(array: &T) -> EqGroups
285    where
286        T: Array,
287        for<'a> T::RefItem<'a>: Eq,
288    {
289        let mut indices = vec![0];
290        if array.is_empty() {
291            return EqGroups { indices };
292        }
293        let mut curr_group = array.value_at(0);
294        for i in 1..array.len() {
295            let v = array.value_at(i);
296            if v == curr_group {
297                continue;
298            }
299            curr_group = v;
300            indices.push(i);
301        }
302        indices.push(array.len());
303        EqGroups::new(indices)
304    }
305
306    /// `intersect` combines the grouping information from each column into a single one.
307    /// This is required so that we know `group by c1, c2` with `c1 = [a, a, c, c, d, d]`
308    /// and `c2 = [g, h, h, h, h, h]` actually forms 4 groups: `[(a, g), (a, h), (c, h), (d, h)]`.
309    ///
310    /// Since the internal encoding is a sequence of sorted indices, this is effectively
311    /// merging all sequences into a single one with deduplication. In the example above,
312    /// the `EqGroups` of `c1` is `[2, 4]` and that of `c2` is `[1]`, so the output of
313    /// `intersect` would be `[1, 2, 4]` identifying the new groups starting at these indices.
314    fn intersect(columns: &[EqGroups]) -> EqGroups {
315        let mut indices = Vec::new();
316        // Use of BinaryHeap here is not to get a performant implementation but a
317        // concise one. The number of group columns would not be huge.
318        // Storing iterator rather than (ci, idx) in heap actually makes the implementation
319        // more verbose:
320        // https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=1e3b098ee3ef352d5a0cac03b3193799
321        use std::cmp::Reverse;
322        use std::collections::BinaryHeap;
323        let mut heap = BinaryHeap::new();
324        for (ci, column) in columns.iter().enumerate() {
325            if let Some(ri) = column.indices.first() {
326                heap.push(Reverse((ri, ci, 0)));
327            }
328        }
329        while let Some(Reverse((ri, ci, idx))) = heap.pop() {
330            if let Some(ri_next) = columns[ci].indices.get(idx + 1) {
331                heap.push(Reverse((ri_next, ci, idx + 1)));
332            }
333            if indices.last() == Some(ri) {
334                continue;
335            }
336            indices.push(*ri);
337        }
338        EqGroups::new(indices)
339    }
340}
341
/// Iterator over the ranges described by an `EqGroups` boundary list.
struct EqGroupsIter<'a> {
    /// Boundary indices `[0, I1, ..., Len]`.
    indices: &'a [usize],
    /// Position of the start boundary of the next range to yield.
    curr: usize,
}

impl Iterator for EqGroupsIter<'_> {
    type Item = Range<usize>;

    /// Yields `indices[curr]..indices[curr + 1]` while both boundaries exist.
    fn next(&mut self) -> Option<Self::Item> {
        match self.indices.get(self.curr..self.curr + 2) {
            Some(&[start, end]) => {
                self.curr += 1;
                Some(start..end)
            }
            _ => None,
        }
    }
}
359
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use futures_async_stream::for_await;
    use risingwave_common::array::{Array as _, I64Array};
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::DataType;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;
    use crate::executor::test_utils::MockExecutor;

    /// Schema of `width` unnamed `Int32` columns, used for the mocked child executors.
    fn int32_schema(width: usize) -> Schema {
        Schema::new((0..width).map(|_| Field::unnamed(DataType::Int32)).collect())
    }

    /// Group key `$1, $2` (both `int4`) shared by the grouped tests.
    fn group_by_columns_1_and_2() -> Vec<BoxedExpression> {
        (1..=2)
            .map(|idx| build_from_pretty(format!("${idx}:int4")))
            .collect()
    }

    /// Builds a `SortAggExecutor` whose schema is laid out like `new_boxed_executor` does it:
    /// the group key types first, followed by the aggregate return types.
    fn sort_agg_executor(
        child: BoxedExecutor,
        group_key: Vec<BoxedExpression>,
        aggs: Vec<BoxedAggregateFunction>,
        output_size_limit: usize,
        shutdown_rx: ShutdownToken,
    ) -> Box<SortAggExecutor> {
        let fields: Vec<Field> = group_key
            .iter()
            .map(|e| e.return_type())
            .chain(aggs.iter().map(|e| e.return_type()))
            .map(Field::unnamed)
            .collect();
        Box::new(SortAggExecutor {
            aggs,
            group_key,
            child,
            schema: Schema { fields },
            identity: "SortAggExecutor".to_owned(),
            output_size_limit,
            shutdown_rx,
        })
    }

    /// Reads the `Int64` aggregate column at `col_idx` of `chunk`.
    fn agg_column(chunk: &DataChunk, col_idx: usize) -> Vec<Option<i64>> {
        let column = chunk.column_at(col_idx);
        let array: &I64Array = column.as_ref().into();
        array.iter().collect()
    }

    /// Asserts that the `Int32` group key column at `col_idx` equals `expect`.
    fn check_group_key_column(actual: &DataChunk, col_idx: usize, expect: Vec<Option<i32>>) {
        assert_eq!(
            actual
                .column_at(col_idx)
                .as_int32()
                .iter()
                .collect::<Vec<_>>(),
            expect
        );
    }

    #[tokio::test]
    async fn execute_count_star_int32() -> Result<()> {
        let mut child = MockExecutor::new(int32_schema(3));
        child.add(DataChunk::from_pretty(
            "i i i
             1 1 7
             2 1 8
             3 3 8
             4 3 9",
        ));
        child.add(DataChunk::from_pretty(
            "i i i
             1 3 9
             2 4 9
             3 4 9
             4 5 9",
        ));
        child.add(DataChunk::from_pretty(
            "i i i
             1 5 9
             2 5 9
             3 5 9
             4 5 9",
        ));

        let count_star = build_agg(&AggCall::from_pretty("(count:int8)"))?;
        let executor = sort_agg_executor(
            Box::new(child),
            vec![],
            vec![count_star],
            3,
            ShutdownToken::empty(),
        );

        let fields = &executor.schema().fields;
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].data_type, DataType::Int64);

        let mut stream = executor.execute();
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        assert_matches!(stream.next().await, None);

        // Without a group key all 12 input rows form a single group.
        let chunk = res?;
        assert_eq!(chunk.cardinality(), 1);
        assert_eq!(agg_column(&chunk, 0), vec![Some(12)]);
        Ok(())
    }

    #[tokio::test]
    async fn execute_count_star_int32_grouped() -> Result<()> {
        let mut child = MockExecutor::new(int32_schema(3));
        child.add(DataChunk::from_pretty(
            "i i i
             1 1 7
             2 1 8
             3 3 8
             4 3 9
             5 4 9",
        ));
        child.add(DataChunk::from_pretty(
            "i i i
             1 4 9
             2 4 9
             3 4 9
             4 5 9
             5 6 9
             6 7 9
             7 7 9
             8 8 9",
        ));
        child.add(DataChunk::from_pretty(
            "i i i
             1 8 9
             2 8 9
             3 8 9
             4 8 9
             5 8 9",
        ));

        let count_star = build_agg(&AggCall::from_pretty("(count:int8)"))?;
        let executor = sort_agg_executor(
            Box::new(child),
            group_by_columns_1_and_2(),
            vec![count_star],
            3,
            ShutdownToken::empty(),
        );

        let fields = &executor.schema().fields;
        assert_eq!(fields[0].data_type, DataType::Int32);
        assert_eq!(fields[1].data_type, DataType::Int32);
        assert_eq!(fields[2].data_type, DataType::Int64);

        let mut stream = executor.execute();

        // First output chunk: groups (1, 7), (1, 8), (3, 8).
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(chunk.cardinality(), 3);
        assert_eq!(agg_column(&chunk, 2), vec![Some(1), Some(1), Some(1)]);
        check_group_key_column(&chunk, 0, vec![Some(1), Some(1), Some(3)]);
        check_group_key_column(&chunk, 1, vec![Some(7), Some(8), Some(8)]);

        // Second output chunk: group (4, 9) spans the first two input chunks.
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(chunk.cardinality(), 3);
        assert_eq!(agg_column(&chunk, 2), vec![Some(1), Some(4), Some(1)]);
        check_group_key_column(&chunk, 0, vec![Some(3), Some(4), Some(5)]);
        check_group_key_column(&chunk, 1, vec![Some(9), Some(9), Some(9)]);

        // Third output chunk: group (8, 9) spans the last two input chunks.
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(chunk.cardinality(), 3);
        assert_eq!(agg_column(&chunk, 2), vec![Some(1), Some(2), Some(6)]);
        check_group_key_column(&chunk, 0, vec![Some(6), Some(7), Some(8)]);
        check_group_key_column(&chunk, 1, vec![Some(9), Some(9), Some(9)]);

        assert_matches!(stream.next().await, None);
        Ok(())
    }

    #[tokio::test]
    async fn execute_sum_int32() -> Result<()> {
        let mut child = MockExecutor::new(int32_schema(1));
        child.add(DataChunk::from_pretty(
            " i
              1
              2
              3
              4
              5
              6
              7
              8
              9
             10",
        ));

        let sum_agg = build_agg(&AggCall::from_pretty("(sum:int8 $0:int4)"))?;
        let executor = sort_agg_executor(
            Box::new(child),
            vec![],
            vec![sum_agg],
            4,
            ShutdownToken::empty(),
        );

        let mut stream = executor.execute();
        let chunk = stream.next().await.unwrap()?;
        assert_matches!(stream.next().await, None);

        assert_eq!(agg_column(&chunk, 0), vec![Some(55)]);

        assert_matches!(stream.next().await, None);
        Ok(())
    }

    #[tokio::test]
    async fn execute_sum_int32_grouped() -> Result<()> {
        let mut child = MockExecutor::new(int32_schema(3));
        child.add(DataChunk::from_pretty(
            "i i i
             1 1 7
             2 1 8
             3 3 8
             4 3 9",
        ));
        child.add(DataChunk::from_pretty(
            "i i i
             1 3 9
             2 4 9
             3 4 9
             4 5 9",
        ));
        child.add(DataChunk::from_pretty(
            "i i i
             1 5 9
             2 5 9
             3 5 9
             4 5 9",
        ));

        let sum_agg = build_agg(&AggCall::from_pretty("(sum:int8 $0:int4)"))?;
        let executor = sort_agg_executor(
            Box::new(child),
            group_by_columns_1_and_2(),
            vec![sum_agg],
            4,
            ShutdownToken::empty(),
        );

        let fields = &executor.schema().fields;
        assert_eq!(fields[0].data_type, DataType::Int32);
        assert_eq!(fields[1].data_type, DataType::Int32);
        assert_eq!(fields[2].data_type, DataType::Int64);

        let mut stream = executor.execute();

        // First output chunk is full (4 groups); (3, 9) spans two input chunks.
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(agg_column(&chunk, 2), vec![Some(1), Some(2), Some(3), Some(5)]);
        check_group_key_column(&chunk, 0, vec![Some(1), Some(1), Some(3), Some(3)]);
        check_group_key_column(&chunk, 1, vec![Some(7), Some(8), Some(8), Some(9)]);

        // Remaining groups are flushed at the end of input.
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(agg_column(&chunk, 2), vec![Some(5), Some(14)]);
        check_group_key_column(&chunk, 0, vec![Some(4), Some(5)]);
        check_group_key_column(&chunk, 1, vec![Some(9), Some(9)]);

        assert_matches!(stream.next().await, None);
        Ok(())
    }

    #[tokio::test]
    async fn execute_sum_int32_grouped_exceed_limit() -> Result<()> {
        let mut child = MockExecutor::new(int32_schema(3));
        child.add(DataChunk::from_pretty(
            " i  i  i
              1  1  7
              2  1  8
              3  3  8
              4  3  8
              5  4  9
              6  4  9
              7  5  9
              8  5  9
              9  6 10
             10  6 10",
        ));
        child.add(DataChunk::from_pretty(
            " i  i  i
              1  6 10
              2  7 12",
        ));

        let sum_agg = build_agg(&AggCall::from_pretty("(sum:int8 $0:int4)"))?;
        let executor = sort_agg_executor(
            Box::new(child),
            group_by_columns_1_and_2(),
            vec![sum_agg],
            3,
            ShutdownToken::empty(),
        );

        let fields = &executor.schema().fields;
        assert_eq!(fields[0].data_type, DataType::Int32);
        assert_eq!(fields[1].data_type, DataType::Int32);
        assert_eq!(fields[2].data_type, DataType::Int64);

        let mut stream = executor.execute();

        // check first chunk
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(agg_column(&chunk, 2), vec![Some(1), Some(2), Some(7)]);
        check_group_key_column(&chunk, 0, vec![Some(1), Some(1), Some(3)]);
        check_group_key_column(&chunk, 1, vec![Some(7), Some(8), Some(8)]);

        // check second chunk
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(agg_column(&chunk, 2), vec![Some(11), Some(15), Some(20)]);
        check_group_key_column(&chunk, 0, vec![Some(4), Some(5), Some(6)]);
        check_group_key_column(&chunk, 1, vec![Some(9), Some(9), Some(10)]);

        // check third chunk
        let res = stream.next().await.unwrap();
        assert_matches!(res, Ok(_));
        let chunk = res?;
        assert_eq!(agg_column(&chunk, 2), vec![Some(2)]);
        check_group_key_column(&chunk, 0, vec![Some(7)]);
        check_group_key_column(&chunk, 1, vec![Some(12)]);

        assert_matches!(stream.next().await, None);
        Ok(())
    }

    #[tokio::test]
    async fn test_shutdown_rx() -> Result<()> {
        let child = MockExecutor::with_chunk(
            DataChunk::from_pretty(
                "i
                 4",
            ),
            int32_schema(1),
        );

        let sum_agg = build_agg(&AggCall::from_pretty("(sum:int8 $0:int4)"))?;
        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
        // The group key refers to columns the child does not have; this relies on the
        // shutdown check firing before the group key is evaluated.
        let executor = sort_agg_executor(
            Box::new(child),
            group_by_columns_1_and_2(),
            vec![sum_agg],
            4,
            shutdown_rx,
        );
        shutdown_tx.cancel();
        #[for_await]
        for data in executor.execute() {
            assert!(data.is_err());
            break;
        }

        Ok(())
    }
}