1use std::collections::HashMap;
16
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::util::epoch::EpochPair;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction, build_retractable};
21use risingwave_pb::stream_plan::PbAggNodeVersion;
22
23use super::agg_group::{AggGroup, AlwaysOutput};
24use super::agg_state::AggStateStorage;
25use super::distinct::DistinctDeduplicater;
26use super::{AggExecutorArgs, SimpleAggExecutorExtraArgs, agg_call_filter_res, iter_table_storage};
27use crate::executor::prelude::*;
28
/// Stream executor computing aggregates over its whole input with no grouping key.
///
/// It always keeps exactly one aggregation group (see `ExecutionVars::agg_group`). On every
/// barrier it may emit one output change for that group before forwarding the barrier.
pub struct SimpleAggExecutor<S: StateStore> {
    /// Upstream executor whose chunks are aggregated.
    input: Executor,
    /// Everything except the input; split out so it can be borrowed mutably while the input
    /// stream is being consumed in `execute_inner`.
    inner: ExecutorInner<S>,
}
46
/// Static configuration and owned state tables of a [`SimpleAggExecutor`].
struct ExecutorInner<S: StateStore> {
    /// Agg node version from the plan; forwarded to `AggGroup::create`.
    version: PbAggNodeVersion,

    /// Actor context; handed to `DistinctDeduplicater::new`.
    actor_ctx: ActorContextRef,
    /// Info of this executor; its schema's data types shape the emitted output chunks.
    info: ExecutorInfo,

    /// Primary-key indices of the input, copied from the input executor's info.
    input_pk_indices: Vec<usize>,

    /// Schema of the input, copied from the input executor's info.
    input_schema: Schema,

    /// Aggregate calls evaluated by this executor.
    agg_calls: Vec<AggCall>,

    /// Retractable aggregate functions built from `agg_calls`, one per call, in the same order.
    agg_funcs: Vec<BoxedAggregateFunction>,

    /// Forwarded to `AggGroup::create`.
    /// NOTE(review): presumably the index of the row-count call in `agg_calls` — confirm.
    row_count_index: usize,

    /// State storage per agg call. `apply_chunk` zips it one-to-one with the per-call
    /// visibilities. Only `AggStateStorage::MaterializedInput` entries receive input rows there.
    storages: Vec<AggStateStorage<S>>,

    /// Table receiving the intermediate agg states built in `flush_data`.
    intermediate_state_table: StateTable<S>,

    /// State tables used for distinct deduplication.
    /// NOTE(review): the `usize` key is presumably a column/call index — verify against the
    /// `DistinctDeduplicater`.
    distinct_dedup_tables: HashMap<usize, StateTable<S>>,

    /// Shared watermark epoch; a clone is handed to `DistinctDeduplicater::new`.
    watermark_epoch: AtomicU64Ref,

    /// Forwarded to `AggGroup::create`.
    /// NOTE(review): likely the cache capacity for extreme (min/max) states — confirm.
    extreme_cache_size: usize,

    /// When `false`, a barrier whose output change is an update with identical old and new rows
    /// emits no chunk. When `true`, a chunk is emitted on every barrier.
    must_output_per_barrier: bool,
}
91
92impl<S: StateStore> ExecutorInner<S> {
93 fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
94 iter_table_storage(&mut self.storages)
95 .chain(self.distinct_dedup_tables.values_mut())
96 .chain(std::iter::once(&mut self.intermediate_state_table))
97 }
98}
99
/// Mutable runtime state of a [`SimpleAggExecutor`], created in `execute_inner` after the first
/// barrier.
struct ExecutionVars<S: StateStore> {
    /// The single aggregation group. It uses the `AlwaysOutput` strategy, so `flush_data` expects
    /// an output change on every barrier.
    agg_group: AggGroup<S, AlwaysOutput>,

    /// Deduplicater applied to incoming chunks in `apply_chunk` and flushed in `flush_data`.
    distinct_dedup: DistinctDeduplicater<S>,

