risingwave_frontend/optimizer/plan_node/logical_project_set.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18
19use super::utils::impl_distill_by_unit;
20use super::{
21    BatchProjectSet, ColPrunable, ExprRewritable, Logical, LogicalProject, PlanBase, PlanRef,
22    PlanTreeNodeUnary, PredicatePushdown, StreamProjectSet, ToBatch, ToStream,
23    gen_filter_and_pushdown, generic,
24};
25use crate::error::{ErrorCode, Result};
26use crate::expr::{
27    Expr, ExprImpl, ExprRewriter, ExprVisitor, FunctionCall, InputRef, TableFunction,
28    collect_input_refs,
29};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::generic::GenericPlanRef;
32use crate::optimizer::plan_node::{
33    ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
34};
35use crate::utils::{ColIndexMapping, Condition, Substitute};
36
37/// `LogicalProjectSet` projects one row multiple times according to `select_list`.
38///
39/// Different from `Project`, it supports [`TableFunction`](crate::expr::TableFunction)s.
40/// See also [`ProjectSetSelectItem`](risingwave_pb::expr::ProjectSetSelectItem) for examples.
41///
42/// To have a pk, it has a hidden column `projected_row_id` at the beginning. The implementation of
43/// `LogicalProjectSet` is highly similar to [`LogicalProject`], except for the additional hidden
44/// column.
45#[derive(Debug, Clone, PartialEq, Eq, Hash)]
46pub struct LogicalProjectSet {
47    pub base: PlanBase<Logical>,
48    core: generic::ProjectSet<PlanRef>,
49}
50
51impl LogicalProjectSet {
52    pub fn new(input: PlanRef, select_list: Vec<ExprImpl>) -> Self {
53        assert!(
54            select_list.iter().any(|e| e.has_table_function()),
55            "ProjectSet should have at least one table function."
56        );
57
58        let core = generic::ProjectSet { select_list, input };
59        let base = PlanBase::new_logical_with_core(&core);
60
61        LogicalProjectSet { base, core }
62    }
63
64    /// `create` will analyze select exprs with table functions and construct a plan.
65    ///
66    /// When there is no table functions in the select list, it will return a simple
67    /// `LogicalProject`.
68    ///
69    /// When table functions are used as arguments of a table function or a usual function, the
70    /// arguments will be put at a lower `ProjectSet` while the call will be put at a higher
71    /// `Project` or `ProjectSet`. The plan is like:
72    ///
73    /// ```text
74    /// LogicalProjectSet/LogicalProject -> LogicalProjectSet -> input
75    /// ```
76    ///
77    /// Otherwise it will be a simple `ProjectSet`.
78    pub fn create(input: PlanRef, select_list: Vec<ExprImpl>) -> PlanRef {
79        if select_list
80            .iter()
81            .all(|e: &ExprImpl| !e.has_table_function())
82        {
83            return LogicalProject::create(input, select_list);
84        }
85
86        /// Rewrites a `FunctionCall` or `TableFunction` whose args contain table functions into one
87        /// using `InputRef` as args.
88        struct Rewriter {
89            collected: Vec<TableFunction>,
90            /// The nesting level of calls.
91            ///
92            /// f(x) has level 1 at x, and f(g(x)) has level 2 at x.
93            level: usize,
94            input_schema_len: usize,
95        }
96
97        impl ExprRewriter for Rewriter {
98            fn rewrite_table_function(&mut self, table_func: TableFunction) -> ExprImpl {
99                if self.level == 0 {
100                    // Top-level table function doesn't need to be collected.
101                    self.level += 1;
102
103                    let TableFunction {
104                        args,
105                        return_type,
106                        function_type,
107                        user_defined,
108                    } = table_func;
109                    let args = args
110                        .into_iter()
111                        .map(|expr| self.rewrite_expr(expr))
112                        .collect();
113
114                    self.level -= 1;
115                    TableFunction {
116                        args,
117                        return_type,
118                        function_type,
119                        user_defined,
120                    }
121                    .into()
122                } else {
123                    let input_ref = InputRef::new(
124                        self.input_schema_len + self.collected.len(),
125                        table_func.return_type(),
126                    );
127                    self.collected.push(table_func);
128                    input_ref.into()
129                }
130            }
131
132            fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
133                self.level += 1;
134                let (func_type, inputs, return_type) = func_call.decompose();
135                let inputs = inputs
136                    .into_iter()
137                    .map(|expr| self.rewrite_expr(expr))
138                    .collect();
139                self.level -= 1;
140                FunctionCall::new_unchecked(func_type, inputs, return_type).into()
141            }
142        }
143
144        let mut rewriter = Rewriter {
145            collected: vec![],
146            level: 0,
147            input_schema_len: input.schema().len(),
148        };
149        let select_list: Vec<_> = select_list
150            .into_iter()
151            .map(|e| rewriter.rewrite_expr(e))
152            .collect();
153
154        if rewriter.collected.is_empty() {
155            LogicalProjectSet::new(input, select_list).into()
156        } else {
157            let mut inner_select_list: Vec<_> = input
158                .schema()
159                .data_types()
160                .into_iter()
161                .enumerate()
162                .map(|(i, ty)| InputRef::new(i, ty).into())
163                .collect();
164            inner_select_list.extend(rewriter.collected.into_iter().map(|tf| tf.into()));
165            let inner = LogicalProjectSet::create(input, inner_select_list);
166
167            /// Increase all the input ref in the outer select list, because the inner project set
168            /// will output a hidden column at the beginning.
169            struct IncInputRef {}
170            impl ExprRewriter for IncInputRef {
171                fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
172                    InputRef::new(input_ref.index + 1, input_ref.data_type).into()
173                }
174            }
175            let mut rewriter = IncInputRef {};
176            let select_list: Vec<_> = select_list
177                .into_iter()
178                .map(|e| rewriter.rewrite_expr(e))
179                .collect();
180
181            if select_list.iter().any(|e| e.has_table_function()) {
182                LogicalProjectSet::new(inner, select_list).into()
183            } else {
184                LogicalProject::new(inner, select_list).into()
185            }
186        }
187    }
188
189    pub fn select_list(&self) -> &Vec<ExprImpl> {
190        &self.core.select_list
191    }
192
193    pub fn decompose(self) -> (Vec<ExprImpl>, PlanRef) {
194        self.core.decompose()
195    }
196}
197
198impl PlanTreeNodeUnary for LogicalProjectSet {
199    fn input(&self) -> PlanRef {
200        self.core.input.clone()
201    }
202
203    fn clone_with_input(&self, input: PlanRef) -> Self {
204        Self::new(input, self.select_list().clone())
205    }
206
207    fn rewrite_with_input(
208        &self,
209        input: PlanRef,
210        mut input_col_change: ColIndexMapping,
211    ) -> (Self, ColIndexMapping) {
212        let select_list = self
213            .select_list()
214            .clone()
215            .into_iter()
216            .map(|item| input_col_change.rewrite_expr(item))
217            .collect();
218        let project_set = Self::new(input, select_list);
219        // change the input columns index will not change the output column index
220        let out_col_change = ColIndexMapping::identity(self.schema().len());
221        (project_set, out_col_change)
222    }
223}
224
225impl_plan_tree_node_for_unary! {LogicalProjectSet}
226impl_distill_by_unit!(LogicalProjectSet, core, "LogicalProjectSet");
227// TODO: add verbose display like Project
228
229impl ColPrunable for LogicalProjectSet {
230    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
231        let output_required_cols = required_cols;
232        let required_cols = {
233            let mut required_cols_set = FixedBitSet::from_iter(required_cols.iter().copied());
234            required_cols_set.grow(self.select_list().len() + 1);
235            let mut cols = required_cols.to_vec();
236            // We should not prune table functions, because the final number of result rows is
237            // depended by all table function calls
238            for (i, e) in self.select_list().iter().enumerate() {
239                if e.has_table_function() && !required_cols_set.contains(i + 1) {
240                    cols.push(i + 1);
241                    required_cols_set.set(i + 1, true);
242                }
243            }
244            cols
245        };
246
247        let input_col_num = self.input().schema().len();
248
249        let input_required_cols = collect_input_refs(
250            input_col_num,
251            required_cols
252                .iter()
253                .filter(|&&i| i > 0)
254                .map(|i| &self.select_list()[*i - 1]),
255        )
256        .ones()
257        .collect_vec();
258        let new_input = self.input().prune_col(&input_required_cols, ctx);
259        let mut mapping = ColIndexMapping::with_remaining_columns(
260            &input_required_cols,
261            self.input().schema().len(),
262        );
263        // Rewrite each InputRef with new index.
264        let select_list = required_cols
265            .iter()
266            .filter(|&&id| id > 0)
267            .map(|&id| mapping.rewrite_expr(self.select_list()[id - 1].clone()))
268            .collect();
269
270        // Reconstruct the LogicalProjectSet
271        let new_node: PlanRef = LogicalProjectSet::create(new_input, select_list);
272        if new_node.schema().len() == output_required_cols.len() {
273            // current schema perfectly fit the required columns
274            new_node
275        } else {
276            // projected_row_id column is not needed so we did a projection to remove it
277            let mut new_output_cols = required_cols.to_vec();
278            if !required_cols.contains(&0) {
279                new_output_cols.insert(0, 0);
280            }
281            let mapping =
282                &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
283            let output_required_cols = output_required_cols
284                .iter()
285                .map(|&idx| mapping.map(idx))
286                .collect_vec();
287            let src_size = new_node.schema().len();
288            LogicalProject::with_mapping(
289                new_node,
290                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
291            )
292            .into()
293        }
294    }
295}
296
297impl ExprRewritable for LogicalProjectSet {
298    fn has_rewritable_expr(&self) -> bool {
299        true
300    }
301
302    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
303        let mut core = self.core.clone();
304        core.rewrite_exprs(r);
305        Self {
306            base: self.base.clone_with_new_plan_id(),
307            core,
308        }
309        .into()
310    }
311}
312
313impl ExprVisitable for LogicalProjectSet {
314    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
315        self.core.visit_exprs(v);
316    }
317}
318
319impl PredicatePushdown for LogicalProjectSet {
320    fn predicate_pushdown(
321        &self,
322        predicate: Condition,
323        ctx: &mut PredicatePushdownContext,
324    ) -> PlanRef {
325        // convert the predicate to one that references the child of the project
326        let mut subst = Substitute {
327            mapping: {
328                let mut output_list = self.select_list().clone();
329                output_list.insert(
330                    0,
331                    ExprImpl::InputRef(Box::new(InputRef {
332                        index: 0,
333                        data_type: DataType::Int64,
334                    })),
335                );
336                output_list
337            },
338        };
339
340        let remain_mask = {
341            let mut remain_mask = FixedBitSet::with_capacity(self.select_list().len() + 1);
342            remain_mask.set(0, true);
343            self.select_list()
344                .iter()
345                .enumerate()
346                .for_each(|(i, e)| remain_mask.set(i + 1, e.is_impure() || e.has_table_function()));
347            remain_mask
348        };
349        let (remained_cond, pushed_cond) = predicate.split_disjoint(&remain_mask);
350        let pushed_cond = pushed_cond.rewrite_expr(&mut subst);
351
352        gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
353    }
354}
355
356impl ToBatch for LogicalProjectSet {
357    fn to_batch(&self) -> Result<PlanRef> {
358        let mut new_logical = self.core.clone();
359        new_logical.input = self.input().to_batch()?;
360        Ok(BatchProjectSet::new(new_logical).into())
361    }
362}
363
364impl ToStream for LogicalProjectSet {
365    fn logical_rewrite_for_stream(
366        &self,
367        ctx: &mut RewriteStreamContext,
368    ) -> Result<(PlanRef, ColIndexMapping)> {
369        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
370        let (project_set, out_col_change) =
371            self.rewrite_with_input(input.clone(), input_col_change);
372
373        // Add missing columns of input_pk into the select list.
374        let input_pk = input.expect_stream_key();
375        let i2o = self.core.i2o_col_mapping();
376        let col_need_to_add = input_pk
377            .iter()
378            .cloned()
379            .filter(|i| i2o.try_map(*i).is_none());
380        let input_schema = input.schema();
381        let select_list =
382            project_set
383                .select_list()
384                .iter()
385                .cloned()
386                .chain(col_need_to_add.map(|idx| {
387                    InputRef::new(idx, input_schema.fields[idx].data_type.clone()).into()
388                }))
389                .collect();
390        let project_set = Self::new(input, select_list);
391        // The added columns is at the end, so it will not change existing column indices.
392        // But the target size of `out_col_change` should be the same as the length of the new
393        // schema.
394        let (map, _) = out_col_change.into_parts();
395        let out_col_change = ColIndexMapping::new(map, project_set.schema().len());
396        Ok((project_set.into(), out_col_change))
397    }
398
399    // TODO: implement to_stream_with_dist_required like LogicalProject
400
401    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
402        if self.select_list().iter().any(|item| item.has_now()) {
403            // User may use `now()` in table function in a wrong way, because we allow `now()` in `FROM` clause.
404            return Err(ErrorCode::NotSupported(
405                "General `now()` function in streaming queries".to_owned(),
406                "Streaming `now()` is currently only supported in GenerateSeries and TemporalFilter patterns.".to_owned(),
407            )
408            .into());
409        }
410
411        let new_input = self.input().to_stream(ctx)?;
412        let mut new_logical = self.core.clone();
413        new_logical.input = new_input;
414        Ok(StreamProjectSet::new(new_logical).into())
415    }
416}
417
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};

    use super::*;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Checks that an input functional dependency is carried through a `ProjectSet`.
    ///
    /// Input `[v1, v2, v3]` has the FD `v2 --> v3`. The output is
    /// `[projected_row_id, v3, v2, generate_series(v1, v2, v3)]`, so the same FD must appear
    /// as output column 2 --> output column 1.
    #[tokio::test]
    async fn fd_derivation_project_set() {
        let ctx = OptimizerContext::mock().await;
        let schema = Schema {
            fields: ["v1", "v2", "v3"]
                .into_iter()
                .map(|name| Field::with_name(DataType::Int32, name))
                .collect(),
        };
        let mut values = LogicalValues::new(vec![], schema, ctx);
        values
            .base
            .functional_dependency_mut()
            .add_functional_dependency_by_column_indices(&[1], &[2]);

        let int_col =
            |idx: usize| ExprImpl::InputRef(Box::new(InputRef::new(idx, DataType::Int32)));
        let generate_series = TableFunction::new(
            crate::expr::TableFunctionType::GenerateSeries,
            (0..3).map(&int_col).collect(),
        )
        .unwrap();
        let project_set = LogicalProjectSet::new(
            values.into(),
            vec![
                int_col(2),
                int_col(1),
                ExprImpl::TableFunction(Box::new(generate_series)),
            ],
        );

        let actual: HashSet<FunctionalDependency> = project_set
            .base
            .functional_dependency()
            .as_dependencies()
            .iter()
            .cloned()
            .collect();
        let expected: HashSet<FunctionalDependency> =
            HashSet::from([FunctionalDependency::with_indices(4, &[2], &[1])]);
        assert_eq!(actual, expected);
    }
}