1use std::fmt::Debug;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use risingwave_common::array::StreamChunk;
20use risingwave_common::array::stream_record::{Record, RecordType};
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::Schema;
23use risingwave_common::must_match;
24use risingwave_common::row::{OwnedRow, Row, RowExt};
25use risingwave_common::util::iter_util::ZipEqFast;
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
28use risingwave_pb::stream_plan::PbAggNodeVersion;
29use risingwave_storage::StateStore;
30
31use super::agg_state::{AggState, AggStateStorage};
32use crate::common::table::state_table::StateTable;
33use crate::consistency::consistency_panic;
34use crate::executor::PkIndices;
35use crate::executor::error::StreamExecutorResult;
36
37#[derive(Debug)]
38pub struct Context {
39 group_key: Option<GroupKey>,
40}
41
42impl Context {
43 pub fn group_key(&self) -> Option<&GroupKey> {
44 self.group_key.as_ref()
45 }
46
47 pub fn group_key_row(&self) -> OwnedRow {
48 self.group_key()
49 .map(GroupKey::table_row)
50 .cloned()
51 .unwrap_or_default()
52 }
53}
54
55fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
56 match row {
57 Some(row) => {
58 let mut row_count = row
59 .datum_at(row_count_col)
60 .expect("row count field should not be NULL")
61 .into_int64();
62
63 if row_count < 0 {
64 consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");
65
66 row_count = 0;
74 }
75 row_count.try_into().unwrap()
76 }
77 None => 0,
78 }
79}
80
81pub trait Strategy {
82 fn infer_change_type(
85 ctx: &Context,
86 prev_row: Option<&OwnedRow>,
87 curr_row: &OwnedRow,
88 row_count_col: usize,
89 ) -> Option<RecordType>;
90}
91
/// Marker type for the [`Strategy`] that reports an `Insert` for a group with
/// no previous row, regardless of the group's current row count.
pub struct AlwaysOutput;
/// Marker type for the [`Strategy`] that reports nothing while both the
/// previous and the current row counts of a group are zero.
pub struct OnlyOutputIfHasInput;
97
98impl Strategy for AlwaysOutput {
99 fn infer_change_type(
100 ctx: &Context,
101 prev_row: Option<&OwnedRow>,
102 _curr_row: &OwnedRow,
103 row_count_col: usize,
104 ) -> Option<RecordType> {
105 let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
106 match prev_row {
107 None => {
108 assert_eq!(prev_row_count, 0);
112
113 Some(RecordType::Insert)
115 }
116 Some(_prev_outputs) => Some(RecordType::Update),
127 }
128 }
129}
130
131impl Strategy for OnlyOutputIfHasInput {
132 fn infer_change_type(
133 ctx: &Context,
134 prev_row: Option<&OwnedRow>,
135 curr_row: &OwnedRow,
136 row_count_col: usize,
137 ) -> Option<RecordType> {
138 let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
139 let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);
140
141 match (prev_row_count, curr_row_count) {
142 (0, 0) => {
143 None
145 }
146 (0, _) => {
147 Some(RecordType::Insert)
149 }
150 (_, 0) => {
151 Some(RecordType::Delete)
153 }
154 (_, _) => {
155 if prev_row.expect("must exist previous row") == curr_row {
157 None
159 } else {
160 Some(RecordType::Update)
161 }
162 }
163 }
164 }
165}
166
167#[derive(Clone, Debug)]
169pub struct GroupKey {
170 row_prefix: OwnedRow,
171 table_pk_projection: Arc<[usize]>,
172}
173
174impl GroupKey {
175 pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
176 let table_pk_projection =
177 table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
178 Self {
179 row_prefix,
180 table_pk_projection,
181 }
182 }
183
184 pub fn len(&self) -> usize {
185 self.row_prefix.len()
186 }
187
188 pub fn is_empty(&self) -> bool {
189 self.row_prefix.is_empty()
190 }
191
192 pub fn table_row(&self) -> &OwnedRow {
194 &self.row_prefix
195 }
196
197 pub fn table_pk(&self) -> impl Row + '_ {
199 (&self.row_prefix).project(&self.table_pk_projection)
200 }
201
202 pub fn cache_key(&self) -> impl Row + '_ {
204 self.table_row()
205 }
206}
207
208pub struct AggGroup<S: StateStore, Strtg: Strategy> {
210 ctx: Context,
212
213 states: Vec<AggState>,
215
216 prev_inter_states: Option<OwnedRow>,
218
219 prev_outputs: Option<OwnedRow>,
222
223 row_count_index: usize,
225
226 emit_on_window_close: bool,
228
229 _phantom: PhantomData<(S, Strtg)>,
230}
231
232impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 f.debug_struct("AggGroup")
235 .field("group_key", &self.ctx.group_key)
236 .field("prev_inter_states", &self.prev_inter_states)
237 .field("prev_outputs", &self.prev_outputs)
238 .field("row_count_index", &self.row_count_index)
239 .field("emit_on_window_close", &self.emit_on_window_close)
240 .finish()
241 }
242}
243
244impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
245 fn estimated_heap_size(&self) -> usize {
246 self.states
248 .iter()
249 .map(|state| state.estimated_heap_size())
250 .sum()
251 }
252}
253
254impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
255 #[allow(clippy::too_many_arguments)]
260 pub async fn create(
261 version: PbAggNodeVersion,
262 group_key: Option<GroupKey>,
263 agg_calls: &[AggCall],
264 agg_funcs: &[BoxedAggregateFunction],
265 storages: &[AggStateStorage<S>],
266 intermediate_state_table: &StateTable<S>,
267 pk_indices: &PkIndices,
268 row_count_index: usize,
269 emit_on_window_close: bool,
270 extreme_cache_size: usize,
271 input_schema: &Schema,
272 ) -> StreamExecutorResult<Self> {
273 let inter_states = intermediate_state_table
274 .get_row(group_key.as_ref().map(GroupKey::table_pk))
275 .await?;
276 if let Some(inter_states) = &inter_states {
277 assert_eq!(inter_states.len(), agg_calls.len());
278 }
279
280 let mut states = Vec::with_capacity(agg_calls.len());
281 for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
282 let state = AggState::create(
283 version,
284 agg_call,
285 agg_func,
286 &storages[idx],
287 inter_states.as_ref().map(|s| &s[idx]),
288 pk_indices,
289 extreme_cache_size,
290 input_schema,
291 )?;
292 states.push(state);
293 }
294
295 let mut this = Self {
296 ctx: Context { group_key },
297 states,
298 prev_inter_states: inter_states,
299 prev_outputs: None, row_count_index,
301 emit_on_window_close,
302 _phantom: PhantomData,
303 };
304
305 if !this.emit_on_window_close && this.prev_inter_states.is_some() {
306 let (outputs, _stats) = this.get_outputs(storages, agg_funcs).await?;
307 this.prev_outputs = Some(outputs);
308 }
309
310 Ok(this)
311 }
312
313 #[allow(clippy::too_many_arguments)]
316 pub fn for_eowc_output(
317 version: PbAggNodeVersion,
318 group_key: Option<GroupKey>,
319 agg_calls: &[AggCall],
320 agg_funcs: &[BoxedAggregateFunction],
321 storages: &[AggStateStorage<S>],
322 inter_states: &OwnedRow,
323 pk_indices: &PkIndices,
324 row_count_index: usize,
325 emit_on_window_close: bool,
326 extreme_cache_size: usize,
327 input_schema: &Schema,
328 ) -> StreamExecutorResult<Self> {
329 let mut states = Vec::with_capacity(agg_calls.len());
330 for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
331 let state = AggState::create(
332 version,
333 agg_call,
334 agg_func,
335 &storages[idx],
336 Some(&inter_states[idx]),
337 pk_indices,
338 extreme_cache_size,
339 input_schema,
340 )?;
341 states.push(state);
342 }
343
344 Ok(Self {
345 ctx: Context { group_key },
346 states,
347 prev_inter_states: None, prev_outputs: None, row_count_index,
350 emit_on_window_close,
351 _phantom: PhantomData,
352 })
353 }
354
355 pub fn group_key(&self) -> Option<&GroupKey> {
356 self.ctx.group_key()
357 }
358
359 fn curr_row_count(&self) -> usize {
361 let row_count_state = must_match!(
362 self.states[self.row_count_index],
363 AggState::Value(ref state) => state
364 );
365 row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
366 }
367
368 pub(crate) fn is_uninitialized(&self) -> bool {
369 self.prev_inter_states.is_none()
370 }
371
372 pub async fn apply_chunk(
377 &mut self,
378 chunk: &StreamChunk,
379 calls: &[AggCall],
380 funcs: &[BoxedAggregateFunction],
381 visibilities: Vec<Bitmap>,
382 ) -> StreamExecutorResult<()> {
383 if self.curr_row_count() == 0 {
384 tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
385 }
386 for (((state, call), func), visibility) in (self.states.iter_mut())
387 .zip_eq_fast(calls)
388 .zip_eq_fast(funcs)
389 .zip_eq_fast(visibilities)
390 {
391 state.apply_chunk(chunk, call, func, visibility).await?;
392 }
393
394 if self.curr_row_count() == 0 {
395 tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
396 }
397
398 Ok(())
399 }
400
401 fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
404 for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
405 state.reset(func)?;
406 }
407 Ok(())
408 }
409
410 fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
412 let mut inter_states = Vec::with_capacity(self.states.len());
413 for (state, func) in self.states.iter().zip_eq_fast(funcs) {
414 let encoded = match state {
415 AggState::Value(s) => func.encode_state(s)?,
416 AggState::MaterializedInput(_) => None,
418 };
419 inter_states.push(encoded);
420 }
421 Ok(OwnedRow::new(inter_states))
422 }
423
424 async fn get_outputs(
429 &mut self,
430 storages: &[AggStateStorage<S>],
431 funcs: &[BoxedAggregateFunction],
432 ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
433 let row_count = self.curr_row_count();
434 if row_count == 0 {
435 self.reset(funcs)?;
442 }
443 let mut stats = AggStateCacheStats::default();
444 futures::future::try_join_all(
445 self.states
446 .iter_mut()
447 .zip_eq_fast(storages)
448 .zip_eq_fast(funcs)
449 .map(|((state, storage), func)| {
450 state.get_output(storage, func, self.ctx.group_key())
451 }),
452 )
453 .await
454 .map(|outputs_and_stats| {
455 outputs_and_stats
456 .into_iter()
457 .map(|(output, stat)| {
458 stats.merge(stat);
459 output
460 })
461 .collect::<Vec<_>>()
462 })
463 .map(|row| (OwnedRow::new(row), stats))
464 }
465
466 pub fn build_states_change(
471 &mut self,
472 funcs: &[BoxedAggregateFunction],
473 ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
474 let curr_inter_states = self.get_inter_states(funcs)?;
475 let change_type = Strtg::infer_change_type(
476 &self.ctx,
477 self.prev_inter_states.as_ref(),
478 &curr_inter_states,
479 self.row_count_index,
480 );
481
482 tracing::trace!(
483 group = ?self.ctx.group_key_row(),
484 prev_inter_states = ?self.prev_inter_states,
485 curr_inter_states = ?curr_inter_states,
486 change_type = ?change_type,
487 "build intermediate states change"
488 );
489
490 let Some(change_type) = change_type else {
491 return Ok(None);
492 };
493 Ok(Some(match change_type {
494 RecordType::Insert => {
495 let new_row = self
496 .group_key()
497 .map(GroupKey::table_row)
498 .chain(&curr_inter_states)
499 .into_owned_row();
500 self.prev_inter_states = Some(curr_inter_states);
501 Record::Insert { new_row }
502 }
503 RecordType::Delete => {
504 let prev_inter_states = self
505 .prev_inter_states
506 .take()
507 .expect("must exist previous intermediate states");
508 let old_row = self
509 .group_key()
510 .map(GroupKey::table_row)
511 .chain(prev_inter_states)
512 .into_owned_row();
513 Record::Delete { old_row }
514 }
515 RecordType::Update => {
516 let new_row = self
517 .group_key()
518 .map(GroupKey::table_row)
519 .chain(&curr_inter_states)
520 .into_owned_row();
521 let prev_inter_states = self
522 .prev_inter_states
523 .replace(curr_inter_states)
524 .expect("must exist previous intermediate states");
525 let old_row = self
526 .group_key()
527 .map(GroupKey::table_row)
528 .chain(prev_inter_states)
529 .into_owned_row();
530 Record::Update { old_row, new_row }
531 }
532 }))
533 }
534
535 pub async fn build_outputs_change(
543 &mut self,
544 storages: &[AggStateStorage<S>],
545 funcs: &[BoxedAggregateFunction],
546 ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
547 let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;
548
549 let change_type = Strtg::infer_change_type(
550 &self.ctx,
551 self.prev_outputs.as_ref(),
552 &curr_outputs,
553 self.row_count_index,
554 );
555
556 tracing::trace!(
557 group = ?self.ctx.group_key_row(),
558 prev_outputs = ?self.prev_outputs,
559 curr_outputs = ?curr_outputs,
560 change_type = ?change_type,
561 "build outputs change"
562 );
563
564 let Some(change_type) = change_type else {
565 return Ok((None, stats));
566 };
567 Ok((
568 Some(match change_type {
569 RecordType::Insert => {
570 let new_row = self
571 .group_key()
572 .map(GroupKey::table_row)
573 .chain(&curr_outputs)
574 .into_owned_row();
575 self.prev_outputs = Some(curr_outputs);
579 Record::Insert { new_row }
580 }
581 RecordType::Delete => {
582 let prev_outputs = self.prev_outputs.take();
583 let old_row = self
584 .group_key()
585 .map(GroupKey::table_row)
586 .chain(prev_outputs)
587 .into_owned_row();
588 Record::Delete { old_row }
589 }
590 RecordType::Update => {
591 let new_row = self
592 .group_key()
593 .map(GroupKey::table_row)
594 .chain(&curr_outputs)
595 .into_owned_row();
596 let prev_outputs = self.prev_outputs.replace(curr_outputs);
597 let old_row = self
598 .group_key()
599 .map(GroupKey::table_row)
600 .chain(prev_outputs)
601 .into_owned_row();
602 Record::Update { old_row, new_row }
603 }
604 }),
605 stats,
606 ))
607 }
608}
609
/// Cache statistics collected while reading agg state outputs.
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    /// Number of agg state cache lookups.
    pub agg_state_cache_lookup_count: u64,
    /// Number of agg state cache lookups that missed.
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    /// Adds the counters of `other` onto `self`.
    fn merge(&mut self, other: Self) {
        let Self {
            agg_state_cache_lookup_count: lookups,
            agg_state_cache_miss_count: misses,
        } = other;
        self.agg_state_cache_lookup_count += lookups;
        self.agg_state_cache_miss_count += misses;
    }
}