risingwave_frontend/optimizer/rule/
apply_project_transpose_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::{ApplyOffsetRewriter, BoxedRule, Rule};
19use crate::expr::{ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::PlanRef;
21use crate::optimizer::plan_node::{LogicalApply, LogicalProject};
22
/// Transpose `LogicalApply` and `LogicalProject`.
///
/// Before:
///
/// ```text
///     LogicalApply
///    /            \
///  Domain      LogicalProject
///                  |
///                Input
/// ```
///
/// After:
///
/// ```text
///    LogicalProject
///          |
///    LogicalApply
///    /            \
///  Domain        Input
/// ```
///
/// The new `LogicalProject` first outputs every column of `Domain` unchanged. It then outputs
/// the original project's expressions, rewritten to be evaluated on top of the new
/// `LogicalApply`. As a result the output schema matches the original apply's. Any references
/// to the project's outputs in the apply's `on` condition are replaced by those same rewritten
/// expressions.
///
/// Only inner applies are supported. Once a `LogicalProject` has been matched on the right
/// side, the rule asserts that the apply's join type is `JoinType::Inner`.
pub struct ApplyProjectTransposeRule {}
45impl Rule for ApplyProjectTransposeRule {
46    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
47        let apply: &LogicalApply = plan.as_logical_apply()?;
48        let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
49            apply.clone().decompose();
50        let project = right.as_logical_project()?;
51        assert_eq!(join_type, JoinType::Inner);
52
53        // Insert all the columns of `LogicalApply`'s left at the beginning of the new
54        // `LogicalProject`.
55        let mut exprs: Vec<ExprImpl> = left
56            .schema()
57            .data_types()
58            .into_iter()
59            .enumerate()
60            .map(|(index, data_type)| InputRef::new(index, data_type).into())
61            .collect();
62
63        let (proj_exprs, proj_input) = project.clone().decompose();
64
65        // replace correlated_input_ref in project exprs
66        let mut rewriter =
67            ApplyOffsetRewriter::new(left.schema().len(), &correlated_indices, correlated_id);
68
69        let new_proj_exprs: Vec<ExprImpl> = proj_exprs
70            .into_iter()
71            .map(|expr| rewriter.rewrite_expr(expr))
72            .collect_vec();
73
74        exprs.extend(new_proj_exprs.clone());
75
76        let mut rewriter = ApplyOnConditionRewriter {
77            left_input_len: left.schema().len(),
78            mapping: new_proj_exprs,
79        };
80        let new_on = on.rewrite_expr(&mut rewriter);
81        let new_apply = LogicalApply::create(
82            left,
83            proj_input,
84            join_type,
85            new_on,
86            correlated_id,
87            correlated_indices,
88            max_one_row,
89        );
90
91        let new_project = LogicalProject::create(new_apply, exprs);
92        Some(new_project)
93    }
94}
95
96impl ApplyProjectTransposeRule {
97    pub fn create() -> BoxedRule {
98        Box::new(ApplyProjectTransposeRule {})
99    }
100}
101
102pub struct ApplyOnConditionRewriter {
103    pub left_input_len: usize,
104    pub mapping: Vec<ExprImpl>,
105}
106
107impl ExprRewriter for ApplyOnConditionRewriter {
108    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
109        if input_ref.index >= self.left_input_len {
110            self.mapping[input_ref.index() - self.left_input_len].clone()
111        } else {
112            input_ref.into()
113        }
114    }
115}