1use std::collections::HashMap;
16
17use risingwave_common::array::stream_record::Record;
18use risingwave_common::util::epoch::EpochPair;
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_expr::aggregate::{AggCall, BoxedAggregateFunction, build_retractable};
21use risingwave_pb::stream_plan::PbAggNodeVersion;
22
23use super::agg_group::{AggGroup, AlwaysOutput};
24use super::agg_state::AggStateStorage;
25use super::distinct::DistinctDeduplicater;
26use super::{AggExecutorArgs, SimpleAggExecutorExtraArgs, agg_call_filter_res, iter_table_storage};
27use crate::executor::prelude::*;
28
29pub struct SimpleAggExecutor<S: StateStore> {
43 input: Executor,
44 inner: ExecutorInner<S>,
45}
46
47struct ExecutorInner<S: StateStore> {
48 version: PbAggNodeVersion,
50
51 actor_ctx: ActorContextRef,
52 info: ExecutorInfo,
53
54 input_stream_key: Vec<usize>,
56
57 input_schema: Schema,
59
60 agg_calls: Vec<AggCall>,
62
63 agg_funcs: Vec<BoxedAggregateFunction>,
65
66 row_count_index: usize,
68
69 storages: Vec<AggStateStorage<S>>,
71
72 intermediate_state_table: StateTable<S>,
76
77 distinct_dedup_tables: HashMap<usize, StateTable<S>>,
80
81 watermark_epoch: AtomicU64Ref,
83
84 extreme_cache_size: usize,
86
87 must_output_per_barrier: bool,
90}
91
92impl<S: StateStore> ExecutorInner<S> {
93 fn all_state_tables_mut(&mut self) -> impl Iterator<Item = &mut StateTable<S>> {
94 iter_table_storage(&mut self.storages)
95 .chain(self.distinct_dedup_tables.values_mut())
96 .chain(std::iter::once(&mut self.intermediate_state_table))
97 }
98}
99
100struct ExecutionVars<S: StateStore> {
101 agg_group: AggGroup<S, AlwaysOutput>,
103
104 distinct_dedup: DistinctDeduplicater<S>,
106
107 state_changed: bool,
109}
110
111impl<S: StateStore> Execute for SimpleAggExecutor<S> {
112 fn execute(self: Box<Self>) -> BoxedMessageStream {
113 self.execute_inner().boxed()
114 }
115}
116
117impl<S: StateStore> SimpleAggExecutor<S> {
118 pub fn new(args: AggExecutorArgs<S, SimpleAggExecutorExtraArgs>) -> StreamResult<Self> {
119 let input_info = args.input.info().clone();
120 Ok(Self {
121 input: args.input,
122 inner: ExecutorInner {
123 version: args.version,
124 actor_ctx: args.actor_ctx,
125 info: args.info,
126 input_stream_key: input_info.stream_key,
127 input_schema: input_info.schema,
128 agg_funcs: args.agg_calls.iter().map(build_retractable).try_collect()?,
129 agg_calls: args.agg_calls,
130 row_count_index: args.row_count_index,
131 storages: args.storages,
132 intermediate_state_table: args.intermediate_state_table,
133 distinct_dedup_tables: args.distinct_dedup_tables,
134 watermark_epoch: args.watermark_epoch,
135 extreme_cache_size: args.extreme_cache_size,
136 must_output_per_barrier: args.extra.must_output_per_barrier,
137 },
138 })
139 }
140
141 async fn apply_chunk(
142 this: &mut ExecutorInner<S>,
143 vars: &mut ExecutionVars<S>,
144 chunk: StreamChunk,
145 ) -> StreamExecutorResult<()> {
146 if chunk.cardinality() == 0 {
147 return Ok(());
149 }
150
151 let mut call_visibilities = Vec::with_capacity(this.agg_calls.len());
153 for agg_call in &this.agg_calls {
154 let vis = agg_call_filter_res(agg_call, &chunk).await?;
155 call_visibilities.push(vis);
156 }
157
158 let visibilities = vars
160 .distinct_dedup
161 .dedup_chunk(
162 chunk.ops(),
163 chunk.columns(),
164 call_visibilities,
165 &mut this.distinct_dedup_tables,
166 None,
167 )
168 .await?;
169
170 for (storage, visibility) in this.storages.iter_mut().zip_eq_fast(visibilities.iter()) {
172 if let AggStateStorage::MaterializedInput { table, mapping, .. } = storage {
173 let chunk = chunk.project_with_vis(mapping.upstream_columns(), visibility.clone());
174 table.write_chunk(chunk);
175 }
176 }
177
178 vars.agg_group
180 .apply_chunk(&chunk, &this.agg_calls, &this.agg_funcs, visibilities)
181 .await?;
182
183 vars.state_changed = true;
185
186 Ok(())
187 }
188
189 async fn flush_data(
190 this: &mut ExecutorInner<S>,
191 vars: &mut ExecutionVars<S>,
192 epoch: EpochPair,
193 ) -> StreamExecutorResult<Option<StreamChunk>> {
194 if vars.state_changed || vars.agg_group.is_uninitialized() {
195 vars.distinct_dedup.flush(&mut this.distinct_dedup_tables)?;
197
198 if let Some(inter_states_change) =
200 vars.agg_group.build_states_change(&this.agg_funcs)?
201 {
202 this.intermediate_state_table
203 .write_record(inter_states_change);
204 }
205 }
206 vars.state_changed = false;
207
208 let (outputs_change, _stats) = vars
210 .agg_group
211 .build_outputs_change(&this.storages, &this.agg_funcs)
212 .await?;
213
214 let change =
215 outputs_change.expect("`AlwaysOutput` strategy will output a change in any case");
216 let chunk = if !this.must_output_per_barrier
217 && let Record::Update { old_row, new_row } = &change
218 && old_row == new_row
219 {
220 None
222 } else {
223 Some(change.to_stream_chunk(&this.info.schema.data_types()))
224 };
225
226 futures::future::try_join_all(
228 this.all_state_tables_mut()
229 .map(|table| table.commit_assert_no_update_vnode_bitmap(epoch)),
230 )
231 .await?;
232
233 Ok(chunk)
234 }
235
236 async fn try_flush_data(this: &mut ExecutorInner<S>) -> StreamExecutorResult<()> {
237 futures::future::try_join_all(this.all_state_tables_mut().map(|table| table.try_flush()))
238 .await?;
239 Ok(())
240 }
241
242 #[try_stream(ok = Message, error = StreamExecutorError)]
243 async fn execute_inner(self) {
244 let Self {
245 input,
246 inner: mut this,
247 } = self;
248
249 let mut input = input.execute();
250 let barrier = expect_first_barrier(&mut input).await?;
251 let first_epoch = barrier.epoch;
252 yield Message::Barrier(barrier);
253
254 for table in this.all_state_tables_mut() {
255 table.init_epoch(first_epoch).await?;
256 }
257
258 let distinct_dedup = DistinctDeduplicater::new(
259 &this.agg_calls,
260 this.watermark_epoch.clone(),
261 &this.distinct_dedup_tables,
262 &this.actor_ctx,
263 );
264
265 let (agg_group, _stats) = AggGroup::create(
267 this.version,
268 None,
269 &this.agg_calls,
270 &this.agg_funcs,
271 &this.storages,
272 &this.intermediate_state_table,
273 &this.input_stream_key,
274 this.row_count_index,
275 false, this.extreme_cache_size,
277 &this.input_schema,
278 )
279 .await?;
280
281 let mut vars = ExecutionVars {
282 agg_group,
283 distinct_dedup,
284 state_changed: false,
285 };
286
287 #[for_await]
288 for msg in input {
289 let msg = msg?;
290 match msg {
291 Message::Watermark(_) => {}
292 Message::Chunk(chunk) => {
293 Self::apply_chunk(&mut this, &mut vars, chunk).await?;
294 Self::try_flush_data(&mut this).await?;
295 }
296 Message::Barrier(barrier) => {
297 if let Some(chunk) =
298 Self::flush_data(&mut this, &mut vars, barrier.epoch).await?
299 {
300 yield Message::Chunk(chunk);
301 }
302 yield Message::Barrier(barrier);
303 }
304 }
305 }
306 }
307}
308
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::agg_executor::new_boxed_simple_agg_executor;
    use crate::executor::test_utils::*;

    /// Input schema shared by every test: three `Int64` columns.
    fn input_schema() -> Schema {
        Schema {
            fields: (0..3).map(|_| Field::unnamed(DataType::Int64)).collect(),
        }
    }

    /// `count(*)`, `sum($0)`, `sum($1)` and `min($0)`, all typed `int8`.
    fn agg_calls() -> Vec<AggCall> {
        vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
            AggCall::from_pretty("(min:int8 $0:int8)"),
        ]
    }

    /// Builds the simple agg executor under test on top of `input` and starts its stream.
    async fn start_simple_agg<S: StateStore>(
        store: S,
        input: Executor,
        must_output_per_barrier: bool,
    ) -> BoxedMessageStream {
        new_boxed_simple_agg_executor(
            ActorContext::for_test(123),
            store,
            input,
            false,
            agg_calls(),
            0,
            vec![2],
            1,
            must_output_per_barrier,
        )
        .await
        .execute()
    }

    /// Pulls the next message, panicking if the stream ended or yielded an error.
    async fn next_message(stream: &mut BoxedMessageStream) -> Message {
        stream.next().await.unwrap().unwrap()
    }

    /// Asserts that the next message is a chunk equal to the pretty-printed `expected`.
    async fn assert_next_chunk(stream: &mut BoxedMessageStream, expected: &str) {
        let msg = next_message(stream).await;
        assert_eq!(*msg.as_chunk().unwrap(), StreamChunk::from_pretty(expected));
    }

    /// Asserts that the next message is a barrier.
    async fn assert_next_barrier(stream: &mut BoxedMessageStream) {
        assert_matches!(next_message(stream).await, Message::Barrier { .. });
    }

    #[tokio::test]
    async fn test_simple_aggregation_in_memory() {
        test_simple_aggregation(MemoryStateStore::new()).await
    }

    async fn test_simple_aggregation<S: StateStore>(store: S) {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(input_schema(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(concat!(
            " I I I\n",
            " + 100 200 1001\n",
            " + 10 14 1002\n",
            " + 4 300 1003",
        )));
        tx.push_barrier(test_epoch(3), false);
        // The `D` row is invisible, so `(10, 14)` is not retracted: hence count 2 and min 10.
        tx.push_chunk(StreamChunk::from_pretty(concat!(
            " I I I\n",
            " - 100 200 1001\n",
            " - 10 14 1002 D\n",
            " - 4 300 1003\n",
            " + 104 500 1004",
        )));
        tx.push_barrier(test_epoch(4), false);

        let mut agg = start_simple_agg(store, source, false).await;

        // First barrier, forwarded as-is.
        let _ = next_message(&mut agg).await;

        assert_next_chunk(&mut agg, concat!(" I I I I\n", " + 0 . . . ")).await;
        assert_next_barrier(&mut agg).await;

        assert_next_chunk(
            &mut agg,
            concat!(" I I I I\n", " U- 0 . . .\n", " U+ 3 114 514 4"),
        )
        .await;
        assert_next_barrier(&mut agg).await;

        assert_next_chunk(
            &mut agg,
            concat!(" I I I I\n", " U- 3 114 514 4\n", " U+ 2 114 514 10"),
        )
        .await;
    }

    #[tokio::test]
    async fn test_simple_aggregation_always_output_per_epoch() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(input_schema(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(concat!(
            " I I I\n",
            " + 100 200 1001\n",
            " - 100 200 1001",
        )));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let mut agg = start_simple_agg(store, source, true).await;

        // First barrier, forwarded as-is.
        let _ = next_message(&mut agg).await;

        assert_next_chunk(&mut agg, concat!(" I I I I\n", " + 0 . . . ")).await;
        assert_next_barrier(&mut agg).await;

        // The insert and delete cancel out, but a no-op update is still emitted.
        assert_next_chunk(&mut agg, concat!(" I I I I\n", " U- 0 . . .\n", " U+ 0 . . .")).await;
        assert_next_barrier(&mut agg).await;

        // Nothing happened in this epoch at all, yet output is still produced.
        assert_next_chunk(&mut agg, concat!(" I I I I\n", " U- 0 . . .\n", " U+ 0 . . .")).await;
        assert_next_barrier(&mut agg).await;
    }

    #[tokio::test]
    async fn test_simple_aggregation_omit_noop_update() {
        let store = MemoryStateStore::new();
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(input_schema(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(concat!(
            " I I I\n",
            " + 100 200 1001\n",
            " - 100 200 1001",
        )));
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let mut agg = start_simple_agg(store, source, false).await;

        // First barrier, forwarded as-is.
        let _ = next_message(&mut agg).await;

        assert_next_chunk(&mut agg, concat!(" I I I I\n", " + 0 . . . ")).await;
        assert_next_barrier(&mut agg).await;

        // No-op updates are suppressed, so only barriers follow.
        assert_next_barrier(&mut agg).await;
        assert_next_barrier(&mut agg).await;
    }
}