risingwave_stream/executor/aggregate/
minput.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Bound::{self};
16
17use futures::{StreamExt, pin_mut};
18use futures_async_stream::for_await;
19use itertools::Itertools;
20use risingwave_common::array::StreamChunk;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::{OwnedRow, RowExt};
23use risingwave_common::types::Datum;
24use risingwave_common::util::row_serde::OrderedRowSerde;
25use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_expr::aggregate::{AggCall, AggType, BoxedAggregateFunction, PbAggKind};
28use risingwave_pb::stream_plan::PbAggNodeVersion;
29use risingwave_storage::StateStore;
30use risingwave_storage::store::PrefetchOptions;
31
32use super::agg_group::{AggStateCacheStats, GroupKey};
33use super::agg_state_cache::{AggStateCache, GenericAggStateCache};
34use crate::common::StateTableColumnMapping;
35use crate::common::state_cache::{OrderedStateCache, TopNStateCache};
36use crate::common::table::state_table::StateTable;
37use crate::executor::{PkIndices, StreamExecutorResult};
38
/// Aggregation state as a materialization of input chunks.
///
/// For example, in `string_agg`, several useful columns are picked from input chunks and
/// stored in the state table when applying chunks, and the aggregation result is calculated
/// only when the output is requested.
#[derive(EstimateSize)]
pub struct MaterializedInputState {
    /// Argument column indices in input chunks.
    arg_col_indices: Vec<usize>,

    /// Argument column indices in state table, group key skipped.
    state_table_arg_col_indices: Vec<usize>,

    /// The columns to order by in input chunks.
    order_col_indices: Vec<usize>,

    /// The columns to order by in state table, group key skipped.
    state_table_order_col_indices: Vec<usize>,

    /// Cache of state table.
    ///
    /// `new` builds it on a `TopNStateCache` for `min`, `max`, `first_value` and `last_value`,
    /// and on an `OrderedStateCache` for the other supported aggregates.
    cache: Box<dyn AggStateCache + Send + Sync>,

    /// Whether to output the first value from cache.
    ///
    /// Set for `min`, `max`, `first_value` and `last_value`.
    output_first_value: bool,

