risingwave_stream/executor/project/
project_scalar.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use multimap::MultiMap;
16use risingwave_common::row::RowExt;
17use risingwave_common::types::ToOwnedDatum;
18use risingwave_common::util::iter_util::ZipEqFast;
19use risingwave_expr::expr::NonStrictExpression;
20
21use crate::executor::prelude::*;
22
/// `ProjectExecutor` projects its input data through a list of expressions. Each expression
/// takes a chunk of data and returns a new column; the columns are assembled into a new chunk
/// that keeps the ops and visibility of the input chunk. The resulting inserts, deletes and
/// updates are then passed on to the next operator.
pub struct ProjectExecutor {
    /// Upstream executor whose messages are projected.
    input: Executor,
    /// Projection state and logic; consumed together with `input` when the executor runs.
    inner: Inner,
}
30
/// State and logic of the projection, kept separate from the input executor so that it can be
/// moved into the message stream on its own.
struct Inner {
    /// Actor context passed to the constructor. It is stored but not read by the code in this
    /// file (hence the leading underscore).
    _ctx: ActorContextRef,

    /// Expressions of the current projection.
    exprs: Vec<NonStrictExpression>,
    /// All the watermark derivations, as (`input_column_index`, `output_column_index`) pairs.
    /// The derivation expression is the projection expression at `output_column_index`.
    watermark_derivations: MultiMap<usize, usize>,
    /// Indices of nondecreasing expressions in the expression list.
    nondecreasing_expr_indices: Vec<usize>,
    /// Last seen values of nondecreasing expressions, buffered to periodically produce watermarks.
    /// There is one slot per entry of `nondecreasing_expr_indices`; a slot is reset to `None`
    /// once its value has been emitted as a watermark.
    last_nondec_expr_values: Vec<Option<ScalarImpl>>,
    /// Whether the stream is paused. While paused, buffered values are not emitted as
    /// watermarks when a barrier arrives.
    is_paused: bool,

