1use std::collections::HashMap;
16
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::util::epoch::EpochPair;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction, build_retractable};
21use risingwave_pb::stream_plan::PbAggNodeVersion;
22
23use super::agg_group::{AggGroup, AlwaysOutput};
24use super::agg_state::AggStateStorage;
25use super::distinct::DistinctDeduplicater;
26use super::{AggExecutorArgs, SimpleAggExecutorExtraArgs, agg_call_filter_res, iter_table_storage};
27use crate::executor::prelude::*;
28
/// Streaming executor for aggregation over the whole input ("simple" aggregation).
///
/// All input rows are folded into a single aggregation group (see `ExecutionVars::agg_group`),
/// so there is one output row. Chunks received between two barriers only update the
/// aggregation state; the resulting change of the output row is emitted as at most one chunk
/// right before the next barrier is forwarded.
pub struct SimpleAggExecutor<S: StateStore> {
    /// The upstream executor whose output is aggregated.
    input: Executor,
    /// Everything except the input. Kept separate so that `execute_inner` can consume `input`
    /// while still holding mutable access to the state.
    inner: ExecutorInner<S>,
}
46
/// Everything [`SimpleAggExecutor`] needs besides its input: configuration, aggregate
/// functions, and the state tables backing the aggregation.
struct ExecutorInner<S: StateStore> {
    /// Version of the aggregation node; forwarded to `AggGroup::create`.
    version: PbAggNodeVersion,

    /// Context of the actor running this executor; forwarded to `DistinctDeduplicater::new`.
    actor_ctx: ActorContextRef,
    /// Info of this executor. The data types of its schema are used to build output chunks.
    info: ExecutorInfo,

    /// Primary-key column indices of the input, copied from the input executor's info.
    input_pk_indices: Vec<usize>,

    /// Schema of the input, copied from the input executor's info.
    input_schema: Schema,

    /// The aggregate calls computed by this executor.
    agg_calls: Vec<AggCall>,

    /// Retractable aggregate functions built from `agg_calls` (one per call, in the same order).
    agg_funcs: Vec<BoxedAggregateFunction>,

    /// Forwarded to `AggGroup::create`. NOTE(review): presumably the position of the row-count
    /// (`count(*)`) call within `agg_calls` — confirm.
    row_count_index: usize,

    /// State storage for each agg call (zipped with the per-call visibilities in
    /// `apply_chunk`). `MaterializedInput` storages receive the visible projection of every
    /// input chunk.
    storages: Vec<AggStateStorage<S>>,

    /// Table that persists the agg group's intermediate states; written in `flush_data`
    /// whenever the state changed or the group is still uninitialized.
    intermediate_state_table: StateTable<S>,

    /// State tables used by the distinct deduplicater. NOTE(review): presumably keyed by the
    /// distinct column index — confirm in `DistinctDeduplicater`.
    distinct_dedup_tables: HashMap<usize, StateTable<S>>,

    /// Shared epoch handle forwarded to `DistinctDeduplicater::new`. NOTE(review): presumably
    /// used for cache eviction — confirm.
    watermark_epoch: AtomicU64Ref,

    /// Forwarded to `AggGroup::create`. NOTE(review): presumably the cache capacity for
    /// extreme (min/max) states — confirm.
    extreme_cache_size: usize,

    /// When `true`, a chunk is emitted on every barrier even if the output row did not change.
    /// When `false`, an update whose old and new rows are equal is omitted.
    must_output_per_barrier: bool,
}
91
92impl<S: StateStore> ExecutorInner<S> {
93 fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
94 iter_table_storage(&mut self.storages)
95 .chain(self.distinct_dedup_tables.values_mut())
96 .chain(std::iter::once(&mut self.intermediate_state_table))
97 }
98}
99
/// Mutable per-execution state, created in `execute_inner` after the first barrier.
struct ExecutionVars<S: StateStore> {
    /// The single aggregation group, using the `AlwaysOutput` strategy.
    agg_group: AggGroup<S, AlwaysOutput>,

    /// Deduplicater for distinct agg calls; operates on `ExecutorInner::distinct_dedup_tables`.
    distinct_dedup: DistinctDeduplicater<S>,

