risingwave_frontend/optimizer/rule/table_function_to_project_set_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::catalog::Schema;
17use risingwave_common::types::DataType;
18
19use super::{BoxedRule, Rule};
20use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
21use crate::optimizer::PlanRef;
22use crate::optimizer::plan_node::generic::GenericPlanRef;
23use crate::optimizer::plan_node::{
24    LogicalProject, LogicalProjectSet, LogicalTableFunction, LogicalValues, PlanTreeNodeUnary,
25};
26
/// Transforms a `LogicalTableFunction` (a table function used in the `FROM` clause) into a
/// `LogicalProjectSet`, so that it can be unnested later if it contains a
/// `CorrelatedInputRef`.
///
/// Before:
///
/// ```text
///             LogicalTableFunction
/// ```
///
/// After:
///
/// ```text
///             LogicalProject (type alignment)
///                   |
///            LogicalProjectSet
///                   |
///             LogicalValues
/// ```
///
/// The `LogicalValues` input is a single row with no columns, so the table function becomes
/// the only select-list item of the `LogicalProjectSet`. The top `LogicalProject` aligns the
/// output with the original `LogicalTableFunction`. It does three things:
/// - hides the `LogicalProjectSet`'s leading `projected_row_id` column;
/// - expands a struct-typed result into one column per struct field;
/// - when the table function has `with_ordinality` set, appends an ordinality column computed
///   as `projected_row_id + 1`.
pub struct TableFunctionToProjectSetRule {}
46impl Rule for TableFunctionToProjectSetRule {
47    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
48        let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?;
49        let table_function =
50            ExprImpl::TableFunction(logical_table_function.table_function().clone().into());
51        let table_function_return_type = table_function.return_type();
52        let logical_values = LogicalValues::create(
53            vec![vec![]],
54            Schema::new(vec![]),
55            logical_table_function.base.ctx().clone(),
56        );
57        let logical_project_set = LogicalProjectSet::create(logical_values, vec![table_function]);
58        // We need a project to align schema type because
59        // 1. `LogicalProjectSet` has a hidden column `projected_row_id` (0-th col)
60        // 2. When the function returns a struct type, TableFunction will return flatten it into multiple columns, while ProjectSet still returns a single column.
61        let table_function_col_idx = 1;
62        let logical_project = if let DataType::Struct(st) = table_function_return_type.clone() {
63            let exprs = st
64                .types()
65                .enumerate()
66                .map(|(i, data_type)| {
67                    let field_access = FunctionCall::new_unchecked(
68                        ExprType::Field,
69                        vec![
70                            InputRef::new(
71                                table_function_col_idx,
72                                table_function_return_type.clone(),
73                            )
74                            .into(),
75                            ExprImpl::literal_int(i as i32),
76                        ],
77                        data_type.clone(),
78                    );
79                    ExprImpl::FunctionCall(field_access.into())
80                })
81                .collect_vec();
82            LogicalProject::new(logical_project_set, exprs)
83        } else {
84            LogicalProject::with_out_col_idx(
85                logical_project_set,
86                std::iter::once(table_function_col_idx),
87            )
88        };
89
90        if logical_table_function.with_ordinality {
91            let projected_row_id = InputRef::new(0, DataType::Int64).into();
92            let ordinality = FunctionCall::new(
93                ExprType::Add,
94                vec![projected_row_id, ExprImpl::literal_bigint(1)],
95            )
96            .unwrap() // i64 + i64 is ok
97            .into();
98            let mut exprs = logical_project.exprs().clone();
99            exprs.push(ordinality);
100            let logical_project = LogicalProject::new(logical_project.input(), exprs);
101            Some(logical_project.into())
102        } else {
103            Some(logical_project.into())
104        }
105    }
106}
107
108impl TableFunctionToProjectSetRule {
109    pub fn create() -> BoxedRule {
110        Box::new(TableFunctionToProjectSetRule {})
111    }
112}