risingwave_frontend/optimizer/rule/
table_function_to_project_set_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::catalog::Schema;
17use risingwave_common::types::DataType;
18
19use super::prelude::{PlanRef, *};
20use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
21use crate::optimizer::plan_node::generic::GenericPlanRef;
22use crate::optimizer::plan_node::{
23    LogicalProject, LogicalProjectSet, LogicalTableFunction, LogicalValues, PlanTreeNodeUnary,
24};
25
/// Transform a `TableFunction` (used in the FROM clause) into a `ProjectSet`, so that it can be
/// unnested later if it contains `CorrelatedInputRef`.
///
/// Before:
///
/// ```text
///             LogicalTableFunction
/// ```
///
/// After:
///
/// ```text
///             LogicalProject (type alignment)
///                   |
///            LogicalProjectSet
///                   |
///             LogicalValues
/// ```
///
/// The `LogicalProject` on top aligns the output schema with the one produced by the original
/// table function:
/// - it selects only the table function column (column 1) of the `LogicalProjectSet`, whose
///   column 0 is the hidden `projected_row_id`;
/// - if the function returns a struct type, it expands that column into one column per field;
/// - if `with_ordinality` is set, it appends `projected_row_id + 1` as the last column.
pub struct TableFunctionToProjectSetRule {}
45impl Rule<Logical> for TableFunctionToProjectSetRule {
46    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
47        let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?;
48        let table_function =
49            ExprImpl::TableFunction(logical_table_function.table_function().clone().into());
50        let table_function_return_type = table_function.return_type();
51        let logical_values = LogicalValues::create(
52            vec![vec![]],
53            Schema::new(vec![]),
54            logical_table_function.base.ctx().clone(),
55        );
56        let logical_project_set = LogicalProjectSet::create(logical_values, vec![table_function]);
57        // We need a project to align schema type because
58        // 1. `LogicalProjectSet` has a hidden column `projected_row_id` (0-th col)
59        // 2. When the function returns a struct type, TableFunction will return flatten it into multiple columns, while ProjectSet still returns a single column.
60        let table_function_col_idx = 1;
61        let logical_project = if let DataType::Struct(st) = table_function_return_type.clone() {
62            let exprs = st
63                .types()
64                .enumerate()
65                .map(|(i, data_type)| {
66                    let field_access = FunctionCall::new_unchecked(
67                        ExprType::Field,
68                        vec![
69                            InputRef::new(
70                                table_function_col_idx,
71                                table_function_return_type.clone(),
72                            )
73                            .into(),
74                            ExprImpl::literal_int(i as i32),
75                        ],
76                        data_type.clone(),
77                    );
78                    ExprImpl::FunctionCall(field_access.into())
79                })
80                .collect_vec();
81            LogicalProject::new(logical_project_set, exprs)
82        } else {
83            LogicalProject::with_out_col_idx(
84                logical_project_set,
85                std::iter::once(table_function_col_idx),
86            )
87        };
88
89        if logical_table_function.with_ordinality {
90            let projected_row_id = InputRef::new(0, DataType::Int64).into();
91            let ordinality = FunctionCall::new(
92                ExprType::Add,
93                vec![projected_row_id, ExprImpl::literal_bigint(1)],
94            )
95            .unwrap() // i64 + i64 is ok
96            .into();
97            let mut exprs = logical_project.exprs().clone();
98            exprs.push(ordinality);
99            let logical_project = LogicalProject::new(logical_project.input(), exprs);
100            Some(logical_project.into())
101        } else {
102            Some(logical_project.into())
103        }
104    }
105}
106
107impl TableFunctionToProjectSetRule {
108    pub fn create() -> BoxedRule {
109        Box::new(TableFunctionToProjectSetRule {})
110    }
111}