risingwave_frontend/optimizer/rule/
push_calculation_of_join_rule.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_pb::expr::expr_node::Type;
19
20use super::BoxedRule;
21use crate::expr::{Expr, ExprImpl, ExprRewriter, FunctionCall, InputRef, align_types};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{LogicalJoin, LogicalProject};
24use crate::optimizer::rule::Rule;
25use crate::utils::{ColIndexMapping, Condition};
26
/// Optimizer rule over `LogicalJoin` plans.
///
/// For comparison conjuncts of the join condition whose two operands each
/// reference columns of only one join input, the operand expressions are
/// evaluated by a `LogicalProject` placed on top of the corresponding input,
/// and the conjunct is rewritten to compare the projected columns instead.
pub struct PushCalculationOfJoinRule {}
28
29impl Rule for PushCalculationOfJoinRule {
30 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
31 let join: &LogicalJoin = plan.as_logical_join()?;
32 let (mut left, mut right, mut on, join_type, mut output_indices) = join.clone().decompose();
33 let left_col_num = left.schema().len();
34 let right_col_num = right.schema().len();
35
36 let exprs = on.conjunctions;
37 let (left_exprs, right_exprs, indices_and_ty_of_func_calls) =
38 Self::find_comparison_exprs(left_col_num, right_col_num, &exprs);
39
40 let left_exprs_non_input_ref: Vec<_> = left_exprs
42 .iter()
43 .filter(|e| e.as_input_ref().is_none())
44 .cloned()
45 .collect();
46 let right_exprs_non_input_ref: Vec<_> = right_exprs
47 .iter()
48 .filter(|e| e.as_input_ref().is_none())
49 .cloned()
50 .collect();
51
52 let new_internal_col_num = left_col_num
53 + left_exprs_non_input_ref.len()
54 + right_col_num
55 + right_exprs_non_input_ref.len();
56 let mut col_index_mapping = {
59 let map = (0..left_col_num)
60 .chain(
61 (left_col_num..left_col_num + right_col_num)
62 .map(|i| i + left_exprs_non_input_ref.len()),
63 )
64 .map(Some)
65 .collect_vec();
66 ColIndexMapping::new(map, new_internal_col_num)
67 };
68 let (mut exprs, new_output_indices) =
69 Self::remap_exprs_and_output_indices(exprs, output_indices, &mut col_index_mapping);
70 output_indices = new_output_indices;
71
72 let mut left_index = left_col_num;
79 let mut right_index = left_col_num + left_exprs_non_input_ref.len() + right_col_num;
80 let mut right_exprs_mapping = {
81 let map = (0..right_col_num)
82 .map(|i| i + left_col_num + left_exprs_non_input_ref.len())
83 .map(Some)
84 .collect_vec();
85 ColIndexMapping::new(map, new_internal_col_num)
86 };
87 for (((index_of_func_call, ty), left_expr), right_expr) in indices_and_ty_of_func_calls
89 .into_iter()
90 .zip_eq_fast(&left_exprs)
91 .zip_eq_fast(&right_exprs)
92 {
93 let left_input = if left_expr.as_input_ref().is_some() {
94 left_expr.clone()
95 } else {
96 left_index += 1;
97 InputRef::new(left_index - 1, left_expr.return_type()).into()
98 };
99 let right_input = if right_expr.as_input_ref().is_some() {
100 right_exprs_mapping.rewrite_expr(right_expr.clone())
101 } else {
102 right_index += 1;
103 InputRef::new(right_index - 1, right_expr.return_type()).into()
104 };
105 exprs[index_of_func_call] = FunctionCall::new(ty, vec![left_input, right_input])
106 .unwrap()
107 .into();
108 }
109 on = Condition {
110 conjunctions: exprs,
111 };
112
113 let new_input = |input: PlanRef, appended_exprs: Vec<ExprImpl>| {
115 let mut exprs = input
116 .schema()
117 .data_types()
118 .into_iter()
119 .enumerate()
120 .map(|(i, data_type)| InputRef::new(i, data_type).into())
121 .collect_vec();
122 exprs.extend(appended_exprs);
123 LogicalProject::create(input, exprs)
124 };
125 if !left_exprs_non_input_ref.is_empty() {
127 left = new_input(left, left_exprs_non_input_ref);
128 }
129 if !right_exprs_non_input_ref.is_empty() {
130 right = new_input(right, right_exprs_non_input_ref);
131 }
132
133 Some(LogicalJoin::with_output_indices(left, right, join_type, on, output_indices).into())
134 }
135}
136
137impl PushCalculationOfJoinRule {
138 fn find_comparison_exprs(
140 left_col_num: usize,
141 right_col_num: usize,
142 exprs: &[ExprImpl],
143 ) -> (Vec<ExprImpl>, Vec<ExprImpl>, Vec<(usize, Type)>) {
144 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
145 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
146
147 let mut left_exprs = vec![];
148 let mut right_exprs = vec![];
149 let mut indices_and_ty_of_func_calls = vec![];
152 let is_comparison_type = |ty| {
153 matches!(
154 ty,
155 Type::LessThan
156 | Type::LessThanOrEqual
157 | Type::Equal
158 | Type::IsNotDistinctFrom
159 | Type::GreaterThan
160 | Type::GreaterThanOrEqual
161 )
162 };
163 for (index, expr) in exprs.iter().enumerate() {
164 let ExprImpl::FunctionCall(func) = expr else {
165 continue;
166 };
167 if !is_comparison_type(func.func_type()) {
168 continue;
169 }
170 if expr.count_nows() > 0 {
172 continue;
173 }
174 let (mut ty, left, right) = func.clone().decompose_as_binary();
175 let left_input_bits = left.collect_input_refs(left_col_num + right_col_num);
178 let right_input_bits = right.collect_input_refs(left_col_num + right_col_num);
179 let (mut left, mut right) = if left_input_bits.is_subset(&left_bit_map)
180 && right_input_bits.is_subset(&right_bit_map)
181 {
182 (left, right)
183 } else if left_input_bits.is_subset(&right_bit_map)
184 && right_input_bits.is_subset(&left_bit_map)
185 {
186 ty = ExprImpl::reverse_comparison(ty);
187 (right, left)
188 } else {
189 continue;
190 };
191 if left.as_input_ref().is_some()
194 && right.as_input_ref().is_some()
195 && left.return_type() == right.return_type()
196 {
197 continue;
198 }
199 align_types([&mut left, &mut right].into_iter()).unwrap();
201 left_exprs.push(left);
202 {
203 let mut shift_with_offset = ColIndexMapping::with_shift_offset(
204 left_col_num + right_col_num,
205 -(left_col_num as isize),
206 );
207 let right = shift_with_offset.rewrite_expr(right);
208 right_exprs.push(right);
209 }
210 indices_and_ty_of_func_calls.push((index, ty));
211 }
212 (left_exprs, right_exprs, indices_and_ty_of_func_calls)
213 }
214
215 fn remap_exprs_and_output_indices(
217 exprs: Vec<ExprImpl>,
218 output_indices: Vec<usize>,
219 col_index_mapping: &mut ColIndexMapping,
220 ) -> (Vec<ExprImpl>, Vec<usize>) {
221 let exprs: Vec<ExprImpl> = exprs
222 .into_iter()
223 .map(|expr| col_index_mapping.rewrite_expr(expr))
224 .collect();
225 let output_indices = output_indices
226 .into_iter()
227 .map(|i| col_index_mapping.map(i))
228 .collect();
229 (exprs, output_indices)
230 }
231
232 pub fn create() -> BoxedRule {
233 Box::new(PushCalculationOfJoinRule {})
234 }
235}