risingwave_frontend/optimizer/rule/
push_calculation_of_join_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_pb::expr::expr_node::Type;
19
20use super::BoxedRule;
21use crate::expr::{Expr, ExprImpl, ExprRewriter, FunctionCall, InputRef, align_types};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{LogicalJoin, LogicalProject};
24use crate::optimizer::rule::Rule;
25use crate::utils::{ColIndexMapping, Condition};
26
/// Optimizer rule that rewrites a `LogicalJoin` so that the operands of the comparison
/// predicates in its `on` condition are read from plain input columns.
///
/// Operands that are not already an `InputRef` are evaluated by a `LogicalProject` added on
/// top of the corresponding join input, and the predicate is rebuilt to reference the
/// projected columns.
pub struct PushCalculationOfJoinRule {}
28
29impl Rule for PushCalculationOfJoinRule {
30    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
31        let join: &LogicalJoin = plan.as_logical_join()?;
32        let (mut left, mut right, mut on, join_type, mut output_indices) = join.clone().decompose();
33        let left_col_num = left.schema().len();
34        let right_col_num = right.schema().len();
35
36        let exprs = on.conjunctions;
37        let (left_exprs, right_exprs, indices_and_ty_of_func_calls) =
38            Self::find_comparison_exprs(left_col_num, right_col_num, &exprs);
39
40        // Store only the expressions that need a new column in the projection
41        let left_exprs_non_input_ref: Vec<_> = left_exprs
42            .iter()
43            .filter(|e| e.as_input_ref().is_none())
44            .cloned()
45            .collect();
46        let right_exprs_non_input_ref: Vec<_> = right_exprs
47            .iter()
48            .filter(|e| e.as_input_ref().is_none())
49            .cloned()
50            .collect();
51
52        let new_internal_col_num = left_col_num
53            + left_exprs_non_input_ref.len()
54            + right_col_num
55            + right_exprs_non_input_ref.len();
56        // used to shift indices of input_refs pointing the right side of `join` with
57        // `left_exprs.len`.
58        let mut col_index_mapping = {
59            let map = (0..left_col_num)
60                .chain(
61                    (left_col_num..left_col_num + right_col_num)
62                        .map(|i| i + left_exprs_non_input_ref.len()),
63                )
64                .map(Some)
65                .collect_vec();
66            ColIndexMapping::new(map, new_internal_col_num)
67        };
68        let (mut exprs, new_output_indices) =
69            Self::remap_exprs_and_output_indices(exprs, output_indices, &mut col_index_mapping);
70        output_indices = new_output_indices;
71
72        // ```ignore
73        // the internal table of join has has the following schema:
74        // original left's columns | left_exprs | original right's columns | right_exprs
75        //```
76        // `left_index` and `right_index` will scan through `left_exprs` and `right_exprs`
77        // respectively.
78        let mut left_index = left_col_num;
79        let mut right_index = left_col_num + left_exprs_non_input_ref.len() + right_col_num;
80        let mut right_exprs_mapping = {
81            let map = (0..right_col_num)
82                .map(|i| i + left_col_num + left_exprs_non_input_ref.len())
83                .map(Some)
84                .collect_vec();
85            ColIndexMapping::new(map, new_internal_col_num)
86        };
87        // replace chosen function calls.
88        for (((index_of_func_call, ty), left_expr), right_expr) in indices_and_ty_of_func_calls
89            .into_iter()
90            .zip_eq_fast(&left_exprs)
91            .zip_eq_fast(&right_exprs)
92        {
93            let left_input = if left_expr.as_input_ref().is_some() {
94                left_expr.clone()
95            } else {
96                left_index += 1;
97                InputRef::new(left_index - 1, left_expr.return_type()).into()
98            };
99            let right_input = if right_expr.as_input_ref().is_some() {
100                right_exprs_mapping.rewrite_expr(right_expr.clone())
101            } else {
102                right_index += 1;
103                InputRef::new(right_index - 1, right_expr.return_type()).into()
104            };
105            exprs[index_of_func_call] = FunctionCall::new(ty, vec![left_input, right_input])
106                .unwrap()
107                .into();
108        }
109        on = Condition {
110            conjunctions: exprs,
111        };
112
113        // add project to do the calculation.
114        let new_input = |input: PlanRef, appended_exprs: Vec<ExprImpl>| {
115            let mut exprs = input
116                .schema()
117                .data_types()
118                .into_iter()
119                .enumerate()
120                .map(|(i, data_type)| InputRef::new(i, data_type).into())
121                .collect_vec();
122            exprs.extend(appended_exprs);
123            LogicalProject::create(input, exprs)
124        };
125        // avoid unnecessary `project`s.
126        if !left_exprs_non_input_ref.is_empty() {
127            left = new_input(left, left_exprs_non_input_ref);
128        }
129        if !right_exprs_non_input_ref.is_empty() {
130            right = new_input(right, right_exprs_non_input_ref);
131        }
132
133        Some(LogicalJoin::with_output_indices(left, right, join_type, on, output_indices).into())
134    }
135}
136
137impl PushCalculationOfJoinRule {
138    /// find the comparison exprs and return their inputs, types and indices.
139    fn find_comparison_exprs(
140        left_col_num: usize,
141        right_col_num: usize,
142        exprs: &[ExprImpl],
143    ) -> (Vec<ExprImpl>, Vec<ExprImpl>, Vec<(usize, Type)>) {
144        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
145        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
146
147        let mut left_exprs = vec![];
148        let mut right_exprs = vec![];
149        // indices and return types of function calls whose's inputs will be calculated in
150        // `project`s
151        let mut indices_and_ty_of_func_calls = vec![];
152        let is_comparison_type = |ty| {
153            matches!(
154                ty,
155                Type::LessThan
156                    | Type::LessThanOrEqual
157                    | Type::Equal
158                    | Type::IsNotDistinctFrom
159                    | Type::GreaterThan
160                    | Type::GreaterThanOrEqual
161            )
162        };
163        for (index, expr) in exprs.iter().enumerate() {
164            let ExprImpl::FunctionCall(func) = expr else {
165                continue;
166            };
167            if !is_comparison_type(func.func_type()) {
168                continue;
169            }
170            // Do not decompose the comparison if it contains `now()`
171            if expr.count_nows() > 0 {
172                continue;
173            }
174            let (mut ty, left, right) = func.clone().decompose_as_binary();
175            // we just cast the return types of inputs of binary predicates for `HashJoin` and
176            // `DynamicFilter`.
177            let left_input_bits = left.collect_input_refs(left_col_num + right_col_num);
178            let right_input_bits = right.collect_input_refs(left_col_num + right_col_num);
179            let (mut left, mut right) = if left_input_bits.is_subset(&left_bit_map)
180                && right_input_bits.is_subset(&right_bit_map)
181            {
182                (left, right)
183            } else if left_input_bits.is_subset(&right_bit_map)
184                && right_input_bits.is_subset(&left_bit_map)
185            {
186                ty = ExprImpl::reverse_comparison(ty);
187                (right, left)
188            } else {
189                continue;
190            };
191            // when both `left` and `right` are `input_ref`, and they have the same return type
192            // there is no need to calculate them in project.
193            if left.as_input_ref().is_some()
194                && right.as_input_ref().is_some()
195                && left.return_type() == right.return_type()
196            {
197                continue;
198            }
199            // align return types to avoid error when executing join.
200            align_types([&mut left, &mut right].into_iter()).unwrap();
201            left_exprs.push(left);
202            {
203                let mut shift_with_offset = ColIndexMapping::with_shift_offset(
204                    left_col_num + right_col_num,
205                    -(left_col_num as isize),
206                );
207                let right = shift_with_offset.rewrite_expr(right);
208                right_exprs.push(right);
209            }
210            indices_and_ty_of_func_calls.push((index, ty));
211        }
212        (left_exprs, right_exprs, indices_and_ty_of_func_calls)
213    }
214
215    /// use `col_index_mapping` to remap `exprs` and `output_indices`.
216    fn remap_exprs_and_output_indices(
217        exprs: Vec<ExprImpl>,
218        output_indices: Vec<usize>,
219        col_index_mapping: &mut ColIndexMapping,
220    ) -> (Vec<ExprImpl>, Vec<usize>) {
221        let exprs: Vec<ExprImpl> = exprs
222            .into_iter()
223            .map(|expr| col_index_mapping.rewrite_expr(expr))
224            .collect();
225        let output_indices = output_indices
226            .into_iter()
227            .map(|i| col_index_mapping.map(i))
228            .collect();
229        (exprs, output_indices)
230    }
231
232    pub fn create() -> BoxedRule {
233        Box::new(PushCalculationOfJoinRule {})
234    }
235}