Skip to main content

risingwave_stream/executor/aggregate/
simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::util::epoch::EpochPair;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction, build_retractable};
21use risingwave_pb::stream_plan::PbAggNodeVersion;
22
23use super::agg_group::{AggGroup, AlwaysOutput};
24use super::agg_state::AggStateStorage;
25use super::distinct::DistinctDeduplicater;
26use super::{AggExecutorArgs, SimpleAggExecutorExtraArgs, agg_call_filter_res, iter_table_storage};
27use crate::executor::prelude::*;
28
29/// `SimpleAggExecutor` is the aggregation operator for streaming system.
30/// To create an aggregation operator, states and expressions should be passed along the
31/// constructor.
32///
33/// `SimpleAggExecutor` maintains multiple states together. If there are `n` states and `n`
34/// expressions, there will be `n` columns as output.
35///
36/// As the engine processes data in chunks, it is possible that multiple update
37/// messages could consolidate to a single row update. For example, our source
38/// emits 1000 inserts in one chunk, and we aggregate the count function on that.
39/// Current `SimpleAggExecutor` will only emit one row for a whole chunk.
40/// Therefore, we "automatically" implement a window function inside
41/// `SimpleAggExecutor`.
42pub struct SimpleAggExecutor<S: StateStore> {
43    input: Executor,
44    inner: ExecutorInner<S>,
45}
46
47struct ExecutorInner<S: StateStore> {
48    /// Version of aggregation executors.
49    version: PbAggNodeVersion,
50
51    actor_ctx: ActorContextRef,
52    info: ExecutorInfo,
53
54    /// Stream key indices from input. Only used by `AggNodeVersion` before `ISSUE_13465`.
55    input_stream_key: Vec<usize>,
56
57    /// Schema from input.
58    input_schema: Schema,
59
60    /// An operator will support multiple aggregation calls.
61    agg_calls: Vec<AggCall>,
62
63    /// Aggregate functions.
64    agg_funcs: Vec<BoxedAggregateFunction>,
65
66    /// Index of row count agg call (`count(*)`) in the call list.
67    row_count_index: usize,
68
69    /// State storage for each agg calls.
70    storages: Vec<AggStateStorage<S>>,
71
72    /// Intermediate state table for value-state agg calls.
73    /// The state of all value-state aggregates are collected and stored in this
74    /// table when `flush_data` is called.
75    intermediate_state_table: StateTable<S>,
76
77    /// State tables for deduplicating rows on distinct key for distinct agg calls.
78    /// One table per distinct column (may be shared by multiple agg calls).
79    distinct_dedup_tables: HashMap<usize, StateTable<S>>,
80
81    /// Watermark epoch.
82    watermark_epoch: AtomicU64Ref,
83
84    /// Extreme state cache size
85    extreme_cache_size: usize,
86
87    /// Required by the downstream `RowMergeExecutor`,
88    /// currently only used by the `approx_percentile`'s two phase plan
89    must_output_per_barrier: bool,
90}
91
92impl<S: StateStore> ExecutorInner<S> {
93    fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
94        iter_table_storage(&mut self.storages)
95            .chain(self.distinct_dedup_tables.values_mut())
96            .chain(std::iter::once(&mut self.intermediate_state_table))
97    }
98}
99
100struct ExecutionVars<S: StateStore> {
101    /// The single [`AggGroup`].
102    agg_group: AggGroup<S, AlwaysOutput>,
103
104    /// Distinct deduplicater to deduplicate input rows for each distinct agg call.
105    distinct_dedup: DistinctDeduplicater<S>,
106
107    /// Mark the agg state is changed in the current epoch or not.
108    state_changed: bool,
109}
110
111impl<S: StateStore> Execute for SimpleAggExecutor<S> {
112    fn execute(self: Box<Self>) -> BoxedMessageStream {
113        self.execute_inner().boxed()
114    }
115}
116
impl<S: StateStore> SimpleAggExecutor<S> {
    /// Builds the executor from the generic aggregation executor arguments.
    ///
    /// The input's stream key and schema are copied out of the input executor's info, and a
    /// retractable aggregate function is built for every agg call.
    ///
    /// # Errors
    ///
    /// Returns an error if building any of the aggregate functions fails.
    pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
        let input_info = args.input.info().clone();
        Ok(Self {
            input: args.input,
            inner: ExecutorInner {
                version: args.version,
                actor_ctx: args.actor_ctx,
                info: args.info,
                input_stream_key: input_info.stream_key,
                input_schema: input_info.schema,
                // Must be evaluated before `agg_calls` is moved out of `args` below.
                agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
                agg_calls: args.agg_calls,
                row_count_index: args.row_count_index,
                storages: args.storages,
                intermediate_state_table: args.intermediate_state_table,
                distinct_dedup_tables: args.distinct_dedup_tables,
                watermark_epoch: args.watermark_epoch,
                extreme_cache_size: args.extreme_cache_size,
                must_output_per_barrier: args.extra.must_output_per_barrier,
            },
        })
    }

    /// Applies one input chunk to the single aggregation group without emitting any output.
    ///
    /// Steps:
    /// 1. Compute a per-call visibility from each agg call's filter.
    /// 2. Deduplicate rows for distinct agg calls.
    /// 3. Write the projected, visible rows into every `MaterializedInput` storage table.
    /// 4. Feed the chunk into the agg group.
    ///
    /// Chunks with zero cardinality are ignored. Any other chunk marks the state as changed,
    /// even if every row ends up invisible to every call.
    async fn apply_chunk(
        this: &mut ExecutorInner<S>,
        vars: &mut ExecutionVars<S>,
        chunk: StreamChunk,
    ) -> StreamExecutorResult<()> {
        if chunk.cardinality() == 0 {
            // If the chunk is empty, do nothing.
            return Ok(());
        }

        // Calculate the row visibility for every agg call.
        let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
        for agg_call in &this.agg_calls {
            let vis = agg_call_filter_res(agg_call, &chunk).await?;
            call_visibilities.push(vis);
        }

        // Deduplicate for distinct columns.
        let visibilities = vars
            .distinct_dedup
            .dedup_chunk(
                chunk.ops(),
                chunk.columns(),
                call_visibilities,
                &mut this.distinct_dedup_tables,
                None,
            )
            .await?;

        // Materialize input chunk if needed and possible.
        // `storages` and `visibilities` are both aligned with `agg_calls`.
        for (storage, visibility) in this.storages.iter_mut().zip_eq_fast(visibilities.iter()) {
            if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage {
                let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
                table.write_chunk(chunk);
            }
        }

        // Apply chunk to each of the state (per agg_call).
        vars.agg_group
            .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
            .await?;

        // Mark state as changed.
        vars.state_changed = true;

        Ok(())
    }

    /// Handles a barrier: persists changed state, builds the output change for the epoch, and
    /// commits every state table at `epoch`.
    ///
    /// Intermediate states and distinct dedup state are written only when the state changed
    /// in this epoch, or when the agg group has not been initialized yet.
    ///
    /// Returns `Some(chunk)` with the output change, or `None` when the change is an update
    /// whose old and new rows are equal and `must_output_per_barrier` is not set.
    async fn flush_data(
        this: &mut ExecutorInner<S>,
        vars: &mut ExecutionVars<S>,
        epoch: EpochPair,
    ) -> StreamExecutorResult<Option<StreamChunk>> {
        if vars.state_changed || vars.agg_group.is_uninitialized() {
            // Flush distinct dedup state.
            vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;

            // Build and apply change for intermediate states.
            if let Some(inter_states_change) =
                vars.agg_group.build_states_change(&this.agg_funcs)?
            {
                this.intermediate_state_table
                    .write_record(inter_states_change);
            }
        }
        vars.state_changed = false;

        // Build and apply change for the final outputs.
        let (outputs_change, _stats) = vars
            .agg_group
            .build_outputs_change(&this.storages, &this.agg_funcs)
            .await?;

        let change =
            outputs_change.expect("`AlwaysOutput` strategy will output a change in any case");
        let chunk = if !this.must_output_per_barrier
            && let Record::Update { old_row, new_row } = &change
            && old_row == new_row
        {
            // for cases without approx percentile, we don't need to output the change if it's noop
            None
        } else {
            Some(change.to_stream_chunk(&this.info.schema.data_types()))
        };

        // Commit all state tables.
        // This happens after the output change has been built from the current state.
        futures::future::try_join_all(
            this.all_state_tables_mut()
                .map(|table| table.commit_assert_no_update_vnode_bitmap(epoch)),
        )
        .await?;

        Ok(chunk)
    }

    /// Calls `try_flush` on every state table after a chunk has been applied.
    ///
    /// NOTE(review): presumably this lets tables spill buffered writes before the next barrier
    /// when the buffer grows large; confirm against `StateTable::try_flush`.
    async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
        futures::future::try_join_all(this.all_state_tables_mut().map(|table| table.try_flush()))
            .await?;
        Ok(())
    }

    /// Main loop of the executor.
    ///
    /// Forwards the first barrier, initializes all state tables at its epoch, and restores the
    /// agg group from the intermediate state table. It then processes messages:
    /// - chunks are applied to the state and produce no output;
    /// - barriers flush the state, emitting the output chunk (if any) before the barrier;
    /// - watermarks are dropped.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute_inner(self) {
        let Self {
            input,
            inner: mut this,
        } = self;

        let mut input = input.execute();
        let barrier = expect_first_barrier(&mut input).await?;
        let first_epoch = barrier.epoch;
        // NOTE(review): the first barrier is forwarded before the tables are initialized;
        // presumably so that downstream is not blocked on `init_epoch`. Confirm before
        // reordering.
        yield Message::Barrier(barrier);

        for table in this.all_state_tables_mut() {
            table.init_epoch(first_epoch).await?;
        }

        let distinct_dedup = DistinctDeduplicater::new(
            &this.agg_calls,
            this.watermark_epoch.clone(),
            &this.distinct_dedup_tables,
            &this.actor_ctx,
        );

        // This will fetch previous agg states from the intermediate state table.
        let (agg_group, _stats) = AggGroup::create(
            this.version,
            None,
            &this.agg_calls,
            &this.agg_funcs,
            &this.storages,
            &this.intermediate_state_table,
            &this.input_stream_key,
            this.row_count_index,
            false, // emit on window close
            this.extreme_cache_size,
            &this.input_schema,
        )
        .await?;

        let mut vars = ExecutionVars {
            agg_group,
            distinct_dedup,
            state_changed: false,
        };

        #[for_await]
        for msg in input {
            let msg = msg?;
            match msg {
                // Watermarks are not forwarded. Presumably this is because the single output
                // row has no column a watermark could be derived for; confirm.
                Message::Watermark(_) => {}
                Message::Chunk(chunk) => {
                    Self::apply_chunk(&mut this, &mut vars, chunk).await?;
                    Self::try_flush_data(&mut this).await?;
                }
                Message::Barrier(barrier) => {
                    // The output of the epoch must precede the barrier that closes it.
                    if let Some(chunk) =
                        Self::flush_data(&mut this, &mut vars, barrier.epoch).await?
                    {
                        yield Message::Chunk(chunk);
                    }
                    yield Message::Barrier(barrier);
                }
            }
        }
    }
}
308
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
    use crate::executor::test_utils::*;

    /// Input schema shared by all tests: two `Int64` value columns followed by an `Int64`
    /// primary key column (index 2, used as the stream key of the mock source).
    fn input_schema() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
                // primary key column
                Field::unnamed(DataType::Int64),
            ],
        }
    }

    /// Agg calls shared by all tests, in output-column order:
    /// `count(*)`, `sum($0)`, `sum($1)`, `min($0)`.
    fn test_agg_calls() -> Vec<AggCall> {
        vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
            AggCall::from_pretty("(min:int8 $0:int8)"),
        ]
    }

    #[tokio::test]
    async fn test_simple_aggregation_in_memory() {
        test_simple_aggregation(MemoryStateStore::new()).await
    }

    async fn test_simple_aggregation<S: StateStore>(store: S) {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(input_schema(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            +  10  14 1002
            +   4 300 1003",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            - 100 200 1001
            -  10  14 1002 D
            -   4 300 1003
            + 104 500 1004",
        ));
        tx.push_barrier(test_epoch(4), false);

        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            test_agg_calls(),
            0,
            vec![2],
            1,
            false,
        )
        .await;
        let mut simple_agg = simple_agg.execute();

        // Consume the init barrier
        simple_agg.next().await.unwrap().unwrap();
        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I  I
                + 0   .   .  . "
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 0   .   .  .
                U+ 3 114 514  4"
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // The invisible (`D`) delete of row 1002 must be ignored, so 10 remains the minimum.
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 3 114 514  4
                U+ 2 114 514 10"
            )
        );
    }

    // NOTE(kwannoel): `approx_percentile` + `keyed_merge` depend on this property for correctness.
    #[tokio::test]
    async fn test_simple_aggregation_always_output_per_epoch() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(input_schema(), vec![2]);
        // initial barrier
        tx.push_barrier(test_epoch(1), false);
        // next barrier
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            test_agg_calls(),
            0,
            vec![2],
            1,
            true,
        )
        .await;
        let mut simple_agg = simple_agg.execute();

        // Consume the init barrier
        simple_agg.next().await.unwrap().unwrap();
        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I  I
                + 0   .   .  . "
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Consume stream chunk: the insert and delete cancel out, but a no-op update is
        // still emitted because `must_output_per_barrier` is set.
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 0   .   .  .
                U+ 0   .   .  ."
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Consume stream chunk: emitted even for an epoch without any input.
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I   I   I  I
                U- 0   .   .  .
                U+ 0   .   .  ."
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );
    }

    // Without `must_output_per_barrier`, an epoch whose net effect leaves the aggregation
    // result unchanged must not emit a no-op `U-`/`U+` pair.
    #[tokio::test]
    async fn test_simple_aggregation_omit_noop_update() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(input_schema(), vec![2]);
        // initial barrier
        tx.push_barrier(test_epoch(1), false);
        // next barrier
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            test_agg_calls(),
            0,
            vec![2],
            1,
            false,
        )
        .await;
        let mut simple_agg = simple_agg.execute();

        // Consume the init barrier
        simple_agg.next().await.unwrap().unwrap();
        // Consume stream chunk
        let msg = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I  I
                + 0   .   .  . "
            )
        );
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // No stream chunk: the insert and delete cancel out.
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // No stream chunk: no input in this epoch.
        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );
    }
}