risingwave_frontend/optimizer/plan_node/
eq_join_predicate.rs1use std::fmt;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19
20use crate::expr::{
21 ExprRewriter, ExprType, ExprVisitor, FunctionCall, InequalityInputPair, InputRef,
22 InputRefDisplay,
23};
24use crate::utils::{ColIndexMapping, Condition, ConditionDisplay};
25
/// A join predicate split into equi-join keys plus the remaining (non-equi) condition.
///
/// The full predicate is the conjunction of every equi-join key and `other_cond`
/// (see `EqJoinPredicate::all_cond`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EqJoinPredicate {
    /// The part of the predicate that is not an equi-join key, `AND`-ed with the keys.
    other_cond: Condition,

    /// `(left, right, null_safe)` triples, each meaning `left = right`, or
    /// `left IS NOT DISTINCT FROM right` when `null_safe` is `true`.
    ///
    /// `right` indexes into the concatenated left++right schema: accessors such as
    /// `eq_indexes` subtract `left_cols_num` to obtain a right-input-relative index.
    eq_keys: Vec<(InputRef, InputRef, bool)>,

    /// Number of columns produced by the left input.
    left_cols_num: usize,
    /// Number of columns produced by the right input.
    right_cols_num: usize,
}
41
42impl fmt::Display for EqJoinPredicate {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
44 let mut eq_keys = self.eq_keys().iter();
45 if let Some((k1, k2, null_safe)) = eq_keys.next() {
46 write!(
47 f,
48 "{} {} {}",
49 k1,
50 if *null_safe {
51 "IS NOT DISTINCT FROM"
52 } else {
53 "="
54 },
55 k2
56 )?;
57 }
58 for (k1, k2, null_safe) in eq_keys {
59 write!(
60 f,
61 "AND {} {} {}",
62 k1,
63 if *null_safe {
64 "IS NOT DISTINCT FROM"
65 } else {
66 "="
67 },
68 k2
69 )?;
70 }
71 if !self.other_cond.always_true() {
72 write!(f, " AND {}", self.other_cond)?;
73 }
74
75 Ok(())
76 }
77}
78
79impl EqJoinPredicate {
80 pub fn new(
82 other_cond: Condition,
83 eq_keys: Vec<(InputRef, InputRef, bool)>,
84 left_cols_num: usize,
85 right_cols_num: usize,
86 ) -> Self {
87 Self {
88 other_cond,
89 eq_keys,
90 left_cols_num,
91 right_cols_num,
92 }
93 }
94
95 pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self {
111 let (eq_keys, other_cond) = on_clause.split_eq_keys(left_cols_num, right_cols_num);
112 Self::new(other_cond, eq_keys, left_cols_num, right_cols_num)
113 }
114
115 pub fn eq_cond(&self) -> Condition {
117 Condition {
118 conjunctions: self
119 .eq_keys
120 .iter()
121 .cloned()
122 .map(|(l, r, null_safe)| {
123 FunctionCall::new(
124 if null_safe {
125 ExprType::IsNotDistinctFrom
126 } else {
127 ExprType::Equal
128 },
129 vec![l.into(), r.into()],
130 )
131 .unwrap()
132 .into()
133 })
134 .collect(),
135 }
136 }
137
138 pub fn non_eq_cond(&self) -> Condition {
139 self.other_cond.clone()
140 }
141
142 pub fn all_cond(&self) -> Condition {
143 let cond = self.eq_cond();
144 cond.and(self.non_eq_cond())
145 }
146
147 pub fn has_eq(&self) -> bool {
148 !self.eq_keys.is_empty()
149 }
150
151 pub fn has_non_eq(&self) -> bool {
152 !self.other_cond.always_true()
153 }
154
155 pub fn other_cond(&self) -> &Condition {
157 &self.other_cond
158 }
159
160 pub fn other_cond_mut(&mut self) -> &mut Condition {
162 &mut self.other_cond
163 }
164
165 pub fn eq_predicate(&self) -> Self {
167 Self {
168 other_cond: Condition::true_cond(),
169 eq_keys: self.eq_keys.clone(),
170 left_cols_num: self.left_cols_num,
171 right_cols_num: self.right_cols_num,
172 }
173 }
174
175 pub fn eq_keys(&self) -> &[(InputRef, InputRef, bool)] {
179 self.eq_keys.as_ref()
180 }
181
182 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
186 self.eq_keys
187 .iter()
188 .map(|(left, right, _)| (left.index(), right.index() - self.left_cols_num))
189 .collect()
190 }
191
192 pub(crate) fn inequality_pairs_v2(&self) -> Vec<(usize, InequalityInputPair)> {
197 self.other_cond()
198 .extract_inequality_keys(self.left_cols_num, self.right_cols_num)
199 }
200
201 pub fn eq_indexes_typed(&self) -> Vec<(InputRef, InputRef)> {
203 self.eq_keys
204 .iter()
205 .cloned()
206 .map(|(left, mut right, _)| {
207 right.index -= self.left_cols_num;
208 (left, right)
209 })
210 .collect()
211 }
212
213 pub fn eq_keys_are_type_aligned(&self) -> bool {
214 let mut aligned = true;
215 for (l, r, _) in &self.eq_keys {
216 aligned &= l.data_type == r.data_type;
217 }
218 aligned
219 }
220
221 pub fn left_eq_indexes(&self) -> Vec<usize> {
222 self.eq_keys
223 .iter()
224 .map(|(left, _, _)| left.index())
225 .collect()
226 }
227
228 pub fn right_eq_indexes(&self) -> Vec<usize> {
230 self.eq_keys
231 .iter()
232 .map(|(_, right, _)| right.index() - self.left_cols_num)
233 .collect()
234 }
235
236 pub fn null_safes(&self) -> Vec<bool> {
237 self.eq_keys
238 .iter()
239 .map(|(_, _, null_safe)| *null_safe)
240 .collect()
241 }
242
243 pub fn r2l_eq_columns_mapping(
245 &self,
246 left_cols_num: usize,
247 right_cols_num: usize,
248 ) -> ColIndexMapping {
249 let mut map = vec![None; right_cols_num];
250 for (left, right, _) in self.eq_keys() {
251 map[right.index - left_cols_num] = Some(left.index);
252 }
253 ColIndexMapping::new(map, left_cols_num)
254 }
255
256 pub fn l2r_eq_columns_mapping(
258 &self,
259 left_cols_num: usize,
260 right_cols_num: usize,
261 ) -> ColIndexMapping {
262 let mut map = vec![None; left_cols_num];
263 for (left, right, _) in self.eq_keys() {
264 map[left.index] = Some(right.index - left_cols_num);
265 }
266 ColIndexMapping::new(map, right_cols_num)
267 }
268
269 pub fn reorder(self, reorder_idx: &[usize]) -> Self {
271 assert!(reorder_idx.len() <= self.eq_keys.len());
272 let mut new_eq_keys = Vec::with_capacity(self.eq_keys.len());
273 for idx in reorder_idx {
274 new_eq_keys.push(self.eq_keys[*idx].clone());
275 }
276 for idx in 0..self.eq_keys.len() {
277 if !reorder_idx.contains(&idx) {
278 new_eq_keys.push(self.eq_keys[idx].clone());
279 }
280 }
281
282 Self::new(
283 self.other_cond,
284 new_eq_keys,
285 self.left_cols_num,
286 self.right_cols_num,
287 )
288 }
289
290 pub fn retain_prefix_eq_key(self, prefix_len: usize) -> Self {
293 assert!(prefix_len <= self.eq_keys.len());
294 let (retain_eq_key, other_eq_key) = self.eq_keys.split_at(prefix_len);
295 let mut new_other_conjunctions = self.other_cond.conjunctions;
296 new_other_conjunctions.extend(
297 other_eq_key
298 .iter()
299 .cloned()
300 .map(|(l, r, null_safe)| {
301 FunctionCall::new(
302 if null_safe {
303 ExprType::IsNotDistinctFrom
304 } else {
305 ExprType::Equal
306 },
307 vec![l.into(), r.into()],
308 )
309 .unwrap()
310 .into()
311 })
312 .collect_vec(),
313 );
314
315 let new_other_cond = Condition {
316 conjunctions: new_other_conjunctions,
317 };
318
319 Self::new(
320 new_other_cond,
321 retain_eq_key.to_owned(),
322 self.left_cols_num,
323 self.right_cols_num,
324 )
325 }
326
327 pub fn rewrite_exprs(&self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
328 let mut new = self.clone();
329 new.other_cond = new.other_cond.rewrite_expr(rewriter);
330 new
331 }
332
333 pub fn visit_exprs(&self, v: &mut (impl ExprVisitor + ?Sized)) {
334 self.other_cond.visit_expr(v);
335 }
336}
337
/// Formatting adapter that renders an [`EqJoinPredicate`] against an input schema.
pub struct EqJoinPredicateDisplay<'a> {
    /// The predicate to render.
    pub eq_join_predicate: &'a EqJoinPredicate,
    /// Schema handed to `InputRefDisplay` / `ConditionDisplay` when rendering column refs.
    pub input_schema: &'a Schema,
}
342
343impl EqJoinPredicateDisplay<'_> {
344 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
345 let that = self.eq_join_predicate;
346 let mut eq_keys = that.eq_keys().iter();
347 let mut printed_any = false;
348 if let Some((k1, k2, null_safe)) = eq_keys.next() {
349 write!(
350 f,
351 "{} {} {}",
352 InputRefDisplay {
353 input_ref: k1,
354 input_schema: self.input_schema
355 },
356 if *null_safe {
357 "IS NOT DISTINCT FROM"
358 } else {
359 "="
360 },
361 InputRefDisplay {
362 input_ref: k2,
363 input_schema: self.input_schema
364 }
365 )?;
366 printed_any = true;
367 }
368 for (k1, k2, null_safe) in eq_keys {
369 write!(
370 f,
371 " AND {} {} {}",
372 InputRefDisplay {
373 input_ref: k1,
374 input_schema: self.input_schema
375 },
376 if *null_safe {
377 "IS NOT DISTINCT FROM"
378 } else {
379 "="
380 },
381 InputRefDisplay {
382 input_ref: k2,
383 input_schema: self.input_schema
384 }
385 )?;
386 printed_any = true;
387 }
388 if !that.other_cond.always_true() {
389 write!(
390 f,
391 "{}{}",
392 if printed_any { " AND " } else { "" },
393 ConditionDisplay {
394 condition: &that.other_cond,
395 input_schema: self.input_schema
396 }
397 )?;
398 printed_any = true;
399 }
400 if !printed_any {
401 write!(f, "true")?;
402 }
403
404 Ok(())
405 }
406}
407
408impl fmt::Display for EqJoinPredicateDisplay<'_> {
409 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
410 self.fmt(f)
411 }
412}
413
414impl fmt::Debug for EqJoinPredicateDisplay<'_> {
415 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
416 self.fmt(f)
417 }
418}