risingwave_frontend/optimizer/rule/
apply_project_set_transpose_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::{ApplyOffsetRewriter, BoxedRule, Rule};
19use crate::expr::{ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::PlanRef;
21use crate::optimizer::plan_node::generic::GenericPlanRef;
22use crate::optimizer::plan_node::{LogicalApply, LogicalProject, LogicalProjectSet};
23
/// Swaps a `LogicalApply` with the `LogicalProjectSet` found on its right side, so that the
/// apply runs against the project set's input and the project set is evaluated on top of it.
///
/// Before:
///
/// ```text
///     LogicalApply
///    /            \
///  Domain      LogicalProjectSet
///                  |
///                Input
/// ```
///
/// After:
///
/// ```text
///     LogicalProject (reorder)
///           |
///    LogicalProjectSet
///          |
///    LogicalApply
///    /            \
///  Domain        Input
/// ```
///
/// The new project set passes every column of the apply's left side through in front of the
/// original select list. Because a project set emits `projected_row_id` as its first column, a
/// reordering `LogicalProject` is placed on top so the output schema matches the original apply.
///
/// The apply is expected to be an inner join (this is asserted). The rule does not fire when the
/// join condition refers to a table-function column of the project set.
pub struct ApplyProjectSetTransposeRule {}
48impl Rule for ApplyProjectSetTransposeRule {
49    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
50        let apply: &LogicalApply = plan.as_logical_apply()?;
51        let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
52            apply.clone().decompose();
53        let project_set: &LogicalProjectSet = right.as_logical_project_set()?;
54        let left_schema_len = left.schema().len();
55        assert_eq!(join_type, JoinType::Inner);
56
57        // Insert all the columns of `LogicalApply`'s left at the beginning of the new
58        // `LogicalProjectSet`.
59        let mut exprs: Vec<ExprImpl> = left
60            .schema()
61            .data_types()
62            .into_iter()
63            .enumerate()
64            .map(|(index, data_type)| InputRef::new(index, data_type).into())
65            .collect();
66
67        let (proj_exprs, proj_input) = project_set.clone().decompose();
68
69        // replace correlated_input_ref in project exprs
70        let mut rewriter =
71            ApplyOffsetRewriter::new(left.schema().len(), &correlated_indices, correlated_id);
72
73        let new_proj_exprs: Vec<ExprImpl> = proj_exprs
74            .into_iter()
75            .map(|expr| rewriter.rewrite_expr(expr))
76            .collect_vec();
77
78        exprs.extend(new_proj_exprs.clone());
79
80        let mut rewriter =
81            ApplyOnCondRewriterForProjectSet::new(left.schema().len(), new_proj_exprs);
82        let new_on = on.rewrite_expr(&mut rewriter);
83
84        if rewriter.refer_table_function {
85            // The join on condition refers to the table function column of the `project_set` which
86            // cannot be unnested.
87            return None;
88        }
89
90        let new_apply = LogicalApply::create(
91            left,
92            proj_input,
93            join_type,
94            new_on,
95            correlated_id,
96            correlated_indices,
97            max_one_row,
98        );
99
100        let new_project_set = LogicalProjectSet::create(new_apply, exprs);
101
102        // Since `project_set` has a field `projected_row_id` in its left most column, we need a
103        // project to reorder it to align the schema type.
104        let out_col_idxs = (1..=left_schema_len)
105            .chain(vec![0])
106            .chain((left_schema_len + 1)..new_project_set.schema().len());
107        let reorder_project = LogicalProject::with_out_col_idx(new_project_set, out_col_idxs);
108
109        Some(reorder_project.into())
110    }
111}
112
113impl ApplyProjectSetTransposeRule {
114    pub fn create() -> BoxedRule {
115        Box::new(ApplyProjectSetTransposeRule {})
116    }
117}
118
119pub struct ApplyOnCondRewriterForProjectSet {
120    pub left_input_len: usize,
121    pub mapping: Vec<ExprImpl>,
122    pub refer_table_function: bool,
123}
124
125impl ApplyOnCondRewriterForProjectSet {
126    pub fn new(left_input_len: usize, mapping: Vec<ExprImpl>) -> Self {
127        Self {
128            left_input_len,
129            mapping,
130            refer_table_function: false,
131        }
132    }
133}
134
135impl ExprRewriter for ApplyOnCondRewriterForProjectSet {
136    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
137        if input_ref.index >= self.left_input_len {
138            // We need to minus 1 to align `projected_row_id` field in `project_set`.
139            let expr = self.mapping[input_ref.index() - self.left_input_len - 1].clone();
140            if matches!(expr, ExprImpl::TableFunction(_)) {
141                self.refer_table_function = true;
142            }
143            expr
144        } else {
145            input_ref.into()
146        }
147    }
148}