risingwave_stream/executor/project/
project_set.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use either::Either;
16use multimap::MultiMap;
17use risingwave_common::array::{ArrayRef, DataChunk, Op};
18use risingwave_common::bail;
19use risingwave_common::row::RowExt;
20use risingwave_common::types::ToOwnedDatum;
21use risingwave_common::util::iter_util::ZipEqFast;
22use risingwave_expr::ExprError;
23use risingwave_expr::expr::{self, EvalErrorReport, NonStrictExpression};
24use risingwave_expr::table_function::{self, BoxedTableFunction, TableFunctionOutputIter};
25use risingwave_pb::expr::PbProjectSetSelectItem;
26use risingwave_pb::expr::project_set_select_item::PbSelectItem;
27
28use crate::executor::prelude::*;
29use crate::task::ActorEvalErrorReport;
30
/// Shift between an index into the select list and the matching output column.
///
/// Output column 0 holds `projected_row_id`, so the select item at index `i` is written to
/// output column `i + PROJ_ROW_ID_OFFSET`.
const PROJ_ROW_ID_OFFSET: usize = 1;
32
/// `ProjectSetExecutor` evaluates a list of select items (scalar expressions and
/// set-returning table functions) against every input chunk.
///
/// Each input row can expand into zero or more output rows. Every output row starts with an
/// `Int64` `projected_row_id` column that counts the outputs of that input row from 0,
/// followed by one column per select item. Update operations are rewritten into plain
/// insert/delete, because after expansion a `U-` is no longer guaranteed to be directly
/// followed by its `U+`.
pub struct ProjectSetExecutor {
    /// Upstream executor whose messages are consumed.
    input: Executor,
    /// Configuration and evaluation logic, kept apart from `input`.
    inner: Inner,
}
40
/// Everything a `ProjectSetExecutor` needs besides its input executor.
struct Inner {
    /// Actor context; kept but not read by the code visible in this file.
    _ctx: ActorContextRef,

