risingwave_stream/executor/aggregate/agg_group.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;

use futures::future::try_join_all;
use risingwave_common::array::StreamChunk;
use risingwave_common::array::stream_record::{Record, RecordType};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::Schema;
use risingwave_common::must_match;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use risingwave_storage::StateStore;

use super::agg_state::{AggState, AggStateStorage};
use crate::common::table::state_table::StateTable;
use crate::consistency::consistency_panic;
use crate::executor::PkIndices;
use crate::executor::error::StreamExecutorResult;

#[derive(Debug)]
pub struct Context {
    group_key: Option<GroupKey>,
}

impl Context {
    pub fn group_key(&self) -> Option<&GroupKey> {
        self.group_key.as_ref()
    }

    pub fn group_key_row(&self) -> OwnedRow {
        self.group_key()
            .map(GroupKey::table_row)
            .cloned()
            .unwrap_or_default()
    }
}

/// Get the row count from the given `row` (an intermediate states row or an outputs row),
/// whose row count column is at `row_count_col`. A `None` row means the group doesn't exist,
/// hence a count of 0.
fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
    match row {
        Some(row) => {
            let mut row_count = row
                .datum_at(row_count_col)
                .expect("row count field should not be NULL")
                .into_int64();

            if row_count < 0 {
                consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");

                // NOTE: This is the case where an inconsistent `DELETE` arrives at the HashAgg
                // executor and there's no corresponding group (it never existed, or has already
                // been deleted). The previous row count is `0` and the current row count becomes
                // `-1` after handling the `DELETE`. To tolerate the inconsistency, we reset
                // `row_count` to `0` here, so that `OnlyOutputIfHasInput` reports no change and
                // the inconsistency stays hidden from downstream. This doesn't prevent incorrect
                // results for existing groups, but at least it prevents downstream from panicking
                // due to non-existent keys.
                // See https://github.com/risingwavelabs/risingwave/issues/14031 for more information.
                row_count = 0;
            }
            row_count.try_into().unwrap()
        }
        None => 0,
    }
}
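// A minimal usage sketch (module-internal, since `Context` has no public constructor); the
// datum values below are illustrative:
//
//     let ctx = Context { group_key: None };
//     let row = OwnedRow::new(vec![Some(3i64.into()), None]); // `count(*)` at column 0
//     assert_eq!(row_count_of(&ctx, Some(&row), 0), 3);
//     assert_eq!(row_count_of(&ctx, None::<OwnedRow>, 0), 0); // absent group means count 0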

pub trait Strategy {
    /// Infer the change type of the aggregation result. Ownership of `prev_row` and
    /// `curr_row` is not taken.
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType>;
}

/// The strategy that always outputs the aggregation result, no matter whether there are input
/// rows or not.
pub struct AlwaysOutput;
/// The strategy that only outputs the aggregation result when there are input rows. If the row
/// count drops to 0, the output row will be deleted.
pub struct OnlyOutputIfHasInput;

impl Strategy for AlwaysOutput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        _curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        match prev_row {
            None => {
                // First time building changes; assert to ensure correctness.
                // Note that the converse doesn't hold, i.e. `prev_row_count == 0` doesn't imply
                // `prev_row == None`.
                assert_eq!(prev_row_count, 0);

                // Generate output no matter whether the current row count is 0 or not.
                Some(RecordType::Insert)
            }
            // NOTE(kwannoel): We always output, even if the update is a no-op.
            // e.g. the following will still be emitted downstream:
            // ```
            // U- 1
            // U+ 1
            // ```
            // This is to support `approx_percentile` via `row_merge`, which requires
            // both the lhs and rhs to always output updates per epoch, or not at all.
            // Otherwise, if only one side updates, we are unable to construct a full row,
            // as the `row_merge` executor is stateless.
            Some(_prev_row) => Some(RecordType::Update),
        }
    }
}

impl Strategy for OnlyOutputIfHasInput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);

        match (prev_row_count, curr_row_count) {
            (0, 0) => {
                // No rows of the current group exist.
                None
            }
            (0, _) => {
                // Insert a new output row for this newly emerged group.
                Some(RecordType::Insert)
            }
            (_, 0) => {
                // Delete the old output row for this now-vanished group.
                Some(RecordType::Delete)
            }
            (_, _) => {
                // Update the output row.
                if prev_row.expect("must exist previous row") == curr_row {
                    // No output change.
                    None
                } else {
                    Some(RecordType::Update)
                }
            }
        }
    }
}
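// A sketch (module-internal) of how the two strategies differ on the same transition, a group
// whose row count drops from 1 to 0, with the row count column at index 0; the values are
// illustrative:
//
//     let ctx = Context { group_key: None };
//     let prev = OwnedRow::new(vec![Some(1i64.into())]); // previous row count = 1
//     let curr = OwnedRow::new(vec![Some(0i64.into())]); // current row count = 0
//     // `OnlyOutputIfHasInput` retracts the output row of the vanished group...
//     assert!(matches!(
//         OnlyOutputIfHasInput::infer_change_type(&ctx, Some(&prev), &curr, 0),
//         Some(RecordType::Delete)
//     ));
//     // ...while `AlwaysOutput` still emits an update carrying the "empty" aggregation result.
//     assert!(matches!(
//         AlwaysOutput::infer_change_type(&ctx, Some(&prev), &curr, 0),
//         Some(RecordType::Update)
//     ));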

/// [`GroupKey`] wraps a concrete group key and handles its mapping to the state table pk.
#[derive(Clone, Debug)]
pub struct GroupKey {
    row_prefix: OwnedRow,
    table_pk_projection: Arc<[usize]>,
}

impl GroupKey {
    pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
        let table_pk_projection =
            table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
        Self {
            row_prefix,
            table_pk_projection,
        }
    }

    pub fn len(&self) -> usize {
        self.row_prefix.len()
    }

    pub fn is_empty(&self) -> bool {
        self.row_prefix.is_empty()
    }

    /// Get the group key as the state table row prefix.
    pub fn table_row(&self) -> &OwnedRow {
        &self.row_prefix
    }

    /// Get the group key as the state table pk prefix.
    pub fn table_pk(&self) -> impl Row + '_ {
        (&self.row_prefix).project(&self.table_pk_projection)
    }

    /// Get the group key as the LRU cache key prefix.
    pub fn cache_key(&self) -> impl Row + '_ {
        self.table_row()
    }
}
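// A minimal sketch of the pk-projection mapping: suppose the group key columns are `(a, b)`
// while the state table pk is ordered `(b, a)`; the projection `[1, 0]` re-orders the prefix
// accordingly (values are illustrative):
//
//     let key = GroupKey::new(
//         OwnedRow::new(vec![Some(1i64.into()), Some(2i64.into())]), // (a, b)
//         Some(Arc::from([1usize, 0].as_slice())),
//     );
//     assert_eq!(key.table_row(), &OwnedRow::new(vec![Some(1i64.into()), Some(2i64.into())]));
//     assert_eq!(
//         key.table_pk().into_owned_row(),
//         OwnedRow::new(vec![Some(2i64.into()), Some(1i64.into())]) // (b, a)
//     );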

/// [`AggGroup`] manages agg states of all agg calls for one `group_key`.
pub struct AggGroup<S: StateStore, Strtg: Strategy> {
    /// Agg group context, containing the group key.
    ctx: Context,

    /// Current managed states for all [`AggCall`]s.
    states: Vec<AggState>,

    /// Previous intermediate states, stored in the intermediate state table.
    prev_inter_states: Option<OwnedRow>,

    /// Previous outputs, yielded to downstream.
    /// If `EOWC` is true, this field is not used.
    prev_outputs: Option<OwnedRow>,

    /// Index of row count agg call (`count(*)`) in the call list.
    row_count_index: usize,

    /// Whether the emit policy is EOWC.
    emit_on_window_close: bool,

    _phantom: PhantomData<(S, Strtg)>,
}

impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AggGroup")
            .field("group_key", &self.ctx.group_key)
            .field("prev_inter_states", &self.prev_inter_states)
            .field("prev_outputs", &self.prev_outputs)
            .field("row_count_index", &self.row_count_index)
            .field("emit_on_window_close", &self.emit_on_window_close)
            .finish()
    }
}

impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
    fn estimated_heap_size(&self) -> usize {
        // TODO(rc): should include the size of `prev_inter_states` and `prev_outputs`
        self.states
            .iter()
            .map(|state| state.estimated_heap_size())
            .sum()
    }
}

impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
    /// Create [`AggGroup`] for the given [`AggCall`]s and `group_key`.
    /// For [`SimpleAggExecutor`], the `group_key` should be `None`.
    ///
    /// [`SimpleAggExecutor`]: crate::executor::aggregate::SimpleAggExecutor
    #[allow(clippy::too_many_arguments)]
    pub async fn create(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        intermediate_state_table: &StateTable<S>,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let inter_states = intermediate_state_table
            .get_row(group_key.as_ref().map(GroupKey::table_pk))
            .await?;
        if let Some(inter_states) = &inter_states {
            assert_eq!(inter_states.len(), agg_calls.len());
        }

        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                inter_states.as_ref().map(|s| &s[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        let mut this = Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: inter_states,
            prev_outputs: None, // will be set below
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        };

        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
            let (outputs, _stats) = this.get_outputs(storages, agg_funcs).await?;
            this.prev_outputs = Some(outputs);
        }

        Ok(this)
    }
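    // A construction sketch from an assumed caller's perspective (e.g. a hash agg executor
    // creating a group lazily when its key is first seen); all bindings below are illustrative:
    //
    //     let agg_group = AggGroup::<_, OnlyOutputIfHasInput>::create(
    //         version,
    //         Some(GroupKey::new(group_key_row, None)),
    //         &agg_calls,
    //         &agg_funcs,
    //         &storages,
    //         &intermediate_state_table,
    //         &input_pk_indices,
    //         row_count_index,
    //         false, // emit_on_window_close
    //         extreme_cache_size,
    //         &input_schema,
    //     )
    //     .await?;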

    /// Create a group from intermediate states for EOWC output.
    /// Will always produce `Insert` when building changes.
    #[allow(clippy::too_many_arguments)]
    pub fn for_eowc_output(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        inter_states: &OwnedRow,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                Some(&inter_states[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        Ok(Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: None, // this doesn't matter
            prev_outputs: None,      // this will make sure the outputs change to be `Insert`
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        })
    }

    pub fn group_key(&self) -> Option<&GroupKey> {
        self.ctx.group_key()
    }

    /// Get the current row count of this group.
    fn curr_row_count(&self) -> usize {
        let row_count_state = must_match!(
            self.states[self.row_count_index],
            AggState::Value(ref state) => state
        );
        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
    }

    pub(crate) fn is_uninitialized(&self) -> bool {
        self.prev_inter_states.is_none()
    }

    /// Apply the input chunk to all managed agg states.
    ///
    /// `visibilities` contains the row visibility of the input chunk for each agg call,
    /// positionally aligned with `calls` and `funcs`.
    pub async fn apply_chunk(
        &mut self,
        chunk: &StreamChunk,
        calls: &[AggCall],
        funcs: &[BoxedAggregateFunction],
        visibilities: Vec<Bitmap>,
    ) -> StreamExecutorResult<()> {
        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
        }

        // Apply the chunk to the states in batches, so that at most `concurrency` state
        // futures are awaited concurrently.
        let concurrency = 10;
        let len = self.states.len();

        for batch_start in (0..len).step_by(concurrency) {
            let batch_end = std::cmp::min(batch_start + concurrency, len);

            // Create the futures for this batch of states.
            let futures = self.states[batch_start..batch_end]
                .iter_mut()
                .zip_eq_fast(&calls[batch_start..batch_end])
                .zip_eq_fast(&funcs[batch_start..batch_end])
                .zip_eq_fast(&visibilities[batch_start..batch_end])
                .map(|(((state, call), func), visibility)| {
                    state.apply_chunk(chunk, call, func, visibility.clone())
                });

            try_join_all(futures).await?;
        }

        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
        }

        Ok(())
    }
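    // A call sketch from an assumed caller's perspective: the executor computes one visibility
    // bitmap per agg call (e.g. from each call's FILTER clause combined with the chunk's own
    // visibility) and passes them positionally aligned with `calls` / `funcs`:
    //
    //     let visibilities: Vec<Bitmap> = agg_calls
    //         .iter()
    //         .map(|_| chunk.visibility().clone()) // no FILTER clause: all visible rows count
    //         .collect();
    //     agg_group
    //         .apply_chunk(&chunk, &agg_calls, &agg_funcs, visibilities)
    //         .await?;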

    /// Reset all in-memory states to their initial state, i.e. reset all agg state structs to
    /// the status as if they were just created, with no input applied and no rows in the state
    /// table.
    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
            state.reset(func)?;
        }
        Ok(())
    }

    /// Get the encoded intermediate states of all managed agg states.
    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
        let mut inter_states = Vec::with_capacity(self.states.len());
        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
            let encoded = match state {
                AggState::Value(s) => func.encode_state(s)?,
                // Materialized-input states don't need to be stored in the intermediate state
                // table, so they are encoded as NULL here.
                AggState::MaterializedInput(_) => None,
            };
            inter_states.push(encoded);
        }
        Ok(OwnedRow::new(inter_states))
    }
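    // Shape sketch: with agg calls `[count(*), string_agg(..)]`, the encoded intermediate row
    // is `(Some(encoded_count), None)`: the materialized-input state occupies a NULL slot, so
    // the row stays positionally aligned with the agg call list and with the value columns of
    // the intermediate state table.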

    /// Get the outputs of all managed agg states, without the group key prefix.
    /// May need to read/sync from the state table if the state is not cached in memory.
    /// This method is idempotent, i.e. it can be called multiple times and the outputs are
    /// guaranteed to be the same.
    async fn get_outputs(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
        let row_count = self.curr_row_count();
        if row_count == 0 {
            // Reset all states (in fact, only value states will be reset).
            // This is important because some agg calls (e.g. `sum`) should output NULL when no
            // row has been applied, while others (e.g. `sum0`) should output 0.
            // It also prevents an inconsistent negative row count from making things worse.
            // FIXME(rc): Deciding whether to reset states according to `row_count` is not precisely
            // correct, see https://github.com/risingwavelabs/risingwave/issues/7412 for the bug description.
            self.reset(funcs)?;
        }
        let mut stats = AggStateCacheStats::default();
        futures::future::try_join_all(
            self.states
                .iter_mut()
                .zip_eq_fast(storages)
                .zip_eq_fast(funcs)
                .map(|((state, storage), func)| {
                    state.get_output(storage, func, self.ctx.group_key())
                }),
        )
        .await
        .map(|outputs_and_stats| {
            outputs_and_stats
                .into_iter()
                .map(|(output, stat)| {
                    stats.merge(stat);
                    output
                })
                .collect::<Vec<_>>()
        })
        .map(|row| (OwnedRow::new(row), stats))
    }

    /// Build the change for the aggregation intermediate states, according to the previous and
    /// current agg states. The change should be applied to the intermediate state table.
    ///
    /// The saved previous intermediate states are updated to the latest states after calling
    /// this method.
    pub fn build_states_change(
        &mut self,
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
        let curr_inter_states = self.get_inter_states(funcs)?;
        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_inter_states.as_ref(),
            &curr_inter_states,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_inter_states = ?self.prev_inter_states,
            curr_inter_states = ?curr_inter_states,
            change_type = ?change_type,
            "build intermediate states change"
        );

        let Some(change_type) = change_type else {
            return Ok(None);
        };
        Ok(Some(match change_type {
            RecordType::Insert => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                self.prev_inter_states = Some(curr_inter_states);
                Record::Insert { new_row }
            }
            RecordType::Delete => {
                let prev_inter_states = self
                    .prev_inter_states
                    .take()
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Delete { old_row }
            }
            RecordType::Update => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                let prev_inter_states = self
                    .prev_inter_states
                    .replace(curr_inter_states)
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Update { old_row, new_row }
            }
        }))
    }
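    // A persistence sketch from an assumed caller's perspective: the returned record, if any,
    // is written to the intermediate state table so the group can be restored later by
    // `AggGroup::create`:
    //
    //     if let Some(change) = agg_group.build_states_change(&agg_funcs)? {
    //         intermediate_state_table.write_record(change);
    //     }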

    /// Build the aggregation result change, according to the previous and current agg outputs.
    /// The change should be yielded to downstream.
    ///
    /// The saved previous outputs are updated to the latest outputs after calling this method.
    ///
    /// Note that this method is likely to cost more than `build_states_change`, because
    /// producing outputs for materialized-input states may involve state table reads.
    pub async fn build_outputs_change(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;

        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_outputs.as_ref(),
            &curr_outputs,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_outputs = ?self.prev_outputs,
            curr_outputs = ?curr_outputs,
            change_type = ?change_type,
            "build outputs change"
        );

        let Some(change_type) = change_type else {
            return Ok((None, stats));
        };
        Ok((
            Some(match change_type {
                RecordType::Insert => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    // Although we say the `prev_outputs` field is not used in EOWC mode, we still
                    // update it here to keep the code simple. When actually running in EOWC mode,
                    // `build_outputs_change` is called only once for each group.
                    self.prev_outputs = Some(curr_outputs);
                    Record::Insert { new_row }
                }
                RecordType::Delete => {
                    let prev_outputs = self.prev_outputs.take();
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Delete { old_row }
                }
                RecordType::Update => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Update { old_row, new_row }
                }
            }),
            stats,
        ))
    }
}
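// End-to-end sketch of how an executor is assumed to drive one `AggGroup` within an epoch
// (simplified; real executors manage many groups along with caches and metrics):
//
//     // 1. Apply the input chunks of this epoch.
//     agg_group
//         .apply_chunk(&chunk, &agg_calls, &agg_funcs, visibilities)
//         .await?;
//     // 2. On barrier, persist the intermediate states change.
//     if let Some(change) = agg_group.build_states_change(&agg_funcs)? {
//         intermediate_state_table.write_record(change);
//     }
//     // 3. Build the result change and yield it downstream (the executor converts the record
//     //    into a stream chunk).
//     let (change, _stats) = agg_group
//         .build_outputs_change(&storages, &agg_funcs)
//         .await?;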

/// Stats for agg state cache operations.
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    pub agg_state_cache_lookup_count: u64,
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    fn merge(&mut self, other: Self) {
        self.agg_state_cache_lookup_count += other.agg_state_cache_lookup_count;
        self.agg_state_cache_miss_count += other.agg_state_cache_miss_count;
    }
}