risingwave_frontend/optimizer/rule/push_calculation_of_join_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_pb::expr::expr_node::Type;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, FunctionCall, InputRef, align_types};
22use crate::optimizer::plan_node::{LogicalJoin, LogicalProject};
23use crate::utils::{ColIndexMapping, Condition};
24
/// Optimizer rule that moves the calculation of comparison operands found in a join's `ON`
/// condition into `LogicalProject`s placed on the join's inputs.
///
/// The rule carries no configuration; construct it through `PushCalculationOfJoinRule::create`.
pub struct PushCalculationOfJoinRule {}
26
27impl Rule<Logical> for PushCalculationOfJoinRule {
28    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
29        let join: &LogicalJoin = plan.as_logical_join()?;
30        let (mut left, mut right, mut on, join_type, mut output_indices) = join.clone().decompose();
31        let left_col_num = left.schema().len();
32        let right_col_num = right.schema().len();
33
34        let exprs = on.conjunctions;
35        let (left_exprs, right_exprs, indices_and_ty_of_func_calls) =
36            Self::find_comparison_exprs(left_col_num, right_col_num, &exprs);
37
38        // Store only the expressions that need a new column in the projection
39        let left_exprs_non_input_ref: Vec<_> = left_exprs
40            .iter()
41            .filter(|e| e.as_input_ref().is_none())
42            .cloned()
43            .collect();
44        let right_exprs_non_input_ref: Vec<_> = right_exprs
45            .iter()
46            .filter(|e| e.as_input_ref().is_none())
47            .cloned()
48            .collect();
49
50        let new_internal_col_num = left_col_num
51            + left_exprs_non_input_ref.len()
52            + right_col_num
53            + right_exprs_non_input_ref.len();
54        // used to shift indices of input_refs pointing the right side of `join` with
55        // `left_exprs.len`.
56        let mut col_index_mapping = {
57            let map = (0..left_col_num)
58                .chain(
59                    (left_col_num..left_col_num + right_col_num)
60                        .map(|i| i + left_exprs_non_input_ref.len()),
61                )
62                .map(Some)
63                .collect_vec();
64            ColIndexMapping::new(map, new_internal_col_num)
65        };
66        let (mut exprs, new_output_indices) =
67            Self::remap_exprs_and_output_indices(exprs, output_indices, &mut col_index_mapping);
68        output_indices = new_output_indices;
69
70        // ```ignore
71        // the internal table of join has has the following schema:
72        // original left's columns | left_exprs | original right's columns | right_exprs
73        //```
74        // `left_index` and `right_index` will scan through `left_exprs` and `right_exprs`
75        // respectively.
76        let mut left_index = left_col_num;
77        let mut right_index = left_col_num + left_exprs_non_input_ref.len() + right_col_num;
78        let mut right_exprs_mapping = {
79            let map = (0..right_col_num)
80                .map(|i| i + left_col_num + left_exprs_non_input_ref.len())
81                .map(Some)
82                .collect_vec();
83            ColIndexMapping::new(map, new_internal_col_num)
84        };
85        // replace chosen function calls.
86        for (((index_of_func_call, ty), left_expr), right_expr) in indices_and_ty_of_func_calls
87            .into_iter()
88            .zip_eq_fast(&left_exprs)
89            .zip_eq_fast(&right_exprs)
90        {
91            let left_input = if left_expr.as_input_ref().is_some() {
92                left_expr.clone()
93            } else {
94                left_index += 1;
95                InputRef::new(left_index - 1, left_expr.return_type()).into()
96            };
97            let right_input = if right_expr.as_input_ref().is_some() {
98                right_exprs_mapping.rewrite_expr(right_expr.clone())
99            } else {
100                right_index += 1;
101                InputRef::new(right_index - 1, right_expr.return_type()).into()
102            };
103            exprs[index_of_func_call] = FunctionCall::new(ty, vec![left_input, right_input])
104                .unwrap()
105                .into();
106        }
107        on = Condition {
108            conjunctions: exprs,
109        };
110
111        // add project to do the calculation.
112        let new_input = |input: PlanRef, appended_exprs: Vec<ExprImpl>| {
113            let mut exprs = input
114                .schema()
115                .data_types()
116                .into_iter()
117                .enumerate()
118                .map(|(i, data_type)| InputRef::new(i, data_type).into())
119                .collect_vec();
120            exprs.extend(appended_exprs);
121            LogicalProject::create(input, exprs)
122        };
123        // avoid unnecessary `project`s.
124        if !left_exprs_non_input_ref.is_empty() {
125            left = new_input(left, left_exprs_non_input_ref);
126        }
127        if !right_exprs_non_input_ref.is_empty() {
128            right = new_input(right, right_exprs_non_input_ref);
129        }
130
131        Some(LogicalJoin::with_output_indices(left, right, join_type, on, output_indices).into())
132    }
133}
134
135impl PushCalculationOfJoinRule {
136    /// find the comparison exprs and return their inputs, types and indices.
137    fn find_comparison_exprs(
138        left_col_num: usize,
139        right_col_num: usize,
140        exprs: &[ExprImpl],
141    ) -> (Vec<ExprImpl>, Vec<ExprImpl>, Vec<(usize, Type)>) {
142        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
143        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
144
145        let mut left_exprs = vec![];
146        let mut right_exprs = vec![];
147        // indices and return types of function calls whose's inputs will be calculated in
148        // `project`s
149        let mut indices_and_ty_of_func_calls = vec![];
150        let is_comparison_type = |ty| {
151            matches!(
152                ty,
153                Type::LessThan
154                    | Type::LessThanOrEqual
155                    | Type::Equal
156                    | Type::IsNotDistinctFrom
157                    | Type::GreaterThan
158                    | Type::GreaterThanOrEqual
159            )
160        };
161        for (index, expr) in exprs.iter().enumerate() {
162            let ExprImpl::FunctionCall(func) = expr else {
163                continue;
164            };
165            if !is_comparison_type(func.func_type()) {
166                continue;
167            }
168            // Do not decompose the comparison if it contains `now()`
169            if expr.count_nows() > 0 {
170                continue;
171            }
172            let (mut ty, left, right) = func.clone().decompose_as_binary();
173            // we just cast the return types of inputs of binary predicates for `HashJoin` and
174            // `DynamicFilter`.
175            let left_input_bits = left.collect_input_refs(left_col_num + right_col_num);
176            let right_input_bits = right.collect_input_refs(left_col_num + right_col_num);
177            let (mut left, mut right) = if left_input_bits.is_subset(&left_bit_map)
178                && right_input_bits.is_subset(&right_bit_map)
179            {
180                (left, right)
181            } else if left_input_bits.is_subset(&right_bit_map)
182                && right_input_bits.is_subset(&left_bit_map)
183            {
184                ty = ExprImpl::reverse_comparison(ty);
185                (right, left)
186            } else {
187                continue;
188            };
189            // when both `left` and `right` are `input_ref`, and they have the same return type
190            // there is no need to calculate them in project.
191            if left.as_input_ref().is_some()
192                && right.as_input_ref().is_some()
193                && left.return_type() == right.return_type()
194            {
195                continue;
196            }
197            // align return types to avoid error when executing join.
198            align_types([&mut left, &mut right].into_iter()).unwrap();
199            left_exprs.push(left);
200            {
201                let mut shift_with_offset = ColIndexMapping::with_shift_offset(
202                    left_col_num + right_col_num,
203                    -(left_col_num as isize),
204                );
205                let right = shift_with_offset.rewrite_expr(right);
206                right_exprs.push(right);
207            }
208            indices_and_ty_of_func_calls.push((index, ty));
209        }
210        (left_exprs, right_exprs, indices_and_ty_of_func_calls)
211    }
212
213    /// use `col_index_mapping` to remap `exprs` and `output_indices`.
214    fn remap_exprs_and_output_indices(
215        exprs: Vec<ExprImpl>,
216        output_indices: Vec<usize>,
217        col_index_mapping: &mut ColIndexMapping,
218    ) -> (Vec<ExprImpl>, Vec<usize>) {
219        let exprs: Vec<ExprImpl> = exprs
220            .into_iter()
221            .map(|expr| col_index_mapping.rewrite_expr(expr))
222            .collect();
223        let output_indices = output_indices
224            .into_iter()
225            .map(|i| col_index_mapping.map(i))
226            .collect();
227        (exprs, output_indices)
228    }
229
230    pub fn create() -> BoxedRule {
231        Box::new(PushCalculationOfJoinRule {})
232    }
233}