risingwave_stream/executor/aggregate/
simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::util::epoch::EpochPair;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction, build_retractable};
21use risingwave_pb::stream_plan::PbAggNodeVersion;
22
23use super::agg_group::{AggGroup, AlwaysOutput};
24use super::agg_state::AggStateStorage;
25use super::distinct::DistinctDeduplicater;
26use super::{AggExecutorArgs, SimpleAggExecutorExtraArgs, agg_call_filter_res, iter_table_storage};
27use crate::executor::prelude::*;
28
29/// `SimpleAggExecutor` is the aggregation operator for streaming system.
30/// To create an aggregation operator, states and expressions should be passed along the
31/// constructor.
32///
33/// `SimpleAggExecutor` maintains multiple states together. If there are `n` states and `n`
34/// expressions, there will be `n` columns as output.
35///
36/// As the engine processes data in chunks, it is possible that multiple update
37/// messages could consolidate to a single row update. For example, our source
38/// emits 1000 inserts in one chunk, and we aggregate the count function on that.
39/// Current `SimpleAggExecutor` will only emit one row for a whole chunk.
40/// Therefore, we "automatically" implement a window function inside
41/// `SimpleAggExecutor`.
42pub struct SimpleAggExecutor<S: StateStore> {
43    input: Executor,
44    inner: ExecutorInner<S>,
45}
46
47struct ExecutorInner<S: StateStore> {
48    /// Version of aggregation executors.
49    version: PbAggNodeVersion,
50
51    actor_ctx: ActorContextRef,
52    info: ExecutorInfo,
53
54    /// Pk indices from input. Only used by `AggNodeVersion` before `ISSUE_13465`.
55    input_pk_indices: Vec<usize>,
56
57    /// Schema from input.
58    input_schema: Schema,
59
60    /// An operator will support multiple aggregation calls.
61    agg_calls: Vec<AggCall>,
62
63    /// Aggregate functions.
64    agg_funcs: Vec<BoxedAggregateFunction>,
65
66    /// Index of row count agg call (`count(*)`) in the call list.
67    row_count_index: usize,
68
69    /// State storage for each agg calls.
70    storages: Vec<AggStateStorage<S>>,
71
72    /// Intermediate state table for value-state agg calls.
73    /// The state of all value-state aggregates are collected and stored in this
74    /// table when `flush_data` is called.
75    intermediate_state_table: StateTable<S>,
76
77    /// State tables for deduplicating rows on distinct key for distinct agg calls.
78    /// One table per distinct column (may be shared by multiple agg calls).
79    distinct_dedup_tables: HashMap<usize, StateTable<S>>,
80
81    /// Watermark epoch.
82    watermark_epoch: AtomicU64Ref,
83
84    /// Extreme state cache size
85    extreme_cache_size: usize,
86
87    /// Required by the downstream `RowMergeExecutor`,
88    /// currently only used by the `approx_percentile`'s two phase plan
89    must_output_per_barrier: bool,
90}
91
92impl<S: StateStore> ExecutorInner<S> {
93    fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
94        iter_table_storage(&mut self.storages)
95            .chain(self.distinct_dedup_tables.values_mut())
96            .chain(std::iter::once(&mut self.intermediate_state_table))
97    }
98}
99
100struct ExecutionVars<S: StateStore> {
101    /// The single [`AggGroup`].
102    agg_group: AggGroup<S, AlwaysOutput>,
103
104    /// Distinct deduplicater to deduplicate input rows for each distinct agg call.
105    distinct_dedup: DistinctDeduplicater<S>,
106
107    /// Mark the agg state is changed in the current epoch or not.
108    state_changed: bool,
109}
110
111impl<S: StateStore> Execute for SimpleAggExecutor<S> {
112    fn execute(self: Box<Self>) -> BoxedMessageStream {
113        self.execute_inner().boxed()
114    }
115}
116
117impl<S: StateStore> SimpleAggExecutor<S> {
118    pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
119        let input_info = args.input.info().clone();
120        Ok(Self {
121            input: args.input,
122            inner: ExecutorInner {
123                version: args.version,
124                actor_ctx: args.actor_ctx,
125                info: args.info,
126                input_pk_indices: input_info.pk_indices,
127                input_schema: input_info.schema,
128                agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
129                agg_calls: args.agg_calls,
130                row_count_index: args.row_count_index,
131                storages: args.storages,
132                intermediate_state_table: args.intermediate_state_table,
133                distinct_dedup_tables: args.distinct_dedup_tables,
134                watermark_epoch: args.watermark_epoch,
135                extreme_cache_size: args.extreme_cache_size,
136                must_output_per_barrier: args.extra.must_output_per_barrier,
137            },
138        })
139    }
140
141    async fn apply_chunk(
142        this: &mut ExecutorInner<S>,
143        vars: &mut ExecutionVars<S>,
144        chunk: StreamChunk,
145    ) -> StreamExecutorResult<()> {
146        if chunk.cardinality() == 0 {
147            // If the chunk is empty, do nothing.
148            return Ok(());
149        }
150
151        // Calculate the row visibility for every agg call.
152        let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
153        for agg_call in &this.agg_calls {
154            let vis = agg_call_filter_res(agg_call, &chunk).await?;
155            call_visibilities.push(vis);
156        }
157
158        // Deduplicate for distinct columns.
159        let visibilities = vars
160            .distinct_dedup
161            .dedup_chunk(
162                chunk.ops(),
163                chunk.columns(),
164                call_visibilities,
165                &mut this.distinct_dedup_tables,
166                None,
167            )
168            .await?;
169
170        // Materialize input chunk if needed and possible.
171        for (storage, visibility) in this.storages.iter_mut().zip_eq_fast(visibilities.iter()) {
172            if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage {
173                let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
174                table.write_chunk(chunk);
175            }
176        }
177
178        // Apply chunk to each of the state (per agg_call).
179        vars.agg_group
180            .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
181            .await?;
182
183        // Mark state as changed.
184        vars.state_changed = true;
185
186        Ok(())
187    }
188
189    async fn flush_data(
190        this: &mut ExecutorInner<S>,
191        vars: &mut ExecutionVars<S>,
192        epoch: EpochPair,
193    ) -> StreamExecutorResult<Option<StreamChunk>> {
194        if vars.state_changed || vars.agg_group.is_uninitialized() {
195            // Flush distinct dedup state.
196            vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
197
198            // Build and apply change for intermediate states.
199            if let Some(inter_states_change) =
200                vars.agg_group.build_states_change(&this.agg_funcs)?
201            {
202                this.intermediate_state_table
203                    .write_record(inter_states_change);
204            }
205        }
206        vars.state_changed = false;
207
208        // Build and apply change for the final outputs.
209        let (outputs_change, _stats) = vars
210            .agg_group
211            .build_outputs_change(&this.storages, &this.agg_funcs)
212            .await?;
213
214        let change =
215            outputs_change.expect("`AlwaysOutput` strategy will output a change in any case");
216        let chunk = if !this.must_output_per_barrier
217            && let Record::Update { old_row, new_row } = &change
218            && old_row == new_row
219        {
220            // for cases without approx percentile, we don't need to output the change if it's noop
221            None
222        } else {
223            Some(change.to_stream_chunk(&this.info.schema.data_types()))
224        };
225
226        // Commit all state tables.
227        futures::future::try_join_all(
228            this.all_state_tables_mut()
229                .map(|table| table.commit_assert_no_update_vnode_bitmap(epoch)),
230        )
231        .await?;
232
233        Ok(chunk)
234    }
235
236    async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
237        futures::future::try_join_all(this.all_state_tables_mut().map(|table| table.try_flush()))
238            .await?;
239        Ok(())
240    }
241
242    #[try_stream(ok = Message, error = StreamExecutorError)]
243    async fn execute_inner(self) {
244        let Self {
245            input,
246            inner: mut this,
247        } = self;
248
249        let mut input = input.execute();
250        let barrier = expect_first_barrier(&mut input).await?;
251        let first_epoch = barrier.epoch;
252        yield Message::Barrier(barrier);
253
254        for table in this.all_state_tables_mut() {
255            table.init_epoch(first_epoch).await?;
256        }
257
258        let distinct_dedup = DistinctDeduplicater::new(
259            &this.agg_calls,
260            this.watermark_epoch.clone(),
261            &this.distinct_dedup_tables,
262            &this.actor_ctx,
263        );
264
265        // This will fetch previous agg states from the intermediate state table.
266        let (agg_group, _stats) = AggGroup::create(
267            this.version,
268            None,
269            &this.agg_calls,
270            &this.agg_funcs,
271            &this.storages,
272            &this.intermediate_state_table,
273            &this.input_pk_indices,
274            this.row_count_index,
275            false, // emit on window close
276            this.extreme_cache_size,
277            &this.input_schema,
278        )
279        .await?;
280
281        let mut vars = ExecutionVars {
282            agg_group,
283            distinct_dedup,
284            state_changed: false,
285        };
286
287        #[for_await]
288        for msg in input {
289            let msg = msg?;
290            match msg {
291                Message::Watermark(_) => {}
292                Message::Chunk(chunk) => {
293                    Self::apply_chunk(&mut this, &mut vars, chunk).await?;
294                    Self::try_flush_data(&mut this).await?;
295                }
296                Message::Barrier(barrier) => {
297                    if let Some(chunk) =
298                        Self::flush_data(&mut this, &mut vars, barrier.epoch).await?
299                    {
300                        yield Message::Chunk(chunk);
301                    }
302                    yield Message::Barrier(barrier);
303                }
304            }
305        }
306    }
307}
308
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::types::*;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
    use crate::executor::test_utils::*;

    /// Input schema shared by all tests: three `Int64` columns. The mock source is built with
    /// `vec![2]` as its pk indices, so the last column acts as the primary key.
    fn test_input_schema() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
                // primary key column
                Field::unnamed(DataType::Int64),
            ],
        }
    }

    /// Aggregation calls shared by all tests: `count(*)`, `sum($0)`, `sum($1)` and `min($0)`.
    fn test_agg_calls() -> Vec<AggCall> {
        vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
            AggCall::from_pretty("(min:int8 $0:int8)"),
        ]
    }

    #[tokio::test]
    async fn test_simple_aggregation_in_memory() {
        test_simple_aggregation(MemoryStateStore::new()).await
    }

    async fn test_simple_aggregation<S: StateStore>(store: S) {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(test_input_schema(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            +  10  14 1002
            +   4 300 1003",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            - 100 200 1001
            -  10  14 1002 D
            -   4 300 1003
            + 104 500 1004",
        ));
        tx.push_barrier(test_epoch(4), false);

        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            test_agg_calls(),
            0,
            vec![2],
            1,
            false,
        )
        .await;
        let mut simple_agg = simple_agg.execute();

        // Consume the init barrier
        simple_agg.next().await.unwrap().unwrap();
        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I  I
                + 0   .   .  . "
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 0   .   .  .
                U+ 3 114 514  4"
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 3 114 514  4
                U+ 2 114 514 10"
            )
        );
    }

    // NOTE(kwannoel): `approx_percentile` + `keyed_merge` depend on this property for correctness.
    #[tokio::test]
    async fn test_simple_aggregation_always_output_per_epoch() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(test_input_schema(), vec![2]);
        // initial barrier
        tx.push_barrier(test_epoch(1), false);
        // next barrier
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            test_agg_calls(),
            0,
            vec![2],
            1,
            true,
        )
        .await;
        let mut simple_agg = simple_agg.execute();

        // Consume the init barrier
        simple_agg.next().await.unwrap().unwrap();
        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I  I
                + 0   .   .  . "
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 0   .   .  .
                U+ 0   .   .  ."
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 0   .   .  .
                U+ 0   .   .  ."
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );
    }

    // Counterpart of the test above. Without `must_output_per_barrier`, an update whose old
    // and new rows are equal is not emitted, so only barriers follow the initial output.
    #[tokio::test]
    async fn test_simple_aggregation_omit_noop_update() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(test_input_schema(), vec![2]);
        // initial barrier
        tx.push_barrier(test_epoch(1), false);
        // next barrier
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
                + 100 200 1001
                - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            test_agg_calls(),
            0,
            vec![2],
            1,
            false,
        )
        .await;
        let mut simple_agg = simple_agg.execute();

        // Consume the init barrier
        simple_agg.next().await.unwrap().unwrap();
        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I  I
                + 0   .   .  . "
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // No stream chunk
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // No stream chunk
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );
    }
}