risingwave_stream/executor/aggregate/
agg_group.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use risingwave_common::array::StreamChunk;
20use risingwave_common::array::stream_record::{Record, RecordType};
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::Schema;
23use risingwave_common::must_match;
24use risingwave_common::row::{OwnedRow, Row, RowExt};
25use risingwave_common::util::iter_util::ZipEqFast;
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
28use risingwave_pb::stream_plan::PbAggNodeVersion;
29use risingwave_storage::StateStore;
30
31use super::agg_state::{AggState, AggStateStorage};
32use crate::common::table::state_table::StateTable;
33use crate::consistency::consistency_panic;
34use crate::executor::PkIndices;
35use crate::executor::error::StreamExecutorResult;
36
37#[derive(Debug)]
38pub struct Context {
39    group_key: Option<GroupKey>,
40}
41
42impl Context {
43    pub fn group_key(&self) -> Option<&GroupKey> {
44        self.group_key.as_ref()
45    }
46
47    pub fn group_key_row(&self) -> OwnedRow {
48        self.group_key()
49            .map(GroupKey::table_row)
50            .cloned()
51            .unwrap_or_default()
52    }
53}
54
55fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
56    match row {
57        Some(row) => {
58            let mut row_count = row
59                .datum_at(row_count_col)
60                .expect("row count field should not be NULL")
61                .into_int64();
62
63            if row_count < 0 {
64                consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");
65
66                // NOTE: Here is the case that an inconsistent `DELETE` arrives at HashAgg executor, and there's no
67                // corresponding group existing before (or has been deleted). In this case, previous row count should
68                // be `0` and current row count be `-1` after handling the `DELETE`. To ignore the inconsistency, we
69                // reset `row_count` to `0` here, so that `OnlyOutputIfHasInput` will return no change, so that the
70                // inconsistent will be hidden from downstream. This won't prevent from incorrect results of existing
71                // groups, but at least can prevent from downstream panicking due to non-existing keys.
72                // See https://github.com/risingwavelabs/risingwave/issues/14031 for more information.
73                row_count = 0;
74            }
75            row_count.try_into().unwrap()
76        }
77        None => 0,
78    }
79}
80
81pub trait Strategy {
82    /// Infer the change type of the aggregation result. Don't need to take the ownership of
83    /// `prev_row` and `curr_row`.
84    fn infer_change_type(
85        ctx: &Context,
86        prev_row: Option<&OwnedRow>,
87        curr_row: &OwnedRow,
88        row_count_col: usize,
89    ) -> Option<RecordType>;
90}
91
/// Strategy that emits the aggregation result regardless of whether the group has any input
/// rows.
pub struct AlwaysOutput;
/// Strategy that emits the aggregation result only while the group has input rows. Once the row
/// count drops to 0, the output row is deleted.
pub struct OnlyOutputIfHasInput;
97
98impl Strategy for AlwaysOutput {
99    fn infer_change_type(
100        ctx: &Context,
101        prev_row: Option<&OwnedRow>,
102        _curr_row: &OwnedRow,
103        row_count_col: usize,
104    ) -> Option<RecordType> {
105        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
106        match prev_row {
107            None => {
108                // First time to build changes, assert to ensure correctness.
109                // Note that it's not true vice versa, i.e. `prev_row_count == 0` doesn't imply
110                // `prev_outputs == None`.
111                assert_eq!(prev_row_count, 0);
112
113                // Generate output no matter whether current row count is 0 or not.
114                Some(RecordType::Insert)
115            }
116            // NOTE(kwannoel): We always output, even if the update is a no-op.
117            // e.g. the following will still be emitted downstream:
118            // ```
119            // U- 1
120            // U+ 1
121            // ```
122            // This is to support `approx_percentile` via `row_merge`, which requires
123            // both the lhs and rhs to always output updates per epoch, or not all.
124            // Otherwise we are unable to construct a full row, if only one side updates,
125            // as the `row_merge` executor is stateless.
126            Some(_prev_outputs) => Some(RecordType::Update),
127        }
128    }
129}
130
131impl Strategy for OnlyOutputIfHasInput {
132    fn infer_change_type(
133        ctx: &Context,
134        prev_row: Option<&OwnedRow>,
135        curr_row: &OwnedRow,
136        row_count_col: usize,
137    ) -> Option<RecordType> {
138        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
139        let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);
140
141        match (prev_row_count, curr_row_count) {
142            (0, 0) => {
143                // No rows of current group exist.
144                None
145            }
146            (0, _) => {
147                // Insert new output row for this newly emerged group.
148                Some(RecordType::Insert)
149            }
150            (_, 0) => {
151                // Delete old output row for this newly disappeared group.
152                Some(RecordType::Delete)
153            }
154            (_, _) => {
155                // Update output row.
156                if prev_row.expect("must exist previous row") == curr_row {
157                    // No output change.
158                    None
159                } else {
160                    Some(RecordType::Update)
161                }
162            }
163        }
164    }
165}
166
167/// [`GroupKey`] wraps a concrete group key and handle its mapping to state table pk.
168#[derive(Clone, Debug)]
169pub struct GroupKey {
170    row_prefix: OwnedRow,
171    table_pk_projection: Arc<[usize]>,
172}
173
174impl GroupKey {
175    pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
176        let table_pk_projection =
177            table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
178        Self {
179            row_prefix,
180            table_pk_projection,
181        }
182    }
183
184    pub fn len(&self) -> usize {
185        self.row_prefix.len()
186    }
187
188    pub fn is_empty(&self) -> bool {
189        self.row_prefix.is_empty()
190    }
191
192    /// Get the group key for state table row prefix.
193    pub fn table_row(&self) -> &OwnedRow {
194        &self.row_prefix
195    }
196
197    /// Get the group key for state table pk prefix.
198    pub fn table_pk(&self) -> impl Row + '_ {
199        (&self.row_prefix).project(&self.table_pk_projection)
200    }
201
202    /// Get the group key for LRU cache key prefix.
203    pub fn cache_key(&self) -> impl Row + '_ {
204        self.table_row()
205    }
206}
207
208/// [`AggGroup`] manages agg states of all agg calls for one `group_key`.
209pub struct AggGroup<S: StateStore, Strtg: Strategy> {
210    /// Agg group context, containing the group key.
211    ctx: Context,
212
213    /// Current managed states for all [`AggCall`]s.
214    states: Vec<AggState>,
215
216    /// Previous intermediate states, stored in the intermediate state table.
217    prev_inter_states: Option<OwnedRow>,
218
219    /// Previous outputs, yielded to downstream.
220    /// If `EOWC` is true, this field is not used.
221    prev_outputs: Option<OwnedRow>,
222
223    /// Index of row count agg call (`count(*)`) in the call list.
224    row_count_index: usize,
225
226    /// Whether the emit policy is EOWC.
227    emit_on_window_close: bool,
228
229    _phantom: PhantomData<(S, Strtg)>,
230}
231
232impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        f.debug_struct("AggGroup")
235            .field("group_key", &self.ctx.group_key)
236            .field("prev_inter_states", &self.prev_inter_states)
237            .field("prev_outputs", &self.prev_outputs)
238            .field("row_count_index", &self.row_count_index)
239            .field("emit_on_window_close", &self.emit_on_window_close)
240            .finish()
241    }
242}
243
244impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
245    fn estimated_heap_size(&self) -> usize {
246        // TODO(rc): should include the size of `prev_inter_states` and `prev_outputs`
247        self.states
248            .iter()
249            .map(|state| state.estimated_heap_size())
250            .sum()
251    }
252}
253
254impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
255    /// Create [`AggGroup`] for the given [`AggCall`]s and `group_key`.
256    /// For [`SimpleAggExecutor`], the `group_key` should be `None`.
257    ///
258    /// [`SimpleAggExecutor`]: crate::executor::aggregate::SimpleAggExecutor
259    #[allow(clippy::too_many_arguments)]
260    pub async fn create(
261        version: PbAggNodeVersion,
262        group_key: Option<GroupKey>,
263        agg_calls: &[AggCall],
264        agg_funcs: &[BoxedAggregateFunction],
265        storages: &[AggStateStorage<S>],
266        intermediate_state_table: &StateTable<S>,
267        pk_indices: &PkIndices,
268        row_count_index: usize,
269        emit_on_window_close: bool,
270        extreme_cache_size: usize,
271        input_schema: &Schema,
272    ) -> StreamExecutorResult<Self> {
273        let inter_states = intermediate_state_table
274            .get_row(group_key.as_ref().map(GroupKey::table_pk))
275            .await?;
276        if let Some(inter_states) = &inter_states {
277            assert_eq!(inter_states.len(), agg_calls.len());
278        }
279
280        let mut states = Vec::with_capacity(agg_calls.len());
281        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
282            let state = AggState::create(
283                version,
284                agg_call,
285                agg_func,
286                &storages[idx],
287                inter_states.as_ref().map(|s| &s[idx]),
288                pk_indices,
289                extreme_cache_size,
290                input_schema,
291            )?;
292            states.push(state);
293        }
294
295        let mut this = Self {
296            ctx: Context { group_key },
297            states,
298            prev_inter_states: inter_states,
299            prev_outputs: None, // will be set below
300            row_count_index,
301            emit_on_window_close,
302            _phantom: PhantomData,
303        };
304
305        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
306            let (outputs, _stats) = this.get_outputs(storages, agg_funcs).await?;
307            this.prev_outputs = Some(outputs);
308        }
309
310        Ok(this)
311    }
312
313    /// Create a group from intermediate states for EOWC output.
314    /// Will always produce `Insert` when building change.
315    #[allow(clippy::too_many_arguments)]
316    pub fn for_eowc_output(
317        version: PbAggNodeVersion,
318        group_key: Option<GroupKey>,
319        agg_calls: &[AggCall],
320        agg_funcs: &[BoxedAggregateFunction],
321        storages: &[AggStateStorage<S>],
322        inter_states: &OwnedRow,
323        pk_indices: &PkIndices,
324        row_count_index: usize,
325        emit_on_window_close: bool,
326        extreme_cache_size: usize,
327        input_schema: &Schema,
328    ) -> StreamExecutorResult<Self> {
329        let mut states = Vec::with_capacity(agg_calls.len());
330        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
331            let state = AggState::create(
332                version,
333                agg_call,
334                agg_func,
335                &storages[idx],
336                Some(&inter_states[idx]),
337                pk_indices,
338                extreme_cache_size,
339                input_schema,
340            )?;
341            states.push(state);
342        }
343
344        Ok(Self {
345            ctx: Context { group_key },
346            states,
347            prev_inter_states: None, // this doesn't matter
348            prev_outputs: None,      // this will make sure the outputs change to be `Insert`
349            row_count_index,
350            emit_on_window_close,
351            _phantom: PhantomData,
352        })
353    }
354
355    pub fn group_key(&self) -> Option<&GroupKey> {
356        self.ctx.group_key()
357    }
358
359    /// Get current row count of this group.
360    fn curr_row_count(&self) -> usize {
361        let row_count_state = must_match!(
362            self.states[self.row_count_index],
363            AggState::Value(ref state) => state
364        );
365        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
366    }
367
368    pub(crate) fn is_uninitialized(&self) -> bool {
369        self.prev_inter_states.is_none()
370    }
371
372    /// Apply input chunk to all managed agg states.
373    ///
374    /// `mappings` contains the column mappings from input chunk to each agg call.
375    /// `visibilities` contains the row visibility of the input chunk for each agg call.
376    pub async fn apply_chunk(
377        &mut self,
378        chunk: &StreamChunk,
379        calls: &[AggCall],
380        funcs: &[BoxedAggregateFunction],
381        visibilities: Vec<Bitmap>,
382    ) -> StreamExecutorResult<()> {
383        if self.curr_row_count() == 0 {
384            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
385        }
386        for (((state, call), func), visibility) in (self.states.iter_mut())
387            .zip_eq_fast(calls)
388            .zip_eq_fast(funcs)
389            .zip_eq_fast(visibilities)
390        {
391            state.apply_chunk(chunk, call, func, visibility).await?;
392        }
393
394        if self.curr_row_count() == 0 {
395            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
396        }
397
398        Ok(())
399    }
400
401    /// Reset all in-memory states to their initial state, i.e. to reset all agg state structs to
402    /// the status as if they are just created, no input applied and no row in state table.
403    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
404        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
405            state.reset(func)?;
406        }
407        Ok(())
408    }
409
410    /// Get the encoded intermediate states of all managed agg states.
411    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
412        let mut inter_states = Vec::with_capacity(self.states.len());
413        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
414            let encoded = match state {
415                AggState::Value(s) => func.encode_state(s)?,
416                // For minput state, we don't need to store it in state table.
417                AggState::MaterializedInput(_) => None,
418            };
419            inter_states.push(encoded);
420        }
421        Ok(OwnedRow::new(inter_states))
422    }
423
424    /// Get the outputs of all managed agg states, without group key prefix.
425    /// Possibly need to read/sync from state table if the state not cached in memory.
426    /// This method is idempotent, i.e. it can be called multiple times and the outputs are
427    /// guaranteed to be the same.
428    async fn get_outputs(
429        &mut self,
430        storages: &[AggStateStorage<S>],
431        funcs: &[BoxedAggregateFunction],
432    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
433        let row_count = self.curr_row_count();
434        if row_count == 0 {
435            // Reset all states (in fact only value states will be reset).
436            // This is important because for some agg calls (e.g. `sum`), if no row is applied,
437            // they should output NULL, for some other calls (e.g. `sum0`), they should output 0.
438            // This actually also prevents inconsistent negative row count from being worse.
439            // FIXME(rc): Deciding whether to reset states according to `row_count` is not precisely
440            // correct, see https://github.com/risingwavelabs/risingwave/issues/7412 for bug description.
441            self.reset(funcs)?;
442        }
443        let mut stats = AggStateCacheStats::default();
444        futures::future::try_join_all(
445            self.states
446                .iter_mut()
447                .zip_eq_fast(storages)
448                .zip_eq_fast(funcs)
449                .map(|((state, storage), func)| {
450                    state.get_output(storage, func, self.ctx.group_key())
451                }),
452        )
453        .await
454        .map(|outputs_and_stats| {
455            outputs_and_stats
456                .into_iter()
457                .map(|(output, stat)| {
458                    stats.merge(stat);
459                    output
460                })
461                .collect::<Vec<_>>()
462        })
463        .map(|row| (OwnedRow::new(row), stats))
464    }
465
466    /// Build change for aggregation intermediate states, according to previous and current agg states.
467    /// The change should be applied to the intermediate state table.
468    ///
469    /// The saved previous inter states will be updated to the latest states after calling this method.
470    pub fn build_states_change(
471        &mut self,
472        funcs: &[BoxedAggregateFunction],
473    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
474        let curr_inter_states = self.get_inter_states(funcs)?;
475        let change_type = Strtg::infer_change_type(
476            &self.ctx,
477            self.prev_inter_states.as_ref(),
478            &curr_inter_states,
479            self.row_count_index,
480        );
481
482        tracing::trace!(
483            group = ?self.ctx.group_key_row(),
484            prev_inter_states = ?self.prev_inter_states,
485            curr_inter_states = ?curr_inter_states,
486            change_type = ?change_type,
487            "build intermediate states change"
488        );
489
490        let Some(change_type) = change_type else {
491            return Ok(None);
492        };
493        Ok(Some(match change_type {
494            RecordType::Insert => {
495                let new_row = self
496                    .group_key()
497                    .map(GroupKey::table_row)
498                    .chain(&curr_inter_states)
499                    .into_owned_row();
500                self.prev_inter_states = Some(curr_inter_states);
501                Record::Insert { new_row }
502            }
503            RecordType::Delete => {
504                let prev_inter_states = self
505                    .prev_inter_states
506                    .take()
507                    .expect("must exist previous intermediate states");
508                let old_row = self
509                    .group_key()
510                    .map(GroupKey::table_row)
511                    .chain(prev_inter_states)
512                    .into_owned_row();
513                Record::Delete { old_row }
514            }
515            RecordType::Update => {
516                let new_row = self
517                    .group_key()
518                    .map(GroupKey::table_row)
519                    .chain(&curr_inter_states)
520                    .into_owned_row();
521                let prev_inter_states = self
522                    .prev_inter_states
523                    .replace(curr_inter_states)
524                    .expect("must exist previous intermediate states");
525                let old_row = self
526                    .group_key()
527                    .map(GroupKey::table_row)
528                    .chain(prev_inter_states)
529                    .into_owned_row();
530                Record::Update { old_row, new_row }
531            }
532        }))
533    }
534
535    /// Build aggregation result change, according to previous and current agg outputs.
536    /// The change should be yielded to downstream.
537    ///
538    /// The saved previous outputs will be updated to the latest outputs after this method.
539    ///
540    /// Note that this method is very likely to cost more than `build_states_change`, because it
541    /// needs to produce output for materialized input states which may involve state table read.
542    pub async fn build_outputs_change(
543        &mut self,
544        storages: &[AggStateStorage<S>],
545        funcs: &[BoxedAggregateFunction],
546    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
547        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;
548
549        let change_type = Strtg::infer_change_type(
550            &self.ctx,
551            self.prev_outputs.as_ref(),
552            &curr_outputs,
553            self.row_count_index,
554        );
555
556        tracing::trace!(
557            group = ?self.ctx.group_key_row(),
558            prev_outputs = ?self.prev_outputs,
559            curr_outputs = ?curr_outputs,
560            change_type = ?change_type,
561            "build outputs change"
562        );
563
564        let Some(change_type) = change_type else {
565            return Ok((None, stats));
566        };
567        Ok((
568            Some(match change_type {
569                RecordType::Insert => {
570                    let new_row = self
571                        .group_key()
572                        .map(GroupKey::table_row)
573                        .chain(&curr_outputs)
574                        .into_owned_row();
575                    // Although we say the `prev_outputs` field is not used in EOWC mode, we still
576                    // do the same here to keep the code simple. When it's actually running in EOWC
577                    // mode, `build_outputs_change` will be called only once for each group.
578                    self.prev_outputs = Some(curr_outputs);
579                    Record::Insert { new_row }
580                }
581                RecordType::Delete => {
582                    let prev_outputs = self.prev_outputs.take();
583                    let old_row = self
584                        .group_key()
585                        .map(GroupKey::table_row)
586                        .chain(prev_outputs)
587                        .into_owned_row();
588                    Record::Delete { old_row }
589                }
590                RecordType::Update => {
591                    let new_row = self
592                        .group_key()
593                        .map(GroupKey::table_row)
594                        .chain(&curr_outputs)
595                        .into_owned_row();
596                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
597                    let old_row = self
598                        .group_key()
599                        .map(GroupKey::table_row)
600                        .chain(prev_outputs)
601                        .into_owned_row();
602                    Record::Update { old_row, new_row }
603                }
604            }),
605            stats,
606        ))
607    }
608}
609
/// Stats for agg state cache operations.
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    /// Accumulated count of agg state cache lookups.
    pub agg_state_cache_lookup_count: u64,
    /// Accumulated count of agg state cache misses.
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    /// Add the counters of `other` onto the counters of `self`.
    fn merge(&mut self, other: Self) {
        let Self {
            agg_state_cache_lookup_count: lookups,
            agg_state_cache_miss_count: misses,
        } = other;
        self.agg_state_cache_lookup_count += lookups;
        self.agg_state_cache_miss_count += misses;
    }
}