risingwave_stream/executor/aggregate/
simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::util::epoch::EpochPair;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction, build_retractable};
21use risingwave_pb::stream_plan::PbAggNodeVersion;
22
23use super::agg_group::{AggGroup, AlwaysOutput};
24use super::agg_state::AggStateStorage;
25use super::distinct::DistinctDeduplicater;
26use super::{AggExecutorArgs, SimpleAggExecutorExtraArgs, agg_call_filter_res, iter_table_storage};
27use crate::executor::prelude::*;
28
29/// `SimpleAggExecutor` is the aggregation operator for streaming system.
30/// To create an aggregation operator, states and expressions should be passed along the
31/// constructor.
32///
33/// `SimpleAggExecutor` maintains multiple states together. If there are `n` states and `n`
34/// expressions, there will be `n` columns as output.
35///
36/// As the engine processes data in chunks, it is possible that multiple update
37/// messages could consolidate to a single row update. For example, our source
38/// emits 1000 inserts in one chunk, and we aggregate the count function on that.
39/// Current `SimpleAggExecutor` will only emit one row for a whole chunk.
40/// Therefore, we "automatically" implement a window function inside
41/// `SimpleAggExecutor`.
42pub struct SimpleAggExecutor<S: StateStore> {
43    input: Executor,
44    inner: ExecutorInner<S>,
45}
46
47struct ExecutorInner<S: StateStore> {
48    /// Version of aggregation executors.
49    version: PbAggNodeVersion,
50
51    actor_ctx: ActorContextRef,
52    info: ExecutorInfo,
53
54    /// Pk indices from input. Only used by `AggNodeVersion` before `ISSUE_13465`.
55    input_pk_indices: Vec<usize>,
56
57    /// Schema from input.
58    input_schema: Schema,
59
60    /// An operator will support multiple aggregation calls.
61    agg_calls: Vec<AggCall>,
62
63    /// Aggregate functions.
64    agg_funcs: Vec<BoxedAggregateFunction>,
65
66    /// Index of row count agg call (`count(*)`) in the call list.
67    row_count_index: usize,
68
69    /// State storage for each agg calls.
70    storages: Vec<AggStateStorage<S>>,
71
72    /// Intermediate state table for value-state agg calls.
73    /// The state of all value-state aggregates are collected and stored in this
74    /// table when `flush_data` is called.
75    intermediate_state_table: StateTable<S>,
76
77    /// State tables for deduplicating rows on distinct key for distinct agg calls.
78    /// One table per distinct column (may be shared by multiple agg calls).
79    distinct_dedup_tables: HashMap<usize, StateTable<S>>,
80
81    /// Watermark epoch.
82    watermark_epoch: AtomicU64Ref,
83
84    /// Extreme state cache size
85    extreme_cache_size: usize,
86
87    /// Required by the downstream `RowMergeExecutor`,
88    /// currently only used by the `approx_percentile`'s two phase plan
89    must_output_per_barrier: bool,
90}
91
92impl<S: StateStore> ExecutorInner<S> {
93    fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
94        iter_table_storage(&mut self.storages)
95            .chain(self.distinct_dedup_tables.values_mut())
96            .chain(std::iter::once(&mut self.intermediate_state_table))
97    }
98}
99
100struct ExecutionVars<S: StateStore> {
101    /// The single [`AggGroup`].
102    agg_group: AggGroup<S, AlwaysOutput>,
103
104    /// Distinct deduplicater to deduplicate input rows for each distinct agg call.
105    distinct_dedup: DistinctDeduplicater<S>,
106
107    /// Mark the agg state is changed in the current epoch or not.
108    state_changed: bool,
109}
110
111impl<S: StateStore> Execute for SimpleAggExecutor<S> {
112    fn execute(self: Box<Self>) -> BoxedMessageStream {
113        self.execute_inner().boxed()
114    }
115}
116
117impl<S: StateStore> SimpleAggExecutor<S> {
118    pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
119        let input_info = args.input.info().clone();
120        Ok(Self {
121            input: args.input,
122            inner: ExecutorInner {
123                version: args.version,
124                actor_ctx: args.actor_ctx,
125                info: args.info,
126                input_pk_indices: input_info.pk_indices,
127                input_schema: input_info.schema,
128                agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
129                agg_calls: args.agg_calls,
130                row_count_index: args.row_count_index,
131                storages: args.storages,
132                intermediate_state_table: args.intermediate_state_table,
133                distinct_dedup_tables: args.distinct_dedup_tables,
134                watermark_epoch: args.watermark_epoch,
135                extreme_cache_size: args.extreme_cache_size,
136                must_output_per_barrier: args.extra.must_output_per_barrier,
137            },
138        })
139    }
140
141    async fn apply_chunk(
142        this: &mut ExecutorInner<S>,
143        vars: &mut ExecutionVars<S>,
144        chunk: StreamChunk,
145    ) -> StreamExecutorResult<()> {
146        if chunk.cardinality() == 0 {
147            // If the chunk is empty, do nothing.
148            return Ok(());
149        }
150
151        // Calculate the row visibility for every agg call.
152        let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
153        for agg_call in &this.agg_calls {
154            let vis = agg_call_filter_res(agg_call, &chunk).await?;
155            call_visibilities.push(vis);
156        }
157
158        // Deduplicate for distinct columns.
159        let visibilities = vars
160            .distinct_dedup
161            .dedup_chunk(
162                chunk.ops(),
163                chunk.columns(),
164                call_visibilities,
165                &mut this.distinct_dedup_tables,
166                None,
167            )
168            .await?;
169
170        // Materialize input chunk if needed and possible.
171        for (storage, visibility) in this.storages.iter_mut().zip_eq_fast(visibilities.iter()) {
172            if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage {
173                let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
174                table.write_chunk(chunk);
175            }
176        }
177
178        // Apply chunk to each of the state (per agg_call).
179        vars.agg_group
180            .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
181            .await?;
182
183        // Mark state as changed.
184        vars.state_changed = true;
185
186        Ok(())
187    }
188
189    async fn flush_data(
190        this: &mut ExecutorInner<S>,
191        vars: &mut ExecutionVars<S>,
192        epoch: EpochPair,
193    ) -> StreamExecutorResult<Option<StreamChunk>> {
194        if vars.state_changed || vars.agg_group.is_uninitialized() {
195            // Flush distinct dedup state.
196            vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
197
198            // Build and apply change for intermediate states.
199            if let Some(inter_states_change) =
200                vars.agg_group.build_states_change(&this.agg_funcs)?
201            {
202                this.intermediate_state_table
203                    .write_record(inter_states_change);
204            }
205        }
206        vars.state_changed = false;
207
208        // Build and apply change for the final outputs.
209        let (outputs_change, _stats) = vars
210            .agg_group
211            .build_outputs_change(&this.storages, &this.agg_funcs)
212            .await?;
213
214        let change =
215            outputs_change.expect("`AlwaysOutput` strategy will output a change in any case");
216        let chunk = if !this.must_output_per_barrier
217            && let Record::Update { old_row, new_row } = &change
218            && old_row == new_row
219        {
220            // for cases without approx percentile, we don't need to output the change if it's noop
221            None
222        } else {
223            Some(change.to_stream_chunk(&this.info.schema.data_types()))
224        };
225
226        // Commit all state tables.
227        futures::future::try_join_all(
228            this.all_state_tables_mut()
229                .map(|table| table.commit_assert_no_update_vnode_bitmap(epoch)),
230        )
231        .await?;
232
233        Ok(chunk)
234    }
235
236    async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
237        futures::future::try_join_all(this.all_state_tables_mut().map(|table| table.try_flush()))
238            .await?;
239        Ok(())
240    }
241
242    #[try_stream(ok = Message, error = StreamExecutorError)]
243    async fn execute_inner(self) {
244        let Self {
245            input,
246            inner: mut this,
247        } = self;
248
249        let mut input = input.execute();
250        let barrier = expect_first_barrier(&mut input).await?;
251        let first_epoch = barrier.epoch;
252        yield Message::Barrier(barrier);
253
254        for table in this.all_state_tables_mut() {
255            table.init_epoch(first_epoch).await?;
256        }
257
258        let distinct_dedup = DistinctDeduplicater::new(
259            &this.agg_calls,
260            this.watermark_epoch.clone(),
261            &this.distinct_dedup_tables,
262            &this.actor_ctx,
263        );
264
265        // This will fetch previous agg states from the intermediate state table.
266        let mut vars = ExecutionVars {
267            agg_group: AggGroup::create(
268                this.version,
269                None,
270                &this.agg_calls,
271                &this.agg_funcs,
272                &this.storages,
273                &this.intermediate_state_table,
274                &this.input_pk_indices,
275                this.row_count_index,
276                false, // emit on window close
277                this.extreme_cache_size,
278                &this.input_schema,
279            )
280            .await?,
281            distinct_dedup,
282            state_changed: false,
283        };
284
285        #[for_await]
286        for msg in input {
287            let msg = msg?;
288            match msg {
289                Message::Watermark(_) => {}
290                Message::Chunk(chunk) => {
291                    Self::apply_chunk(&mut this, &mut vars, chunk).await?;
292                    Self::try_flush_data(&mut this).await?;
293                }
294                Message::Barrier(barrier) => {
295                    if let Some(chunk) =
296                        Self::flush_data(&mut this, &mut vars, barrier.epoch).await?
297                    {
298                        yield Message::Chunk(chunk);
299                    }
300                    yield Message::Barrier(barrier);
301                }
302            }
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::types::*;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
    use crate::executor::test_utils::*;

    /// Input schema shared by all tests: three `Int64` columns, the last one being the
    /// primary key column.
    fn input_schema() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
                // primary key column
                Field::unnamed(DataType::Int64),
            ],
        }
    }

    /// Agg calls shared by all tests: a row count, a sum over each of the first two
    /// columns, and a min over the first column.
    fn agg_calls() -> Vec<AggCall> {
        vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
            AggCall::from_pretty("(min:int8 $0:int8)"),
        ]
    }

    /// Builds a simple agg executor on top of `source` and starts executing it.
    ///
    /// `must_output_per_barrier` is forwarded as the last argument of
    /// `new_boxed_simple_agg_executor` (presumably the flag of the same name; confirm
    /// against that helper's signature).
    async fn start_simple_agg<S: StateStore>(
        store: S,
        source: MockSource,
        must_output_per_barrier: bool,
    ) -> BoxedMessageStream {
        let source = source.into_executor(input_schema(), vec![2]);
        let simple_agg = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            source,
            false,
            agg_calls(),
            0,
            vec![2],
            1,
            must_output_per_barrier,
        )
        .await;
        simple_agg.execute()
    }

    /// Asserts that the next message is a chunk equal to the pretty-printed `expected`.
    async fn expect_chunk(stream: &mut BoxedMessageStream, expected: &str) {
        let msg = stream.next().await.unwrap().unwrap();
        assert_eq!(*msg.as_chunk().unwrap(), StreamChunk::from_pretty(expected));
    }

    /// Asserts that the next message is a barrier.
    async fn expect_barrier(stream: &mut BoxedMessageStream) {
        assert_matches!(
            stream.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );
    }

    #[tokio::test]
    async fn test_simple_aggregation_in_memory() {
        test_simple_aggregation(MemoryStateStore::new()).await
    }

    /// Inserts three rows in one epoch, then retracts two visible rows and inserts one in
    /// the next epoch, checking the emitted change before each barrier.
    async fn test_simple_aggregation<S: StateStore>(store: S) {
        let (mut tx, source) = MockSource::channel();
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            +  10  14 1002
            +   4 300 1003",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            - 100 200 1001
            -  10  14 1002 D
            -   4 300 1003
            + 104 500 1004",
        ));
        tx.push_barrier(test_epoch(4), false);

        let mut simple_agg = start_simple_agg(store, source, false).await;

        // Consume the init barrier.
        simple_agg.next().await.unwrap().unwrap();

        // The initial output is emitted before the second barrier.
        expect_chunk(
            &mut simple_agg,
            " I   I   I  I
                + 0   .   .  . ",
        )
        .await;
        expect_barrier(&mut simple_agg).await;

        // Result after the three inserts.
        expect_chunk(
            &mut simple_agg,
            "  I   I   I  I
                U- 0   .   .  .
                U+ 3 114 514  4",
        )
        .await;
        expect_barrier(&mut simple_agg).await;

        // Result after the retractions and the new insert.
        expect_chunk(
            &mut simple_agg,
            "  I   I   I  I
                U- 3 114 514  4
                U+ 2 114 514 10",
        )
        .await;
    }

    // NOTE(kwannoel): `approx_percentile` + `keyed_merge` depend on this property for correctness.
    #[tokio::test]
    async fn test_simple_aggregation_always_output_per_epoch() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        // initial barrier
        tx.push_barrier(test_epoch(1), false);
        // next barrier
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let mut simple_agg = start_simple_agg(store, source, true).await;

        // Consume the init barrier.
        simple_agg.next().await.unwrap().unwrap();

        // The initial output is emitted before the second barrier.
        expect_chunk(
            &mut simple_agg,
            " I   I   I  I
                + 0   .   .  . ",
        )
        .await;
        expect_barrier(&mut simple_agg).await;

        // The insert and the delete cancel out, yet an update is still emitted.
        expect_chunk(
            &mut simple_agg,
            "  I   I   I  I
                U- 0   .   .  .
                U+ 0   .   .  .",
        )
        .await;
        expect_barrier(&mut simple_agg).await;

        // An epoch without any input emits an update as well.
        expect_chunk(
            &mut simple_agg,
            "  I   I   I  I
                U- 0   .   .  .
                U+ 0   .   .  .",
        )
        .await;
        expect_barrier(&mut simple_agg).await;
    }

    // Counterpart of `test_simple_aggregation_always_output_per_epoch`: with the last
    // executor argument set to `false`, updates that do not change the result are omitted.
    #[tokio::test]
    async fn test_simple_aggregation_omit_noop_update() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        // initial barrier
        tx.push_barrier(test_epoch(1), false);
        // next barrier
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
                + 100 200 1001
                - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let mut simple_agg = start_simple_agg(store, source, false).await;

        // Consume the init barrier.
        simple_agg.next().await.unwrap().unwrap();

        // The initial output is emitted before the second barrier.
        expect_chunk(
            &mut simple_agg,
            " I   I   I  I
                + 0   .   .  . ",
        )
        .await;
        expect_barrier(&mut simple_agg).await;

        // No stream chunk: the insert and the delete cancel out.
        expect_barrier(&mut simple_agg).await;

        // No stream chunk: the epoch had no input.
        expect_barrier(&mut simple_agg).await;
    }
}