risingwave_stream/executor/aggregate/
stateless_simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::array::Op;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_expr::aggregate::{
19    AggCall, AggregateState, BoxedAggregateFunction, build_retractable,
20};
21
22use super::agg_call_filter_res;
23use crate::executor::prelude::*;
24
25pub struct StatelessSimpleAggExecutor {
26    _ctx: ActorContextRef,
27    pub(super) input: Executor,
28    pub(super) schema: Schema,
29    pub(super) aggs: Vec<BoxedAggregateFunction>,
30    pub(super) agg_calls: Vec<AggCall>,
31}
32
33impl Execute for StatelessSimpleAggExecutor {
34    fn execute(self: Box<Self>) -> BoxedMessageStream {
35        self.execute_inner().boxed()
36    }
37}
38
39impl StatelessSimpleAggExecutor {
40    async fn apply_chunk(
41        agg_calls: &[AggCall],
42        aggs: &[BoxedAggregateFunction],
43        states: &mut [AggregateState],
44        chunk: &StreamChunk,
45    ) -> StreamExecutorResult<()> {
46        for ((agg, call), state) in aggs.iter().zip_eq_fast(agg_calls).zip_eq_fast(states) {
47            let vis = agg_call_filter_res(call, chunk).await?;
48            let chunk = chunk.project_with_vis(call.args.val_indices(), vis);
49            agg.update(state, &chunk).await?;
50        }
51        Ok(())
52    }
53
54    #[try_stream(ok = Message, error = StreamExecutorError)]
55    async fn execute_inner(self) {
56        let StatelessSimpleAggExecutor {
57            _ctx,
58            input,
59            schema,
60            aggs,
61            agg_calls,
62        } = self;
63        let input = input.execute();
64        let mut is_dirty = false;
65        let mut states: Vec<_> = aggs.iter().map(|agg| agg.create_state()).try_collect()?;
66
67        #[for_await]
68        for msg in input {
69            let msg = msg?;
70            match msg {
71                Message::Watermark(_) => {}
72                Message::Chunk(chunk) => {
73                    Self::apply_chunk(&agg_calls, &aggs, &mut states, &chunk).await?;
74                    is_dirty = true;
75                }
76                m @ Message::Barrier(_) => {
77                    if is_dirty {
78                        is_dirty = false;
79
80                        let mut builders = schema.create_array_builders(1);
81                        for ((agg, state), builder) in aggs
82                            .iter()
83                            .zip_eq_fast(states.iter_mut())
84                            .zip_eq_fast(builders.iter_mut())
85                        {
86                            let data = agg.get_result(state).await?;
87                            *state = agg.create_state()?;
88                            trace!("append: {:?}", data);
89                            builder.append(data);
90                        }
91                        let columns = builders
92                            .into_iter()
93                            .map(|builder| Ok::<_, StreamExecutorError>(builder.finish().into()))
94                            .try_collect()?;
95                        let ops = vec![Op::Insert; 1];
96
97                        yield Message::Chunk(StreamChunk::new(ops, columns));
98                    }
99
100                    yield m;
101                }
102            }
103        }
104    }
105}
106
107impl StatelessSimpleAggExecutor {
108    pub fn new(
109        ctx: ActorContextRef,
110        input: Executor,
111        schema: Schema,
112        agg_calls: Vec<AggCall>,
113    ) -> StreamResult<Self> {
114        let aggs = agg_calls.iter().map(build_retractable).try_collect()?;
115        Ok(StatelessSimpleAggExecutor {
116            _ctx: ctx,
117            input,
118            schema,
119            aggs,
120            agg_calls,
121        })
122    }
123}
124
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::schema_test_utils;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;
    use crate::executor::test_utils::agg_executor::generate_agg_schema;

    /// With only barriers in the input, the executor forwards each barrier and
    /// never emits a result chunk.
    #[tokio::test]
    async fn test_no_chunk() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema_test_utils::ii(), vec![2]);
        for epoch in 1..=3 {
            tx.push_barrier(test_epoch(epoch), false);
        }

        let agg_calls = vec![AggCall::from_pretty("(count:int8)")];
        let agg_schema = generate_agg_schema(&source, &agg_calls, None);

        let executor = StatelessSimpleAggExecutor::new(
            ActorContext::for_test(123),
            source,
            agg_schema,
            agg_calls,
        )
        .unwrap();
        let mut stream = executor.boxed().execute();

        // Three barriers in, three barriers out, nothing in between.
        for _ in 0..3 {
            assert_matches!(
                stream.next().await.unwrap().unwrap(),
                Message::Barrier { .. }
            );
        }
    }

    /// Each barrier that follows a chunk is preceded by one result row covering
    /// only the rows received since the previous barrier.
    #[tokio::test]
    async fn test_local_simple_agg() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(schema_test_utils::iii(), vec![2]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            + 100 200 1001
            +  10  14 1002
            +   4 300 1003",
        ));
        tx.push_barrier(test_epoch(2), false);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I    I
            - 100 200 1001
            -  10  14 1002 D
            -   4 300 1003
            + 104 500 1004",
        ));
        tx.push_barrier(test_epoch(3), false);

        let agg_calls = vec![
            AggCall::from_pretty("(count:int8)"),
            AggCall::from_pretty("(sum:int8 $0:int8)"),
            AggCall::from_pretty("(sum:int8 $1:int8)"),
        ];
        let agg_schema = generate_agg_schema(&source, &agg_calls, None);

        let executor = StatelessSimpleAggExecutor::new(
            ActorContext::for_test(123),
            source,
            agg_schema,
            agg_calls,
        )
        .unwrap();
        let mut stream = executor.boxed().execute();

        // The first message is the initial barrier; no chunk preceded it.
        stream.next().await.unwrap().unwrap();

        // Results for the first chunk, emitted just before the second barrier.
        let first_output = stream.next().await.unwrap().unwrap();
        assert_eq!(
            first_output.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I   I   I
                + 3 114 514"
            )
        );

        assert_matches!(
            stream.next().await.unwrap().unwrap(),
            Message::Barrier { .. }
        );

        // Results for the second chunk alone: the states were recreated after
        // the previous emission, so the first chunk no longer contributes.
        let second_output = stream.next().await.unwrap().unwrap();
        assert_eq!(
            second_output.into_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I I I
                + -1 0 0"
            )
        );
    }
}