risingwave_frontend/optimizer/rule/
push_calculation_of_join_rule.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_pb::expr::expr_node::Type;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, FunctionCall, InputRef, align_types};
22use crate::optimizer::plan_node::{LogicalJoin, LogicalProject};
23use crate::utils::{ColIndexMapping, Condition};
24
/// Optimizer rule that moves the computations inside a join's comparison conditions down to
/// the join's inputs.
///
/// For each conjunct of the join condition that is a comparison (`<`, `<=`, `=`,
/// `IS NOT DISTINCT FROM`, `>`, `>=`) whose two operands each reference columns of a different
/// side of the join, every operand that is not a plain column reference is evaluated by a
/// `LogicalProject` placed on top of its side. The comparison is then rewritten to compare the
/// projected columns.
pub struct PushCalculationOfJoinRule {}
26
27impl Rule<Logical> for PushCalculationOfJoinRule {
28 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
29 let join: &LogicalJoin = plan.as_logical_join()?;
30 let (mut left, mut right, mut on, join_type, mut output_indices) = join.clone().decompose();
31 let left_col_num = left.schema().len();
32 let right_col_num = right.schema().len();
33
34 let exprs = on.conjunctions;
35 let (left_exprs, right_exprs, indices_and_ty_of_func_calls) =
36 Self::find_comparison_exprs(left_col_num, right_col_num, &exprs);
37
38 let left_exprs_non_input_ref: Vec<_> = left_exprs
40 .iter()
41 .filter(|e| e.as_input_ref().is_none())
42 .cloned()
43 .collect();
44 let right_exprs_non_input_ref: Vec<_> = right_exprs
45 .iter()
46 .filter(|e| e.as_input_ref().is_none())
47 .cloned()
48 .collect();
49
50 let new_internal_col_num = left_col_num
51 + left_exprs_non_input_ref.len()
52 + right_col_num
53 + right_exprs_non_input_ref.len();
54 let mut col_index_mapping = {
57 let map = (0..left_col_num)
58 .chain(
59 (left_col_num..left_col_num + right_col_num)
60 .map(|i| i + left_exprs_non_input_ref.len()),
61 )
62 .map(Some)
63 .collect_vec();
64 ColIndexMapping::new(map, new_internal_col_num)
65 };
66 let (mut exprs, new_output_indices) =
67 Self::remap_exprs_and_output_indices(exprs, output_indices, &mut col_index_mapping);
68 output_indices = new_output_indices;
69
70 let mut left_index = left_col_num;
77 let mut right_index = left_col_num + left_exprs_non_input_ref.len() + right_col_num;
78 let mut right_exprs_mapping = {
79 let map = (0..right_col_num)
80 .map(|i| i + left_col_num + left_exprs_non_input_ref.len())
81 .map(Some)
82 .collect_vec();
83 ColIndexMapping::new(map, new_internal_col_num)
84 };
85 for (((index_of_func_call, ty), left_expr), right_expr) in indices_and_ty_of_func_calls
87 .into_iter()
88 .zip_eq_fast(&left_exprs)
89 .zip_eq_fast(&right_exprs)
90 {
91 let left_input = if left_expr.as_input_ref().is_some() {
92 left_expr.clone()
93 } else {
94 left_index += 1;
95 InputRef::new(left_index - 1, left_expr.return_type()).into()
96 };
97 let right_input = if right_expr.as_input_ref().is_some() {
98 right_exprs_mapping.rewrite_expr(right_expr.clone())
99 } else {
100 right_index += 1;
101 InputRef::new(right_index - 1, right_expr.return_type()).into()
102 };
103 exprs[index_of_func_call] = FunctionCall::new(ty, vec![left_input, right_input])
104 .unwrap()
105 .into();
106 }
107 on = Condition {
108 conjunctions: exprs,
109 };
110
111 let new_input = |input: PlanRef, appended_exprs: Vec<ExprImpl>| {
113 let mut exprs = input
114 .schema()
115 .data_types()
116 .into_iter()
117 .enumerate()
118 .map(|(i, data_type)| InputRef::new(i, data_type).into())
119 .collect_vec();
120 exprs.extend(appended_exprs);
121 LogicalProject::create(input, exprs)
122 };
123 if !left_exprs_non_input_ref.is_empty() {
125 left = new_input(left, left_exprs_non_input_ref);
126 }
127 if !right_exprs_non_input_ref.is_empty() {
128 right = new_input(right, right_exprs_non_input_ref);
129 }
130
131 Some(LogicalJoin::with_output_indices(left, right, join_type, on, output_indices).into())
132 }
133}
134
135impl PushCalculationOfJoinRule {
136 fn find_comparison_exprs(
138 left_col_num: usize,
139 right_col_num: usize,
140 exprs: &[ExprImpl],
141 ) -> (Vec<ExprImpl>, Vec<ExprImpl>, Vec<(usize, Type)>) {
142 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
143 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
144
145 let mut left_exprs = vec![];
146 let mut right_exprs = vec![];
147 let mut indices_and_ty_of_func_calls = vec![];
150 let is_comparison_type = |ty| {
151 matches!(
152 ty,
153 Type::LessThan
154 | Type::LessThanOrEqual
155 | Type::Equal
156 | Type::IsNotDistinctFrom
157 | Type::GreaterThan
158 | Type::GreaterThanOrEqual
159 )
160 };
161 for (index, expr) in exprs.iter().enumerate() {
162 let ExprImpl::FunctionCall(func) = expr else {
163 continue;
164 };
165 if !is_comparison_type(func.func_type()) {
166 continue;
167 }
168 if expr.count_nows() > 0 {
170 continue;
171 }
172 let (mut ty, left, right) = func.clone().decompose_as_binary();
173 let left_input_bits = left.collect_input_refs(left_col_num + right_col_num);
176 let right_input_bits = right.collect_input_refs(left_col_num + right_col_num);
177 let (mut left, mut right) = if left_input_bits.is_subset(&left_bit_map)
178 && right_input_bits.is_subset(&right_bit_map)
179 {
180 (left, right)
181 } else if left_input_bits.is_subset(&right_bit_map)
182 && right_input_bits.is_subset(&left_bit_map)
183 {
184 ty = ExprImpl::reverse_comparison(ty);
185 (right, left)
186 } else {
187 continue;
188 };
189 if left.as_input_ref().is_some()
192 && right.as_input_ref().is_some()
193 && left.return_type() == right.return_type()
194 {
195 continue;
196 }
197 align_types([&mut left, &mut right].into_iter()).unwrap();
199 left_exprs.push(left);
200 {
201 let mut shift_with_offset = ColIndexMapping::with_shift_offset(
202 left_col_num + right_col_num,
203 -(left_col_num as isize),
204 );
205 let right = shift_with_offset.rewrite_expr(right);
206 right_exprs.push(right);
207 }
208 indices_and_ty_of_func_calls.push((index, ty));
209 }
210 (left_exprs, right_exprs, indices_and_ty_of_func_calls)
211 }
212
213 fn remap_exprs_and_output_indices(
215 exprs: Vec<ExprImpl>,
216 output_indices: Vec<usize>,
217 col_index_mapping: &mut ColIndexMapping,
218 ) -> (Vec<ExprImpl>, Vec<usize>) {
219 let exprs: Vec<ExprImpl> = exprs
220 .into_iter()
221 .map(|expr| col_index_mapping.rewrite_expr(expr))
222 .collect();
223 let output_indices = output_indices
224 .into_iter()
225 .map(|i| col_index_mapping.map(i))
226 .collect();
227 (exprs, output_indices)
228 }
229
230 pub fn create() -> BoxedRule {
231 Box::new(PushCalculationOfJoinRule {})
232 }
233}