    /// Whether there are likely no-op updates in the output chunks, so that eliminating them with
    /// `StreamChunk::eliminate_adjacent_noop_update` could be beneficial.
    eliminate_noop_updates: bool,
}
50
51impl ProjectExecutor {
52    pub fn new(
53        ctx: ActorContextRef,
54        input: Executor,
55        exprs: Vec<NonStrictExpression>,
56        watermark_derivations: MultiMap<usize, usize>,
57        nondecreasing_expr_indices: Vec<usize>,
58        noop_update_hint: bool,
59    ) -> Self {
60        let n_nondecreasing_exprs = nondecreasing_expr_indices.len();
61        let eliminate_noop_updates =
62            noop_update_hint || ctx.config.developer.aggressive_noop_update_elimination;
63        Self {
64            input,
65            inner: Inner {
66                _ctx: ctx,
67                exprs,
68                watermark_derivations,
69                nondecreasing_expr_indices,
70                last_nondec_expr_values: vec![None; n_nondecreasing_exprs],
71                is_paused: false,
72                eliminate_noop_updates,
73            },
74        }
75    }
76}
77
78impl Debug for ProjectExecutor {
79    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("ProjectExecutor")
81            .field("exprs", &self.inner.exprs)
82            .finish()
83    }
84}
85
86impl Execute for ProjectExecutor {
87    fn execute(self: Box<Self>) -> BoxedMessageStream {
88        self.inner.execute(self.input).boxed()
89    }
90}
91
92pub async fn apply_project_exprs(
93    exprs: &[NonStrictExpression],
94    chunk: StreamChunk,
95) -> StreamExecutorResult<StreamChunk> {
96    let (data_chunk, ops) = chunk.into_parts();
97    let mut projected_columns = Vec::new();
98
99    for expr in exprs {
100        let evaluated_expr = expr.eval_infallible(&data_chunk).await;
101        projected_columns.push(evaluated_expr);
102    }
103    let (_, vis) = data_chunk.into_parts();
104
105    let new_chunk = StreamChunk::with_visibility(ops, projected_columns, vis);
106
107    Ok(new_chunk)
108}
109
110impl Inner {
111    async fn map_filter_chunk(
112        &self,
113        chunk: StreamChunk,
114    ) -> StreamExecutorResult<Option<StreamChunk>> {
115        let mut new_chunk = apply_project_exprs(&self.exprs, chunk).await?;
116        if self.eliminate_noop_updates {
117            new_chunk = new_chunk.eliminate_adjacent_noop_update();
118        }
119        Ok(Some(new_chunk))
120    }
121
122    async fn handle_watermark(&self, watermark: Watermark) -> StreamExecutorResult<Vec<Watermark>> {
123        let out_col_indices = match self.watermark_derivations.get_vec(&watermark.col_idx) {
124            Some(v) => v,
125            None => return Ok(vec![]),
126        };
127        let mut ret = vec![];
128        for out_col_idx in out_col_indices {
129            let out_col_idx = *out_col_idx;
130            let derived_watermark = watermark
131                .clone()
132                .transform_with_expr(&self.exprs[out_col_idx], out_col_idx)
133                .await;
134            if let Some(derived_watermark) = derived_watermark {
135                ret.push(derived_watermark);
136            } else {
137                warn!(
138                    "a NULL watermark is derived with the expression {}!",
139                    out_col_idx
140                );
141            }
142        }
143        Ok(ret)
144    }
145
146    #[try_stream(ok = Message, error = StreamExecutorError)]
147    async fn execute(mut self, input: Executor) {
148        let mut input = input.execute();
149        let first_barrier = expect_first_barrier(&mut input).await?;
150        self.is_paused = first_barrier.is_pause_on_startup();
151        yield Message::Barrier(first_barrier);
152
153        #[for_await]
154        for msg in input {
155            let msg = msg?;
156            match msg {
157                Message::Watermark(w) => {
158                    let watermarks = self.handle_watermark(w).await?;
159                    for watermark in watermarks {
160                        yield Message::Watermark(watermark)
161                    }
162                }
163                Message::Chunk(chunk) => match self.map_filter_chunk(chunk).await? {
164                    Some(new_chunk) => {
165                        if !self.nondecreasing_expr_indices.is_empty()
166                            && let Some((_, first_visible_row)) = new_chunk.rows().next()
167                        {
168                            // it's ok to use the first row here, just one chunk delay
169                            first_visible_row
170                                .project(&self.nondecreasing_expr_indices)
171                                .iter()
172                                .enumerate()
173                                .for_each(|(idx, value)| {
174                                    self.last_nondec_expr_values[idx] =
175                                        Some(value.to_owned_datum().expect(
176                                            "non-decreasing expression should never be NULL",
177                                        ));
178                                });
179                        }
180                        yield Message::Chunk(new_chunk)
181                    }
182                    None => continue,
183                },
184                Message::Barrier(barrier) => {
185                    if !self.is_paused {
186                        for (&expr_idx, value) in self
187                            .nondecreasing_expr_indices
188                            .iter()
189                            .zip_eq_fast(&mut self.last_nondec_expr_values)
190                        {
191                            if let Some(value) = std::mem::take(value) {
192                                yield Message::Watermark(Watermark::new(
193                                    expr_idx,
194                                    self.exprs[expr_idx].return_type(),
195                                    value,
196                                ))
197                            }
198                        }
199                    }
200
201                    if let Some(mutation) = barrier.mutation.as_deref() {
202                        match mutation {
203                            Mutation::Pause => {
204                                self.is_paused = true;
205                            }
206                            Mutation::Resume => {
207                                self.is_paused = false;
208                            }
209                            _ => (),
210                        }
211                    }
212
213                    yield Message::Barrier(barrier);
214                }
215            }
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use std::sync::atomic::{self, AtomicI64};

    use risingwave_common::array::DataChunk;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::types::DefaultOrd;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_expr::expr::{self, Expression, ValueImpl};

    use super::*;
    use crate::executor::test_utils::expr::build_from_pretty;
    use crate::executor::test_utils::{MockSource, StreamExecutorTestExt};

    /// Input schema shared by the tests: two unnamed `Int64` columns.
    fn two_int64_columns() -> Schema {
        let fields = vec![
            Field::unnamed(DataType::Int64),
            Field::unnamed(DataType::Int64),
        ];
        Schema { fields }
    }

    /// Projects `$0 + $1` over two chunks and checks both output chunks and the stop barrier.
    #[tokio::test]
    async fn test_projection() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(two_int64_columns(), vec![0]);

        let sum_expr = build_from_pretty("(add:int8 $0:int8 $1:int8)");

        let mut proj = ProjectExecutor::new(
            ActorContext::for_test(123),
            source,
            vec![sum_expr],
            MultiMap::new(),
            vec![],
            false,
        )
        .boxed()
        .execute();

        tx.push_barrier(test_epoch(1), false);
        let first_msg = proj.next().await.unwrap().unwrap();
        first_msg.as_barrier().unwrap();

        tx.push_chunk(StreamChunk::from_pretty(
            " I I
            + 1 4
            + 2 5
            + 3 6",
        ));
        tx.push_chunk(StreamChunk::from_pretty(
            " I I
            + 7 8
            - 3 6",
        ));

        let first_output = proj.next().await.unwrap().unwrap();
        assert_eq!(
            *first_output.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I
                + 5
                + 7
                + 9"
            )
        );

        let second_output = proj.next().await.unwrap().unwrap();
        assert_eq!(
            *second_output.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                "  I
                + 15
                -  9"
            )
        );

        tx.push_barrier(test_epoch(2), true);
        let last_msg = proj.next().await.unwrap().unwrap();
        assert!(last_msg.is_stop());
    }

    /// Counter backing `DummyNondecreasingExpr`; every evaluation observes a larger value.
    static DUMMY_COUNTER: AtomicI64 = AtomicI64::new(0);

    /// Returns the current counter value and increments the counter by one.
    fn next_dummy_value() -> i64 {
        DUMMY_COUNTER.fetch_add(1, atomic::Ordering::SeqCst)
    }

    /// Expression that ignores its input and returns the next counter value.
    #[derive(Debug)]
    struct DummyNondecreasingExpr;

    #[async_trait::async_trait]
    impl Expression for DummyNondecreasingExpr {
        fn return_type(&self) -> DataType {
            DataType::Int64
        }

        async fn eval_v2(&self, input: &DataChunk) -> expr::Result<ValueImpl> {
            let next = next_dummy_value();
            let capacity = input.capacity();
            Ok(ValueImpl::Scalar {
                value: Some(next.into()),
                capacity,
            })
        }

        async fn eval_row(&self, _input: &OwnedRow) -> expr::Result<Datum> {
            let next = next_dummy_value();
            Ok(Some(next.into()))
        }
    }

    /// Checks watermarks derived from an input watermark and watermarks generated from a
    /// nondecreasing expression at barriers.
    #[tokio::test]
    async fn test_watermark_projection() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(two_int64_columns(), StreamKey::new());

        let plus_one = build_from_pretty("(add:int8 $0:int8 1:int8)");
        let minus_one = build_from_pretty("(subtract:int8 $0:int8 1:int8)");
        let counter_expr = NonStrictExpression::for_test(DummyNondecreasingExpr);

        let mut proj = ProjectExecutor::new(
            ActorContext::for_test(123),
            source,
            vec![plus_one, minus_one, counter_expr],
            MultiMap::from_iter([(0, 0), (0, 1)]),
            vec![2],
            false,
        )
        .boxed()
        .execute();

        tx.push_barrier(test_epoch(1), false);
        tx.push_int64_watermark(0, 100);

        proj.expect_barrier().await;
        let w1 = proj.expect_watermark().await;
        let w2 = proj.expect_watermark().await;
        // The two derived watermarks may arrive in either order; sort them by column index.
        let (on_col0, on_col1) = if w2.col_idx <= w1.col_idx {
            (w2, w1)
        } else {
            (w1, w2)
        };

        assert_eq!(
            on_col0,
            Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(101)
            }
        );
        assert_eq!(
            on_col1,
            Watermark {
                col_idx: 1,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(99)
            }
        );

        // just push some random chunks
        tx.push_chunk(StreamChunk::from_pretty(
            "   I I
            + 120 4
            + 146 5
            + 133 6",
        ));
        proj.expect_chunk().await;
        tx.push_chunk(StreamChunk::from_pretty(
            "   I I
            + 213 8
            - 133 6",
        ));
        proj.expect_chunk().await;

        // A barrier flushes the buffered nondecreasing value as a watermark first.
        tx.push_barrier(test_epoch(2), false);
        let earlier = proj.expect_watermark().await;
        proj.expect_barrier().await;

        tx.push_chunk(StreamChunk::from_pretty(
            "   I I
            + 100 3
            + 104 5
            + 187 3",
        ));
        proj.expect_chunk().await;

        tx.push_barrier(test_epoch(3), false);
        let later = proj.expect_watermark().await;
        proj.expect_barrier().await;

        assert_eq!(earlier.col_idx, later.col_idx);
        assert!(earlier.val.default_cmp(&later.val).is_le());

        tx.push_int64_watermark(1, 100);
        tx.push_barrier(test_epoch(4), true);

        let last_msg = proj.next().await.unwrap().unwrap();
        assert!(last_msg.is_stop());
    }
}