// risingwave_frontend/optimizer/rule/pull_up_hop_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::util::column_index_mapping::ColIndexMapping;
16use risingwave_pb::plan_common::JoinType;
17
18use super::{BoxedRule, Rule};
19use crate::optimizer::plan_node::{LogicalHopWindow, LogicalJoin};
20use crate::utils::IndexRewriter;
21
/// Optimizer rule that pulls a `LogicalHopWindow` found directly under either input of a
/// `LogicalJoin` up above that join, when the join condition and join type allow it.
#[derive(Debug)]
pub struct PullUpHopRule {}
23
24impl Rule for PullUpHopRule {
25    fn apply(&self, plan: crate::PlanRef) -> Option<crate::PlanRef> {
26        let join = plan.as_logical_join()?;
27
28        let (left, right, on, join_type, mut output_index) = join.clone().decompose();
29
30        let (left_input_index_on_condition, right_input_index_on_condition) =
31            join.input_idx_on_condition();
32
33        let (left_output_pos, right_output_pos) = {
34            let mut left_output_pos = vec![];
35            let mut right_output_pos = vec![];
36            output_index.iter_mut().enumerate().for_each(|(pos, idx)| {
37                if *idx < left.schema().len() {
38                    left_output_pos.push(pos);
39                } else {
40                    right_output_pos.push(pos);
41                    // make right output index start from 0. We can identify left and right output
42                    // index by the output_pos.
43                    *idx -= left.schema().len();
44                }
45            });
46            (left_output_pos, right_output_pos)
47        };
48
49        let mut old_i2new_i = ColIndexMapping::empty(0, 0);
50
51        let mut pull_up_left = false;
52        let mut pull_up_right = false;
53
54        let (new_left, left_time_col, left_window_slide, left_window_size, left_window_offset) =
55            if let Some(hop) = left.as_logical_hop_window()
56                && left_input_index_on_condition.iter().all(|&index| {
57                    (hop.output_window_start_col_idx() != Some(index))
58                        && (hop.output_window_end_col_idx() != Some(index))
59                })
60                && join_type != JoinType::RightAnti
61                && join_type != JoinType::RightSemi
62                && join_type != JoinType::RightOuter
63                && join_type != JoinType::FullOuter
64            {
65                let (input, time_col, window_slide, window_size, window_offset, _) =
66                    hop.clone().into_parts();
67
68                old_i2new_i = old_i2new_i.union(
69                    &join
70                        .i2l_col_mapping_ignore_join_type()
71                        .composite(&hop.o2i_col_mapping()),
72                );
73                left_output_pos.iter().for_each(|&pos| {
74                    output_index[pos] = hop.output2internal_col_mapping().map(output_index[pos]);
75                });
76                pull_up_left = true;
77                (
78                    input,
79                    Some(time_col),
80                    Some(window_slide),
81                    Some(window_size),
82                    Some(window_offset),
83                )
84            } else {
85                old_i2new_i = old_i2new_i.union(&join.i2l_col_mapping_ignore_join_type());
86
87                (left, None, None, None, None)
88            };
89
90        let (new_right, right_time_col, right_window_slide, right_window_size, right_window_offset) =
91            if let Some(hop) = right.as_logical_hop_window()
92                && right_input_index_on_condition.iter().all(|&index| {
93                    hop.output_window_start_col_idx() != Some(index)
94                        && hop.output_window_end_col_idx() != Some(index)
95                })
96                && join_type != JoinType::LeftAnti
97                && join_type != JoinType::LeftSemi
98                && join_type != JoinType::LeftOuter
99                && join_type != JoinType::FullOuter
100            {
101                let (input, time_col, window_slide, window_size, window_offset, _) =
102                    hop.clone().into_parts();
103
104                old_i2new_i = old_i2new_i.union(
105                    &join
106                        .i2r_col_mapping_ignore_join_type()
107                        .composite(&hop.o2i_col_mapping())
108                        .clone_with_offset(new_left.schema().len()),
109                );
110
111                right_output_pos.iter().for_each(|&pos| {
112                    output_index[pos] = hop.output2internal_col_mapping().map(output_index[pos]);
113                });
114                pull_up_right = true;
115                (
116                    input,
117                    Some(time_col),
118                    Some(window_slide),
119                    Some(window_size),
120                    Some(window_offset),
121                )
122            } else {
123                old_i2new_i = old_i2new_i.union(
124                    &join
125                        .i2r_col_mapping_ignore_join_type()
126                        .clone_with_offset(new_left.schema().len()),
127                );
128
129                (right, None, None, None, None)
130            };
131
132        if !pull_up_left && !pull_up_right {
133            return None;
134        }
135
136        let new_output_index = {
137            let new_right_output_len =
138                if join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti {
139                    0
140                } else {
141                    new_right.schema().len()
142                };
143            let new_left_output_len =
144                if join_type == JoinType::RightSemi || join_type == JoinType::RightAnti {
145                    0
146                } else {
147                    new_left.schema().len()
148                };
149
150            // The left output index can separate into two parts:
151            // `left_other_column | left_window_start | letf_window_end`
152            // The right output index can separate into two parts:
153            // `right_other_column | right_window_start | right_window_end`
154            //
155            // If we pull up left, the column index will be changed to:
156            // `left_other_column | right_column | left_window_start | letf_window_end`,
157            // we need to update the index of left window start and left window end.
158            //
159            // If we pull up right and left, the column index will be changed to:
160            // `left_other_column | right_other_column | left_window_start | letf_window_end |
161            // right_window_tart | right_window_end |`, we need to update the index of
162            // left window start and left window end and right window start and right window end.
163            if pull_up_left {
164                left_output_pos.iter().for_each(|&pos| {
165                    if output_index[pos] >= new_left_output_len {
166                        output_index[pos] += new_right_output_len;
167                    }
168                });
169            }
170            if pull_up_right && pull_up_left {
171                right_output_pos.iter().for_each(|&pos| {
172                    if output_index[pos] < new_right_output_len {
173                        output_index[pos] += new_left_output_len;
174                    } else {
175                        output_index[pos] +=
176                            new_left_output_len + LogicalHopWindow::ADDITION_COLUMN_LEN;
177                    }
178                });
179            } else {
180                right_output_pos.iter().for_each(|&pos| {
181                    output_index[pos] += new_left_output_len;
182                });
183            }
184            output_index
185        };
186        let new_left_len = new_left.schema().len();
187        let new_cond = on.rewrite_expr(&mut IndexRewriter::new(old_i2new_i));
188        let new_join = LogicalJoin::new(new_left, new_right, join_type, new_cond);
189
190        let new_hop = if pull_up_left && pull_up_right {
191            let left_hop = LogicalHopWindow::create(
192                new_join.into(),
193                left_time_col.unwrap(),
194                left_window_slide.unwrap(),
195                left_window_size.unwrap(),
196                left_window_offset.unwrap(),
197            );
198            LogicalHopWindow::create(
199                left_hop,
200                right_time_col
201                    .unwrap()
202                    .clone_with_offset(new_left_len as isize),
203                right_window_slide.unwrap(),
204                right_window_size.unwrap(),
205                right_window_offset.unwrap(),
206            )
207        } else if pull_up_left {
208            LogicalHopWindow::create(
209                new_join.into(),
210                left_time_col.unwrap(),
211                left_window_slide.unwrap(),
212                left_window_size.unwrap(),
213                left_window_offset.unwrap(),
214            )
215        } else {
216            LogicalHopWindow::create(
217                new_join.into(),
218                right_time_col
219                    .unwrap()
220                    .clone_with_offset(new_left_len as isize),
221                right_window_slide.unwrap(),
222                right_window_size.unwrap(),
223                right_window_offset.unwrap(),
224            )
225        };
226
227        Some(
228            new_hop
229                .as_logical_hop_window()
230                .unwrap()
231                .clone_with_output_indices(new_output_index)
232                .into(),
233        )
234    }
235}
236
237impl PullUpHopRule {
238    pub fn create() -> BoxedRule {
239        Box::new(PullUpHopRule {})
240    }
241}