    /// Set by `apply_chunk` once a non-empty chunk has been applied and cleared by
    /// `flush_data`. The dedup and intermediate states are flushed on a barrier only when this
    /// is set or the agg group is still uninitialized.
    state_changed: bool,
}
110
111impl<S: StateStore> Execute for SimpleAggExecutor<S> {
112 fn execute(self: Box<Self>) -> BoxedMessageStream {
113 self.execute_inner().boxed()
114 }
115}
116
117impl<S: StateStore> SimpleAggExecutor<S> {
118 pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
119 let input_info = args.input.info().clone();
120 Ok(Self {
121 input: args.input,
122 inner: ExecutorInner {
123 version: args.version,
124 actor_ctx: args.actor_ctx,
125 info: args.info,
126 input_pk_indices: input_info.pk_indices,
127 input_schema: input_info.schema,
128 agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
129 agg_calls: args.agg_calls,
130 row_count_index: args.row_count_index,
131 storages: args.storages,
132 intermediate_state_table: args.intermediate_state_table,
133 distinct_dedup_tables: args.distinct_dedup_tables,
134 watermark_epoch: args.watermark_epoch,
135 extreme_cache_size: args.extreme_cache_size,
136 must_output_per_barrier: args.extra.must_output_per_barrier,
137 },
138 })
139 }
140
141 async fn apply_chunk(
142 this: &mut ExecutorInner<S>,
143 vars: &mut ExecutionVars<S>,
144 chunk: StreamChunk,
145 ) -> StreamExecutorResult<()> {
146 if chunk.cardinality() == 0 {
147 return Ok(());
149 }
150
151 let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
153 for agg_call in &this.agg_calls {
154 let vis = agg_call_filter_res(agg_call, &chunk).await?;
155 call_visibilities.push(vis);
156 }
157
158 let visibilities = vars
160 .distinct_dedup
161 .dedup_chunk(
162 chunk.ops(),
163 chunk.columns(),
164 call_visibilities,
165 &mut this.distinct_dedup_tables,
166 None,
167 )
168 .await?;
169
170 for (storage, visibility) in this.storages.iter_mut().zip_eq_fast(visibilities.iter()) {
172 if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage {
173 let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
174 table.write_chunk(chunk);
175 }
176 }
177
178 vars.agg_group
180 .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
181 .await?;
182
183 vars.state_changed = true;
185
186 Ok(())
187 }
188
189 async fn flush_data(
190 this: &mut ExecutorInner<S>,
191 vars: &mut ExecutionVars<S>,
192 epoch: EpochPair,
193 ) -> StreamExecutorResult<Option<StreamChunk>> {
194 if vars.state_changed || vars.agg_group.is_uninitialized() {
195 vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
197
198 if let Some(inter_states_change) =
200 vars.agg_group.build_states_change(&this.agg_funcs)?
201 {
202 this.intermediate_state_table
203 .write_record(inter_states_change);
204 }
205 }
206 vars.state_changed = false;
207
208 let (outputs_change, _stats) = vars
210 .agg_group
211 .build_outputs_change(&this.storages, &this.agg_funcs)
212 .await?;
213
214 let change =
215 outputs_change.expect("`AlwaysOutput` strategy will output a change in any case");
216 let chunk = if !this.must_output_per_barrier
217 && let Record::Update { old_row, new_row } = &change
218 && old_row == new_row
219 {
220 None
222 } else {
223 Some(change.to_stream_chunk(&this.info.schema.data_types()))
224 };
225
226 futures::future::try_join_all(
228 this.all_state_tables_mut()
229 .map(|table| table.commit_assert_no_update_vnode_bitmap(epoch)),
230 )
231 .await?;
232
233 Ok(chunk)
234 }
235
236 async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
237 futures::future::try_join_all(this.all_state_tables_mut().map(|table| table.try_flush()))
238 .await?;
239 Ok(())
240 }
241
242 #[try_stream(ok = Message, error = StreamExecutorError)]
243 async fn execute_inner(self) {
244 let Self {
245 input,
246 inner: mut this,
247 } = self;
248
249 let mut input = input.execute();
250 let barrier = expect_first_barrier(&mut input).await?;
251 let first_epoch = barrier.epoch;
252 yield Message::Barrier(barrier);
253
254 for table in this.all_state_tables_mut() {
255 table.init_epoch(first_epoch).await?;
256 }
257
258 let distinct_dedup = DistinctDeduplicater::new(
259 &this.agg_calls,
260 this.watermark_epoch.clone(),
261 &this.distinct_dedup_tables,
262 &this.actor_ctx,
263 );
264
265 let (agg_group, _stats) = AggGroup::create(
267 this.version,
268 None,
269 &this.agg_calls,
270 &this.agg_funcs,
271 &this.storages,
272 &this.intermediate_state_table,
273 &this.input_pk_indices,
274 this.row_count_index,
275 false, this.extreme_cache_size,
277 &this.input_schema,
278 )
279 .await?;
280
281 let mut vars = ExecutionVars {
282 agg_group,
283 distinct_dedup,
284 state_changed: false,
285 };
286
287 #[for_await]
288 for msg in input {
289 let msg = msg?;
290 match msg {
291 Message::Watermark(_) => {}
292 Message::Chunk(chunk) => {
293 Self::apply_chunk(&mut this, &mut vars, chunk).await?;
294 Self::try_flush_data(&mut this).await?;
295 }
296 Message::Barrier(barrier) => {
297 if let Some(chunk) =
298 Self::flush_data(&mut this, &mut vars, barrier.epoch).await?
299 {
300 yield Message::Chunk(chunk);
301 }
302 yield Message::Barrier(barrier);
303 }
304 }
305 }
306 }
307}
308
309#[cfg(test)]
310mod tests {
311 use assert_matches::assert_matches;
312 use risingwave_common::array::stream_chunk::StreamChunkTestExt;
313 use risingwave_common::catalog::Field;
314 use risingwave_common::types::*;
315 use risingwave_common::util::epoch::test_epoch;
316 use risingwave_storage::memory::MemoryStateStore;
317
318 use super::*;
319 use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
320 use crate::executor::test_utils::*;
321
322 #[tokio::test]
323 async fn test_simple_aggregation_in_memory() {
324 test_simple_aggregation(MemoryStateStore::new()).await
325 }
326
327 async fn test_simple_aggregation<S: StateStore>(store: S) {
328 let schema = Schema {
329 fields: vec![
330 Field::unnamed(DataType::Int64),
331 Field::unnamed(DataType::Int64),
332 Field::unnamed(DataType::Int64),
334 ],
335 };
336 let (mut tx, source) = MockSource::channel();
337 let source = source.into_executor(schema, vec![2]);
338 tx.push_barrier(test_epoch(1), false);
339 tx.push_barrier(test_epoch(2), false);
340 tx.push_chunk(StreamChunk::from_pretty(
341 " I I I
342 + 100 200 1001
343 + 10 14 1002
344 + 4 300 1003",
345 ));
346 tx.push_barrier(test_epoch(3), false);
347 tx.push_chunk(StreamChunk::from_pretty(
348 " I I I
349 - 100 200 1001
350 - 10 14 1002 D
351 - 4 300 1003
352 + 104 500 1004",
353 ));
354 tx.push_barrier(test_epoch(4), false);
355
356 let agg_calls = vec![
357 AggCall::from_pretty("(count:int8)"),
358 AggCall::from_pretty("(sum:int8 $0:int8)"),
359 AggCall::from_pretty("(sum:int8 $1:int8)"),
360 AggCall::from_pretty("(min:int8 $0:int8)"),
361 ];
362
363 let simple_agg = new_boxed_simple_agg_executor(
364 ActorContext::for_test(123),
365 store,
366 source,
367 false,
368 agg_calls,
369 0,
370 vec![2],
371 1,
372 false,
373 )
374 .await;
375 let mut simple_agg = simple_agg.execute();
376
377 simple_agg.next().await.unwrap().unwrap();
379 let msg = simple_agg.next().await.unwrap().unwrap();
381 assert_eq!(
382 *msg.as_chunk().unwrap(),
383 StreamChunk::from_pretty(
384 " I I I I
385 + 0 . . . "
386 )
387 );
388 assert_matches!(
389 simple_agg.next().await.unwrap().unwrap(),
390 Message::Barrier { .. }
391 );
392
393 let msg = simple_agg.next().await.unwrap().unwrap();
395 assert_eq!(
396 *msg.as_chunk().unwrap(),
397 StreamChunk::from_pretty(
398 " I I I I
399 U- 0 . . .
400 U+ 3 114 514 4"
401 )
402 );
403 assert_matches!(
404 simple_agg.next().await.unwrap().unwrap(),
405 Message::Barrier { .. }
406 );
407
408 let msg = simple_agg.next().await.unwrap().unwrap();
409 assert_eq!(
410 *msg.as_chunk().unwrap(),
411 StreamChunk::from_pretty(
412 " I I I I
413 U- 3 114 514 4
414 U+ 2 114 514 10"
415 )
416 );
417 }
418
419 #[tokio::test]
421 async fn test_simple_aggregation_always_output_per_epoch() {
422 let store = MemoryStateStore::new();
423 let schema = Schema {
424 fields: vec![
425 Field::unnamed(DataType::Int64),
426 Field::unnamed(DataType::Int64),
427 Field::unnamed(DataType::Int64),
429 ],
430 };
431 let (mut tx, source) = MockSource::channel();
432 let source = source.into_executor(schema, vec![2]);
433 tx.push_barrier(test_epoch(1), false);
435 tx.push_barrier(test_epoch(2), false);
437 tx.push_chunk(StreamChunk::from_pretty(
438 " I I I
439 + 100 200 1001
440 - 100 200 1001",
441 ));
442 tx.push_barrier(test_epoch(3), false);
443 tx.push_barrier(test_epoch(4), false);
444
445 let agg_calls = vec![
446 AggCall::from_pretty("(count:int8)"),
447 AggCall::from_pretty("(sum:int8 $0:int8)"),
448 AggCall::from_pretty("(sum:int8 $1:int8)"),
449 AggCall::from_pretty("(min:int8 $0:int8)"),
450 ];
451
452 let simple_agg = new_boxed_simple_agg_executor(
453 ActorContext::for_test(123),
454 store,
455 source,
456 false,
457 agg_calls,
458 0,
459 vec![2],
460 1,
461 true,
462 )
463 .await;
464 let mut simple_agg = simple_agg.execute();
465
466 simple_agg.next().await.unwrap().unwrap();
468 let msg = simple_agg.next().await.unwrap().unwrap();
470 assert_eq!(
471 *msg.as_chunk().unwrap(),
472 StreamChunk::from_pretty(
473 " I I I I
474 + 0 . . . "
475 )
476 );
477 assert_matches!(
478 simple_agg.next().await.unwrap().unwrap(),
479 Message::Barrier { .. }
480 );
481
482 let msg = simple_agg.next().await.unwrap().unwrap();
484 assert_eq!(
485 *msg.as_chunk().unwrap(),
486 StreamChunk::from_pretty(
487 " I I I I
488 U- 0 . . .
489 U+ 0 . . ."
490 )
491 );
492 assert_matches!(
493 simple_agg.next().await.unwrap().unwrap(),
494 Message::Barrier { .. }
495 );
496
497 let msg = simple_agg.next().await.unwrap().unwrap();
499 assert_eq!(
500 *msg.as_chunk().unwrap(),
501 StreamChunk::from_pretty(
502 " I I I I
503 U- 0 . . .
504 U+ 0 . . ."
505 )
506 );
507 assert_matches!(
508 simple_agg.next().await.unwrap().unwrap(),
509 Message::Barrier { .. }
510 );
511 }
512
513 #[tokio::test]
515 async fn test_simple_aggregation_omit_noop_update() {
516 let store = MemoryStateStore::new();
517 let schema = Schema {
518 fields: vec![
519 Field::unnamed(DataType::Int64),
520 Field::unnamed(DataType::Int64),
521 Field::unnamed(DataType::Int64),
523 ],
524 };
525 let (mut tx, source) = MockSource::channel();
526 let source = source.into_executor(schema, vec![2]);
527 tx.push_barrier(test_epoch(1), false);
529 tx.push_barrier(test_epoch(2), false);
531 tx.push_chunk(StreamChunk::from_pretty(
532 " I I I
533 + 100 200 1001
534 - 100 200 1001",
535 ));
536 tx.push_barrier(test_epoch(3), false);
537 tx.push_barrier(test_epoch(4), false);
538
539 let agg_calls = vec![
540 AggCall::from_pretty("(count:int8)"),
541 AggCall::from_pretty("(sum:int8 $0:int8)"),
542 AggCall::from_pretty("(sum:int8 $1:int8)"),
543 AggCall::from_pretty("(min:int8 $0:int8)"),
544 ];
545
546 let simple_agg = new_boxed_simple_agg_executor(
547 ActorContext::for_test(123),
548 store,
549 source,
550 false,
551 agg_calls,
552 0,
553 vec![2],
554 1,
555 false,
556 )
557 .await;
558 let mut simple_agg = simple_agg.execute();
559
560 simple_agg.next().await.unwrap().unwrap();
562 let msg = simple_agg.next().await.unwrap().unwrap();
564 assert_eq!(
565 *msg.as_chunk().unwrap(),
566 StreamChunk::from_pretty(
567 " I I I I
568 + 0 . . . "
569 )
570 );
571 assert_matches!(
572 simple_agg.next().await.unwrap().unwrap(),
573 Message::Barrier { .. }
574 );
575
576 assert_matches!(
578 simple_agg.next().await.unwrap().unwrap(),
579 Message::Barrier { .. }
580 );
581
582 assert_matches!(
584 simple_agg.next().await.unwrap().unwrap(),
585 Message::Barrier { .. }
586 );
587 }
588}