use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;

use futures::future::try_join_all;
use risingwave_common::array::StreamChunk;
use risingwave_common::array::stream_record::{Record, RecordType};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::Schema;
use risingwave_common::must_match;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction};
use risingwave_pb::stream_plan::PbAggNodeVersion;
use risingwave_storage::StateStore;

use super::agg_state::{AggState, AggStateStorage};
use crate::common::table::state_table::StateTable;
use crate::consistency::consistency_panic;
use crate::executor::PkIndices;
use crate::executor::error::StreamExecutorResult;

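/// Shared context of an [`AggGroup`], carrying the group key for logging and sanity checks.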
#[derive(Debug)]
pub struct Context {
    group_key: Option<GroupKey>,
}

impl Context {
    pub fn group_key(&self) -> Option<&GroupKey> {
        self.group_key.as_ref()
    }

    pub fn group_key_row(&self) -> OwnedRow {
        self.group_key()
            .map(GroupKey::table_row)
            .cloned()
            .unwrap_or_default()
    }
}

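/// Extract the row count from the given row, treating `None` as `0`.
/// A negative row count indicates an inconsistency; it is reported and clamped to `0`.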
fn row_count_of(ctx: &Context, row: Option<impl Row>, row_count_col: usize) -> usize {
    match row {
        Some(row) => {
            let mut row_count = row
                .datum_at(row_count_col)
                .expect("row count field should not be NULL")
                .into_int64();

            if row_count < 0 {
                consistency_panic!(group = ?ctx.group_key_row(), row_count, "row count should be non-negative");

                // Tolerate the inconsistency and recover by treating the row count as 0.
                row_count = 0;
            }
            row_count.try_into().unwrap()
        }
        None => 0,
    }
}

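/// Strategy that determines the change type of the aggregation result, given the previous
/// and current rows (either intermediate states or final outputs).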
pub trait Strategy {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType>;
}

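/// The strategy that always outputs a change: an `Insert` the first time the group is seen,
/// and an `Update` afterwards, even if the result row is unchanged.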
pub struct AlwaysOutput;
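/// The strategy that only outputs a change when the group has (or had) input rows:
/// `Insert` when the row count becomes positive, `Delete` when it drops back to zero,
/// `Update` otherwise, and no output if the result row is unchanged.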
pub struct OnlyOutputIfHasInput;

impl Strategy for AlwaysOutput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        _curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        match prev_row {
            None => {
                // First time building changes for this group, so it must have an empty history.
                assert_eq!(prev_row_count, 0);

                Some(RecordType::Insert)
            }
            // Always produce an `Update` afterwards, even if the group becomes empty,
            // so that downstream always sees the latest aggregation result.
            Some(_prev_outputs) => Some(RecordType::Update),
        }
    }
}

impl Strategy for OnlyOutputIfHasInput {
    fn infer_change_type(
        ctx: &Context,
        prev_row: Option<&OwnedRow>,
        curr_row: &OwnedRow,
        row_count_col: usize,
    ) -> Option<RecordType> {
        let prev_row_count = row_count_of(ctx, prev_row, row_count_col);
        let curr_row_count = row_count_of(ctx, Some(curr_row), row_count_col);

        match (prev_row_count, curr_row_count) {
            (0, 0) => {
                // The group is still empty, nothing to output.
                None
            }
            (0, _) => {
                // The group becomes non-empty.
                Some(RecordType::Insert)
            }
            (_, 0) => {
                // The group becomes empty.
                Some(RecordType::Delete)
            }
            (_, _) => {
                if prev_row.expect("must exist previous row") == curr_row {
                    // The result row is unchanged, no need to output.
                    None
                } else {
                    Some(RecordType::Update)
                }
            }
        }
    }
}

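/// Group key of an [`AggGroup`], wrapping the group key row together with a projection
/// used to derive the state-table primary key from it.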
#[derive(Clone, Debug)]
pub struct GroupKey {
    row_prefix: OwnedRow,
    table_pk_projection: Arc<[usize]>,
}

impl GroupKey {
    pub fn new(row_prefix: OwnedRow, table_pk_projection: Option<Arc<[usize]>>) -> Self {
        let table_pk_projection =
            table_pk_projection.unwrap_or_else(|| (0..row_prefix.len()).collect());
        Self {
            row_prefix,
            table_pk_projection,
        }
    }

    pub fn len(&self) -> usize {
        self.row_prefix.len()
    }

    pub fn is_empty(&self) -> bool {
        self.row_prefix.is_empty()
    }

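    /// Get the group key row used as the prefix of rows stored in the state tables.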
    pub fn table_row(&self) -> &OwnedRow {
        &self.row_prefix
    }

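    /// Get the group key projected to the primary key of the state tables.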
    pub fn table_pk(&self) -> impl Row + '_ {
        (&self.row_prefix).project(&self.table_pk_projection)
    }

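    /// Get the group key used as the key of executor-level caches, which is the full group key row.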
    pub fn cache_key(&self) -> impl Row + '_ {
        self.table_row()
    }
}

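/// [`AggGroup`] manages the agg states of all agg calls for one `group_key`.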
pub struct AggGroup<S: StateStore, Strtg: Strategy> {
    /// Shared context, mainly the group key of this group.
    ctx: Context,

    /// Current managed states of all aggregation calls.
    states: Vec<AggState>,

    /// Previous intermediate states, as persisted in the intermediate state table.
    prev_inter_states: Option<OwnedRow>,

    /// Previously generated outputs of all aggregation calls.
    prev_outputs: Option<OwnedRow>,

    /// Index of the row count aggregation call in the call list.
    row_count_index: usize,

    /// Whether the executor operates in `EMIT ON WINDOW CLOSE` mode.
    emit_on_window_close: bool,

    _phantom: PhantomData<(S, Strtg)>,
}

impl<S: StateStore, Strtg: Strategy> Debug for AggGroup<S, Strtg> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AggGroup")
            .field("group_key", &self.ctx.group_key)
            .field("prev_inter_states", &self.prev_inter_states)
            .field("prev_outputs", &self.prev_outputs)
            .field("row_count_index", &self.row_count_index)
            .field("emit_on_window_close", &self.emit_on_window_close)
            .finish()
    }
}

impl<S: StateStore, Strtg: Strategy> EstimateSize for AggGroup<S, Strtg> {
    fn estimated_heap_size(&self) -> usize {
        self.states
            .iter()
            .map(|state| state.estimated_heap_size())
            .sum()
    }
}

impl<S: StateStore, Strtg: Strategy> AggGroup<S, Strtg> {
    #[allow(clippy::too_many_arguments)]
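    /// Create an [`AggGroup`] for the given [`AggCall`]s and `group_key`, recovering the
    /// previous intermediate states (and outputs, unless in emit-on-window-close mode)
    /// from the intermediate state table.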
    pub async fn create(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        intermediate_state_table: &StateTable<S>,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let inter_states = intermediate_state_table
            .get_row(group_key.as_ref().map(GroupKey::table_pk))
            .await?;
        if let Some(inter_states) = &inter_states {
            assert_eq!(inter_states.len(), agg_calls.len());
        }

        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                inter_states.as_ref().map(|s| &s[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        let mut this = Self {
            ctx: Context { group_key },
            states,
            prev_inter_states: inter_states,
            prev_outputs: None, // initialized below if needed
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        };

        if !this.emit_on_window_close && this.prev_inter_states.is_some() {
            let (outputs, _stats) = this.get_outputs(storages, agg_funcs).await?;
            this.prev_outputs = Some(outputs);
        }

        Ok(this)
    }

    #[allow(clippy::too_many_arguments)]
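    /// Create an [`AggGroup`] directly from the given intermediate states, for
    /// emit-on-window-close output purposes, without reading the intermediate state table.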
    pub fn for_eowc_output(
        version: PbAggNodeVersion,
        group_key: Option<GroupKey>,
        agg_calls: &[AggCall],
        agg_funcs: &[BoxedAggregateFunction],
        storages: &[AggStateStorage<S>],
        inter_states: &OwnedRow,
        pk_indices: &PkIndices,
        row_count_index: usize,
        emit_on_window_close: bool,
        extreme_cache_size: usize,
        input_schema: &Schema,
    ) -> StreamExecutorResult<Self> {
        let mut states = Vec::with_capacity(agg_calls.len());
        for (idx, (agg_call, agg_func)) in agg_calls.iter().zip_eq_fast(agg_funcs).enumerate() {
            let state = AggState::create(
                version,
                agg_call,
                agg_func,
                &storages[idx],
                Some(&inter_states[idx]),
                pk_indices,
                extreme_cache_size,
                input_schema,
            )?;
            states.push(state);
        }

        Ok(Self {
            ctx: Context { group_key },
            states,
            // This group is only used for producing output, so previous states/outputs
            // are not needed.
            prev_inter_states: None,
            prev_outputs: None,
            row_count_index,
            emit_on_window_close,
            _phantom: PhantomData,
        })
    }

    pub fn group_key(&self) -> Option<&GroupKey> {
        self.ctx.group_key()
    }

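    /// Get the current row count of this group from the row count state.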
    fn curr_row_count(&self) -> usize {
        let row_count_state = must_match!(
            self.states[self.row_count_index],
            AggState::Value(ref state) => state
        );
        row_count_of(&self.ctx, Some([row_count_state.as_datum().clone()]), 0)
    }

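    /// Whether this group is uninitialized, i.e. has no previous intermediate states.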
    pub(crate) fn is_uninitialized(&self) -> bool {
        self.prev_inter_states.is_none()
    }

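    /// Apply the input chunk to all managed aggregation states, with per-call visibilities.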
    pub async fn apply_chunk(
        &mut self,
        chunk: &StreamChunk,
        calls: &[AggCall],
        funcs: &[BoxedAggregateFunction],
        visibilities: Vec<Bitmap>,
    ) -> StreamExecutorResult<()> {
        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "first time see this group");
        }

        // Apply the chunk to the states in batches of `concurrency`, to bound the number
        // of futures joined at once.
        let concurrency = 10;
        let len = self.states.len();

        for chunk_start in (0..len).step_by(concurrency) {
            let chunk_end = std::cmp::min(chunk_start + concurrency, len);

            let futures = &mut self.states[chunk_start..chunk_end]
                .iter_mut()
                .zip_eq_fast(&calls[chunk_start..chunk_end])
                .zip_eq_fast(&funcs[chunk_start..chunk_end])
                .zip_eq_fast(&visibilities[chunk_start..chunk_end])
                .map(|(((state, call), func), visibility)| {
                    state.apply_chunk(chunk, call, func, visibility.clone())
                });

            try_join_all(futures).await?;
        }

        if self.curr_row_count() == 0 {
            tracing::trace!(group = ?self.ctx.group_key_row(), "last time see this group");
        }

        Ok(())
    }

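    /// Reset all managed states to their initial state, as if no input has been applied.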
    fn reset(&mut self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<()> {
        for (state, func) in self.states.iter_mut().zip_eq_fast(funcs) {
            state.reset(func)?;
        }
        Ok(())
    }

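    /// Encode the current managed states into a row of intermediate states.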
    fn get_inter_states(&self, funcs: &[BoxedAggregateFunction]) -> StreamExecutorResult<OwnedRow> {
        let mut inter_states = Vec::with_capacity(self.states.len());
        for (state, func) in self.states.iter().zip_eq_fast(funcs) {
            let encoded = match state {
                AggState::Value(s) => func.encode_state(s)?,
                // Materialized-input states are kept in their own state tables, so there is
                // nothing to encode here.
                AggState::MaterializedInput(_) => None,
            };
            inter_states.push(encoded);
        }
        Ok(OwnedRow::new(inter_states))
    }

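    /// Compute the outputs of all managed agg states, possibly reading from the state
    /// storages, and collect agg state cache statistics along the way.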
    async fn get_outputs(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(OwnedRow, AggStateCacheStats)> {
        let row_count = self.curr_row_count();
        if row_count == 0 {
            // The group is empty now, so reset all states to produce their initial outputs.
            self.reset(funcs)?;
        }
        let mut stats = AggStateCacheStats::default();
        futures::future::try_join_all(
            self.states
                .iter_mut()
                .zip_eq_fast(storages)
                .zip_eq_fast(funcs)
                .map(|((state, storage), func)| {
                    state.get_output(storage, func, self.ctx.group_key())
                }),
        )
        .await
        .map(|outputs_and_stats| {
            outputs_and_stats
                .into_iter()
                .map(|(output, stat)| {
                    stats.merge(stat);
                    output
                })
                .collect::<Vec<_>>()
        })
        .map(|row| (OwnedRow::new(row), stats))
    }

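    /// Build the change for the intermediate states of all agg calls, according to the
    /// output [`Strategy`]. Returns `None` if no change should be emitted.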
    pub fn build_states_change(
        &mut self,
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<Option<Record<OwnedRow>>> {
        let curr_inter_states = self.get_inter_states(funcs)?;
        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_inter_states.as_ref(),
            &curr_inter_states,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_inter_states = ?self.prev_inter_states,
            curr_inter_states = ?curr_inter_states,
            change_type = ?change_type,
            "build intermediate states change"
        );

        let Some(change_type) = change_type else {
            return Ok(None);
        };
        Ok(Some(match change_type {
            RecordType::Insert => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                self.prev_inter_states = Some(curr_inter_states);
                Record::Insert { new_row }
            }
            RecordType::Delete => {
                let prev_inter_states = self
                    .prev_inter_states
                    .take()
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Delete { old_row }
            }
            RecordType::Update => {
                let new_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(&curr_inter_states)
                    .into_owned_row();
                let prev_inter_states = self
                    .prev_inter_states
                    .replace(curr_inter_states)
                    .expect("must exist previous intermediate states");
                let old_row = self
                    .group_key()
                    .map(GroupKey::table_row)
                    .chain(prev_inter_states)
                    .into_owned_row();
                Record::Update { old_row, new_row }
            }
        }))
    }

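    /// Build the change for the outputs of all agg calls, according to the output
    /// [`Strategy`]. Returns `None` as the change if nothing should be emitted, along with
    /// the agg state cache statistics.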
    pub async fn build_outputs_change(
        &mut self,
        storages: &[AggStateStorage<S>],
        funcs: &[BoxedAggregateFunction],
    ) -> StreamExecutorResult<(Option<Record<OwnedRow>>, AggStateCacheStats)> {
        let (curr_outputs, stats) = self.get_outputs(storages, funcs).await?;

        let change_type = Strtg::infer_change_type(
            &self.ctx,
            self.prev_outputs.as_ref(),
            &curr_outputs,
            self.row_count_index,
        );

        tracing::trace!(
            group = ?self.ctx.group_key_row(),
            prev_outputs = ?self.prev_outputs,
            curr_outputs = ?curr_outputs,
            change_type = ?change_type,
            "build outputs change"
        );

        let Some(change_type) = change_type else {
            return Ok((None, stats));
        };
        Ok((
            Some(match change_type {
                RecordType::Insert => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    self.prev_outputs = Some(curr_outputs);
                    Record::Insert { new_row }
                }
                RecordType::Delete => {
                    let prev_outputs = self.prev_outputs.take();
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Delete { old_row }
                }
                RecordType::Update => {
                    let new_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(&curr_outputs)
                        .into_owned_row();
                    let prev_outputs = self.prev_outputs.replace(curr_outputs);
                    let old_row = self
                        .group_key()
                        .map(GroupKey::table_row)
                        .chain(prev_outputs)
                        .into_owned_row();
                    Record::Update { old_row, new_row }
                }
            }),
            stats,
        ))
    }
}

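/// Statistics collected from the agg state caches while computing outputs.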
#[derive(Debug, Default)]
pub struct AggStateCacheStats {
    pub agg_state_cache_lookup_count: u64,
    pub agg_state_cache_miss_count: u64,
}

impl AggStateCacheStats {
    fn merge(&mut self, other: Self) {
        self.agg_state_cache_lookup_count += other.agg_state_cache_lookup_count;
        self.agg_state_cache_miss_count += other.agg_state_cache_miss_count;
    }
}