risingwave_stream/executor/aggregate/
agg_group.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use risingwave_common::array::StreamChunk;
20use risingwave_common::array::stream_record::{Record, RecordType};
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::Schema;
23use risingwave_common::must_match;
24use risingwave_common::row::{OwnedRow, Row, RowExt};
25use risingwave_common::util::iter_util::ZipEqFast;
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
28use risingwave_pb::stream_plan::PbAggNodeVersion;
29use risingwave_storage::StateStore;
30
31use super::agg_state::{AggState, AggStateStorage};
32use crate::common::table::state_table::StateTable;
33use crate::consistency::consistency_panic;
34use crate::executor::StreamKeyRef;
35use crate::executor::error::StreamExecutorResult;
36
37#[derive(Debug)]
38pub struct Context {
39    group_key: Option<GroupKey>,
40}
41
42impl Context {
43    pub fn group_key(&self) -> Option<&GroupKey> {
44        self.group_key.as_ref()
45    }
46
47    pub fn group_key_row(&self) -> OwnedRow {
48        self.group_key()
49            .map(GroupKey::table_row)
50            .cloned()
51            .unwrap_or_default()
52    }
53}
54
55fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
56    match row {
57        Some(row) => {
58            let mut row_count = row
59                .datum_at(row_count_col)
60                .expect("row count field should not be NULL")
61                .into_int64();
62
63            if row_count < 0 {
64                consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");
65
66                // NOTE: Here is the case that an inconsistent `DELETE` arrives at HashAgg executor, and there's no
67                // corresponding group existing before (or has been deleted). In this case, previous row count should
68                // be `0` and current row count be `-1` after handling the `DELETE`. To ignore the inconsistency, we
69                // reset `row_count` to `0` here, so that `OnlyOutputIfHasInput` will return no change, so that the
70                // inconsistent will be hidden from downstream. This won't prevent from incorrect results of existing
71                // groups, but at least can prevent from downstream panicking due to non-existing keys.
72                // See https://github.com/risingwavelabs/risingwave/issues/14031 for more information.
73                row_count = 0;
74            }
75            row_count.try_into().unwrap()
76        }
77        None => 0,
78    }
79}
80
/// Policy that decides which kind of change, if any, should be produced for a group given its
/// previous and current rows.
pub trait Strategy {
    /// Infer the change type of the aggregation result. Don't need to take the ownership of
    /// `prev_row` and `curr_row`.
    ///
    /// - `ctx`: context of the group being processed.
    /// - `prev_row`: the previously recorded row, or `None` if there is none.
    /// - `curr_row`: the current row.
    /// - `row_count_col`: index of the row count column inside `prev_row` / `curr_row`.
    ///
    /// Returns `None` when no change should be produced.
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType>;
}
91
/// The strategy that always outputs the aggregation result, regardless of whether there are
/// input rows or not.
pub struct AlwaysOutput;
/// The strategy that only outputs the aggregation result when there are input rows. If the row
/// count drops to 0, the output row will be deleted.
pub struct OnlyOutputIfHasInput;
97
98impl Strategy for AlwaysOutput {
99    fn infer_change_type(
100        ctx: &Context,
101        prev_row: Option<&OwnedRow>,
102        _curr_row: &OwnedRow,
103        row_count_col: usize,
104    ) -> Option<RecordType> {
105        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
106        match prev_row {
107            None => {
108                // First time to build changes, assert to ensure correctness.
109                // Note that it's not true vice versa, i.e. `prev_row_count == 0` doesn't imply
110                // `prev_outputs == None`.
111                assert_eq!(prev_row_count, 0);
112
113                // Generate output no matter whether current row count is 0 or not.
114                Some(RecordType::Insert)
115            }
116            // NOTE(kwannoel): We always output, even if the update is a no-op.
117            // e.g. the following will still be emitted downstream:
118            // ```
119            // U- 1
120            // U+ 1
121            // ```
122            // This is to support `approx_percentile` via `row_merge`, which requires
123            // both the lhs and rhs to always output updates per epoch, or not all.
124            // Otherwise we are unable to construct a full row, if only one side updates,
125            // as the `row_merge` executor is stateless.
126            Some(_prev_outputs) => Some(RecordType::Update),
127        }
128    }
129}
130
131impl Strategy for OnlyOutputIfHasInput {
132    fn infer_change_type(
133        ctx: &Context,
134        prev_row: Option<&OwnedRow>,
135        curr_row: &OwnedRow,
136        row_count_col: usize,
137    ) -> Option<RecordType> {
138        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
139        let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);
140
141        match (prev_row_count, curr_row_count) {
142            (0, 0) => {
143                // No rows of current group exist.
144                None
145            }
146            (0, _) => {
147                // Insert new output row for this newly emerged group.
148                Some(RecordType::Insert)
149            }
150            (_, 0) => {
151                // Delete old output row for this newly disappeared group.
152                Some(RecordType::Delete)
153            }
154            (_, _) => {
155                // Update output row.
156                if prev_row.expect("must exist previous row") == curr_row {
157                    // No output change.
158                    None
159                } else {
160                    Some(RecordType::Update)
161                }
162            }
163        }
164    }
165}
166
167/// [`GroupKey`] wraps a concrete group key and handle its mapping to state table pk.
168#[derive(Clone, Debug)]
169pub struct GroupKey {
170    row_prefix: OwnedRow,
171    table_pk_projection: Arc<[usize]>,
172}
173
174impl GroupKey {
175    pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
176        let table_pk_projection =
177            table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
178        Self {
179            row_prefix,
180            table_pk_projection,
181        }
182    }
183
184    pub fn len(&self) -> usize {
185        self.row_prefix.len()
186    }
187
188    pub fn is_empty(&self) -> bool {
189        self.row_prefix.is_empty()
190    }
191
192    /// Get the group key for state table row prefix.
193    pub fn table_row(&self) -> &OwnedRow {
194        &self.row_prefix
195    }
196
197    /// Get the group key for state table pk prefix.
198    pub fn table_pk(&self) -> impl Row + '_ {
199        (&self.row_prefix).project(&self.table_pk_projection)
200    }
201
202    /// Get the group key for LRU cache key prefix.
203    pub fn cache_key(&self) -> impl Row + '_ {
204        self.table_row()
205    }
206}
207
/// [`AggGroup`] manages agg states of all agg calls for one `group_key`.
///
/// `S` is the state store type and `Strtg` is the [`Strategy`] used to infer which change to
/// build from the previous and current rows.
pub struct AggGroup<S: StateStore, Strtg: Strategy> {
    /// Agg group context, containing the group key.
    ctx: Context,

