risingwave_frontend/optimizer/plan_node/
eq_join_predicate.rs1use std::fmt;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19
20use crate::expr::{
21 ExprRewriter, ExprType, ExprVisitor, FunctionCall, InequalityInputPair, InputRef,
22 InputRefDisplay,
23};
24use crate::utils::{ColIndexMapping, Condition, ConditionDisplay};
25
/// A join predicate split into equality key pairs and a residual condition.
///
/// The key pairs live in `eq_keys`; every other conjunct of the join condition is kept in
/// `other_cond`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EqJoinPredicate {
    /// Residual condition. `all_cond` ANDs it with the key comparisons, and the `Display`
    /// output appends it after the key pairs.
    other_cond: Condition,

    /// Key pairs as `(left, right, null_safe)`.
    ///
    /// A pair with `null_safe == true` is rebuilt/rendered as `IS NOT DISTINCT FROM`,
    /// otherwise as `=`. The right-hand `InputRef` index is stored shifted by
    /// `left_cols_num`; accessors such as `right_eq_indexes` subtract that offset.
    eq_keys: Vec<(InputRef, InputRef, bool)>,

    /// Offset subtracted from right-hand key indices.
    /// NOTE(review): presumably the column count of the left input — confirm with callers.
    left_cols_num: usize,
    /// Forwarded to `Condition::extract_inequality_keys` together with `left_cols_num`.
    /// NOTE(review): presumably the column count of the right input — confirm with callers.
    right_cols_num: usize,
}
41
42impl fmt::Display for EqJoinPredicate {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
44 let mut eq_keys = self.eq_keys().iter();
45 if let Some((k1, k2, null_safe)) = eq_keys.next() {
46 write!(
47 f,
48 "{} {} {}",
49 k1,
50 if *null_safe {
51 "IS NOT DISTINCT FROM"
52 } else {
53 "="
54 },
55 k2
56 )?;
57 }
58 for (k1, k2, null_safe) in eq_keys {
59 write!(
60 f,
61 "AND {} {} {}",
62 k1,
63 if *null_safe {
64 "IS NOT DISTINCT FROM"
65 } else {
66 "="
67 },
68 k2
69 )?;
70 }
71 if !self.other_cond.always_true() {
72 write!(f, " AND {}", self.other_cond)?;
73 }
74
75 Ok(())
76 }
77}
78
79impl EqJoinPredicate {
80 pub fn new(
82 other_cond: Condition,
83 eq_keys: Vec<(InputRef, InputRef, bool)>,
84 left_cols_num: usize,
85 right_cols_num: usize,
86 ) -> Self {
87 Self {
88 other_cond,
89 eq_keys,
90 left_cols_num,
91 right_cols_num,
92 }
93 }
94
95 pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self {
111 let (eq_keys, other_cond) = on_clause.split_eq_keys(left_cols_num, right_cols_num);
112 Self::new(other_cond, eq_keys, left_cols_num, right_cols_num)
113 }
114
115 pub fn eq_cond(&self) -> Condition {
117 Condition {
118 conjunctions: self
119 .eq_keys
120 .iter()
121 .cloned()
122 .map(|(l, r, null_safe)| {
123 FunctionCall::new(
124 if null_safe {
125 ExprType::IsNotDistinctFrom
126 } else {
127 ExprType::Equal
128 },
129 vec![l.into(), r.into()],
130 )
131 .unwrap()
132 .into()
133 })
134 .collect(),
135 }
136 }
137
138 pub fn non_eq_cond(&self) -> Condition {
139 self.other_cond.clone()
140 }
141
142 pub fn all_cond(&self) -> Condition {
143 let cond = self.eq_cond();
144 cond.and(self.non_eq_cond())
145 }
146
147 pub fn has_eq(&self) -> bool {
148 !self.eq_keys.is_empty()
149 }
150
151 pub fn has_non_eq(&self) -> bool {
152 !self.other_cond.always_true()
153 }
154
155 pub fn other_cond(&self) -> &Condition {
157 &self.other_cond
158 }
159
160 pub fn other_cond_mut(&mut self) -> &mut Condition {
162 &mut self.other_cond
163 }
164
165 pub fn eq_predicate(&self) -> Self {
167 Self {
168 other_cond: Condition::true_cond(),
169 eq_keys: self.eq_keys.clone(),
170 left_cols_num: self.left_cols_num,
171 right_cols_num: self.right_cols_num,
172 }
173 }
174
175 pub fn eq_keys(&self) -> &[(InputRef, InputRef, bool)] {
179 self.eq_keys.as_ref()
180 }
181
182 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
186 self.eq_keys
187 .iter()
188 .map(|(left, right, _)| (left.index(), right.index() - self.left_cols_num))
189 .collect()
190 }
191
192 pub(crate) fn inequality_pairs(&self) -> (usize, Vec<(usize, InequalityInputPair)>) {
193 (
194 self.left_cols_num,
195 self.other_cond()
196 .extract_inequality_keys(self.left_cols_num, self.right_cols_num),
197 )
198 }
199
200 pub fn eq_indexes_typed(&self) -> Vec<(InputRef, InputRef)> {
202 self.eq_keys
203 .iter()
204 .cloned()
205 .map(|(left, mut right, _)| {
206 right.index -= self.left_cols_num;
207 (left, right)
208 })
209 .collect()
210 }
211
212 pub fn eq_keys_are_type_aligned(&self) -> bool {
213 let mut aligned = true;
214 for (l, r, _) in &self.eq_keys {
215 aligned &= l.data_type == r.data_type;
216 }
217 aligned
218 }
219
220 pub fn left_eq_indexes(&self) -> Vec<usize> {
221 self.eq_keys
222 .iter()
223 .map(|(left, _, _)| left.index())
224 .collect()
225 }
226
227 pub fn right_eq_indexes(&self) -> Vec<usize> {
229 self.eq_keys
230 .iter()
231 .map(|(_, right, _)| right.index() - self.left_cols_num)
232 .collect()
233 }
234
235 pub fn null_safes(&self) -> Vec<bool> {
236 self.eq_keys
237 .iter()
238 .map(|(_, _, null_safe)| *null_safe)
239 .collect()
240 }
241
242 pub fn r2l_eq_columns_mapping(
244 &self,
245 left_cols_num: usize,
246 right_cols_num: usize,
247 ) -> ColIndexMapping {
248 let mut map = vec![None; right_cols_num];
249 for (left, right, _) in self.eq_keys() {
250 map[right.index - left_cols_num] = Some(left.index);
251 }
252 ColIndexMapping::new(map, left_cols_num)
253 }
254
255 pub fn l2r_eq_columns_mapping(
257 &self,
258 left_cols_num: usize,
259 right_cols_num: usize,
260 ) -> ColIndexMapping {
261 let mut map = vec![None; left_cols_num];
262 for (left, right, _) in self.eq_keys() {
263 map[left.index] = Some(right.index - left_cols_num);
264 }
265 ColIndexMapping::new(map, right_cols_num)
266 }
267
268 pub fn reorder(self, reorder_idx: &[usize]) -> Self {
270 assert!(reorder_idx.len() <= self.eq_keys.len());
271 let mut new_eq_keys = Vec::with_capacity(self.eq_keys.len());
272 for idx in reorder_idx {
273 new_eq_keys.push(self.eq_keys[*idx].clone());
274 }
275 for idx in 0..self.eq_keys.len() {
276 if !reorder_idx.contains(&idx) {
277 new_eq_keys.push(self.eq_keys[idx].clone());
278 }
279 }
280
281 Self::new(
282 self.other_cond,
283 new_eq_keys,
284 self.left_cols_num,
285 self.right_cols_num,
286 )
287 }
288
289 pub fn retain_prefix_eq_key(self, prefix_len: usize) -> Self {
292 assert!(prefix_len <= self.eq_keys.len());
293 let (retain_eq_key, other_eq_key) = self.eq_keys.split_at(prefix_len);
294 let mut new_other_conjunctions = self.other_cond.conjunctions;
295 new_other_conjunctions.extend(
296 other_eq_key
297 .iter()
298 .cloned()
299 .map(|(l, r, null_safe)| {
300 FunctionCall::new(
301 if null_safe {
302 ExprType::IsNotDistinctFrom
303 } else {
304 ExprType::Equal
305 },
306 vec![l.into(), r.into()],
307 )
308 .unwrap()
309 .into()
310 })
311 .collect_vec(),
312 );
313
314 let new_other_cond = Condition {
315 conjunctions: new_other_conjunctions,
316 };
317
318 Self::new(
319 new_other_cond,
320 retain_eq_key.to_owned(),
321 self.left_cols_num,
322 self.right_cols_num,
323 )
324 }
325
326 pub fn rewrite_exprs(&self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
327 let mut new = self.clone();
328 new.other_cond = new.other_cond.rewrite_expr(rewriter);
329 new
330 }
331
332 pub fn visit_exprs(&self, v: &mut (impl ExprVisitor + ?Sized)) {
333 self.other_cond.visit_expr(v);
334 }
335}
336
/// Formatting adapter that renders an `EqJoinPredicate` through `InputRefDisplay` and
/// `ConditionDisplay`, using `input_schema` to render the referenced columns.
pub struct EqJoinPredicateDisplay<'a> {
    /// The predicate to render.
    pub eq_join_predicate: &'a EqJoinPredicate,
    /// Schema handed to `InputRefDisplay` / `ConditionDisplay` for every rendered reference.
    pub input_schema: &'a Schema,
}
341
342impl EqJoinPredicateDisplay<'_> {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
344 let that = self.eq_join_predicate;
345 let mut eq_keys = that.eq_keys().iter();
346 if let Some((k1, k2, null_safe)) = eq_keys.next() {
347 write!(
348 f,
349 "{} {} {}",
350 InputRefDisplay {
351 input_ref: k1,
352 input_schema: self.input_schema
353 },
354 if *null_safe {
355 "IS NOT DISTINCT FROM"
356 } else {
357 "="
358 },
359 InputRefDisplay {
360 input_ref: k2,
361 input_schema: self.input_schema
362 }
363 )?;
364 }
365 for (k1, k2, null_safe) in eq_keys {
366 write!(
367 f,
368 " AND {} {} {}",
369 InputRefDisplay {
370 input_ref: k1,
371 input_schema: self.input_schema
372 },
373 if *null_safe {
374 "IS NOT DISTINCT FROM"
375 } else {
376 "="
377 },
378 InputRefDisplay {
379 input_ref: k2,
380 input_schema: self.input_schema
381 }
382 )?;
383 }
384 if !that.other_cond.always_true() {
385 write!(
386 f,
387 " AND {}",
388 ConditionDisplay {
389 condition: &that.other_cond,
390 input_schema: self.input_schema
391 }
392 )?;
393 }
394
395 Ok(())
396 }
397}
398
/// `Display` output is produced by the inherent `EqJoinPredicateDisplay::fmt`.
impl fmt::Display for EqJoinPredicateDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        // Not self-recursion: method-call lookup prefers the inherent `fmt` defined on
        // `EqJoinPredicateDisplay` over this trait method of the same name.
        self.fmt(f)
    }
}
404
/// `Debug` output is identical to the `Display` output; both delegate to the inherent
/// `EqJoinPredicateDisplay::fmt`.
impl fmt::Debug for EqJoinPredicateDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        // Not self-recursion: method-call lookup prefers the inherent `fmt` defined on
        // `EqJoinPredicateDisplay` over this trait method of the same name.
        self.fmt(f)
    }
}