risingwave_stream/executor/aggregate/
minput.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Bound::{self};
16
17use futures::{StreamExt, pin_mut};
18use futures_async_stream::for_await;
19use itertools::Itertools;
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::{OwnedRow, RowExt};
23use risingwave_common::types::Datum;
24use risingwave_common::util::row_serde::OrderedRowSerde;
25use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_expr::aggregate::{AggCall, AggType, BoxedAggregateFunction, PbAggKind};
28use risingwave_pb::stream_plan::PbAggNodeVersion;
29use risingwave_storage::StateStore;
30use risingwave_storage::store::PrefetchOptions;
31
32use super::agg_group::{AggStateCacheStats, GroupKey};
33use super::agg_state_cache::{AggStateCache, GenericAggStateCache};
34use crate::common::StateTableColumnMapping;
35use crate::common::state_cache::{OrderedStateCache, TopNStateCache};
36use crate::common::table::state_table::StateTable;
37use crate::executor::{PkIndices, StreamExecutorResult};
38
39/// Aggregation state as a materialization of input chunks.
40///
41/// For example, in `string_agg`, several useful columns are picked from input chunks and
42/// stored in the state table when applying chunks, and the aggregation result is calculated
43/// when need to get output.
44#[derive(EstimateSize)]
45pub struct MaterializedInputState {
46    /// Argument column indices in input chunks.
47    arg_col_indices: Vec<usize>,
48
49    /// Argument column indices in state table, group key skipped.
50    state_table_arg_col_indices: Vec<usize>,
51
52    /// The columns to order by in input chunks.
53    order_col_indices: Vec<usize>,
54
55    /// The columns to order by in state table, group key skipped.
56    state_table_order_col_indices: Vec<usize>,
57
58    /// Cache of state table.
59    cache: Box<dyn AggStateCache + Send + Sync>,
60
61    /// Whether to output the first value from cache.
62    output_first_value: bool,
63
64    /// Serializer for cache key.
65    #[estimate_size(ignore)]
66    cache_key_serializer: OrderedRowSerde,
67}
68
69impl MaterializedInputState {
70    /// Create an instance from [`AggCall`].
71    pub fn new(
72        version: PbAggNodeVersion,
73        agg_call: &AggCall,
74        pk_indices: &PkIndices,
75        order_columns: &[ColumnOrder],
76        col_mapping: &StateTableColumnMapping,
77        extreme_cache_size: usize,
78        input_schema: &Schema,
79    ) -> StreamExecutorResult<Self> {
80        if agg_call.distinct && version < PbAggNodeVersion::Issue12140 {
81            panic!(
82                "RisingWave versions before issue #12140 is resolved has critical bug, you must re-create current MV to ensure correctness."
83            );
84        }
85
86        let arg_col_indices = agg_call.args.val_indices().to_vec();
87
88        let (order_col_indices, order_types) = if version < PbAggNodeVersion::Issue13465 {
89            generate_order_columns_before_version_issue_13465(
90                agg_call,
91                pk_indices,
92                &arg_col_indices,
93            )
94        } else {
95            order_columns
96                .iter()
97                .map(|o| (o.column_index, o.order_type))
98                .unzip()
99        };
100
101        // map argument columns to state table column indices
102        let state_table_arg_col_indices = arg_col_indices
103            .iter()
104            .map(|i| {
105                col_mapping
106                    .upstream_to_state_table(*i)
107                    .expect("the argument columns must appear in the state table")
108            })
109            .collect_vec();
110
111        // map order by columns to state table column indices
112        let state_table_order_col_indices = order_col_indices
113            .iter()
114            .map(|i| {
115                col_mapping
116                    .upstream_to_state_table(*i)
117                    .expect("the order columns must appear in the state table")
118            })
119            .collect_vec();
120
121        let cache_key_data_types = order_col_indices
122            .iter()
123            .map(|i| input_schema[*i].data_type())
124            .collect_vec();
125        let cache_key_serializer = OrderedRowSerde::new(cache_key_data_types, order_types);
126
127        let cache: Box<dyn AggStateCache + Send + Sync> = match agg_call.agg_type {
128            AggType::Builtin(
129                PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue,
130            ) => Box::new(GenericAggStateCache::new(
131                TopNStateCache::new(extreme_cache_size),
132                agg_call.args.arg_types(),
133            )),
134            AggType::Builtin(
135                PbAggKind::StringAgg
136                | PbAggKind::ArrayAgg
137                | PbAggKind::JsonbAgg
138                | PbAggKind::JsonbObjectAgg
139                | PbAggKind::PercentileCont
140                | PbAggKind::PercentileDisc
141                | PbAggKind::Mode,
142            )
143            | AggType::WrapScalar(_) => Box::new(GenericAggStateCache::new(
144                OrderedStateCache::new(),
145                agg_call.args.arg_types(),
146            )),
147            _ => panic!(
148                "Agg type `{}` is not expected to have materialized input state",
149                agg_call.agg_type
150            ),
151        };
152        let output_first_value = matches!(
153            agg_call.agg_type,
154            AggType::Builtin(
155                PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue
156            )
157        );
158
159        Ok(Self {
160            arg_col_indices,
161            state_table_arg_col_indices,
162            order_col_indices,
163            state_table_order_col_indices,
164            cache,
165            output_first_value,
166            cache_key_serializer,
167        })
168    }
169
170    /// Apply a chunk of data to the state cache.
171    /// This method should never involve any state table operations.
172    pub fn apply_chunk(&mut self, chunk: &StreamChunk) -> StreamExecutorResult<()> {
173        self.cache.apply_batch(
174            chunk,
175            &self.cache_key_serializer,
176            &self.arg_col_indices,
177            &self.order_col_indices,
178        );
179        Ok(())
180    }
181
182    /// Get the output of the state.
183    /// We may need to read from the state table into the cache to get the output.
184    pub async fn get_output(
185        &mut self,
186        state_table: &StateTable<impl StateStore>,
187        group_key: Option<&GroupKey>,
188        func: &BoxedAggregateFunction,
189    ) -> StreamExecutorResult<(Datum, AggStateCacheStats)> {
190        let mut stats = AggStateCacheStats::default();
191        stats.agg_state_cache_lookup_count += 1;
192
193        if !self.cache.is_synced() {
194            stats.agg_state_cache_miss_count += 1;
195
196            let mut cache_filler = self.cache.begin_syncing();
197            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
198                &(Bound::Unbounded, Bound::Unbounded);
199            let all_data_iter = state_table
200                .iter_with_prefix(
201                    group_key.map(GroupKey::table_pk),
202                    sub_range,
203                    PrefetchOptions {
204                        prefetch: cache_filler.capacity().is_none(),
205                        for_large_query: false,
206                    },
207                )
208                .await?;
209            pin_mut!(all_data_iter);
210
211            #[for_await]
212            for keyed_row in all_data_iter.take(cache_filler.capacity().unwrap_or(usize::MAX)) {
213                let state_row = keyed_row?;
214                let cache_key = {
215                    let mut cache_key = Vec::new();
216                    self.cache_key_serializer.serialize(
217                        state_row
218                            .as_ref()
219                            .project(&self.state_table_order_col_indices),
220                        &mut cache_key,
221                    );
222                    cache_key.into()
223                };
224                let cache_value = self
225                    .state_table_arg_col_indices
226                    .iter()
227                    .map(|i| state_row[*i].clone())
228                    .collect();
229                cache_filler.append(cache_key, cache_value);
230            }
231            cache_filler.finish();
232        }
233        assert!(self.cache.is_synced());
234
235        if self.output_first_value {
236            // special case for `min`, `max`, `first_value` and `last_value`
237            // take the first value from the cache
238            Ok((self.cache.output_first(), stats))
239        } else {
240            const CHUNK_SIZE: usize = 1024;
241            let chunks = self.cache.output_batches(CHUNK_SIZE).collect_vec();
242            let mut state = func.create_state()?;
243            for chunk in chunks {
244                func.update(&mut state, &chunk).await?;
245            }
246            Ok((func.get_result(&state).await?, stats))
247        }
248    }
249
250    #[cfg(test)]
251    async fn get_output_no_stats(
252        &mut self,
253        state_table: &StateTable<impl StateStore>,
254        group_key: Option<&GroupKey>,
255        func: &BoxedAggregateFunction,
256    ) -> StreamExecutorResult<Datum> {
257        let (res, _stats) = self.get_output(state_table, group_key, func).await?;
258        Ok(res)
259    }
260}
261
262/// Copied from old code before <https://github.com/risingwavelabs/risingwave/commit/0020507edbc4010b20aeeb560c7bea9159315602>.
263fn generate_order_columns_before_version_issue_13465(
264    agg_call: &AggCall,
265    pk_indices: &PkIndices,
266    arg_col_indices: &[usize],
267) -> (Vec<usize>, Vec<OrderType>) {
268    let (mut order_col_indices, mut order_types) = if matches!(
269        agg_call.agg_type,
270        AggType::Builtin(PbAggKind::Min | PbAggKind::Max)
271    ) {
272        // `min`/`max` need not to order by any other columns, but have to
273        // order by the agg value implicitly.
274        let order_type = if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Min)) {
275            OrderType::ascending()
276        } else {
277            OrderType::descending()
278        };
279        (vec![arg_col_indices[0]], vec![order_type])
280    } else {
281        agg_call
282            .column_orders
283            .iter()
284            .map(|p| {
285                (
286                    p.column_index,
287                    if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::LastValue)) {
288                        p.order_type.reverse()
289                    } else {
290                        p.order_type
291                    },
292                )
293            })
294            .unzip()
295    };
296
297    if agg_call.distinct {
298        // If distinct, we need to materialize input with the distinct keys
299        // As we only support single-column distinct for now, we use the
300        // `agg_call.args.val_indices()[0]` as the distinct key.
301        if !order_col_indices.contains(&agg_call.args.val_indices()[0]) {
302            order_col_indices.push(agg_call.args.val_indices()[0]);
303            order_types.push(OrderType::ascending());
304        }
305    } else {
306        // If not distinct, we need to materialize input with the primary keys
307        let pk_len = pk_indices.len();
308        order_col_indices.extend(pk_indices.iter());
309        order_types.extend(itertools::repeat_n(OrderType::ascending(), pk_len));
310    }
311
312    (order_col_indices, order_types)
313}
314
315#[cfg(test)]
316mod tests {
317    use std::collections::HashSet;
318
319    use itertools::Itertools;
320    use rand::Rng;
321    use rand::seq::IteratorRandom;
322    use risingwave_common::array::StreamChunk;
323    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
324    use risingwave_common::row::OwnedRow;
325    use risingwave_common::test_prelude::StreamChunkTestExt;
326    use risingwave_common::types::{DataType, ListValue};
327    use risingwave_common::util::epoch::{EpochPair, test_epoch};
328    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
329    use risingwave_expr::aggregate::{AggCall, build_append_only};
330    use risingwave_pb::stream_plan::PbAggNodeVersion;
331    use risingwave_storage::StateStore;
332    use risingwave_storage::memory::MemoryStateStore;
333
334    use super::*;
335    use crate::common::StateTableColumnMapping;
336    use crate::common::table::state_table::StateTable;
337    use crate::common::table::test_utils::gen_pbtable;
338    use crate::executor::{PkIndices, StreamExecutorResult};
339
340    fn create_chunk<S: StateStore>(
341        pretty: &str,
342        table: &mut StateTable<S>,
343        col_mapping: &StateTableColumnMapping,
344    ) -> StreamChunk {
345        let chunk = StreamChunk::from_pretty(pretty);
346        table.write_chunk(chunk.project(col_mapping.upstream_columns()));
347        chunk
348    }
349
350    async fn create_mem_state_table(
351        input_schema: &Schema,
352        upstream_columns: Vec<usize>,
353        order_types: Vec<OrderType>,
354    ) -> (StateTable<MemoryStateStore>, StateTableColumnMapping) {
355        // see `LogicalAgg::infer_stream_agg_state` for the construction of state table
356        let table_id = TableId::new(rand::rng().random());
357        let columns = upstream_columns
358            .iter()
359            .map(|col_idx| input_schema[*col_idx].data_type())
360            .enumerate()
361            .map(|(i, data_type)| ColumnDesc::unnamed(ColumnId::new(i as i32), data_type))
362            .collect_vec();
363        let mapping = StateTableColumnMapping::new(upstream_columns, None);
364        let pk_len = order_types.len();
365        let table = StateTable::from_table_catalog(
366            &gen_pbtable(table_id, columns, order_types, (0..pk_len).collect(), 0),
367            MemoryStateStore::new(),
368            None,
369        )
370        .await;
371        (table, mapping)
372    }
373
374    #[tokio::test]
375    async fn test_extreme_agg_state_basic_min() -> StreamExecutorResult<()> {
376        // Assumption of input schema:
377        // (a: varchar, b: int32, c: int32, _row_id: int64)
378
379        let field1 = Field::unnamed(DataType::Varchar);
380        let field2 = Field::unnamed(DataType::Int32);
381        let field3 = Field::unnamed(DataType::Int32);
382        let field4 = Field::unnamed(DataType::Int64);
383        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
384
385        let agg_call = AggCall::from_pretty("(min:int4 $2:int4)"); // min(c)
386        let agg = build_append_only(&agg_call).unwrap();
387        let group_key = None;
388
389        let (mut table, mapping) = create_mem_state_table(
390            &input_schema,
391            vec![2, 3],
392            vec![
393                OrderType::ascending(), // for AggKind::Min
394                OrderType::ascending(),
395            ],
396        )
397        .await;
398
399        let order_columns = vec![
400            ColumnOrder::new(2, OrderType::ascending()), // c ASC for AggKind::Min
401            ColumnOrder::new(3, OrderType::ascending()), // _row_id
402        ];
403        let mut state = MaterializedInputState::new(
404            PbAggNodeVersion::LATEST,
405            &agg_call,
406            &PkIndices::new(), // unused
407            &order_columns,
408            &mapping,
409            usize::MAX,
410            &input_schema,
411        )
412        .unwrap();
413
414        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
415        table.init_epoch(epoch).await.unwrap();
416
417        {
418            let chunk = create_chunk(
419                " T i i I
420                + a 1 8 123
421                + b 5 2 128
422                - b 5 2 128
423                + c 1 3 130",
424                &mut table,
425                &mapping,
426            );
427
428            state.apply_chunk(&chunk)?;
429
430            epoch.inc_for_test();
431            table.commit_for_test(epoch).await.unwrap();
432
433            let res = state
434                .get_output_no_stats(&table, group_key.as_ref(), &agg)
435                .await?;
436            assert_eq!(res, Some(3i32.into()));
437        }
438
439        {
440            let chunk = create_chunk(
441                " T i i I
442                + d 0 8 134
443                + e 2 2 137",
444                &mut table,
445                &mapping,
446            );
447
448            state.apply_chunk(&chunk)?;
449
450            epoch.inc_for_test();
451            table.commit_for_test(epoch).await.unwrap();
452
453            let res = state
454                .get_output_no_stats(&table, group_key.as_ref(), &agg)
455                .await?;
456            assert_eq!(res, Some(2i32.into()));
457        }
458
459        {
460            // test recovery (cold start)
461            let mut state = MaterializedInputState::new(
462                PbAggNodeVersion::LATEST,
463                &agg_call,
464                &PkIndices::new(), // unused
465                &order_columns,
466                &mapping,
467                usize::MAX,
468                &input_schema,
469            )
470            .unwrap();
471            let res = state
472                .get_output_no_stats(&table, group_key.as_ref(), &agg)
473                .await?;
474            assert_eq!(res, Some(2i32.into()));
475        }
476
477        Ok(())
478    }
479
480    #[tokio::test]
481    async fn test_extreme_agg_state_basic_max() -> StreamExecutorResult<()> {
482        // Assumption of input schema:
483        // (a: varchar, b: int32, c: int32, _row_id: int64)
484
485        let field1 = Field::unnamed(DataType::Varchar);
486        let field2 = Field::unnamed(DataType::Int32);
487        let field3 = Field::unnamed(DataType::Int32);
488        let field4 = Field::unnamed(DataType::Int64);
489        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
490
491        let agg_call = AggCall::from_pretty("(max:int4 $2:int4)"); // max(c)
492        let agg = build_append_only(&agg_call).unwrap();
493        let group_key = None;
494
495        let (mut table, mapping) = create_mem_state_table(
496            &input_schema,
497            vec![2, 3],
498            vec![
499                OrderType::descending(), // for AggKind::Max
500                OrderType::ascending(),
501            ],
502        )
503        .await;
504
505        let order_columns = vec![
506            ColumnOrder::new(2, OrderType::descending()), // c DESC for AggKind::Max
507            ColumnOrder::new(3, OrderType::ascending()),  // _row_id
508        ];
509        let mut state = MaterializedInputState::new(
510            PbAggNodeVersion::LATEST,
511            &agg_call,
512            &PkIndices::new(), // unused
513            &order_columns,
514            &mapping,
515            usize::MAX,
516            &input_schema,
517        )
518        .unwrap();
519
520        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
521        table.init_epoch(epoch).await.unwrap();
522
523        {
524            let chunk = create_chunk(
525                " T i i I
526                + a 1 8 123
527                + b 5 2 128
528                - b 5 2 128
529                + c 1 3 130",
530                &mut table,
531                &mapping,
532            );
533
534            state.apply_chunk(&chunk)?;
535
536            epoch.inc_for_test();
537            table.commit_for_test(epoch).await.unwrap();
538
539            let res = state
540                .get_output_no_stats(&table, group_key.as_ref(), &agg)
541                .await?;
542            assert_eq!(res, Some(8i32.into()));
543        }
544
545        {
546            let chunk = create_chunk(
547                " T i i I
548                + d 0 9 134
549                + e 2 2 137",
550                &mut table,
551                &mapping,
552            );
553
554            state.apply_chunk(&chunk)?;
555
556            epoch.inc_for_test();
557            table.commit_for_test(epoch).await.unwrap();
558
559            let res = state
560                .get_output_no_stats(&table, group_key.as_ref(), &agg)
561                .await?;
562            assert_eq!(res, Some(9i32.into()));
563        }
564
565        {
566            // test recovery (cold start)
567            let mut state = MaterializedInputState::new(
568                PbAggNodeVersion::LATEST,
569                &agg_call,
570                &PkIndices::new(), // unused
571                &order_columns,
572                &mapping,
573                usize::MAX,
574                &input_schema,
575            )
576            .unwrap();
577
578            let res = state
579                .get_output_no_stats(&table, group_key.as_ref(), &agg)
580                .await?;
581            assert_eq!(res, Some(9i32.into()));
582        }
583
584        Ok(())
585    }
586
587    #[tokio::test]
588    async fn test_extreme_agg_state_with_hidden_input() -> StreamExecutorResult<()> {
589        // Assumption of input schema:
590        // (a: varchar, b: int32, c: int32, _row_id: int64)
591
592        let field1 = Field::unnamed(DataType::Varchar);
593        let field2 = Field::unnamed(DataType::Int32);
594        let field3 = Field::unnamed(DataType::Int32);
595        let field4 = Field::unnamed(DataType::Int64);
596        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
597
598        let agg_call_1 = AggCall::from_pretty("(min:varchar $0:varchar)"); // min(a)
599        let agg_call_2 = AggCall::from_pretty("(max:int4 $1:int4)"); // max(b)
600        let agg1 = build_append_only(&agg_call_1).unwrap();
601        let agg2 = build_append_only(&agg_call_2).unwrap();
602        let group_key = None;
603
604        let (mut table_1, mapping_1) = create_mem_state_table(
605            &input_schema,
606            vec![0, 3],
607            vec![
608                OrderType::ascending(), // for AggKind::Min
609                OrderType::ascending(),
610            ],
611        )
612        .await;
613        let (mut table_2, mapping_2) = create_mem_state_table(
614            &input_schema,
615            vec![1, 3],
616            vec![
617                OrderType::descending(), // for AggKind::Max
618                OrderType::ascending(),
619            ],
620        )
621        .await;
622
623        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
624        table_1.init_epoch(epoch).await.unwrap();
625        table_2.init_epoch(epoch).await.unwrap();
626
627        let order_columns_1 = vec![
628            ColumnOrder::new(0, OrderType::ascending()), // a ASC for AggKind::Min
629            ColumnOrder::new(3, OrderType::ascending()), // _row_id
630        ];
631        let mut state_1 = MaterializedInputState::new(
632            PbAggNodeVersion::LATEST,
633            &agg_call_1,
634            &PkIndices::new(), // unused
635            &order_columns_1,
636            &mapping_1,
637            usize::MAX,
638            &input_schema,
639        )
640        .unwrap();
641
642        let order_columns_2 = vec![
643            ColumnOrder::new(1, OrderType::descending()), // b DESC for AggKind::Max
644            ColumnOrder::new(3, OrderType::ascending()),  // _row_id
645        ];
646        let mut state_2 = MaterializedInputState::new(
647            PbAggNodeVersion::LATEST,
648            &agg_call_2,
649            &PkIndices::new(), // unused
650            &order_columns_2,
651            &mapping_2,
652            usize::MAX,
653            &input_schema,
654        )
655        .unwrap();
656
657        {
658            let chunk_1 = create_chunk(
659                " T i i I
660                + a 1 8 123
661                + b 5 2 128
662                - b 5 2 128
663                + c 1 3 130
664                + . 9 4 131 D
665                + . 6 5 132 D
666                + c . 3 133",
667                &mut table_1,
668                &mapping_1,
669            );
670            let chunk_2 = create_chunk(
671                " T i i I
672                + a 1 8 123
673                + b 5 2 128
674                - b 5 2 128
675                + c 1 3 130
676                + . 9 4 131
677                + . 6 5 132
678                + c . 3 133 D",
679                &mut table_2,
680                &mapping_2,
681            );
682
683            state_1.apply_chunk(&chunk_1)?;
684            state_2.apply_chunk(&chunk_2)?;
685
686            epoch.inc_for_test();
687            table_1.commit_for_test(epoch).await.unwrap();
688            table_2.commit_for_test(epoch).await.unwrap();
689
690            let out1 = state_1
691                .get_output_no_stats(&table_1, group_key.as_ref(), &agg1)
692                .await?;
693            assert_eq!(out1, Some("a".into()));
694
695            let out2 = state_2
696                .get_output_no_stats(&table_2, group_key.as_ref(), &agg2)
697                .await?;
698            assert_eq!(out2, Some(9i32.into()));
699        }
700
701        Ok(())
702    }
703
704    #[tokio::test]
705    async fn test_extreme_agg_state_grouped() -> StreamExecutorResult<()> {
706        // Assumption of input schema:
707        // (a: varchar, b: int32, c: int32, _row_id: int64)
708
709        let field1 = Field::unnamed(DataType::Varchar);
710        let field2 = Field::unnamed(DataType::Int32);
711        let field3 = Field::unnamed(DataType::Int32);
712        let field4 = Field::unnamed(DataType::Int64);
713        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
714
715        let agg_call = AggCall::from_pretty("(max:int4 $1:int4)"); // max(b)
716        let agg = build_append_only(&agg_call).unwrap();
717        let group_key = Some(GroupKey::new(OwnedRow::new(vec![Some(8.into())]), None));
718
719        let (mut table, mapping) = create_mem_state_table(
720            &input_schema,
721            vec![2, 1, 3],
722            vec![
723                OrderType::ascending(),  // c ASC
724                OrderType::descending(), // b DESC for AggKind::Max
725                OrderType::ascending(),  // _row_id ASC
726            ],
727        )
728        .await;
729
730        let order_columns = vec![
731            ColumnOrder::new(1, OrderType::descending()), // b DESC for AggKind::Max
732            ColumnOrder::new(3, OrderType::ascending()),  // _row_id
733        ];
734        let mut state = MaterializedInputState::new(
735            PbAggNodeVersion::LATEST,
736            &agg_call,
737            &PkIndices::new(), // unused
738            &order_columns,
739            &mapping,
740            usize::MAX,
741            &input_schema,
742        )
743        .unwrap();
744
745        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
746        table.init_epoch(epoch).await.unwrap();
747
748        {
749            let chunk = create_chunk(
750                " T i i I
751                + a 1 8 123
752                + b 5 8 128
753                + c 7 3 130 D // hide this row",
754                &mut table,
755                &mapping,
756            );
757
758            state.apply_chunk(&chunk)?;
759
760            epoch.inc_for_test();
761            table.commit_for_test(epoch).await.unwrap();
762
763            let res = state
764                .get_output_no_stats(&table, group_key.as_ref(), &agg)
765                .await?;
766            assert_eq!(res, Some(5i32.into()));
767        }
768
769        {
770            let chunk = create_chunk(
771                " T i i I
772                + d 9 2 134 D // hide this row
773                + e 8 8 137",
774                &mut table,
775                &mapping,
776            );
777
778            state.apply_chunk(&chunk)?;
779
780            epoch.inc_for_test();
781            table.commit_for_test(epoch).await.unwrap();
782
783            let res = state
784                .get_output_no_stats(&table, group_key.as_ref(), &agg)
785                .await?;
786            assert_eq!(res, Some(8i32.into()));
787        }
788
789        {
790            // test recovery (cold start)
791            let mut state = MaterializedInputState::new(
792                PbAggNodeVersion::LATEST,
793                &agg_call,
794                &PkIndices::new(), // unused
795                &order_columns,
796                &mapping,
797                usize::MAX,
798                &input_schema,
799            )
800            .unwrap();
801
802            let res = state
803                .get_output_no_stats(&table, group_key.as_ref(), &agg)
804                .await?;
805            assert_eq!(res, Some(8i32.into()));
806        }
807
808        Ok(())
809    }
810
811    #[tokio::test]
812    async fn test_extreme_agg_state_with_random_values() -> StreamExecutorResult<()> {
813        // Assumption of input schema:
814        // (a: int32, _row_id: int64)
815
816        let field1 = Field::unnamed(DataType::Int32);
817        let field2 = Field::unnamed(DataType::Int64);
818        let input_schema = Schema::new(vec![field1, field2]);
819
820        let agg_call = AggCall::from_pretty("(min:int4 $0:int4)"); // min(a)
821        let agg = build_append_only(&agg_call).unwrap();
822        let group_key = None;
823
824        let (mut table, mapping) = create_mem_state_table(
825            &input_schema,
826            vec![0, 1],
827            vec![
828                OrderType::ascending(), // for AggKind::Min
829                OrderType::ascending(),
830            ],
831        )
832        .await;
833
834        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
835        table.init_epoch(epoch).await.unwrap();
836
837        let order_columns = vec![
838            ColumnOrder::new(0, OrderType::ascending()), // a ASC for AggKind::Min
839            ColumnOrder::new(1, OrderType::ascending()), // _row_id
840        ];
841        let mut state = MaterializedInputState::new(
842            PbAggNodeVersion::LATEST,
843            &agg_call,
844            &PkIndices::new(), // unused
845            &order_columns,
846            &mapping,
847            1024,
848            &input_schema,
849        )
850        .unwrap();
851
852        let mut rng = rand::rng();
853        let insert_values: Vec<i32> = (0..10000).map(|_| rng.random()).collect_vec();
854        let delete_values: HashSet<_> = insert_values
855            .iter()
856            .choose_multiple(&mut rng, 1000)
857            .into_iter()
858            .collect();
859        let mut min_value = i32::MAX;
860
861        {
862            let mut pretty_lines = vec!["i I".to_owned()];
863            for (row_id, value) in insert_values
864                .iter()
865                .enumerate()
866                .take(insert_values.len() / 2)
867            {
868                pretty_lines.push(format!("+ {} {}", value, row_id));
869                if delete_values.contains(&value) {
870                    pretty_lines.push(format!("- {} {}", value, row_id));
871                    continue;
872                }
873                if *value < min_value {
874                    min_value = *value;
875                }
876            }
877
878            let chunk = create_chunk(&pretty_lines.join("\n"), &mut table, &mapping);
879            state.apply_chunk(&chunk)?;
880
881            epoch.inc_for_test();
882            table.commit_for_test(epoch).await.unwrap();
883
884            let res = state
885                .get_output_no_stats(&table, group_key.as_ref(), &agg)
886                .await?;
887            assert_eq!(res, Some(min_value.into()));
888        }
889
890        {
891            let mut pretty_lines = vec!["i I".to_owned()];
892            for (row_id, value) in insert_values
893                .iter()
894                .enumerate()
895                .skip(insert_values.len() / 2)
896            {
897                pretty_lines.push(format!("+ {} {}", value, row_id));
898                if delete_values.contains(&value) {
899                    pretty_lines.push(format!("- {} {}", value, row_id));
900                    continue;
901                }
902                if *value < min_value {
903                    min_value = *value;
904                }
905            }
906
907            let chunk = create_chunk(&pretty_lines.join("\n"), &mut table, &mapping);
908            state.apply_chunk(&chunk)?;
909
910            epoch.inc_for_test();
911            table.commit_for_test(epoch).await.unwrap();
912
913            let res = state
914                .get_output_no_stats(&table, group_key.as_ref(), &agg)
915                .await?;
916            assert_eq!(res, Some(min_value.into()));
917        }
918
919        Ok(())
920    }
921
922    #[tokio::test]
923    async fn test_extreme_agg_state_cache_maintenance() -> StreamExecutorResult<()> {
924        // Assumption of input schema:
925        // (a: int32, _row_id: int64)
926
927        let field1 = Field::unnamed(DataType::Int32);
928        let field2 = Field::unnamed(DataType::Int64);
929        let input_schema = Schema::new(vec![field1, field2]);
930
931        let agg_call = AggCall::from_pretty("(min:int4 $0:int4)"); // min(a)
932        let agg = build_append_only(&agg_call).unwrap();
933        let group_key = None;
934
935        let (mut table, mapping) = create_mem_state_table(
936            &input_schema,
937            vec![0, 1],
938            vec![
939                OrderType::ascending(), // for AggKind::Min
940                OrderType::ascending(),
941            ],
942        )
943        .await;
944
945        let order_columns = vec![
946            ColumnOrder::new(0, OrderType::ascending()), // a ASC for AggKind::Min
947            ColumnOrder::new(1, OrderType::ascending()), // _row_id
948        ];
949        let mut state = MaterializedInputState::new(
950            PbAggNodeVersion::LATEST,
951            &agg_call,
952            &PkIndices::new(), // unused
953            &order_columns,
954            &mapping,
955            3, // cache capacity = 3 for easy testing
956            &input_schema,
957        )
958        .unwrap();
959
960        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
961        table.init_epoch(epoch).await.unwrap();
962
963        {
964            let chunk = create_chunk(
965                " i  I
966                + 4  123
967                + 8  128
968                + 12 129",
969                &mut table,
970                &mapping,
971            );
972            state.apply_chunk(&chunk)?;
973
974            epoch.inc_for_test();
975            table.commit_for_test(epoch).await.unwrap();
976
977            let res = state
978                .get_output_no_stats(&table, group_key.as_ref(), &agg)
979                .await?;
980            assert_eq!(res, Some(4i32.into()));
981        }
982
983        {
984            let chunk = create_chunk(
985                " i I
986                + 9  130 // this will evict 12
987                - 9  130
988                + 13 128
989                - 4  123
990                - 8  128",
991                &mut table,
992                &mapping,
993            );
994            state.apply_chunk(&chunk)?;
995
996            epoch.inc_for_test();
997            table.commit_for_test(epoch).await.unwrap();
998
999            let res = state
1000                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1001                .await?;
1002            assert_eq!(res, Some(12i32.into()));
1003        }
1004
1005        {
1006            let chunk = create_chunk(
1007                " i  I
1008                + 1  131
1009                + 2  132
1010                + 3  133 // evict all from cache
1011                - 1  131
1012                - 2  132
1013                - 3  133
1014                + 14 134",
1015                &mut table,
1016                &mapping,
1017            );
1018            state.apply_chunk(&chunk)?;
1019
1020            epoch.inc_for_test();
1021            table.commit_for_test(epoch).await.unwrap();
1022
1023            let res = state
1024                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1025                .await?;
1026            assert_eq!(res, Some(12i32.into()));
1027        }
1028
1029        Ok(())
1030    }
1031
1032    #[tokio::test]
1033    async fn test_string_agg_state() -> StreamExecutorResult<()> {
1034        // Assumption of input schema:
1035        // (a: varchar, _delim: varchar, b: int32, c: int32, _row_id: int64)
1036        // where `a` is the column to aggregate
1037
1038        let input_schema = Schema::new(vec![
1039            Field::unnamed(DataType::Varchar),
1040            Field::unnamed(DataType::Varchar),
1041            Field::unnamed(DataType::Int32),
1042            Field::unnamed(DataType::Int32),
1043            Field::unnamed(DataType::Int64),
1044        ]);
1045
1046        let agg_call = AggCall::from_pretty(
1047            "(string_agg:varchar $0:varchar $1:varchar orderby $2:asc $0:desc)",
1048        );
1049        let agg = build_append_only(&agg_call).unwrap();
1050        let group_key = None;
1051
1052        let (mut table, mapping) = create_mem_state_table(
1053            &input_schema,
1054            vec![2, 0, 4, 1],
1055            vec![
1056                OrderType::ascending(),  // b ASC
1057                OrderType::descending(), // a DESC
1058                OrderType::ascending(),  // _row_id ASC
1059            ],
1060        )
1061        .await;
1062
1063        let order_columns = vec![
1064            ColumnOrder::new(2, OrderType::ascending()),  // b ASC
1065            ColumnOrder::new(0, OrderType::descending()), // a DESC
1066            ColumnOrder::new(4, OrderType::ascending()),  // _row_id ASC
1067        ];
1068        let mut state = MaterializedInputState::new(
1069            PbAggNodeVersion::LATEST,
1070            &agg_call,
1071            &PkIndices::new(), // unused
1072            &order_columns,
1073            &mapping,
1074            usize::MAX,
1075            &input_schema,
1076        )
1077        .unwrap();
1078
1079        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
1080        table.init_epoch(epoch).await.unwrap();
1081
1082        {
1083            let chunk = create_chunk(
1084                " T T i i I
1085                + a , 1 8 123
1086                + b / 5 2 128
1087                - b / 5 2 128
1088                + c _ 1 3 130",
1089                &mut table,
1090                &mapping,
1091            );
1092            state.apply_chunk(&chunk)?;
1093
1094            epoch.inc_for_test();
1095            table.commit_for_test(epoch).await.unwrap();
1096
1097            let res = state
1098                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1099                .await?;
1100            assert_eq!(res, Some("c,a".into()));
1101        }
1102
1103        {
1104            let chunk = create_chunk(
1105                " T T i i I
1106                + d - 0 8 134
1107                + e + 2 2 137",
1108                &mut table,
1109                &mapping,
1110            );
1111            state.apply_chunk(&chunk)?;
1112
1113            epoch.inc_for_test();
1114            table.commit_for_test(epoch).await.unwrap();
1115
1116            let res = state
1117                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1118                .await?;
1119            assert_eq!(res, Some("d_c,a+e".into()));
1120        }
1121
1122        Ok(())
1123    }
1124
1125    #[tokio::test]
1126    async fn test_array_agg_state() -> StreamExecutorResult<()> {
1127        // Assumption of input schema:
1128        // (a: varchar, b: int32, c: int32, _row_id: int64)
1129        // where `a` is the column to aggregate
1130
1131        let field1 = Field::unnamed(DataType::Varchar);
1132        let field2 = Field::unnamed(DataType::Int32);
1133        let field3 = Field::unnamed(DataType::Int32);
1134        let field4 = Field::unnamed(DataType::Int64);
1135        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
1136
1137        let agg_call = AggCall::from_pretty("(array_agg:int4[] $1:int4 orderby $2:asc $0:desc)");
1138        let agg = build_append_only(&agg_call).unwrap();
1139        let group_key = None;
1140
1141        let (mut table, mapping) = create_mem_state_table(
1142            &input_schema,
1143            vec![2, 0, 3, 1],
1144            vec![
1145                OrderType::ascending(),  // c ASC
1146                OrderType::descending(), // a DESC
1147                OrderType::ascending(),  // _row_id ASC
1148            ],
1149        )
1150        .await;
1151
1152        let order_columns = vec![
1153            ColumnOrder::new(2, OrderType::ascending()),  // c ASC
1154            ColumnOrder::new(0, OrderType::descending()), // a DESC
1155            ColumnOrder::new(3, OrderType::ascending()),  // _row_id ASC
1156        ];
1157        let mut state = MaterializedInputState::new(
1158            PbAggNodeVersion::LATEST,
1159            &agg_call,
1160            &PkIndices::new(), // unused
1161            &order_columns,
1162            &mapping,
1163            usize::MAX,
1164            &input_schema,
1165        )
1166        .unwrap();
1167
1168        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
1169        table.init_epoch(epoch).await.unwrap();
1170        {
1171            let chunk = create_chunk(
1172                " T i i I
1173                + a 1 8 123
1174                + b 5 2 128
1175                - b 5 2 128
1176                + c 2 3 130",
1177                &mut table,
1178                &mapping,
1179            );
1180            state.apply_chunk(&chunk)?;
1181
1182            epoch.inc_for_test();
1183            table.commit_for_test(epoch).await.unwrap();
1184
1185            let res = state
1186                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1187                .await?;
1188            assert_eq!(res.unwrap().as_list(), &ListValue::from_iter([2, 1]));
1189        }
1190
1191        {
1192            let chunk = create_chunk(
1193                " T i i I
1194                + d 0 8 134
1195                + e 2 2 137",
1196                &mut table,
1197                &mapping,
1198            );
1199            state.apply_chunk(&chunk)?;
1200
1201            epoch.inc_for_test();
1202            table.commit_for_test(epoch).await.unwrap();
1203
1204            let res = state
1205                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1206                .await?;
1207            assert_eq!(res.unwrap().as_list(), &ListValue::from_iter([2, 2, 0, 1]));
1208        }
1209
1210        Ok(())
1211    }
1212}