    /// Select items (scalar expressions or table functions) evaluated for every input chunk.
    /// The item at index `i` produces output column `i + 1`; column 0 is `projected_row_id`.
    select_list: Vec<ProjectSetSelectItem>,
    /// Size handed to the `StreamChunkBuilder` that assembles the output chunks.
    chunk_size: usize,
    /// All the watermark derivations, (`input_column_index`, `expr_idx`). And the
    /// derivation expression is the `project_set`'s expression itself.
    watermark_derivations: MultiMap<usize, usize>,
    /// Indices of nondecreasing expressions in the expression list.
    nondecreasing_expr_indices: Vec<usize>,
    /// Receives the errors produced by table functions while rows are being expanded.
    error_report: ActorEvalErrorReport,
}
54
55impl ProjectSetExecutor {
56    pub fn new(
57        ctx: ActorContextRef,
58        input: Executor,
59        select_list: Vec<ProjectSetSelectItem>,
60        chunk_size: usize,
61        watermark_derivations: MultiMap<usize, usize>,
62        nondecreasing_expr_indices: Vec<usize>,
63        error_report: ActorEvalErrorReport,
64    ) -> Self {
65        let inner = Inner {
66            _ctx: ctx,
67            select_list,
68            chunk_size,
69            watermark_derivations,
70            nondecreasing_expr_indices,
71            error_report,
72        };
73
74        Self { input, inner }
75    }
76}
77
78impl Debug for ProjectSetExecutor {
79    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("ProjectSetExecutor")
81            .field("exprs", &self.inner.select_list)
82            .finish()
83    }
84}
85
86impl Execute for ProjectSetExecutor {
87    fn execute(self: Box<Self>) -> BoxedMessageStream {
88        self.inner.execute(self.input).boxed()
89    }
90}
91
92impl Inner {
    /// Drives the executor: consumes `input` and yields the resulting message stream.
    ///
    /// - The first barrier is awaited and forwarded before anything else.
    /// - Input watermarks are translated by `handle_watermark`.
    /// - On a barrier, while not paused, one watermark is emitted for every nondecreasing
    ///   expression that has a recorded value; the recorded value is cleared by doing so.
    ///   `Pause`/`Resume` mutations then update the paused flag and the barrier is forwarded.
    /// - A chunk is expanded row by row into output rows of the form
    ///   `(projected_row_id, item_0, item_1, ...)`.
    ///
    /// # Panics
    ///
    /// Panics if `select_list` is empty, or if a table function's next output belongs to an
    /// input row earlier than the one being processed.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn execute(self, input: Executor) {
        let mut input = input.execute();
        let first_barrier = expect_first_barrier(&mut input).await?;
        let mut is_paused = first_barrier.is_pause_on_startup();
        yield Message::Barrier(first_barrier);

        assert!(!self.select_list.is_empty());
        // The first column is `projected_row_id`, the index of an output row among the
        // outputs of its input row; the select items follow in order.
        let data_types: Vec<_> = std::iter::once(DataType::Int64)
            .chain(self.select_list.iter().map(|i| i.return_type()))
            .collect();
        // a temporary row buffer, reused for every output row
        let mut row = vec![DatumRef::None; data_types.len()];
        let mut builder = StreamChunkBuilder::new(self.chunk_size, data_types);

        // One slot per nondecreasing expression; `None` means nothing to emit at the next
        // barrier.
        let mut last_nondec_expr_values = vec![None; self.nondecreasing_expr_indices.len()];

        #[for_await]
        for msg in input {
            match msg? {
                Message::Watermark(watermark) => {
                    let watermarks = self.handle_watermark(watermark).await?;
                    for watermark in watermarks {
                        yield Message::Watermark(watermark)
                    }
                }
                Message::Barrier(barrier) => {
                    if !is_paused {
                        for (&expr_idx, value) in self
                            .nondecreasing_expr_indices
                            .iter()
                            .zip_eq_fast(&mut last_nondec_expr_values)
                        {
                            // `take` resets the slot, so each recorded value is emitted once.
                            if let Some(value) = std::mem::take(value) {
                                yield Message::Watermark(Watermark::new(
                                    expr_idx + PROJ_ROW_ID_OFFSET,
                                    self.select_list[expr_idx].return_type(),
                                    value,
                                ))
                            }
                        }
                    }

                    // The paused flag is updated after the watermarks above were decided,
                    // so it takes effect from the next barrier on.
                    if let Some(mutation) = barrier.mutation.as_deref() {
                        match mutation {
                            Mutation::Pause => {
                                is_paused = true;
                            }
                            Mutation::Resume => {
                                is_paused = false;
                            }
                            _ => (),
                        }
                    }

                    yield Message::Barrier(barrier);
                }
                Message::Chunk(chunk) => {
                    // Evaluate every select item once for the whole chunk: table functions
                    // give an output iterator (`Left`), scalars give an array (`Right`).
                    let mut results = Vec::with_capacity(self.select_list.len());
                    for select_item in &self.select_list {
                        let result = select_item.eval(chunk.data_chunk()).await?;
                        results.push(result);
                    }

                    // for each input row
                    for row_idx in 0..chunk.capacity() {
                        // ProjectSet cannot preserve that U- is followed by U+,
                        // so we rewrite update to insert/delete.
                        let op = match chunk.ops()[row_idx] {
                            Op::Delete | Op::UpdateDelete => Op::Delete,
                            Op::Insert | Op::UpdateInsert => Op::Insert,
                        };

                        // Whether the output corresponds to the current input row.
                        // An output for an earlier input row is an invariant violation.
                        let is_current_input = |i| {
                            assert!(
                                i >= row_idx,
                                "unexpectedly operating on previous input, i: {i}, row_idx: {row_idx}",
                            );
                            i == row_idx
                        };

                        // for each output row (unbounded; left via `break` below)
                        for projected_row_id in 0i64.. {
                            // SAFETY:
                            // We use `row` as a buffer and don't read elements from the previous
                            // loop. The `transmute` is used for bypassing the borrow checker.
                            let row: &mut [DatumRef<'_>] =
                                unsafe { std::mem::transmute(row.as_mut_slice()) };

                            row[0] = Some(projected_row_id.into());

                            // Whether all table functions are exhausted or have failed for the
                            // current input row.
                            let mut fully_consumed = true;

                            // for each column
                            for (item, value) in results.iter_mut().zip_eq_fast(&mut row[1..]) {
                                *value = match item {
                                    Either::Left(state) => {
                                        if let Some((i, result)) = state.peek()
                                            && is_current_input(i)
                                        {
                                            match result {
                                                Ok(value) => {
                                                    fully_consumed = false;
                                                    value
                                                }
                                                Err(err) => {
                                                    self.error_report.report(err);
                                                    // When we encounter an error from one of the table functions,
                                                    //
                                                    // - if there are other successful table functions, `fully_consumed` will still be
                                                    //   set to `false`, a `NULL` will be set in the output row for the failed table function,
                                                    //   that's why we set `None` here.
                                                    //
                                                    // - if there are no other successful table functions (or we are the only table function),
                                                    //   `fully_consumed` will be set to `true`, we won't output the row at all but skip
                                                    //   the whole result set of the given row. Setting `None` here is no-op.
                                                    None
                                                }
                                            }
                                        } else {
                                            // This table function has no more output for the
                                            // current input row: pad with NULL.
                                            None
                                        }
                                    }
                                    // Scalars repeat the same value on every output row.
                                    Either::Right(array) => array.value_at(row_idx),
                                };
                            }

                            if fully_consumed {
                                // Skip the current input row and break the loop to handle the next input row.
                                // - If all exhausted, this is no-op.
                                // - If all failed, this skips remaining outputs of the current input row.
                                for item in &mut results {
                                    if let Either::Left(state) = item {
                                        while let Some((i, _)) = state.peek()
                                            && is_current_input(i)
                                        {
                                            state.next().await?;
                                        }
                                    }
                                }
                                break;
                            } else {
                                // A chunk is returned only when the builder hands one back.
                                if let Some(chunk) = builder.append_row(op, &*row) {
                                    self.update_last_nondec_expr_values(
                                        &mut last_nondec_expr_values,
                                        &chunk,
                                    );
                                    yield Message::Chunk(chunk);
                                }
                                // move to the next row: advance only the table functions that
                                // contributed to the current input row
                                for item in &mut results {
                                    if let Either::Left(state) = item
                                        && matches!(state.peek(), Some((i, _)) if is_current_input(i))
                                    {
                                        state.next().await?;
                                    }
                                }
                            }
                        }
                    }
                    // Flush whatever is still buffered so that no output row outlives the
                    // input chunk it was derived from.
                    if let Some(chunk) = builder.take() {
                        self.update_last_nondec_expr_values(&mut last_nondec_expr_values, &chunk);
                        yield Message::Chunk(chunk);
                    }
                }
            }
        }
    }
265
266    fn update_last_nondec_expr_values(
267        &self,
268        last_nondec_expr_values: &mut [Datum],
269        chunk: &StreamChunk,
270    ) {
271        if !self.nondecreasing_expr_indices.is_empty() {
272            if let Some((_, first_visible_row)) = chunk.rows().next() {
273                // it's ok to use the first row here, just one chunk delay
274                first_visible_row
275                    .project(&self.nondecreasing_expr_indices)
276                    .iter()
277                    .enumerate()
278                    .for_each(|(idx, value)| {
279                        last_nondec_expr_values[idx] = Some(
280                            value
281                                .to_owned_datum()
282                                .expect("non-decreasing expression should never be NULL"),
283                        );
284                    });
285            }
286        }
287    }
288
289    async fn handle_watermark(&self, watermark: Watermark) -> StreamExecutorResult<Vec<Watermark>> {
290        let expr_indices = match self.watermark_derivations.get_vec(&watermark.col_idx) {
291            Some(v) => v,
292            None => return Ok(vec![]),
293        };
294        let mut ret = vec![];
295        for expr_idx in expr_indices {
296            let expr_idx = *expr_idx;
297            let derived_watermark = match &self.select_list[expr_idx] {
298                ProjectSetSelectItem::Scalar(expr) => {
299                    watermark
300                        .clone()
301                        .transform_with_expr(expr, expr_idx + PROJ_ROW_ID_OFFSET)
302                        .await
303                }
304                ProjectSetSelectItem::Set(_) => {
305                    bail!("Watermark should not be produced by a table function");
306                }
307            };
308
309            if let Some(derived_watermark) = derived_watermark {
310                ret.push(derived_watermark);
311            } else {
312                warn!(
313                    "a NULL watermark is derived with the expression {}!",
314                    expr_idx
315                );
316            }
317        }
318        Ok(ret)
319    }
320}
321
/// Either a scalar expression or a set-returning function.
///
/// See also [`PbProjectSetSelectItem`].
///
/// A similar enum is defined in the `batch` module. The difference is that
/// we use `NonStrictExpression` instead of `BoxedExpression` here.
#[derive(Debug)]
pub enum ProjectSetSelectItem {
    /// A scalar expression; `eval` turns it into a single array for the input chunk.
    Scalar(NonStrictExpression),
    /// A table function; `eval` turns it into an iterator over its outputs.
    Set(BoxedTableFunction),
}
333
334impl From<BoxedTableFunction> for ProjectSetSelectItem {
335    fn from(table_function: BoxedTableFunction) -> Self {
336        ProjectSetSelectItem::Set(table_function)
337    }
338}
339
340impl From<NonStrictExpression> for ProjectSetSelectItem {
341    fn from(expr: NonStrictExpression) -> Self {
342        ProjectSetSelectItem::Scalar(expr)
343    }
344}
345
346impl ProjectSetSelectItem {
347    pub fn from_prost(
348        prost: &PbProjectSetSelectItem,
349        error_report: impl EvalErrorReport + 'static,
350        chunk_size: usize,
351    ) -> Result<Self, ExprError> {
352        match prost.select_item.as_ref().unwrap() {
353            PbSelectItem::Expr(expr) => {
354                expr::build_non_strict_from_prost(expr, error_report).map(Self::Scalar)
355            }
356            PbSelectItem::TableFunction(tf) => {
357                table_function::build_from_prost(tf, chunk_size).map(Self::Set)
358            }
359        }
360    }
361
362    pub fn return_type(&self) -> DataType {
363        match self {
364            ProjectSetSelectItem::Scalar(expr) => expr.return_type(),
365            ProjectSetSelectItem::Set(tf) => tf.return_type(),
366        }
367    }
368
369    pub async fn eval<'a>(
370        &'a self,
371        input: &'a DataChunk,
372    ) -> Result<Either<TableFunctionOutputIter<'a>, ArrayRef>, ExprError> {
373        match self {
374            Self::Scalar(expr) => Ok(Either::Right(expr.eval_infallible(input).await)),
375            Self::Set(tf) => Ok(Either::Left(
376                TableFunctionOutputIter::new(tf.eval(input).await).await?,
377            )),
378        }
379    }
380}