risingwave_frontend/optimizer/rule/join_project_transpose_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::util::column_index_mapping::ColIndexMapping;
17use risingwave_pb::plan_common::JoinType;
18
19use super::{BoxedRule, Rule};
20use crate::expr::{ExprRewriter, InputRef};
21use crate::optimizer::plan_node::{LogicalJoin, LogicalProject};
22use crate::utils::IndexRewriter;
23
/// Pulls `LogicalProject` inputs of a `LogicalJoin` up above the join.
///
/// Before this rule:
/// `join(project_a(t1), project_b(t2))`
///
/// After this rule:
/// `new_project(new_join(t1, t2))`
///
/// `new_join` keeps the join type of the original `join`; only its condition is rewritten so
/// that it references the columns of `t1`/`t2` directly. `new_project` combines the
/// expressions of `project_a` and `project_b` and outputs only the columns that the original
/// `join` outputs.
///
/// A side's project is pulled up only if both of the following hold:
/// - every column of that side referenced by the join condition is a plain column reference
///   (`InputRef`) in the project;
/// - that side is neither null-producing in an outer join nor the non-output side of a
///   semi/anti join.
///
/// A side whose project does not qualify is kept as a join input unchanged. If neither side
/// qualifies, the rule does not apply.
pub struct JoinProjectTransposeRule {}
33
34impl Rule for JoinProjectTransposeRule {
35    fn apply(&self, plan: crate::PlanRef) -> Option<crate::PlanRef> {
36        let join = plan.as_logical_join()?;
37
38        let (left, right, on, join_type, _) = join.clone().decompose();
39
40        let (left_input_index_on_condition, right_input_index_on_condition) =
41            join.input_idx_on_condition();
42
43        let full_output_len = left.schema().len() + right.schema().len();
44        let right_output_len = right.schema().len();
45        let left_output_len = left.schema().len();
46        let mut full_proj_exprs = Vec::with_capacity(full_output_len);
47
48        let mut old_i2new_i = ColIndexMapping::empty(0, 0);
49
50        let mut has_new_left: bool = false;
51        let mut has_new_right: bool = false;
52
53        // prepare for pull up left child.
54        let new_left = if let Some(project) = left.as_logical_project()
55            && left_input_index_on_condition
56                .iter()
57                .all(|index| project.exprs()[*index].as_input_ref().is_some())
58            && join_type != JoinType::RightAnti
59            && join_type != JoinType::RightSemi
60            && join_type != JoinType::RightOuter
61            && join_type != JoinType::FullOuter
62        {
63            let (exprs, child) = project.clone().decompose();
64
65            old_i2new_i = old_i2new_i.union(
66                &join
67                    .i2l_col_mapping_ignore_join_type()
68                    .composite(&project.o2i_col_mapping()),
69            );
70
71            full_proj_exprs.extend(exprs);
72
73            has_new_left = true;
74
75            child
76        } else {
77            old_i2new_i = old_i2new_i.union(&join.i2l_col_mapping_ignore_join_type());
78
79            for i in 0..left_output_len {
80                full_proj_exprs.push(
81                    InputRef {
82                        index: i,
83                        data_type: left.schema().data_types()[i].clone(),
84                    }
85                    .into(),
86                );
87            }
88
89            left
90        };
91
92        // prepare for pull up right child.
93        let new_right = if let Some(project) = right.as_logical_project()
94            && right_input_index_on_condition
95                .iter()
96                .all(|index| project.exprs()[*index].as_input_ref().is_some())
97            && join_type != JoinType::LeftAnti
98            && join_type != JoinType::LeftSemi
99            && join_type != JoinType::LeftOuter
100            && join_type != JoinType::FullOuter
101        {
102            let (exprs, child) = project.clone().decompose();
103
104            old_i2new_i = old_i2new_i.union(
105                &join
106                    .i2r_col_mapping_ignore_join_type()
107                    .composite(&project.o2i_col_mapping())
108                    .clone_with_offset(new_left.schema().len()),
109            );
110
111            let mut index_writer = IndexRewriter::new(
112                ColIndexMapping::identity(child.schema().len())
113                    .clone_with_offset(new_left.schema().len()),
114            );
115            full_proj_exprs.extend(
116                exprs
117                    .into_iter()
118                    .map(|expr| index_writer.rewrite_expr(expr)),
119            );
120
121            has_new_right = true;
122
123            child
124        } else {
125            old_i2new_i = old_i2new_i.union(
126                &join
127                    .i2r_col_mapping_ignore_join_type()
128                    .clone_with_offset(new_left.schema().len()),
129            );
130
131            for i in 0..right_output_len {
132                full_proj_exprs.push(
133                    InputRef {
134                        index: i + new_left.schema().len(),
135                        data_type: right.schema().data_types()[i].clone(),
136                    }
137                    .into(),
138                );
139            }
140
141            right
142        };
143
144        // No project will be pulled up
145        if !has_new_left && !has_new_right {
146            return None;
147        }
148
149        let new_cond = on.rewrite_expr(&mut IndexRewriter::new(old_i2new_i));
150        let new_join = LogicalJoin::new(new_left, new_right, join_type, new_cond);
151
152        // remain only the columns that are output in the original join
153        let new_proj_exprs = join
154            .output_indices()
155            .iter()
156            .map(|i| full_proj_exprs[*i].clone())
157            .collect_vec();
158        let new_project = LogicalProject::new(new_join.into(), new_proj_exprs);
159
160        Some(new_project.into())
161    }
162}
163
164impl JoinProjectTransposeRule {
165    pub fn create() -> BoxedRule {
166        Box::new(JoinProjectTransposeRule {})
167    }
168}