    /// Serializer for cache key, built from the data types and order types of the order-by
    /// columns.
    #[estimate_size(ignore)]
    cache_key_serializer: OrderedRowSerde,
}
68
69impl MaterializedInputState {
70    /// Create an instance from [`AggCall`].
71    pub fn new(
72        version: PbAggNodeVersion,
73        agg_call: &AggCall,
74        pk_indices: &PkIndices,
75        order_columns: &[ColumnOrder],
76        col_mapping: &StateTableColumnMapping,
77        extreme_cache_size: usize,
78        input_schema: &Schema,
79    ) -> StreamExecutorResult<Self> {
80        if agg_call.distinct && version < PbAggNodeVersion::Issue12140 {
81            panic!(
82                "RisingWave versions before issue #12140 is resolved has critical bug, you must re-create current MV to ensure correctness."
83            );
84        }
85
86        let arg_col_indices = agg_call.args.val_indices().to_vec();
87
88        let (order_col_indices, order_types) = if version < PbAggNodeVersion::Issue13465 {
89            generate_order_columns_before_version_issue_13465(
90                agg_call,
91                pk_indices,
92                &arg_col_indices,
93            )
94        } else {
95            order_columns
96                .iter()
97                .map(|o| (o.column_index, o.order_type))
98                .unzip()
99        };
100
101        // map argument columns to state table column indices
102        let state_table_arg_col_indices = arg_col_indices
103            .iter()
104            .map(|i| {
105                col_mapping
106                    .upstream_to_state_table(*i)
107                    .expect("the argument columns must appear in the state table")
108            })
109            .collect_vec();
110
111        // map order by columns to state table column indices
112        let state_table_order_col_indices = order_col_indices
113            .iter()
114            .map(|i| {
115                col_mapping
116                    .upstream_to_state_table(*i)
117                    .expect("the order columns must appear in the state table")
118            })
119            .collect_vec();
120
121        let cache_key_data_types = order_col_indices
122            .iter()
123            .map(|i| input_schema[*i].data_type())
124            .collect_vec();
125        let cache_key_serializer = OrderedRowSerde::new(cache_key_data_types, order_types);
126
127        let cache: Box<dyn AggStateCache + Send + Sync> = match agg_call.agg_type {
128            AggType::Builtin(
129                PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue,
130            ) => Box::new(GenericAggStateCache::new(
131                TopNStateCache::new(extreme_cache_size),
132                agg_call.args.arg_types(),
133            )),
134            AggType::Builtin(
135                PbAggKind::StringAgg
136                | PbAggKind::ArrayAgg
137                | PbAggKind::JsonbAgg
138                | PbAggKind::JsonbObjectAgg,
139            )
140            | AggType::WrapScalar(_) => Box::new(GenericAggStateCache::new(
141                OrderedStateCache::new(),
142                agg_call.args.arg_types(),
143            )),
144            _ => panic!(
145                "Agg type `{}` is not expected to have materialized input state",
146                agg_call.agg_type
147            ),
148        };
149        let output_first_value = matches!(
150            agg_call.agg_type,
151            AggType::Builtin(
152                PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue
153            )
154        );
155
156        Ok(Self {
157            arg_col_indices,
158            state_table_arg_col_indices,
159            order_col_indices,
160            state_table_order_col_indices,
161            cache,
162            output_first_value,
163            cache_key_serializer,
164        })
165    }
166
167    /// Apply a chunk of data to the state cache.
168    /// This method should never involve any state table operations.
169    pub fn apply_chunk(&mut self, chunk: &StreamChunk) -> StreamExecutorResult<()> {
170        self.cache.apply_batch(
171            chunk,
172            &self.cache_key_serializer,
173            &self.arg_col_indices,
174            &self.order_col_indices,
175        );
176        Ok(())
177    }
178
179    /// Get the output of the state.
180    /// We may need to read from the state table into the cache to get the output.
181    pub async fn get_output(
182        &mut self,
183        state_table: &StateTable<impl StateStore>,
184        group_key: Option<&GroupKey>,
185        func: &BoxedAggregateFunction,
186    ) -> StreamExecutorResult<(Datum, AggStateCacheStats)> {
187        let mut stats = AggStateCacheStats::default();
188        stats.agg_state_cache_lookup_count += 1;
189
190        if !self.cache.is_synced() {
191            stats.agg_state_cache_miss_count += 1;
192
193            let mut cache_filler = self.cache.begin_syncing();
194            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
195                &(Bound::Unbounded, Bound::Unbounded);
196            let all_data_iter = state_table
197                .iter_with_prefix(
198                    group_key.map(GroupKey::table_pk),
199                    sub_range,
200                    PrefetchOptions {
201                        prefetch: cache_filler.capacity().is_none(),
202                        for_large_query: false,
203                    },
204                )
205                .await?;
206            pin_mut!(all_data_iter);
207
208            #[for_await]
209            for keyed_row in all_data_iter.take(cache_filler.capacity().unwrap_or(usize::MAX)) {
210                let state_row = keyed_row?;
211                let cache_key = {
212                    let mut cache_key = Vec::new();
213                    self.cache_key_serializer.serialize(
214                        state_row
215                            .as_ref()
216                            .project(&self.state_table_order_col_indices),
217                        &mut cache_key,
218                    );
219                    cache_key.into()
220                };
221                let cache_value = self
222                    .state_table_arg_col_indices
223                    .iter()
224                    .map(|i| state_row[*i].clone())
225                    .collect();
226                cache_filler.append(cache_key, cache_value);
227            }
228            cache_filler.finish();
229        }
230        assert!(self.cache.is_synced());
231
232        if self.output_first_value {
233            // special case for `min`, `max`, `first_value` and `last_value`
234            // take the first value from the cache
235            Ok((self.cache.output_first(), stats))
236        } else {
237            const CHUNK_SIZE: usize = 1024;
238            let chunks = self.cache.output_batches(CHUNK_SIZE).collect_vec();
239            let mut state = func.create_state()?;
240            for chunk in chunks {
241                func.update(&mut state, &chunk).await?;
242            }
243            Ok((func.get_result(&state).await?, stats))
244        }
245    }
246
247    #[cfg(test)]
248    async fn get_output_no_stats(
249        &mut self,
250        state_table: &StateTable<impl StateStore>,
251        group_key: Option<&GroupKey>,
252        func: &BoxedAggregateFunction,
253    ) -> StreamExecutorResult<Datum> {
254        let (res, _stats) = self.get_output(state_table, group_key, func).await?;
255        Ok(res)
256    }
257}
258
259/// Copied from old code before <https://github.com/risingwavelabs/risingwave/commit/0020507edbc4010b20aeeb560c7bea9159315602>.
260fn generate_order_columns_before_version_issue_13465(
261    agg_call: &AggCall,
262    pk_indices: &PkIndices,
263    arg_col_indices: &[usize],
264) -> (Vec<usize>, Vec<OrderType>) {
265    let (mut order_col_indices, mut order_types) = if matches!(
266        agg_call.agg_type,
267        AggType::Builtin(PbAggKind::Min | PbAggKind::Max)
268    ) {
269        // `min`/`max` need not to order by any other columns, but have to
270        // order by the agg value implicitly.
271        let order_type = if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Min)) {
272            OrderType::ascending()
273        } else {
274            OrderType::descending()
275        };
276        (vec![arg_col_indices[0]], vec![order_type])
277    } else {
278        agg_call
279            .column_orders
280            .iter()
281            .map(|p| {
282                (
283                    p.column_index,
284                    if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::LastValue)) {
285                        p.order_type.reverse()
286                    } else {
287                        p.order_type
288                    },
289                )
290            })
291            .unzip()
292    };
293
294    if agg_call.distinct {
295        // If distinct, we need to materialize input with the distinct keys
296        // As we only support single-column distinct for now, we use the
297        // `agg_call.args.val_indices()[0]` as the distinct key.
298        if !order_col_indices.contains(&agg_call.args.val_indices()[0]) {
299            order_col_indices.push(agg_call.args.val_indices()[0]);
300            order_types.push(OrderType::ascending());
301        }
302    } else {
303        // If not distinct, we need to materialize input with the primary keys
304        let pk_len = pk_indices.len();
305        order_col_indices.extend(pk_indices.iter());
306        order_types.extend(itertools::repeat_n(OrderType::ascending(), pk_len));
307    }
308
309    (order_col_indices, order_types)
310}
311
312#[cfg(test)]
313mod tests {
314    use std::collections::HashSet;
315
316    use itertools::Itertools;
317    use rand::Rng;
318    use rand::seq::IteratorRandom;
319    use risingwave_common::array::StreamChunk;
320    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
321    use risingwave_common::row::OwnedRow;
322    use risingwave_common::test_prelude::StreamChunkTestExt;
323    use risingwave_common::types::{DataType, ListValue};
324    use risingwave_common::util::epoch::{EpochPair, test_epoch};
325    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
326    use risingwave_expr::aggregate::{AggCall, build_append_only};
327    use risingwave_pb::stream_plan::PbAggNodeVersion;
328    use risingwave_storage::StateStore;
329    use risingwave_storage::memory::MemoryStateStore;
330
331    use super::*;
332    use crate::common::StateTableColumnMapping;
333    use crate::common::table::state_table::StateTable;
334    use crate::common::table::test_utils::gen_pbtable;
335    use crate::executor::{PkIndices, StreamExecutorResult};
336
337    fn create_chunk<S: StateStore>(
338        pretty: &str,
339        table: &mut StateTable<S>,
340        col_mapping: &StateTableColumnMapping,
341    ) -> StreamChunk {
342        let chunk = StreamChunk::from_pretty(pretty);
343        table.write_chunk(chunk.project(col_mapping.upstream_columns()));
344        chunk
345    }
346
347    async fn create_mem_state_table(
348        input_schema: &Schema,
349        upstream_columns: Vec<usize>,
350        order_types: Vec<OrderType>,
351    ) -> (StateTable<MemoryStateStore>, StateTableColumnMapping) {
352        // see `LogicalAgg::infer_stream_agg_state` for the construction of state table
353        let table_id = TableId::new(rand::rng().random());
354        let columns = upstream_columns
355            .iter()
356            .map(|col_idx| input_schema[*col_idx].data_type())
357            .enumerate()
358            .map(|(i, data_type)| ColumnDesc::unnamed(ColumnId::new(i as i32), data_type))
359            .collect_vec();
360        let mapping = StateTableColumnMapping::new(upstream_columns, None);
361        let pk_len = order_types.len();
362        let table = StateTable::from_table_catalog(
363            &gen_pbtable(table_id, columns, order_types, (0..pk_len).collect(), 0),
364            MemoryStateStore::new(),
365            None,
366        )
367        .await;
368        (table, mapping)
369    }
370
371    #[tokio::test]
372    async fn test_extreme_agg_state_basic_min() -> StreamExecutorResult<()> {
373        // Assumption of input schema:
374        // (a: varchar, b: int32, c: int32, _row_id: int64)
375
376        let field1 = Field::unnamed(DataType::Varchar);
377        let field2 = Field::unnamed(DataType::Int32);
378        let field3 = Field::unnamed(DataType::Int32);
379        let field4 = Field::unnamed(DataType::Int64);
380        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
381
382        let agg_call = AggCall::from_pretty("(min:int4 $2:int4)"); // min(c)
383        let agg = build_append_only(&agg_call).unwrap();
384        let group_key = None;
385
386        let (mut table, mapping) = create_mem_state_table(
387            &input_schema,
388            vec![2, 3],
389            vec![
390                OrderType::ascending(), // for AggKind::Min
391                OrderType::ascending(),
392            ],
393        )
394        .await;
395
396        let order_columns = vec![
397            ColumnOrder::new(2, OrderType::ascending()), // c ASC for AggKind::Min
398            ColumnOrder::new(3, OrderType::ascending()), // _row_id
399        ];
400        let mut state = MaterializedInputState::new(
401            PbAggNodeVersion::LATEST,
402            &agg_call,
403            &PkIndices::new(), // unused
404            &order_columns,
405            &mapping,
406            usize::MAX,
407            &input_schema,
408        )
409        .unwrap();
410
411        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
412        table.init_epoch(epoch).await.unwrap();
413
414        {
415            let chunk = create_chunk(
416                " T i i I
417                + a 1 8 123
418                + b 5 2 128
419                - b 5 2 128
420                + c 1 3 130",
421                &mut table,
422                &mapping,
423            );
424
425            state.apply_chunk(&chunk)?;
426
427            epoch.inc_for_test();
428            table.commit_for_test(epoch).await.unwrap();
429
430            let res = state
431                .get_output_no_stats(&table, group_key.as_ref(), &agg)
432                .await?;
433            assert_eq!(res, Some(3i32.into()));
434        }
435
436        {
437            let chunk = create_chunk(
438                " T i i I
439                + d 0 8 134
440                + e 2 2 137",
441                &mut table,
442                &mapping,
443            );
444
445            state.apply_chunk(&chunk)?;
446
447            epoch.inc_for_test();
448            table.commit_for_test(epoch).await.unwrap();
449
450            let res = state
451                .get_output_no_stats(&table, group_key.as_ref(), &agg)
452                .await?;
453            assert_eq!(res, Some(2i32.into()));
454        }
455
456        {
457            // test recovery (cold start)
458            let mut state = MaterializedInputState::new(
459                PbAggNodeVersion::LATEST,
460                &agg_call,
461                &PkIndices::new(), // unused
462                &order_columns,
463                &mapping,
464                usize::MAX,
465                &input_schema,
466            )
467            .unwrap();
468            let res = state
469                .get_output_no_stats(&table, group_key.as_ref(), &agg)
470                .await?;
471            assert_eq!(res, Some(2i32.into()));
472        }
473
474        Ok(())
475    }
476
477    #[tokio::test]
478    async fn test_extreme_agg_state_basic_max() -> StreamExecutorResult<()> {
479        // Assumption of input schema:
480        // (a: varchar, b: int32, c: int32, _row_id: int64)
481
482        let field1 = Field::unnamed(DataType::Varchar);
483        let field2 = Field::unnamed(DataType::Int32);
484        let field3 = Field::unnamed(DataType::Int32);
485        let field4 = Field::unnamed(DataType::Int64);
486        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
487
488        let agg_call = AggCall::from_pretty("(max:int4 $2:int4)"); // max(c)
489        let agg = build_append_only(&agg_call).unwrap();
490        let group_key = None;
491
492        let (mut table, mapping) = create_mem_state_table(
493            &input_schema,
494            vec![2, 3],
495            vec![
496                OrderType::descending(), // for AggKind::Max
497                OrderType::ascending(),
498            ],
499        )
500        .await;
501
502        let order_columns = vec![
503            ColumnOrder::new(2, OrderType::descending()), // c DESC for AggKind::Max
504            ColumnOrder::new(3, OrderType::ascending()),  // _row_id
505        ];
506        let mut state = MaterializedInputState::new(
507            PbAggNodeVersion::LATEST,
508            &agg_call,
509            &PkIndices::new(), // unused
510            &order_columns,
511            &mapping,
512            usize::MAX,
513            &input_schema,
514        )
515        .unwrap();
516
517        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
518        table.init_epoch(epoch).await.unwrap();
519
520        {
521            let chunk = create_chunk(
522                " T i i I
523                + a 1 8 123
524                + b 5 2 128
525                - b 5 2 128
526                + c 1 3 130",
527                &mut table,
528                &mapping,
529            );
530
531            state.apply_chunk(&chunk)?;
532
533            epoch.inc_for_test();
534            table.commit_for_test(epoch).await.unwrap();
535
536            let res = state
537                .get_output_no_stats(&table, group_key.as_ref(), &agg)
538                .await?;
539            assert_eq!(res, Some(8i32.into()));
540        }
541
542        {
543            let chunk = create_chunk(
544                " T i i I
545                + d 0 9 134
546                + e 2 2 137",
547                &mut table,
548                &mapping,
549            );
550
551            state.apply_chunk(&chunk)?;
552
553            epoch.inc_for_test();
554            table.commit_for_test(epoch).await.unwrap();
555
556            let res = state
557                .get_output_no_stats(&table, group_key.as_ref(), &agg)
558                .await?;
559            assert_eq!(res, Some(9i32.into()));
560        }
561
562        {
563            // test recovery (cold start)
564            let mut state = MaterializedInputState::new(
565                PbAggNodeVersion::LATEST,
566                &agg_call,
567                &PkIndices::new(), // unused
568                &order_columns,
569                &mapping,
570                usize::MAX,
571                &input_schema,
572            )
573            .unwrap();
574
575            let res = state
576                .get_output_no_stats(&table, group_key.as_ref(), &agg)
577                .await?;
578            assert_eq!(res, Some(9i32.into()));
579        }
580
581        Ok(())
582    }
583
584    #[tokio::test]
585    async fn test_extreme_agg_state_with_hidden_input() -> StreamExecutorResult<()> {
586        // Assumption of input schema:
587        // (a: varchar, b: int32, c: int32, _row_id: int64)
588
589        let field1 = Field::unnamed(DataType::Varchar);
590        let field2 = Field::unnamed(DataType::Int32);
591        let field3 = Field::unnamed(DataType::Int32);
592        let field4 = Field::unnamed(DataType::Int64);
593        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
594
595        let agg_call_1 = AggCall::from_pretty("(min:varchar $0:varchar)"); // min(a)
596        let agg_call_2 = AggCall::from_pretty("(max:int4 $1:int4)"); // max(b)
597        let agg1 = build_append_only(&agg_call_1).unwrap();
598        let agg2 = build_append_only(&agg_call_2).unwrap();
599        let group_key = None;
600
601        let (mut table_1, mapping_1) = create_mem_state_table(
602            &input_schema,
603            vec![0, 3],
604            vec![
605                OrderType::ascending(), // for AggKind::Min
606                OrderType::ascending(),
607            ],
608        )
609        .await;
610        let (mut table_2, mapping_2) = create_mem_state_table(
611            &input_schema,
612            vec![1, 3],
613            vec![
614                OrderType::descending(), // for AggKind::Max
615                OrderType::ascending(),
616            ],
617        )
618        .await;
619
620        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
621        table_1.init_epoch(epoch).await.unwrap();
622        table_2.init_epoch(epoch).await.unwrap();
623
624        let order_columns_1 = vec![
625            ColumnOrder::new(0, OrderType::ascending()), // a ASC for AggKind::Min
626            ColumnOrder::new(3, OrderType::ascending()), // _row_id
627        ];
628        let mut state_1 = MaterializedInputState::new(
629            PbAggNodeVersion::LATEST,
630            &agg_call_1,
631            &PkIndices::new(), // unused
632            &order_columns_1,
633            &mapping_1,
634            usize::MAX,
635            &input_schema,
636        )
637        .unwrap();
638
639        let order_columns_2 = vec![
640            ColumnOrder::new(1, OrderType::descending()), // b DESC for AggKind::Max
641            ColumnOrder::new(3, OrderType::ascending()),  // _row_id
642        ];
643        let mut state_2 = MaterializedInputState::new(
644            PbAggNodeVersion::LATEST,
645            &agg_call_2,
646            &PkIndices::new(), // unused
647            &order_columns_2,
648            &mapping_2,
649            usize::MAX,
650            &input_schema,
651        )
652        .unwrap();
653
654        {
655            let chunk_1 = create_chunk(
656                " T i i I
657                + a 1 8 123
658                + b 5 2 128
659                - b 5 2 128
660                + c 1 3 130
661                + . 9 4 131 D
662                + . 6 5 132 D
663                + c . 3 133",
664                &mut table_1,
665                &mapping_1,
666            );
667            let chunk_2 = create_chunk(
668                " T i i I
669                + a 1 8 123
670                + b 5 2 128
671                - b 5 2 128
672                + c 1 3 130
673                + . 9 4 131
674                + . 6 5 132
675                + c . 3 133 D",
676                &mut table_2,
677                &mapping_2,
678            );
679
680            state_1.apply_chunk(&chunk_1)?;
681            state_2.apply_chunk(&chunk_2)?;
682
683            epoch.inc_for_test();
684            table_1.commit_for_test(epoch).await.unwrap();
685            table_2.commit_for_test(epoch).await.unwrap();
686
687            let out1 = state_1
688                .get_output_no_stats(&table_1, group_key.as_ref(), &agg1)
689                .await?;
690            assert_eq!(out1, Some("a".into()));
691
692            let out2 = state_2
693                .get_output_no_stats(&table_2, group_key.as_ref(), &agg2)
694                .await?;
695            assert_eq!(out2, Some(9i32.into()));
696        }
697
698        Ok(())
699    }
700
701    #[tokio::test]
702    async fn test_extreme_agg_state_grouped() -> StreamExecutorResult<()> {
703        // Assumption of input schema:
704        // (a: varchar, b: int32, c: int32, _row_id: int64)
705
706        let field1 = Field::unnamed(DataType::Varchar);
707        let field2 = Field::unnamed(DataType::Int32);
708        let field3 = Field::unnamed(DataType::Int32);
709        let field4 = Field::unnamed(DataType::Int64);
710        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
711
712        let agg_call = AggCall::from_pretty("(max:int4 $1:int4)"); // max(b)
713        let agg = build_append_only(&agg_call).unwrap();
714        let group_key = Some(GroupKey::new(OwnedRow::new(vec![Some(8.into())]), None));
715
716        let (mut table, mapping) = create_mem_state_table(
717            &input_schema,
718            vec![2, 1, 3],
719            vec![
720                OrderType::ascending(),  // c ASC
721                OrderType::descending(), // b DESC for AggKind::Max
722                OrderType::ascending(),  // _row_id ASC
723            ],
724        )
725        .await;
726
727        let order_columns = vec![
728            ColumnOrder::new(1, OrderType::descending()), // b DESC for AggKind::Max
729            ColumnOrder::new(3, OrderType::ascending()),  // _row_id
730        ];
731        let mut state = MaterializedInputState::new(
732            PbAggNodeVersion::LATEST,
733            &agg_call,
734            &PkIndices::new(), // unused
735            &order_columns,
736            &mapping,
737            usize::MAX,
738            &input_schema,
739        )
740        .unwrap();
741
742        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
743        table.init_epoch(epoch).await.unwrap();
744
745        {
746            let chunk = create_chunk(
747                " T i i I
748                + a 1 8 123
749                + b 5 8 128
750                + c 7 3 130 D // hide this row",
751                &mut table,
752                &mapping,
753            );
754
755            state.apply_chunk(&chunk)?;
756
757            epoch.inc_for_test();
758            table.commit_for_test(epoch).await.unwrap();
759
760            let res = state
761                .get_output_no_stats(&table, group_key.as_ref(), &agg)
762                .await?;
763            assert_eq!(res, Some(5i32.into()));
764        }
765
766        {
767            let chunk = create_chunk(
768                " T i i I
769                + d 9 2 134 D // hide this row
770                + e 8 8 137",
771                &mut table,
772                &mapping,
773            );
774
775            state.apply_chunk(&chunk)?;
776
777            epoch.inc_for_test();
778            table.commit_for_test(epoch).await.unwrap();
779
780            let res = state
781                .get_output_no_stats(&table, group_key.as_ref(), &agg)
782                .await?;
783            assert_eq!(res, Some(8i32.into()));
784        }
785
786        {
787            // test recovery (cold start)
788            let mut state = MaterializedInputState::new(
789                PbAggNodeVersion::LATEST,
790                &agg_call,
791                &PkIndices::new(), // unused
792                &order_columns,
793                &mapping,
794                usize::MAX,
795                &input_schema,
796            )
797            .unwrap();
798
799            let res = state
800                .get_output_no_stats(&table, group_key.as_ref(), &agg)
801                .await?;
802            assert_eq!(res, Some(8i32.into()));
803        }
804
805        Ok(())
806    }
807
808    #[tokio::test]
809    async fn test_extreme_agg_state_with_random_values() -> StreamExecutorResult<()> {
810        // Assumption of input schema:
811        // (a: int32, _row_id: int64)
812
813        let field1 = Field::unnamed(DataType::Int32);
814        let field2 = Field::unnamed(DataType::Int64);
815        let input_schema = Schema::new(vec![field1, field2]);
816
817        let agg_call = AggCall::from_pretty("(min:int4 $0:int4)"); // min(a)
818        let agg = build_append_only(&agg_call).unwrap();
819        let group_key = None;
820
821        let (mut table, mapping) = create_mem_state_table(
822            &input_schema,
823            vec![0, 1],
824            vec![
825                OrderType::ascending(), // for AggKind::Min
826                OrderType::ascending(),
827            ],
828        )
829        .await;
830
831        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
832        table.init_epoch(epoch).await.unwrap();
833
834        let order_columns = vec![
835            ColumnOrder::new(0, OrderType::ascending()), // a ASC for AggKind::Min
836            ColumnOrder::new(1, OrderType::ascending()), // _row_id
837        ];
838        let mut state = MaterializedInputState::new(
839            PbAggNodeVersion::LATEST,
840            &agg_call,
841            &PkIndices::new(), // unused
842            &order_columns,
843            &mapping,
844            1024,
845            &input_schema,
846        )
847        .unwrap();
848
849        let mut rng = rand::rng();
850        let insert_values: Vec<i32> = (0..10000).map(|_| rng.random()).collect_vec();
851        let delete_values: HashSet<_> = insert_values
852            .iter()
853            .choose_multiple(&mut rng, 1000)
854            .into_iter()
855            .collect();
856        let mut min_value = i32::MAX;
857
858        {
859            let mut pretty_lines = vec!["i I".to_owned()];
860            for (row_id, value) in insert_values
861                .iter()
862                .enumerate()
863                .take(insert_values.len() / 2)
864            {
865                pretty_lines.push(format!("+ {} {}", value, row_id));
866                if delete_values.contains(&value) {
867                    pretty_lines.push(format!("- {} {}", value, row_id));
868                    continue;
869                }
870                if *value < min_value {
871                    min_value = *value;
872                }
873            }
874
875            let chunk = create_chunk(&pretty_lines.join("\n"), &mut table, &mapping);
876            state.apply_chunk(&chunk)?;
877
878            epoch.inc_for_test();
879            table.commit_for_test(epoch).await.unwrap();
880
881            let res = state
882                .get_output_no_stats(&table, group_key.as_ref(), &agg)
883                .await?;
884            assert_eq!(res, Some(min_value.into()));
885        }
886
887        {
888            let mut pretty_lines = vec!["i I".to_owned()];
889            for (row_id, value) in insert_values
890                .iter()
891                .enumerate()
892                .skip(insert_values.len() / 2)
893            {
894                pretty_lines.push(format!("+ {} {}", value, row_id));
895                if delete_values.contains(&value) {
896                    pretty_lines.push(format!("- {} {}", value, row_id));
897                    continue;
898                }
899                if *value < min_value {
900                    min_value = *value;
901                }
902            }
903
904            let chunk = create_chunk(&pretty_lines.join("\n"), &mut table, &mapping);
905            state.apply_chunk(&chunk)?;
906
907            epoch.inc_for_test();
908            table.commit_for_test(epoch).await.unwrap();
909
910            let res = state
911                .get_output_no_stats(&table, group_key.as_ref(), &agg)
912                .await?;
913            assert_eq!(res, Some(min_value.into()));
914        }
915
916        Ok(())
917    }
918
919    #[tokio::test]
920    async fn test_extreme_agg_state_cache_maintenance() -> StreamExecutorResult<()> {
921        // Assumption of input schema:
922        // (a: int32, _row_id: int64)
923
924        let field1 = Field::unnamed(DataType::Int32);
925        let field2 = Field::unnamed(DataType::Int64);
926        let input_schema = Schema::new(vec![field1, field2]);
927
928        let agg_call = AggCall::from_pretty("(min:int4 $0:int4)"); // min(a)
929        let agg = build_append_only(&agg_call).unwrap();
930        let group_key = None;
931
932        let (mut table, mapping) = create_mem_state_table(
933            &input_schema,
934            vec![0, 1],
935            vec![
936                OrderType::ascending(), // for AggKind::Min
937                OrderType::ascending(),
938            ],
939        )
940        .await;
941
942        let order_columns = vec![
943            ColumnOrder::new(0, OrderType::ascending()), // a ASC for AggKind::Min
944            ColumnOrder::new(1, OrderType::ascending()), // _row_id
945        ];
946        let mut state = MaterializedInputState::new(
947            PbAggNodeVersion::LATEST,
948            &agg_call,
949            &PkIndices::new(), // unused
950            &order_columns,
951            &mapping,
952            3, // cache capacity = 3 for easy testing
953            &input_schema,
954        )
955        .unwrap();
956
957        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
958        table.init_epoch(epoch).await.unwrap();
959
960        {
961            let chunk = create_chunk(
962                " i  I
963                + 4  123
964                + 8  128
965                + 12 129",
966                &mut table,
967                &mapping,
968            );
969            state.apply_chunk(&chunk)?;
970
971            epoch.inc_for_test();
972            table.commit_for_test(epoch).await.unwrap();
973
974            let res = state
975                .get_output_no_stats(&table, group_key.as_ref(), &agg)
976                .await?;
977            assert_eq!(res, Some(4i32.into()));
978        }
979
980        {
981            let chunk = create_chunk(
982                " i I
983                + 9  130 // this will evict 12
984                - 9  130
985                + 13 128
986                - 4  123
987                - 8  128",
988                &mut table,
989                &mapping,
990            );
991            state.apply_chunk(&chunk)?;
992
993            epoch.inc_for_test();
994            table.commit_for_test(epoch).await.unwrap();
995
996            let res = state
997                .get_output_no_stats(&table, group_key.as_ref(), &agg)
998                .await?;
999            assert_eq!(res, Some(12i32.into()));
1000        }
1001
1002        {
1003            let chunk = create_chunk(
1004                " i  I
1005                + 1  131
1006                + 2  132
1007                + 3  133 // evict all from cache
1008                - 1  131
1009                - 2  132
1010                - 3  133
1011                + 14 134",
1012                &mut table,
1013                &mapping,
1014            );
1015            state.apply_chunk(&chunk)?;
1016
1017            epoch.inc_for_test();
1018            table.commit_for_test(epoch).await.unwrap();
1019
1020            let res = state
1021                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1022                .await?;
1023            assert_eq!(res, Some(12i32.into()));
1024        }
1025
1026        Ok(())
1027    }
1028
1029    #[tokio::test]
1030    async fn test_string_agg_state() -> StreamExecutorResult<()> {
1031        // Assumption of input schema:
1032        // (a: varchar, _delim: varchar, b: int32, c: int32, _row_id: int64)
1033        // where `a` is the column to aggregate
1034
1035        let input_schema = Schema::new(vec![
1036            Field::unnamed(DataType::Varchar),
1037            Field::unnamed(DataType::Varchar),
1038            Field::unnamed(DataType::Int32),
1039            Field::unnamed(DataType::Int32),
1040            Field::unnamed(DataType::Int64),
1041        ]);
1042
1043        let agg_call = AggCall::from_pretty(
1044            "(string_agg:varchar $0:varchar $1:varchar orderby $2:asc $0:desc)",
1045        );
1046        let agg = build_append_only(&agg_call).unwrap();
1047        let group_key = None;
1048
1049        let (mut table, mapping) = create_mem_state_table(
1050            &input_schema,
1051            vec![2, 0, 4, 1],
1052            vec![
1053                OrderType::ascending(),  // b ASC
1054                OrderType::descending(), // a DESC
1055                OrderType::ascending(),  // _row_id ASC
1056            ],
1057        )
1058        .await;
1059
1060        let order_columns = vec![
1061            ColumnOrder::new(2, OrderType::ascending()),  // b ASC
1062            ColumnOrder::new(0, OrderType::descending()), // a DESC
1063            ColumnOrder::new(4, OrderType::ascending()),  // _row_id ASC
1064        ];
1065        let mut state = MaterializedInputState::new(
1066            PbAggNodeVersion::LATEST,
1067            &agg_call,
1068            &PkIndices::new(), // unused
1069            &order_columns,
1070            &mapping,
1071            usize::MAX,
1072            &input_schema,
1073        )
1074        .unwrap();
1075
1076        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
1077        table.init_epoch(epoch).await.unwrap();
1078
1079        {
1080            let chunk = create_chunk(
1081                " T T i i I
1082                + a , 1 8 123
1083                + b / 5 2 128
1084                - b / 5 2 128
1085                + c _ 1 3 130",
1086                &mut table,
1087                &mapping,
1088            );
1089            state.apply_chunk(&chunk)?;
1090
1091            epoch.inc_for_test();
1092            table.commit_for_test(epoch).await.unwrap();
1093
1094            let res = state
1095                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1096                .await?;
1097            assert_eq!(res, Some("c,a".into()));
1098        }
1099
1100        {
1101            let chunk = create_chunk(
1102                " T T i i I
1103                + d - 0 8 134
1104                + e + 2 2 137",
1105                &mut table,
1106                &mapping,
1107            );
1108            state.apply_chunk(&chunk)?;
1109
1110            epoch.inc_for_test();
1111            table.commit_for_test(epoch).await.unwrap();
1112
1113            let res = state
1114                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1115                .await?;
1116            assert_eq!(res, Some("d_c,a+e".into()));
1117        }
1118
1119        Ok(())
1120    }
1121
1122    #[tokio::test]
1123    async fn test_array_agg_state() -> StreamExecutorResult<()> {
1124        // Assumption of input schema:
1125        // (a: varchar, b: int32, c: int32, _row_id: int64)
1126        // where `a` is the column to aggregate
1127
1128        let field1 = Field::unnamed(DataType::Varchar);
1129        let field2 = Field::unnamed(DataType::Int32);
1130        let field3 = Field::unnamed(DataType::Int32);
1131        let field4 = Field::unnamed(DataType::Int64);
1132        let input_schema = Schema::new(vec![field1, field2, field3, field4]);
1133
1134        let agg_call = AggCall::from_pretty("(array_agg:int4[] $1:int4 orderby $2:asc $0:desc)");
1135        let agg = build_append_only(&agg_call).unwrap();
1136        let group_key = None;
1137
1138        let (mut table, mapping) = create_mem_state_table(
1139            &input_schema,
1140            vec![2, 0, 3, 1],
1141            vec![
1142                OrderType::ascending(),  // c ASC
1143                OrderType::descending(), // a DESC
1144                OrderType::ascending(),  // _row_id ASC
1145            ],
1146        )
1147        .await;
1148
1149        let order_columns = vec![
1150            ColumnOrder::new(2, OrderType::ascending()),  // c ASC
1151            ColumnOrder::new(0, OrderType::descending()), // a DESC
1152            ColumnOrder::new(3, OrderType::ascending()),  // _row_id ASC
1153        ];
1154        let mut state = MaterializedInputState::new(
1155            PbAggNodeVersion::LATEST,
1156            &agg_call,
1157            &PkIndices::new(), // unused
1158            &order_columns,
1159            &mapping,
1160            usize::MAX,
1161            &input_schema,
1162        )
1163        .unwrap();
1164
1165        let mut epoch = EpochPair::new_test_epoch(test_epoch(1));
1166        table.init_epoch(epoch).await.unwrap();
1167        {
1168            let chunk = create_chunk(
1169                " T i i I
1170                + a 1 8 123
1171                + b 5 2 128
1172                - b 5 2 128
1173                + c 2 3 130",
1174                &mut table,
1175                &mapping,
1176            );
1177            state.apply_chunk(&chunk)?;
1178
1179            epoch.inc_for_test();
1180            table.commit_for_test(epoch).await.unwrap();
1181
1182            let res = state
1183                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1184                .await?;
1185            assert_eq!(res.unwrap().as_list(), &ListValue::from_iter([2, 1]));
1186        }
1187
1188        {
1189            let chunk = create_chunk(
1190                " T i i I
1191                + d 0 8 134
1192                + e 2 2 137",
1193                &mut table,
1194                &mapping,
1195            );
1196            state.apply_chunk(&chunk)?;
1197
1198            epoch.inc_for_test();
1199            table.commit_for_test(epoch).await.unwrap();
1200
1201            let res = state
1202                .get_output_no_stats(&table, group_key.as_ref(), &agg)
1203                .await?;
1204            assert_eq!(res.unwrap().as_list(), &ListValue::from_iter([2, 2, 0, 1]));
1205        }
1206
1207        Ok(())
1208    }
1209}