    /// Current managed states for all [`AggCall`]s.
    states: Vec<AggState>,

    /// Previous intermediate states, stored in the intermediate state table.
    /// `None` if no intermediate states have been recorded for this group.
    prev_inter_states: Option<OwnedRow>,

    /// Previous outputs, yielded to downstream.
    /// If `EOWC` is true, this field is not used.
    prev_outputs: Option<OwnedRow>,

    /// Index of row count agg call (`count(*)`) in the call list.
    row_count_index: usize,

    /// Whether the emit policy is EOWC.
    emit_on_window_close: bool,

    /// Marker for the otherwise unused type parameters `S` and `Strtg`.
    _phantom: PhantomData<(S, Strtg)>,
}
231
232impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        f.debug_struct("AggGroup")
235            .field("group_key", &self.ctx.group_key)
236            .field("prev_inter_states", &self.prev_inter_states)
237            .field("prev_outputs", &self.prev_outputs)
238            .field("row_count_index", &self.row_count_index)
239            .field("emit_on_window_close", &self.emit_on_window_close)
240            .finish()
241    }
242}
243
244impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
245    fn estimated_heap_size(&self) -> usize {
246        // TODO(rc): should include the size of `prev_inter_states` and `prev_outputs`
247        self.states
248            .iter()
249            .map(|state| state.estimated_heap_size())
250            .sum()
251    }
252}
253
254impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
255    /// Create [`AggGroup`] for the given [`AggCall`]s and `group_key`.
256    /// For [`SimpleAggExecutor`], the `group_key` should be `None`.
257    ///
258    /// [`SimpleAggExecutor`]: crate::executor::aggregate::SimpleAggExecutor
259    #[allow(clippy::too_many_arguments)]
260    pub async fn create(
261        version: PbAggNodeVersion,
262        group_key: Option<GroupKey>,
263        agg_calls: &[AggCall],
264        agg_funcs: &[BoxedAggregateFunction],
265        storages: &[AggStateStorage<S>],
266        intermediate_state_table: &StateTable<S>,
267        stream_key: StreamKeyRef<'_>,
268        row_count_index: usize,
269        emit_on_window_close: bool,
270        extreme_cache_size: usize,
271        input_schema: &Schema,
272    ) -> StreamExecutorResult<(Self, AggStateCacheStats)> {
273        let mut stats = AggStateCacheStats::default();
274
275        let inter_states = intermediate_state_table
276            .get_row(group_key.as_ref().map(GroupKey::table_pk))
277            .await?;
278        if let Some(inter_states) = &inter_states {
279            assert_eq!(inter_states.len(), agg_calls.len());
280        }
281
282        let mut states = Vec::with_capacity(agg_calls.len());
283        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
284            let state = AggState::create(
285                version,
286                agg_call,
287                agg_func,
288                &storages[idx],
289                inter_states.as_ref().map(|s| &s[idx]),
290                stream_key,
291                extreme_cache_size,
292                input_schema,
293            )?;
294            states.push(state);
295        }
296
297        let mut this = Self {
298            ctx: Context { group_key },
299            states,
300            prev_inter_states: inter_states,
301            prev_outputs: None, // will be set below
302            row_count_index,
303            emit_on_window_close,
304            _phantom: PhantomData,
305        };
306
307        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
308            let (outputs, init_stats) = this.get_outputs(storages, agg_funcs).await?;
309            this.prev_outputs = Some(outputs);
310            stats.merge(init_stats);
311        }
312
313        Ok((this, stats))
314    }
315
316    /// Create a group from intermediate states for EOWC output.
317    /// Will always produce `Insert` when building change.
318    #[allow(clippy::too_many_arguments)]
319    pub fn for_eowc_output(
320        version: PbAggNodeVersion,
321        group_key: Option<GroupKey>,
322        agg_calls: &[AggCall],
323        agg_funcs: &[BoxedAggregateFunction],
324        storages: &[AggStateStorage<S>],
325        inter_states: &OwnedRow,
326        stream_key: StreamKeyRef<'_>,
327        row_count_index: usize,
328        emit_on_window_close: bool,
329        extreme_cache_size: usize,
330        input_schema: &Schema,
331    ) -> StreamExecutorResult<Self> {
332        let mut states = Vec::with_capacity(agg_calls.len());
333        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
334            let state = AggState::create(
335                version,
336                agg_call,
337                agg_func,
338                &storages[idx],
339                Some(&inter_states[idx]),
340                stream_key,
341                extreme_cache_size,
342                input_schema,
343            )?;
344            states.push(state);
345        }
346
347        Ok(Self {
348            ctx: Context { group_key },
349            states,
350            prev_inter_states: None, // this doesn't matter
351            prev_outputs: None,      // this will make sure the outputs change to be `Insert`
352            row_count_index,
353            emit_on_window_close,
354            _phantom: PhantomData,
355        })
356    }
357
358    pub fn group_key(&self) -> Option<&GroupKey> {
359        self.ctx.group_key()
360    }
361
362    /// Get current row count of this group.
363    fn curr_row_count(&self) -> usize {
364        let row_count_state = must_match!(
365            self.states[self.row_count_index],
366            AggState::Value(ref state) => state
367        );
368        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
369    }
370
371    pub(crate) fn is_uninitialized(&self) -> bool {
372        self.prev_inter_states.is_none()
373    }
374
375    /// Apply input chunk to all managed agg states.
376    ///
377    /// `mappings` contains the column mappings from input chunk to each agg call.
378    /// `visibilities` contains the row visibility of the input chunk for each agg call.
379    pub async fn apply_chunk(
380        &mut self,
381        chunk: &StreamChunk,
382        calls: &[AggCall],
383        funcs: &[BoxedAggregateFunction],
384        visibilities: Vec<Bitmap>,
385    ) -> StreamExecutorResult<()> {
386        if self.curr_row_count() == 0 {
387            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
388        }
389        for (((state, call), func), visibility) in (self.states.iter_mut())
390            .zip_eq_fast(calls)
391            .zip_eq_fast(funcs)
392            .zip_eq_fast(visibilities)
393        {
394            state.apply_chunk(chunk, call, func, visibility).await?;
395        }
396
397        if self.curr_row_count() == 0 {
398            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
399        }
400
401        Ok(())
402    }
403
404    /// Reset all in-memory states to their initial state, i.e. to reset all agg state structs to
405    /// the status as if they are just created, no input applied and no row in state table.
406    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
407        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
408            state.reset(func)?;
409        }
410        Ok(())
411    }
412
413    /// Get the encoded intermediate states of all managed agg states.
414    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
415        let mut inter_states = Vec::with_capacity(self.states.len());
416        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
417            let encoded = match state {
418                AggState::Value(s) => func.encode_state(s)?,
419                // For minput state, we don't need to store it in state table.
420                AggState::MaterializedInput(_) => None,
421            };
422            inter_states.push(encoded);
423        }
424        Ok(OwnedRow::new(inter_states))
425    }
426
427    /// Get the outputs of all managed agg states, without group key prefix.
428    /// Possibly need to read/sync from state table if the state not cached in memory.
429    /// This method is idempotent, i.e. it can be called multiple times and the outputs are
430    /// guaranteed to be the same.
431    async fn get_outputs(
432        &mut self,
433        storages: &[AggStateStorage<S>],
434        funcs: &[BoxedAggregateFunction],
435    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
436        let row_count = self.curr_row_count();
437        if row_count == 0 {
438            // Reset all states (in fact only value states will be reset).
439            // This is important because for some agg calls (e.g. `sum`), if no row is applied,
440            // they should output NULL, for some other calls (e.g. `sum0`), they should output 0.
441            // This actually also prevents inconsistent negative row count from being worse.
442            // FIXME(rc): Deciding whether to reset states according to `row_count` is not precisely
443            // correct, see https://github.com/risingwavelabs/risingwave/issues/7412 for bug description.
444            self.reset(funcs)?;
445        }
446        let mut stats = AggStateCacheStats::default();
447        futures::future::try_join_all(
448            self.states
449                .iter_mut()
450                .zip_eq_fast(storages)
451                .zip_eq_fast(funcs)
452                .map(|((state, storage), func)| {
453                    state.get_output(storage, func, self.ctx.group_key())
454                }),
455        )
456        .await
457        .map(|outputs_and_stats| {
458            outputs_and_stats
459                .into_iter()
460                .map(|(output, stat)| {
461                    stats.merge(stat);
462                    output
463                })
464                .collect::<Vec<_>>()
465        })
466        .map(|row| (OwnedRow::new(row), stats))
467    }
468
469    /// Build change for aggregation intermediate states, according to previous and current agg states.
470    /// The change should be applied to the intermediate state table.
471    ///
472    /// The saved previous inter states will be updated to the latest states after calling this method.
473    pub fn build_states_change(
474        &mut self,
475        funcs: &[BoxedAggregateFunction],
476    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
477        let curr_inter_states = self.get_inter_states(funcs)?;
478        let change_type = Strtg::infer_change_type(
479            &self.ctx,
480            self.prev_inter_states.as_ref(),
481            &curr_inter_states,
482            self.row_count_index,
483        );
484
485        tracing::trace!(
486            group = ?self.ctx.group_key_row(),
487            prev_inter_states = ?self.prev_inter_states,
488            curr_inter_states = ?curr_inter_states,
489            change_type = ?change_type,
490            "build intermediate states change"
491        );
492
493        let Some(change_type) = change_type else {
494            return Ok(None);
495        };
496        Ok(Some(match change_type {
497            RecordType::Insert => {
498                let new_row = self
499                    .group_key()
500                    .map(GroupKey::table_row)
501                    .chain(&curr_inter_states)
502                    .into_owned_row();
503                self.prev_inter_states = Some(curr_inter_states);
504                Record::Insert { new_row }
505            }
506            RecordType::Delete => {
507                let prev_inter_states = self
508                    .prev_inter_states
509                    .take()
510                    .expect("must exist previous intermediate states");
511                let old_row = self
512                    .group_key()
513                    .map(GroupKey::table_row)
514                    .chain(prev_inter_states)
515                    .into_owned_row();
516                Record::Delete { old_row }
517            }
518            RecordType::Update => {
519                let new_row = self
520                    .group_key()
521                    .map(GroupKey::table_row)
522                    .chain(&curr_inter_states)
523                    .into_owned_row();
524                let prev_inter_states = self
525                    .prev_inter_states
526                    .replace(curr_inter_states)
527                    .expect("must exist previous intermediate states");
528                let old_row = self
529                    .group_key()
530                    .map(GroupKey::table_row)
531                    .chain(prev_inter_states)
532                    .into_owned_row();
533                Record::Update { old_row, new_row }
534            }
535        }))
536    }
537
538    /// Build aggregation result change, according to previous and current agg outputs.
539    /// The change should be yielded to downstream.
540    ///
541    /// The saved previous outputs will be updated to the latest outputs after this method.
542    ///
543    /// Note that this method is very likely to cost more than `build_states_change`, because it
544    /// needs to produce output for materialized input states which may involve state table read.
545    pub async fn build_outputs_change(
546        &mut self,
547        storages: &[AggStateStorage<S>],
548        funcs: &[BoxedAggregateFunction],
549    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
550        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;
551
552        let change_type = Strtg::infer_change_type(
553            &self.ctx,
554            self.prev_outputs.as_ref(),
555            &curr_outputs,
556            self.row_count_index,
557        );
558
559        tracing::trace!(
560            group = ?self.ctx.group_key_row(),
561            prev_outputs = ?self.prev_outputs,
562            curr_outputs = ?curr_outputs,
563            change_type = ?change_type,
564            "build outputs change"
565        );
566
567        let Some(change_type) = change_type else {
568            return Ok((None, stats));
569        };
570        Ok((
571            Some(match change_type {
572                RecordType::Insert => {
573                    let new_row = self
574                        .group_key()
575                        .map(GroupKey::table_row)
576                        .chain(&curr_outputs)
577                        .into_owned_row();
578                    // Although we say the `prev_outputs` field is not used in EOWC mode, we still
579                    // do the same here to keep the code simple. When it's actually running in EOWC
580                    // mode, `build_outputs_change` will be called only once for each group.
581                    self.prev_outputs = Some(curr_outputs);
582                    Record::Insert { new_row }
583                }
584                RecordType::Delete => {
585                    let prev_outputs = self.prev_outputs.take();
586                    let old_row = self
587                        .group_key()
588                        .map(GroupKey::table_row)
589                        .chain(prev_outputs)
590                        .into_owned_row();
591                    Record::Delete { old_row }
592                }
593                RecordType::Update => {
594                    let new_row = self
595                        .group_key()
596                        .map(GroupKey::table_row)
597                        .chain(&curr_outputs)
598                        .into_owned_row();
599                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
600                    let old_row = self
601                        .group_key()
602                        .map(GroupKey::table_row)
603                        .chain(prev_outputs)
604                        .into_owned_row();
605                    Record::Update { old_row, new_row }
606                }
607            }),
608            stats,
609        ))
610    }
611}
612
/// Stats for agg state cache operations.
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    /// Number of agg state cache lookups.
    pub agg_state_cache_lookup_count: u64,
    /// Number of agg state cache misses.
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    /// Add the counters of `other` onto `self`.
    fn merge(&mut self, other: Self) {
        let Self {
            agg_state_cache_lookup_count: lookups,
            agg_state_cache_miss_count: misses,
        } = other;
        self.agg_state_cache_lookup_count += lookups;
        self.agg_state_cache_miss_count += misses;
    }
}