1use std::fmt::Debug;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use risingwave_common::array::StreamChunk;
20use risingwave_common::array::stream_record::{Record, RecordType};
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::Schema;
23use risingwave_common::must_match;
24use risingwave_common::row::{OwnedRow, Row, RowExt};
25use risingwave_common::util::iter_util::ZipEqFast;
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
28use risingwave_pb::stream_plan::PbAggNodeVersion;
29use risingwave_storage::StateStore;
30
31use super::agg_state::{AggState, AggStateStorage};
32use crate::common::table::state_table::StateTable;
33use crate::consistency::consistency_panic;
34use crate::executor::StreamKeyRef;
35use crate::executor::error::StreamExecutorResult;
36
/// Shared context of an aggregation group, mainly carrying the (optional) group key.
#[derive(Debug)]
pub struct Context {
    // `None` when there is no grouping (simple aggregation).
    group_key: Option<GroupKey>,
}
41
42impl Context {
43 pub fn group_key(&self) -> Option<&GroupKey> {
44 self.group_key.as_ref()
45 }
46
47 pub fn group_key_row(&self) -> OwnedRow {
48 self.group_key()
49 .map(GroupKey::table_row)
50 .cloned()
51 .unwrap_or_default()
52 }
53}
54
55fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
56 match row {
57 Some(row) => {
58 let mut row_count = row
59 .datum_at(row_count_col)
60 .expect("row count field should not be NULL")
61 .into_int64();
62
63 if row_count < 0 {
64 consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");
65
66 row_count = 0;
74 }
75 row_count.try_into().unwrap()
76 }
77 None => 0,
78 }
79}
80
/// Strategy deciding which kind of change record (if any) an agg group should
/// emit, given its previous and current row and their row counts.
pub trait Strategy {
    /// Infer the change type for the group. `None` means no change should be
    /// emitted. `row_count_col` is the index of the row-count column inside
    /// `prev_row`/`curr_row`.
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType>;
}
91
/// Strategy that always emits a change record: `Insert` the first time a group
/// is seen, `Update` afterwards.
pub struct AlwaysOutput;
/// Strategy that only emits a change record while the group has input rows
/// (non-zero row count), and suppresses no-op updates.
pub struct OnlyOutputIfHasInput;
97
98impl Strategy for AlwaysOutput {
99 fn infer_change_type(
100 ctx: &Context,
101 prev_row: Option<&OwnedRow>,
102 _curr_row: &OwnedRow,
103 row_count_col: usize,
104 ) -> Option<RecordType> {
105 let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
106 match prev_row {
107 None => {
108 assert_eq!(prev_row_count, 0);
112
113 Some(RecordType::Insert)
115 }
116 Some(_prev_outputs) => Some(RecordType::Update),
127 }
128 }
129}
130
131impl Strategy for OnlyOutputIfHasInput {
132 fn infer_change_type(
133 ctx: &Context,
134 prev_row: Option<&OwnedRow>,
135 curr_row: &OwnedRow,
136 row_count_col: usize,
137 ) -> Option<RecordType> {
138 let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
139 let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);
140
141 match (prev_row_count, curr_row_count) {
142 (0, 0) => {
143 None
145 }
146 (0, _) => {
147 Some(RecordType::Insert)
149 }
150 (_, 0) => {
151 Some(RecordType::Delete)
153 }
154 (_, _) => {
155 if prev_row.expect("must exist previous row") == curr_row {
157 None
159 } else {
160 Some(RecordType::Update)
161 }
162 }
163 }
164 }
165}
166
/// Key of an aggregation group, plus the projection that maps it onto the
/// state table's primary-key columns.
#[derive(Clone, Debug)]
pub struct GroupKey {
    // Full group-key row; chained in front of state/output rows.
    row_prefix: OwnedRow,
    // Indices into `row_prefix` forming the state table pk.
    table_pk_projection: Arc<[usize]>,
}
173
174impl GroupKey {
175 pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
176 let table_pk_projection =
177 table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
178 Self {
179 row_prefix,
180 table_pk_projection,
181 }
182 }
183
184 pub fn len(&self) -> usize {
185 self.row_prefix.len()
186 }
187
188 pub fn is_empty(&self) -> bool {
189 self.row_prefix.is_empty()
190 }
191
192 pub fn table_row(&self) -> &OwnedRow {
194 &self.row_prefix
195 }
196
197 pub fn table_pk(&self) -> impl Row + '_ {
199 (&self.row_prefix).project(&self.table_pk_projection)
200 }
201
202 pub fn cache_key(&self) -> impl Row + '_ {
204 self.table_row()
205 }
206}
207
/// [`AggGroup`] manages the aggregate states of one aggregation group.
pub struct AggGroup<S: StateStore, Strtg: Strategy> {
    // Context holding the (optional) group key.
    ctx: Context,

