use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;

use futures::future::try_join_all;
use risingwave_common::array::StreamChunk;
use risingwave_common::array::stream_record::{Record, RecordType};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::Schema;
use risingwave_common::must_match;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use risingwave_storage::StateStore;

use super::agg_state::{AggState, AggStateStorage};
use crate::common::table::state_table::StateTable;
use crate::consistency::consistency_panic;
use crate::executor::PkIndices;
use crate::executor::error::StreamExecutorResult;
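
/// Aggregation context, carrying the group key of the group being aggregated
/// (`None` when the aggregation has no group key).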
#[derive(Debug)]
pub struct Context {
    group_key: Option<GroupKey>,
}

impl Context {
    pub fn group_key(&self) -> Option<&GroupKey> {
        self.group_key.as_ref()
    }

    pub fn group_key_row(&self) -> OwnedRow {
        self.group_key()
            .map(GroupKey::table_row)
            .cloned()
            .unwrap_or_default()
    }
}
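
/// Extract the row count from the row-count column (`row_count_col`) of `row`,
/// treating a missing row as a count of zero.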
fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
    match row {
        Some(row) => {
            let mut row_count = row
                .datum_at(row_count_col)
                .expect("row count field should not be NULL")
                .into_int64();

            if row_count < 0 {
                consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");

                // If the consistency check does not panic, recover by treating the
                // row count as zero.
                row_count = 0;
            }
            row_count.try_into().unwrap()
        }
        None => 0,
    }
}
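
/// Strategy for inferring what kind of change ([`RecordType`]) to emit for a group,
/// given its previous and current row (of intermediate states or outputs) and the index
/// of the row-count column. Returning `None` means no change should be emitted.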
pub trait Strategy {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType>;
}
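
/// Strategy that always emits a change: an `Insert` the first time the group produces
/// an output, and an `Update` on every later output, regardless of the current row count.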
pub struct AlwaysOutput;
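
/// Strategy that emits a change only while the group has input rows: `Insert` when the
/// row count goes from zero to non-zero, `Delete` when it drops back to zero, `Update`
/// when it stays non-zero and the row actually changed, and nothing otherwise.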
pub struct OnlyOutputIfHasInput;

impl Strategy for AlwaysOutput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        _curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        match prev_row {
            None => {
                // This group has never produced an output before.
                assert_eq!(prev_row_count, 0);

                Some(RecordType::Insert)
            }
            // The group has produced an output before, so always emit an update.
            Some(_prev_outputs) => Some(RecordType::Update),
        }
    }
}

impl Strategy for OnlyOutputIfHasInput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);

        match (prev_row_count, curr_row_count) {
            (0, 0) => {
                // The group had no rows and still has none, nothing to emit.
                None
            }
            (0, _) => {
                // The group gets its first rows, emit an insert.
                Some(RecordType::Insert)
            }
            (_, 0) => {
                // All rows of the group are retracted, emit a delete.
                Some(RecordType::Delete)
            }
            (_, _) => {
                // The group still has rows, emit an update only if the row changed.
                if prev_row.expect("must exist previous row") == curr_row {
                    None
                } else {
                    Some(RecordType::Update)
                }
            }
        }
    }
}
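
/// Group key of an [`AggGroup`]. `row_prefix` holds the group key columns as they appear
/// as a prefix of rows in the state tables; `table_pk_projection` selects, from that
/// prefix, the columns that form the state table's primary key.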
#[derive(Clone, Debug)]
pub struct GroupKey {
    row_prefix: OwnedRow,
    table_pk_projection: Arc<[usize]>,
}

impl GroupKey {
    pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
        let table_pk_projection =
            table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
        Self {
            row_prefix,
            table_pk_projection,
        }
    }

    pub fn len(&self) -> usize {
        self.row_prefix.len()
    }

    pub fn is_empty(&self) -> bool {
        self.row_prefix.is_empty()
    }

    /// Get the group key as a prefix of rows in the state tables.
    pub fn table_row(&self) -> &OwnedRow {
        &self.row_prefix
    }

    /// Get the group key projected to the state table's primary key columns.
    pub fn table_pk(&self) -> impl Row + '_ {
        (&self.row_prefix).project(&self.table_pk_projection)
    }

    /// Get the key used for caching, currently the full table row.
    pub fn cache_key(&self) -> impl Row + '_ {
        self.table_row()
    }
}
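
/// One aggregation group: the in-memory aggregation states of the group, plus the
/// previously materialized intermediate states and previously emitted outputs that are
/// needed to build change records.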
pub struct AggGroup<S: StateStore, Strtg: Strategy> {
    /// Context, holding the group key of this group.
    ctx: Context,

    /// Aggregation states, one per aggregation call.
    states: Vec<AggState>,

    /// Previous intermediate states, as stored in the intermediate state table.
    prev_inter_states: Option<OwnedRow>,

    /// Previously emitted outputs.
    prev_outputs: Option<OwnedRow>,

    /// Index of the row-count aggregation among the aggregation calls.
    row_count_index: usize,

    /// Whether the aggregation is emit-on-window-close.
    emit_on_window_close: bool,

    _phantom: PhantomData<(S, Strtg)>,
}

impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AggGroup")
            .field("group_key", &self.ctx.group_key)
            .field("prev_inter_states", &self.prev_inter_states)
            .field("prev_outputs", &self.prev_outputs)
            .field("row_count_index", &self.row_count_index)
            .field("emit_on_window_close", &self.emit_on_window_close)
            .finish()
    }
}

impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
    fn estimated_heap_size(&self) -> usize {
        self.states
            .iter()
            .map(|state| state.estimated_heap_size())
            .sum()
    }
}

impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
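    /// Create an [`AggGroup`] for the given group key. If the group already exists, its
    /// intermediate states are restored from `intermediate_state_table`, and (unless the
    /// aggregation is emit-on-window-close) its previous outputs are recomputed so that
    /// later changes can be inferred against them. Also returns the state-cache
    /// statistics collected while computing the outputs.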
    #[allow(clippy::too_many_arguments)]
    pub async fn create(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        intermediate_state_table: &StateTable<S>,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<(Self, AggStateCacheStats)> {
        let mut stats = AggStateCacheStats::default();

        let inter_states = intermediate_state_table
            .get_row(group_key.as_ref().map(GroupKey::table_pk))
            .await?;
        if let Some(inter_states) = &inter_states {
            assert_eq!(inter_states.len(), agg_calls.len());
        }

        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                inter_states.as_ref().map(|s| &s[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        let mut this = Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: inter_states,
            prev_outputs: None,
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        };

        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
            // Restore the previous outputs for an existing group, so that the first
            // outputs change of this group can be inferred against them.
            let (outputs, init_stats) = this.get_outputs(storages, agg_funcs).await?;
            this.prev_outputs = Some(outputs);
            stats.merge(init_stats);
        }

        Ok((this, stats))
    }
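
    /// Create a temporary [`AggGroup`] from intermediate states that were already read
    /// from the intermediate state table, used to build outputs on window close (EOWC).
    /// The returned group records no previous intermediate states or outputs.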
    #[allow(clippy::too_many_arguments)]
    pub fn for_eowc_output(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        inter_states: &OwnedRow,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                Some(&inter_states[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        Ok(Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: None,
            prev_outputs: None,
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        })
    }

    pub fn group_key(&self) -> Option<&GroupKey> {
        self.ctx.group_key()
    }
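
    /// Get the current row count of the group from the row-count aggregation state.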
    fn curr_row_count(&self) -> usize {
        let row_count_state = must_match!(
            self.states[self.row_count_index],
            AggState::Value(ref state) => state
        );
        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
    }
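
    /// Whether the group has no previously recorded intermediate states.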
    pub(crate) fn is_uninitialized(&self) -> bool {
        self.prev_inter_states.is_none()
    }
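
    /// Apply a chunk of input records to all aggregation states, using the per-call
    /// visibility bitmaps. States are updated in batches to bound the number of
    /// concurrently running futures.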
    pub async fn apply_chunk(
        &mut self,
        chunk: &StreamChunk,
        calls: &[AggCall],
        funcs: &[BoxedAggregateFunction],
        visibilities: Vec<Bitmap>,
    ) -> StreamExecutorResult<()> {
        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
        }

        // Apply the chunk to the states in batches of `concurrency`, to bound the number
        // of futures polled at the same time.
        let concurrency = 10;
        let len = self.states.len();

        for batch_start in (0..len).step_by(concurrency) {
            let batch_end = std::cmp::min(batch_start + concurrency, len);

            let futures = self.states[batch_start..batch_end]
                .iter_mut()
                .zip_eq_fast(&calls[batch_start..batch_end])
                .zip_eq_fast(&funcs[batch_start..batch_end])
                .zip_eq_fast(&visibilities[batch_start..batch_end])
                .map(|(((state, call), func), visibility)| {
                    state.apply_chunk(chunk, call, func, visibility.clone())
                });

            try_join_all(futures).await?;
        }

        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
        }

        Ok(())
    }
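
    /// Reset all in-memory aggregation states to their initial values.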
    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
            state.reset(func)?;
        }
        Ok(())
    }
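
    /// Encode the current in-memory states into a row of intermediate states. States
    /// backed by materialized input encode as `NULL`, since their data is kept in their
    /// own state tables.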
    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
        let mut inter_states = Vec::with_capacity(self.states.len());
        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
            let encoded = match state {
                AggState::Value(s) => func.encode_state(s)?,
                AggState::MaterializedInput(_) => None,
            };
            inter_states.push(encoded);
        }
        Ok(OwnedRow::new(inter_states))
    }
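
    /// Compute the current output row of all aggregation functions. If the group has no
    /// rows, the in-memory states are reset to their initial values first. Also returns
    /// the state-cache statistics accumulated while producing the outputs.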
    async fn get_outputs(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
        let row_count = self.curr_row_count();
        if row_count == 0 {
            // The group is empty, so reset the in-memory states to their initial values
            // before producing outputs.
            self.reset(funcs)?;
        }
        let mut stats = AggStateCacheStats::default();
        futures::future::try_join_all(
            self.states
                .iter_mut()
                .zip_eq_fast(storages)
                .zip_eq_fast(funcs)
                .map(|((state, storage), func)| {
                    state.get_output(storage, func, self.ctx.group_key())
                }),
        )
        .await
        .map(|outputs_and_stats| {
            outputs_and_stats
                .into_iter()
                .map(|(output, stat)| {
                    stats.merge(stat);
                    output
                })
                .collect::<Vec<_>>()
        })
        .map(|row| (OwnedRow::new(row), stats))
    }
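
    /// Build a change record of the intermediate states by comparing the current
    /// intermediate states against the previous ones, according to the strategy `Strtg`.
    /// Updates `prev_inter_states` and returns `None` if no change should be emitted.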
    pub fn build_states_change(
        &mut self,
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
        let curr_inter_states = self.get_inter_states(funcs)?;
        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_inter_states.as_ref(),
            &curr_inter_states,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_inter_states = ?self.prev_inter_states,
            curr_inter_states = ?curr_inter_states,
            change_type = ?change_type,
            "build intermediate states change"
        );

        let Some(change_type) = change_type else {
            return Ok(None);
        };
        Ok(Some(match change_type {
            RecordType::Insert => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                self.prev_inter_states = Some(curr_inter_states);
                Record::Insert { new_row }
            }
            RecordType::Delete => {
                let prev_inter_states = self
                    .prev_inter_states
                    .take()
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Delete { old_row }
            }
            RecordType::Update => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                let prev_inter_states = self
                    .prev_inter_states
                    .replace(curr_inter_states)
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Update { old_row, new_row }
            }
        }))
    }
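
    /// Build a change record of the aggregation outputs for downstream by comparing the
    /// current outputs against the previous ones, according to the strategy `Strtg`.
    /// Updates `prev_outputs` and returns the change (if any) together with the
    /// state-cache statistics collected while computing the outputs.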
    pub async fn build_outputs_change(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;

        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_outputs.as_ref(),
            &curr_outputs,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_outputs = ?self.prev_outputs,
            curr_outputs = ?curr_outputs,
            change_type = ?change_type,
            "build outputs change"
        );

        let Some(change_type) = change_type else {
            return Ok((None, stats));
        };
        Ok((
            Some(match change_type {
                RecordType::Insert => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    self.prev_outputs = Some(curr_outputs);
                    Record::Insert { new_row }
                }
                RecordType::Delete => {
                    let prev_outputs = self.prev_outputs.take();
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Delete { old_row }
                }
                RecordType::Update => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Update { old_row, new_row }
                }
            }),
            stats,
        ))
    }
}
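
/// Statistics of aggregation state cache accesses (lookups and misses), merged across
/// states and operations.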
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    pub agg_state_cache_lookup_count: u64,
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    fn merge(&mut self, other: Self) {
        self.agg_state_cache_lookup_count += other.agg_state_cache_lookup_count;
        self.agg_state_cache_miss_count += other.agg_state_cache_miss_count;
    }
}