risingwave_stream/executor/aggregate/agg_group.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;

use futures::future::try_join_all;
use risingwave_common::array::StreamChunk;
use risingwave_common::array::stream_record::{Record, RecordType};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::Schema;
use risingwave_common::must_match;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use risingwave_storage::StateStore;

use super::agg_state::{AggState, AggStateStorage};
use crate::common::table::state_table::StateTable;
use crate::consistency::consistency_panic;
use crate::executor::PkIndices;
use crate::executor::error::StreamExecutorResult;

#[derive(Debug)]
pub struct Context {
    group_key: Option<GroupKey>,
}

impl Context {
    pub fn group_key(&self) -> Option<&GroupKey> {
        self.group_key.as_ref()
    }

    pub fn group_key_row(&self) -> OwnedRow {
        self.group_key()
            .map(GroupKey::table_row)
            .cloned()
            .unwrap_or_default()
    }
}

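/// Extract the row count from the given `row` at column `row_count_col`, treating a missing row
/// as count `0` and clamping inconsistent negative counts to `0`.
///
/// A minimal sketch of the expected behavior (`ctx` is assumed to be a [`Context`] built
/// elsewhere; values are illustrative):
///
/// ```ignore
/// // A row whose row-count column holds 3 yields 3.
/// assert_eq!(row_count_of(&ctx, Some(OwnedRow::new(vec![Some(ScalarImpl::Int64(3))])), 0), 3);
/// // A missing row means the group does not exist yet: the count is 0.
/// assert_eq!(row_count_of(&ctx, None::<OwnedRow>, 0), 0);
/// ```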
fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
    match row {
        Some(row) => {
            let mut row_count = row
                .datum_at(row_count_col)
                .expect("row count field should not be NULL")
                .into_int64();

            if row_count < 0 {
                consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");

                // NOTE: This is the case where an inconsistent `DELETE` arrives at the HashAgg executor and no
                // corresponding group existed before (or it has already been deleted). In this case, the previous
                // row count is `0` and the current row count becomes `-1` after handling the `DELETE`. To ignore
                // the inconsistency, we reset `row_count` to `0` here, so that `OnlyOutputIfHasInput` reports no
                // change and the inconsistency is hidden from downstream. This doesn't prevent incorrect results
                // for existing groups, but at least it prevents downstream from panicking due to non-existent keys.
                // See https://github.com/risingwavelabs/risingwave/issues/14031 for more information.
                row_count = 0;
            }
            row_count.try_into().unwrap()
        }
        None => 0,
    }
}

pub trait Strategy {
    /// Infer the change type of the aggregation result. Implementations don't need to take
    /// ownership of `prev_row` and `curr_row`.
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType>;
}

/// The strategy that always outputs the aggregation result, regardless of whether there are any
/// input rows.
pub struct AlwaysOutput;
/// The strategy that only outputs the aggregation result when there are input rows. If the row
/// count drops to 0, the output row will be deleted.
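///
/// A hedged sketch of the resulting change types (assuming the row-count column is at index 0
/// and `ctx`, `prev`, `curr` are rows prepared by the caller; the mapping below follows the
/// `match` in the impl):
///
/// ```ignore
/// // (prev_count, curr_count) == (0, 0) => None                      (group never existed)
/// // (prev_count, curr_count) == (0, 2) => Some(RecordType::Insert)  (group just appeared)
/// // (prev_count, curr_count) == (2, 0) => Some(RecordType::Delete)  (group just vanished)
/// // (prev_count, curr_count) == (2, 3) => Some(RecordType::Update)  (unless rows are equal)
/// OnlyOutputIfHasInput::infer_change_type(&ctx, Some(&prev), &curr, 0);
/// ```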
pub struct OnlyOutputIfHasInput;

impl Strategy for AlwaysOutput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        _curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        match prev_row {
            None => {
                // First time to build changes, assert to ensure correctness.
                // Note that the converse is not true, i.e. `prev_row_count == 0` doesn't imply
                // `prev_outputs == None`.
                assert_eq!(prev_row_count, 0);

                // Generate output no matter whether the current row count is 0 or not.
                Some(RecordType::Insert)
            }
            // NOTE(kwannoel): We always output, even if the update is a no-op.
            // e.g. the following will still be emitted downstream:
            // ```
            // U- 1
            // U+ 1
            // ```
            // This is to support `approx_percentile` via `row_merge`, which requires
            // both the lhs and rhs to always output updates per epoch, or not at all.
            // Otherwise, if only one side updates, we are unable to construct a full row,
            // as the `row_merge` executor is stateless.
            Some(_prev_outputs) => Some(RecordType::Update),
        }
    }
}

impl Strategy for OnlyOutputIfHasInput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);

        match (prev_row_count, curr_row_count) {
            (0, 0) => {
                // No rows of current group exist.
                None
            }
            (0, _) => {
                // Insert new output row for this newly emerged group.
                Some(RecordType::Insert)
            }
            (_, 0) => {
                // Delete old output row for this newly disappeared group.
                Some(RecordType::Delete)
            }
            (_, _) => {
                // Update output row.
                if prev_row.expect("must exist previous row") == curr_row {
                    // No output change.
                    None
                } else {
                    Some(RecordType::Update)
                }
            }
        }
    }
}

/// [`GroupKey`] wraps a concrete group key and handles its mapping to the state table pk.
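///
/// A minimal usage sketch (values and the pk ordering are illustrative assumptions):
///
/// ```ignore
/// // Group key is (a, b), while the state table pk happens to be ordered as (b, a).
/// let key = GroupKey::new(
///     OwnedRow::new(vec![Some(ScalarImpl::Int32(1)), Some(ScalarImpl::Int32(2))]),
///     Some(vec![1, 0].into()),
/// );
/// assert_eq!(key.len(), 2);
/// // `table_row()` keeps the original (a, b) order for the row prefix, while
/// // `table_pk()` projects it to (b, a) for state table access.
/// ```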
#[derive(Clone, Debug)]
pub struct GroupKey {
    row_prefix: OwnedRow,
    table_pk_projection: Arc<[usize]>,
}

impl GroupKey {
    pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
        let table_pk_projection =
            table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
        Self {
            row_prefix,
            table_pk_projection,
        }
    }

    pub fn len(&self) -> usize {
        self.row_prefix.len()
    }

    pub fn is_empty(&self) -> bool {
        self.row_prefix.is_empty()
    }

    /// Get the group key for state table row prefix.
    pub fn table_row(&self) -> &OwnedRow {
        &self.row_prefix
    }

    /// Get the group key for state table pk prefix.
    pub fn table_pk(&self) -> impl Row + '_ {
        (&self.row_prefix).project(&self.table_pk_projection)
    }

    /// Get the group key for LRU cache key prefix.
    pub fn cache_key(&self) -> impl Row + '_ {
        self.table_row()
    }
}

/// [`AggGroup`] manages agg states of all agg calls for one `group_key`.
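///
/// A hedged sketch of the typical per-group flow in a non-EOWC executor (all arguments are
/// assumed to be prepared by the surrounding executor; the calls shown are the ones defined
/// below):
///
/// ```ignore
/// let (mut group, _stats) =
///     AggGroup::<_, OnlyOutputIfHasInput>::create(/* version, group key, calls, ... */).await?;
/// group.apply_chunk(&chunk, &agg_calls, &agg_funcs, visibilities).await?;
/// if let Some(states_change) = group.build_states_change(&agg_funcs)? {
///     // Persist the change to the intermediate state table.
/// }
/// let (outputs_change, _stats) = group.build_outputs_change(&storages, &agg_funcs).await?;
/// // Yield `outputs_change` (if any) to downstream.
/// ```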
pub struct AggGroup<S: StateStore, Strtg: Strategy> {
    /// Agg group context, containing the group key.
    ctx: Context,

    /// Current managed states for all [`AggCall`]s.
    states: Vec<AggState>,

    /// Previous intermediate states, stored in the intermediate state table.
    prev_inter_states: Option<OwnedRow>,

    /// Previous outputs, yielded to downstream.
    /// If `EOWC` is true, this field is not used.
    prev_outputs: Option<OwnedRow>,

    /// Index of row count agg call (`count(*)`) in the call list.
    row_count_index: usize,

    /// Whether the emit policy is EOWC.
    emit_on_window_close: bool,

    _phantom: PhantomData<(S, Strtg)>,
}

impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AggGroup")
            .field("group_key", &self.ctx.group_key)
            .field("prev_inter_states", &self.prev_inter_states)
            .field("prev_outputs", &self.prev_outputs)
            .field("row_count_index", &self.row_count_index)
            .field("emit_on_window_close", &self.emit_on_window_close)
            .finish()
    }
}

impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
    fn estimated_heap_size(&self) -> usize {
        // TODO(rc): should include the size of `prev_inter_states` and `prev_outputs`
        self.states
            .iter()
            .map(|state| state.estimated_heap_size())
            .sum()
    }
}

impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
    /// Create [`AggGroup`] for the given [`AggCall`]s and `group_key`.
    /// For [`SimpleAggExecutor`], the `group_key` should be `None`.
    ///
    /// [`SimpleAggExecutor`]: crate::executor::aggregate::SimpleAggExecutor
    #[allow(clippy::too_many_arguments)]
    pub async fn create(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        intermediate_state_table: &StateTable<S>,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<(Self, AggStateCacheStats)> {
        let mut stats = AggStateCacheStats::default();

        let inter_states = intermediate_state_table
            .get_row(group_key.as_ref().map(GroupKey::table_pk))
            .await?;
        if let Some(inter_states) = &inter_states {
            assert_eq!(inter_states.len(), agg_calls.len());
        }

        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                inter_states.as_ref().map(|s| &s[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        let mut this = Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: inter_states,
            prev_outputs: None, // will be set below
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        };

        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
            let (outputs, init_stats) = this.get_outputs(storages, agg_funcs).await?;
            this.prev_outputs = Some(outputs);
            stats.merge(init_stats);
        }

        Ok((this, stats))
    }

    /// Create a group from intermediate states for EOWC output.
    /// Will always produce `Insert` when building the change.
    #[allow(clippy::too_many_arguments)]
    pub fn for_eowc_output(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        inter_states: &OwnedRow,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                Some(&inter_states[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        Ok(Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: None, // this doesn't matter
            prev_outputs: None,      // this ensures the output change will be an `Insert`
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        })
    }

    pub fn group_key(&self) -> Option<&GroupKey> {
        self.ctx.group_key()
    }

    /// Get current row count of this group.
    fn curr_row_count(&self) -> usize {
        let row_count_state = must_match!(
            self.states[self.row_count_index],
            AggState::Value(ref state) => state
        );
        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
    }

    pub(crate) fn is_uninitialized(&self) -> bool {
        self.prev_inter_states.is_none()
    }

    /// Apply the input chunk to all managed agg states.
    ///
    /// `calls` and `funcs` are the agg calls and their corresponding aggregate functions.
    /// `visibilities` contains the row visibility of the input chunk for each agg call.
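    ///
    /// States are applied in fixed-size batches awaited with [`try_join_all`], so the number of
    /// in-flight futures stays bounded. A standalone sketch of that batching pattern (the batch
    /// size and `apply_one` are illustrative placeholders):
    ///
    /// ```ignore
    /// for start in (0..items.len()).step_by(batch_size) {
    ///     let end = (start + batch_size).min(items.len());
    ///     try_join_all(items[start..end].iter_mut().map(apply_one)).await?;
    /// }
    /// ```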
    pub async fn apply_chunk(
        &mut self,
        chunk: &StreamChunk,
        calls: &[AggCall],
        funcs: &[BoxedAggregateFunction],
        visibilities: Vec<Bitmap>,
    ) -> StreamExecutorResult<()> {
        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
        }

        // Apply states in batches of this size to bound the number of concurrently polled futures.
        let concurrency = 10;
        let len = self.states.len();

        for batch_start in (0..len).step_by(concurrency) {
            let batch_end = std::cmp::min(batch_start + concurrency, len);

            // Create the `apply_chunk` futures for this batch of states.
            let futures = self.states[batch_start..batch_end]
                .iter_mut()
                .zip_eq_fast(&calls[batch_start..batch_end])
                .zip_eq_fast(&funcs[batch_start..batch_end])
                .zip_eq_fast(&visibilities[batch_start..batch_end])
                .map(|(((state, call), func), visibility)| {
                    state.apply_chunk(chunk, call, func, visibility.clone())
                });

            try_join_all(futures).await?;
        }

        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
        }

        Ok(())
    }

    /// Reset all in-memory states to their initial state, i.e. reset all agg state structs to
    /// the status as if they were just created, with no input applied and no rows in the state table.
    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
            state.reset(func)?;
        }
        Ok(())
    }

    /// Get the encoded intermediate states of all managed agg states.
    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
        let mut inter_states = Vec::with_capacity(self.states.len());
        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
            let encoded = match state {
                AggState::Value(s) => func.encode_state(s)?,
                // Materialized input (minput) states don't need to be stored in the intermediate state table.
                AggState::MaterializedInput(_) => None,
            };
            inter_states.push(encoded);
        }
        Ok(OwnedRow::new(inter_states))
    }

    /// Get the outputs of all managed agg states, without the group key prefix.
    /// May need to read/sync from the state table if the state is not cached in memory.
    /// This method is idempotent, i.e. it can be called multiple times and the outputs are
    /// guaranteed to be the same.
    async fn get_outputs(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
        let row_count = self.curr_row_count();
        if row_count == 0 {
            // Reset all states (in fact only value states will be reset).
            // This is important because some agg calls (e.g. `sum`) should output NULL when no
            // row has been applied, while some other calls (e.g. `sum0`) should output 0.
            // It also prevents an inconsistent negative row count from getting worse.
            // FIXME(rc): Deciding whether to reset states according to `row_count` is not precisely
            // correct, see https://github.com/risingwavelabs/risingwave/issues/7412 for bug description.
            self.reset(funcs)?;
        }
        let mut stats = AggStateCacheStats::default();
        futures::future::try_join_all(
            self.states
                .iter_mut()
                .zip_eq_fast(storages)
                .zip_eq_fast(funcs)
                .map(|((state, storage), func)| {
                    state.get_output(storage, func, self.ctx.group_key())
                }),
        )
        .await
        .map(|outputs_and_stats| {
            outputs_and_stats
                .into_iter()
                .map(|(output, stat)| {
                    stats.merge(stat);
                    output
                })
                .collect::<Vec<_>>()
        })
        .map(|row| (OwnedRow::new(row), stats))
    }

    /// Build the change for aggregation intermediate states, according to the previous and
    /// current agg states. The change should be applied to the intermediate state table.
    ///
    /// The saved previous intermediate states will be updated to the latest states after calling this method.
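    ///
    /// A hedged usage sketch (how the returned record is persisted is up to the caller; the
    /// `write_record` call below is an assumption about the surrounding executor code):
    ///
    /// ```ignore
    /// if let Some(change) = agg_group.build_states_change(&agg_funcs)? {
    ///     intermediate_state_table.write_record(change);
    /// }
    /// ```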
    pub fn build_states_change(
        &mut self,
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
        let curr_inter_states = self.get_inter_states(funcs)?;
        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_inter_states.as_ref(),
            &curr_inter_states,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_inter_states = ?self.prev_inter_states,
            curr_inter_states = ?curr_inter_states,
            change_type = ?change_type,
            "build intermediate states change"
        );

        let Some(change_type) = change_type else {
            return Ok(None);
        };
        Ok(Some(match change_type {
            RecordType::Insert => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                self.prev_inter_states = Some(curr_inter_states);
                Record::Insert { new_row }
            }
            RecordType::Delete => {
                let prev_inter_states = self
                    .prev_inter_states
                    .take()
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Delete { old_row }
            }
            RecordType::Update => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                let prev_inter_states = self
                    .prev_inter_states
                    .replace(curr_inter_states)
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Update { old_row, new_row }
            }
        }))
    }

    /// Build the aggregation result change, according to the previous and current agg outputs.
    /// The change should be yielded to downstream.
    ///
    /// The saved previous outputs will be updated to the latest outputs after calling this method.
    ///
    /// Note that this method is very likely to cost more than `build_states_change`, because it
    /// needs to produce outputs for materialized input states, which may involve state table reads.
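    ///
    /// A hedged usage sketch (turning the record into a stream chunk is the caller's job; the
    /// `chunk_builder.append_record` call below is an assumption, not part of this module):
    ///
    /// ```ignore
    /// let (change, stats) = agg_group.build_outputs_change(&storages, &agg_funcs).await?;
    /// if let Some(record) = change {
    ///     chunk_builder.append_record(record);
    /// }
    /// ```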
    pub async fn build_outputs_change(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;

        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_outputs.as_ref(),
            &curr_outputs,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_outputs = ?self.prev_outputs,
            curr_outputs = ?curr_outputs,
            change_type = ?change_type,
            "build outputs change"
        );

        let Some(change_type) = change_type else {
            return Ok((None, stats));
        };
        Ok((
            Some(match change_type {
                RecordType::Insert => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    // Although we say the `prev_outputs` field is not used in EOWC mode, we still
                    // do the same here to keep the code simple. When it's actually running in EOWC
                    // mode, `build_outputs_change` will be called only once for each group.
                    self.prev_outputs = Some(curr_outputs);
                    Record::Insert { new_row }
                }
                RecordType::Delete => {
                    let prev_outputs = self.prev_outputs.take();
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Delete { old_row }
                }
                RecordType::Update => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Update { old_row, new_row }
                }
            }),
            stats,
        ))
    }
}

/// Stats for agg state cache operations.
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    pub agg_state_cache_lookup_count: u64,
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    fn merge(&mut self, other: Self) {
        self.agg_state_cache_lookup_count += other.agg_state_cache_lookup_count;
        self.agg_state_cache_miss_count += other.agg_state_cache_miss_count;
    }
}