    // One state per aggregate call.
    states: Vec<AggState>,

    // Previously persisted row of encoded intermediate states;
    // `None` if the group hasn't been persisted yet.
    prev_inter_states: Option<OwnedRow>,

    // Previously emitted outputs row, diffed against current outputs
    // in `build_outputs_change`.
    prev_outputs: Option<OwnedRow>,

    // Index of the row-count aggregate within `states`.
    row_count_index: usize,

    // Whether the executor runs in emit-on-window-close mode.
    emit_on_window_close: bool,

    _phantom: PhantomData<(S, Strtg)>,
}
231
232impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 f.debug_struct("AggGroup")
235 .field("group_key", &self.ctx.group_key)
236 .field("prev_inter_states", &self.prev_inter_states)
237 .field("prev_outputs", &self.prev_outputs)
238 .field("row_count_index", &self.row_count_index)
239 .field("emit_on_window_close", &self.emit_on_window_close)
240 .finish()
241 }
242}
243
244impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
245 fn estimated_heap_size(&self) -> usize {
246 self.states
248 .iter()
249 .map(|state| state.estimated_heap_size())
250 .sum()
251 }
252}
253
impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
    /// Create an [`AggGroup`] for the given group key, restoring any
    /// previously persisted intermediate states from `intermediate_state_table`.
    ///
    /// If the group already existed and we're not in emit-on-window-close
    /// mode, its previous outputs are computed as well so that later output
    /// changes can be diffed against them. Returns the group together with
    /// the agg state cache stats collected while computing those outputs.
    #[allow(clippy::too_many_arguments)]
    pub async fn create(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        intermediate_state_table: &StateTable<S>,
        stream_key: StreamKeyRef<'_>,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<(Self, AggStateCacheStats)> {
        let mut stats = AggStateCacheStats::default();

        // Restore the persisted intermediate states of this group, if any.
        let inter_states = intermediate_state_table
            .get_row(group_key.as_ref().map(GroupKey::table_pk))
            .await?;
        if let Some(inter_states) = &inter_states {
            // One intermediate state column per aggregate call.
            assert_eq!(inter_states.len(), agg_calls.len());
        }

        // Create one `AggState` per aggregate call, seeded from the restored
        // intermediate states when present.
        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                inter_states.as_ref().map(|s| &s[idx]),
                stream_key,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        let mut this = Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: inter_states,
            prev_outputs: None, row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        };

        // In non-EOWC mode, an existing group also needs its previous outputs
        // so that `build_outputs_change` can diff against them later.
        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
            let (outputs, init_stats) = this.get_outputs(storages, agg_funcs).await?;
            this.prev_outputs = Some(outputs);
            stats.merge(init_stats);
        }

        Ok((this, stats))
    }

    /// Create an [`AggGroup`] directly from already-loaded intermediate
    /// states, for the emit-on-window-close output path. Unlike
    /// [`Self::create`], this doesn't read the intermediate state table and
    /// leaves both `prev_inter_states` and `prev_outputs` unset.
    #[allow(clippy::too_many_arguments)]
    pub fn for_eowc_output(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        inter_states: &OwnedRow,
        stream_key: StreamKeyRef<'_>,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                Some(&inter_states[idx]),
                stream_key,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        Ok(Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: None, prev_outputs: None, row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        })
    }

    /// The group key of this group, if any.
    pub fn group_key(&self) -> Option<&GroupKey> {
        self.ctx.group_key()
    }

    /// Current row count of the group, read from the row-count aggregate's
    /// value state.
    fn curr_row_count(&self) -> usize {
        // The row-count aggregate is expected to be a value state.
        let row_count_state = must_match!(
            self.states[self.row_count_index],
            AggState::Value(ref state) => state
        );
        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
    }

    /// Whether this group has no persisted intermediate states yet.
    pub(crate) fn is_uninitialized(&self) -> bool {
        self.prev_inter_states.is_none()
    }

    /// Apply a chunk of input to all aggregate states, with one visibility
    /// bitmap per aggregate call.
    pub async fn apply_chunk(
        &mut self,
        chunk: &StreamChunk,
        calls: &[AggCall],
        funcs: &[BoxedAggregateFunction],
        visibilities: Vec<Bitmap>,
    ) -> StreamExecutorResult<()> {
        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
        }
        for (((state, call), func), visibility) in (self.states.iter_mut())
            .zip_eq_fast(calls)
            .zip_eq_fast(funcs)
            .zip_eq_fast(visibilities)
        {
            state.apply_chunk(chunk, call, func, visibility).await?;
        }

        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
        }

        Ok(())
    }

    /// Reset all in-memory aggregate states to their initial value.
    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
            state.reset(func)?;
        }
        Ok(())
    }

    /// Encode the current value states into a row of intermediate states.
    /// Materialized-input states are kept in their own storages, so their
    /// column is `None` here.
    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
        let mut inter_states = Vec::with_capacity(self.states.len());
        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
            let encoded = match state {
                AggState::Value(s) => func.encode_state(s)?,
                AggState::MaterializedInput(_) => None,
            };
            inter_states.push(encoded);
        }
        Ok(OwnedRow::new(inter_states))
    }

    /// Compute the current outputs of all aggregate states concurrently.
    /// If the group has become empty (row count 0), the states are reset
    /// first, so outputs are the aggregates' initial values. Also returns the
    /// agg state cache stats collected while producing the outputs.
    async fn get_outputs(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
        let row_count = self.curr_row_count();
        if row_count == 0 {
            self.reset(funcs)?;
        }
        let mut stats = AggStateCacheStats::default();
        futures::future::try_join_all(
            self.states
                .iter_mut()
                .zip_eq_fast(storages)
                .zip_eq_fast(funcs)
                .map(|((state, storage), func)| {
                    state.get_output(storage, func, self.ctx.group_key())
                }),
        )
        .await
        .map(|outputs_and_stats| {
            // Fold the per-state stats together while collecting the outputs.
            outputs_and_stats
                .into_iter()
                .map(|(output, stat)| {
                    stats.merge(stat);
                    output
                })
                .collect::<Vec<_>>()
        })
        .map(|row| (OwnedRow::new(row), stats))
    }

    /// Build the change record for the intermediate state table by comparing
    /// the current encoded states against `prev_inter_states` using the
    /// strategy `Strtg`. Updates `prev_inter_states` accordingly; returns
    /// `None` when no change should be emitted.
    pub fn build_states_change(
        &mut self,
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
        let curr_inter_states = self.get_inter_states(funcs)?;
        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_inter_states.as_ref(),
            &curr_inter_states,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_inter_states = ?self.prev_inter_states,
            curr_inter_states = ?curr_inter_states,
            change_type = ?change_type,
            "build intermediate states change"
        );

        let Some(change_type) = change_type else {
            return Ok(None);
        };
        // State-table rows are the group-key columns followed by the
        // intermediate state columns.
        Ok(Some(match change_type {
            RecordType::Insert => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                self.prev_inter_states = Some(curr_inter_states);
                Record::Insert { new_row }
            }
            RecordType::Delete => {
                let prev_inter_states = self
                    .prev_inter_states
                    .take()
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Delete { old_row }
            }
            RecordType::Update => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                // `replace` both installs the current states and hands back
                // the previous ones for the old side of the update.
                let prev_inter_states = self
                    .prev_inter_states
                    .replace(curr_inter_states)
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Update { old_row, new_row }
            }
        }))
    }

    /// Build the change record for downstream output by comparing the current
    /// outputs against `prev_outputs` using the strategy `Strtg`. Updates
    /// `prev_outputs` accordingly; returns `None` (plus the collected stats)
    /// when no change should be emitted.
    pub async fn build_outputs_change(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;

        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_outputs.as_ref(),
            &curr_outputs,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_outputs = ?self.prev_outputs,
            curr_outputs = ?curr_outputs,
            change_type = ?change_type,
            "build outputs change"
        );

        let Some(change_type) = change_type else {
            return Ok((None, stats));
        };
        // Output rows are the group-key columns followed by the agg outputs.
        Ok((
            Some(match change_type {
                RecordType::Insert => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    self.prev_outputs = Some(curr_outputs);
                    Record::Insert { new_row }
                }
                RecordType::Delete => {
                    let prev_outputs = self.prev_outputs.take();
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Delete { old_row }
                }
                RecordType::Update => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Update { old_row, new_row }
                }
            }),
            stats,
        ))
    }
}
612
/// Stats of agg state cache operations, collected while computing outputs.
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    // Total lookups into the agg state cache.
    pub agg_state_cache_lookup_count: u64,
    // Lookups that missed the cache.
    pub agg_state_cache_miss_count: u64,
}
619
620impl AggStateCacheStats {
621 fn merge(&mut self, other: Self) {
622 self.agg_state_cache_lookup_count += other.agg_state_cache_lookup_count;
623 self.agg_state_cache_miss_count += other.agg_state_cache_miss_count;
624 }
625}