    /// Set by `apply_chunk` after a non-empty chunk; cleared by `flush_data` on every barrier.
    state_changed: bool,
}
110
111impl<S: StateStore> Execute for SimpleAggExecutor<S> {
112 fn execute(self: Box<Self>) -> BoxedMessageStream {
113 self.execute_inner().boxed()
114 }
115}
116
117impl<S: StateStore> SimpleAggExecutor<S> {
118 pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
119 let input_info = args.input.info().clone();
120 Ok(Self {
121 input: args.input,
122 inner: ExecutorInner {
123 version: args.version,
124 actor_ctx: args.actor_ctx,
125 info: args.info,
126 input_pk_indices: input_info.pk_indices,
127 input_schema: input_info.schema,
128 agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
129 agg_calls: args.agg_calls,
130 row_count_index: args.row_count_index,
131 storages: args.storages,
132 intermediate_state_table: args.intermediate_state_table,
133 distinct_dedup_tables: args.distinct_dedup_tables,
134 watermark_epoch: args.watermark_epoch,
135 extreme_cache_size: args.extreme_cache_size,
136 must_output_per_barrier: args.extra.must_output_per_barrier,
137 },
138 })
139 }
140
141 async fn apply_chunk(
142 this: &mut ExecutorInner<S>,
143 vars: &mut ExecutionVars<S>,
144 chunk: StreamChunk,
145 ) -> StreamExecutorResult<()> {
146 if chunk.cardinality() == 0 {
147 return Ok(());
149 }
150
151 let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
153 for agg_call in &this.agg_calls {
154 let vis = agg_call_filter_res(agg_call, &chunk).await?;
155 call_visibilities.push(vis);
156 }
157
158 let visibilities = vars
160 .distinct_dedup
161 .dedup_chunk(
162 chunk.ops(),
163 chunk.columns(),
164 call_visibilities,
165 &mut this.distinct_dedup_tables,
166 None,
167 )
168 .await?;
169
170 for (storage, visibility) in this.storages.iter_mut().zip_eq_fast(visibilities.iter()) {
172 if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage {
173 let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
174 table.write_chunk(chunk);
175 }
176 }
177
178 vars.agg_group
180 .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
181 .await?;
182
183 vars.state_changed = true;
185
186 Ok(())
187 }
188
189 async fn flush_data(
190 this: &mut ExecutorInner<S>,
191 vars: &mut ExecutionVars<S>,
192 epoch: EpochPair,
193 ) -> StreamExecutorResult<Option<StreamChunk>> {
194 if vars.state_changed || vars.agg_group.is_uninitialized() {
195 vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
197
198 if let Some(inter_states_change) =
200 vars.agg_group.build_states_change(&this.agg_funcs)?
201 {
202 this.intermediate_state_table
203 .write_record(inter_states_change);
204 }
205 }
206 vars.state_changed = false;
207
208 let (outputs_change, _stats) = vars
210 .agg_group
211 .build_outputs_change(&this.storages, &this.agg_funcs)
212 .await?;
213
214 let change =
215 outputs_change.expect("`AlwaysOutput` strategy will output a change in any case");
216 let chunk = if !this.must_output_per_barrier
217 && let Record::Update { old_row, new_row } = &change
218 && old_row == new_row
219 {
220 None
222 } else {
223 Some(change.to_stream_chunk(&this.info.schema.data_types()))
224 };
225
226 futures::future::try_join_all(
228 this.all_state_tables_mut()
229 .map(|table| table.commit_assert_no_update_vnode_bitmap(epoch)),
230 )
231 .await?;
232
233 Ok(chunk)
234 }
235
236 async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
237 futures::future::try_join_all(this.all_state_tables_mut().map(|table| table.try_flush()))
238 .await?;
239 Ok(())
240 }
241
242 #[try_stream(ok = Message, error = StreamExecutorError)]
243 async fn execute_inner(self) {
244 let Self {
245 input,
246 inner: mut this,
247 } = self;
248
249 let mut input = input.execute();
250 let barrier = expect_first_barrier(&mut input).await?;
251 let first_epoch = barrier.epoch;
252 yield Message::Barrier(barrier);
253
254 for table in this.all_state_tables_mut() {
255 table.init_epoch(first_epoch).await?;
256 }
257
258 let distinct_dedup = DistinctDeduplicater::new(
259 &this.agg_calls,
260 this.watermark_epoch.clone(),
261 &this.distinct_dedup_tables,
262 &this.actor_ctx,
263 );
264
265 let mut vars = ExecutionVars {
267 agg_group: AggGroup::create(
268 this.version,
269 None,
270 &this.agg_calls,
271 &this.agg_funcs,
272 &this.storages,
273 &this.intermediate_state_table,
274 &this.input_pk_indices,
275 this.row_count_index,
276 false, this.extreme_cache_size,
278 &this.input_schema,
279 )
280 .await?,
281 distinct_dedup,
282 state_changed: false,
283 };
284
285 #[for_await]
286 for msg in input {
287 let msg = msg?;
288 match msg {
289 Message::Watermark(_) => {}
290 Message::Chunk(chunk) => {
291 Self::apply_chunk(&mut this, &mut vars, chunk).await?;
292 Self::try_flush_data(&mut this).await?;
293 }
294 Message::Barrier(barrier) => {
295 if let Some(chunk) =
296 Self::flush_data(&mut this, &mut vars, barrier.epoch).await?
297 {
298 yield Message::Chunk(chunk);
299 }
300 yield Message::Barrier(barrier);
301 }
302 }
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::types::*;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
    use crate::executor::test_utils::*;

    /// Input schema shared by every test below: three `Int64` columns.
    fn three_int64_columns() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        }
    }

    /// Agg calls shared by every test below: `count`, `sum($0)`, `sum($1)` and `min($0)`.
    fn count_sum_sum_min() -> Vec<AggCall> {
        vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
            AggCall::from_pretty("(min:int8 $0:int8)"),
        ]
    }

    #[tokio::test]
    async fn test_simple_aggregation_in_memory() {
        test_simple_aggregation(MemoryStateStore::new()).await
    }

    /// Inserts three rows, then retracts two (one retraction is invisible) and inserts one.
    /// Checks the emitted update after each barrier.
    async fn test_simple_aggregation<S: StateStore>(store: S) {
        let (mut tx, mock) = MockSource::channel();
        let mock = mock.into_executor(three_int64_columns(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            " I I I
            + 100 200 1001
            + 10 14 1002
            + 4 300 1003",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_chunk(StreamChunk::from_pretty(
            " I I I
            - 100 200 1001
            - 10 14 1002 D
            - 4 300 1003
            + 104 500 1004",
        ));
        tx.push_barrier(test_epoch(4), false);

        let executor = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            mock,
            false,
            count_sum_sum_min(),
            0,
            vec![2],
            1,
            false,
        )
        .await;
        let mut stream = executor.execute();

        // First barrier is forwarded as-is.
        stream.next().await.unwrap().unwrap();

        let initial = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *initial.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                + 0 . . . "
            )
        );
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });

        let after_inserts = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *after_inserts.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                U- 0 . . .
                U+ 3 114 514 4"
            )
        );
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });

        let after_retracts = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *after_retracts.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                U- 3 114 514 4
                U+ 2 114 514 10"
            )
        );
    }

    /// With `must_output_per_barrier`, even a no-op update is emitted on every barrier.
    #[tokio::test]
    async fn test_simple_aggregation_always_output_per_epoch() {
        let store = MemoryStateStore::new();
        let (mut tx, mock) = MockSource::channel();
        let mock = mock.into_executor(three_int64_columns(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            " I I I
            + 100 200 1001
            - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let executor = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            mock,
            false,
            count_sum_sum_min(),
            0,
            vec![2],
            1,
            true,
        )
        .await;
        let mut stream = executor.execute();

        stream.next().await.unwrap().unwrap();

        let initial = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *initial.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                + 0 . . . "
            )
        );
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });

        let noop_after_chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *noop_after_chunk.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                U- 0 . . .
                U+ 0 . . ."
            )
        );
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });

        let noop_idle = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *noop_idle.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                U- 0 . . .
                U+ 0 . . ."
            )
        );
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });
    }

    /// Without `must_output_per_barrier`, no-op updates are omitted and only barriers follow.
    #[tokio::test]
    async fn test_simple_aggregation_omit_noop_update() {
        let store = MemoryStateStore::new();
        let (mut tx, mock) = MockSource::channel();
        let mock = mock.into_executor(three_int64_columns(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            " I I I
            + 100 200 1001
            - 100 200 1001",
        ));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let executor = new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            mock,
            false,
            count_sum_sum_min(),
            0,
            vec![2],
            1,
            false,
        )
        .await;
        let mut stream = executor.execute();

        stream.next().await.unwrap().unwrap();

        let initial = stream.next().await.unwrap().unwrap();
        assert_eq!(
            *initial.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I I
                + 0 . . . "
            )
        );
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });

        // The insert/delete pair cancels out: no chunk, just the barrier.
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });

        // Idle epoch: again only the barrier.
        assert_matches!(stream.next().await.unwrap().unwrap(), Message::Barrier { .. });
    }
}