risingwave_frontend/optimizer/rule/
table_function_to_project_set_rule.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::types::DataType;
17
18use super::prelude::{PlanRef, *};
19use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
20use crate::optimizer::plan_node::generic::GenericPlanRef;
21use crate::optimizer::plan_node::{
22    LogicalProject, LogicalProjectSet, LogicalTableFunction, LogicalValues, PlanTreeNodeUnary,
23};
24
/// Transforms a `LogicalTableFunction` (a table function used in the FROM clause) into a
/// `LogicalProjectSet`, so that it can be unnested later if it contains `CorrelatedInputRef`.
///
/// Before:
///
/// ```text
///             LogicalTableFunction
/// ```
///
/// After:
///
/// ```text
///             LogicalProject (type alignment [+ ordinality])
///                   |
///            LogicalProjectSet
///                   |
///             LogicalValues (empty scalar)
/// ```
///
/// The top `LogicalProject` restores the output schema of the original table function:
/// - it drops the hidden `projected_row_id` column (0-th column) of the `LogicalProjectSet`;
/// - when the function returns a struct, it expands the single struct column produced by the
///   `LogicalProjectSet` into one column per field, matching how `LogicalTableFunction`
///   flattens struct results;
/// - when the table function has `with_ordinality` set, it appends an ordinality column
///   computed as `projected_row_id + 1`.
pub struct TableFunctionToProjectSetRule {}
44impl Rule<Logical> for TableFunctionToProjectSetRule {
45    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
46        let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?;
47        let table_function =
48            ExprImpl::TableFunction(logical_table_function.table_function().clone().into());
49        let table_function_return_type = table_function.return_type();
50        let logical_values = LogicalValues::create_empty_scalar(logical_table_function.base.ctx());
51        let logical_project_set = LogicalProjectSet::create(logical_values, vec![table_function]);
52        // We need a project to align schema type because
53        // 1. `LogicalProjectSet` has a hidden column `projected_row_id` (0-th col)
54        // 2. When the function returns a struct type, TableFunction will return flatten it into multiple columns, while ProjectSet still returns a single column.
55        let table_function_col_idx = 1;
56        let logical_project = if let DataType::Struct(st) = table_function_return_type.clone() {
57            let exprs = st
58                .types()
59                .enumerate()
60                .map(|(i, data_type)| {
61                    let field_access = FunctionCall::new_unchecked(
62                        ExprType::Field,
63                        vec![
64                            InputRef::new(
65                                table_function_col_idx,
66                                table_function_return_type.clone(),
67                            )
68                            .into(),
69                            ExprImpl::literal_int(i as i32),
70                        ],
71                        data_type.clone(),
72                    );
73                    ExprImpl::FunctionCall(field_access.into())
74                })
75                .collect_vec();
76            LogicalProject::new(logical_project_set, exprs)
77        } else {
78            LogicalProject::with_out_col_idx(
79                logical_project_set,
80                std::iter::once(table_function_col_idx),
81            )
82        };
83
84        if logical_table_function.with_ordinality {
85            let projected_row_id = InputRef::new(0, DataType::Int64).into();
86            let ordinality = FunctionCall::new(
87                ExprType::Add,
88                vec![projected_row_id, ExprImpl::literal_bigint(1)],
89            )
90            .unwrap() // i64 + i64 is ok
91            .into();
92            let mut exprs = logical_project.exprs().clone();
93            exprs.push(ordinality);
94            let logical_project = LogicalProject::new(logical_project.input(), exprs);
95            Some(logical_project.into())
96        } else {
97            Some(logical_project.into())
98        }
99    }
100}
101
102impl TableFunctionToProjectSetRule {
103    pub fn create() -> BoxedRule {
104        Box::new(TableFunctionToProjectSetRule {})
105    }
106}