risingwave_stream/executor/aggregate/
stateless_simple_agg.rs
1use itertools::Itertools;
16use risingwave_common::array::Op;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_expr::aggregate::{
19 AggCall, AggregateState, BoxedAggregateFunction, build_retractable,
20};
21
22use super::agg_call_filter_res;
23use crate::executor::prelude::*;
24
25pub struct StatelessSimpleAggExecutor {
26 _ctx: ActorContextRef,
27 pub(super) input: Executor,
28 pub(super) schema: Schema,
29 pub(super) aggs: Vec<BoxedAggregateFunction>,
30 pub(super) agg_calls: Vec<AggCall>,
31}
32
33impl Execute for StatelessSimpleAggExecutor {
34 fn execute(self: Box<Self>) -> BoxedMessageStream {
35 self.execute_inner().boxed()
36 }
37}
38
39impl StatelessSimpleAggExecutor {
40 async fn apply_chunk(
41 agg_calls: &[AggCall],
42 aggs: &[BoxedAggregateFunction],
43 states: &mut [AggregateState],
44 chunk: &StreamChunk,
45 ) -> StreamExecutorResult<()> {
46 for ((agg, call), state) in aggs.iter().zip_eq_fast(agg_calls).zip_eq_fast(states) {
47 let vis = agg_call_filter_res(call, chunk).await?;
48 let chunk = chunk.project_with_vis(call.args.val_indices(), vis);
49 agg.update(state, &chunk).await?;
50 }
51 Ok(())
52 }
53
54 #[try_stream(ok = Message, error = StreamExecutorError)]
55 async fn execute_inner(self) {
56 let StatelessSimpleAggExecutor {
57 _ctx,
58 input,
59 schema,
60 aggs,
61 agg_calls,
62 } = self;
63 let input = input.execute();
64 let mut is_dirty = false;
65 let mut states: Vec<_> = aggs.iter().map(|agg| agg.create_state()).try_collect()?;
66
67 #[for_await]
68 for msg in input {
69 let msg = msg?;
70 match msg {
71 Message::Watermark(_) => {}
72 Message::Chunk(chunk) => {
73 Self::apply_chunk(&agg_calls, &aggs, &mut states, &chunk).await?;
74 is_dirty = true;
75 }
76 m @ Message::Barrier(_) => {
77 if is_dirty {
78 is_dirty = false;
79
80 let mut builders = schema.create_array_builders(1);
81 for ((agg, state), builder) in aggs
82 .iter()
83 .zip_eq_fast(states.iter_mut())
84 .zip_eq_fast(builders.iter_mut())
85 {
86 let data = agg.get_result(state).await?;
87 *state = agg.create_state()?;
88 trace!("append: {:?}", data);
89 builder.append(data);
90 }
91 let columns = builders
92 .into_iter()
93 .map(|builder| Ok::<_, StreamExecutorError>(builder.finish().into()))
94 .try_collect()?;
95 let ops = vec![Op::Insert; 1];
96
97 yield Message::Chunk(StreamChunk::new(ops, columns));
98 }
99
100 yield m;
101 }
102 }
103 }
104 }
105}
106
107impl StatelessSimpleAggExecutor {
108 pub fn new(
109 ctx: ActorContextRef,
110 input: Executor,
111 schema: Schema,
112 agg_calls: Vec<AggCall>,
113 ) -> StreamResult<Self> {
114 let aggs = agg_calls.iter().map(build_retractable).try_collect()?;
115 Ok(StatelessSimpleAggExecutor {
116 _ctx: ctx,
117 input,
118 schema,
119 aggs,
120 agg_calls,
121 })
122 }
123}
124
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::schema_test_utils;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;
    use crate::executor::test_utils::agg_executor::generate_agg_schema;

    /// Wraps `source` in a `StatelessSimpleAggExecutor` for `agg_calls` and starts it.
    fn start_simple_agg(source: Executor, agg_calls: Vec<AggCall>) -> BoxedMessageStream {
        let schema = generate_agg_schema(&source, &agg_calls, None);
        StatelessSimpleAggExecutor::new(ActorContext::for_test(123), source, schema, agg_calls)
            .unwrap()
            .boxed()
            .execute()
    }

    /// With barriers only, nothing is aggregated and the barriers pass straight through.
    #[tokio::test]
    async fn test_no_chunk() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema_test_utils::ii(), vec![2]);
        for epoch in 1..=3 {
            tx.push_barrier(test_epoch(epoch), false);
        }

        let agg_calls = vec![AggCall::from_pretty("(count:int8)")];
        let mut simple_agg = start_simple_agg(source, agg_calls);

        for _ in 0..3 {
            assert_matches!(
                simple_agg.next().await.unwrap().unwrap(),
                Message::Barrier(_)
            );
        }
    }

    /// Each epoch's chunk is aggregated independently and emitted before its barrier.
    #[tokio::test]
    async fn test_local_simple_agg() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema_test_utils::iii(), vec![2]);

        tx.push_barrier(test_epoch(1), false);
        tx.push_chunk(StreamChunk::from_pretty(
            " I I I
            + 100 200 1001
            + 10 14 1002
            + 4 300 1003",
        ));
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            " I I I
            - 100 200 1001
            - 10 14 1002 D
            - 4 300 1003
            + 104 500 1004",
        ));
        tx.push_barrier(test_epoch(3), false);

        let agg_calls = vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
        ];
        let mut simple_agg = start_simple_agg(source, agg_calls);

        // The first message is the initial barrier; its content is not inspected.
        simple_agg.next().await.unwrap().unwrap();

        // Results for the first chunk.
        let first = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            first.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I
                + 3 114 514"
            )
        );

        assert_matches!(
            simple_agg.next().await.unwrap().unwrap(),
            Message::Barrier(_)
        );

        // Results for the second chunk, computed from freshly reset states.
        let second = simple_agg.next().await.unwrap().unwrap();
        assert_eq!(
            second.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I I I
                + -1 0 0"
            )
        );
    }
}