risingwave_frontend/optimizer/plan_node/
eq_join_predicate.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19
20use crate::expr::{
21    ExprRewriter, ExprType, ExprVisitor, FunctionCall, InequalityInputPair, InputRef,
22    InputRefDisplay,
23};
24use crate::utils::{ColIndexMapping, Condition, ConditionDisplay};
25
26/// The join predicate used in optimizer
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub struct EqJoinPredicate {
29    /// Other conditions, linked with `AND` conjunction.
30    other_cond: Condition,
31
32    /// `Vec` of `(left_col_index, right_col_index, null_safe)`,
33    /// representing a conjunction of `left_col_index = right_col_index`
34    ///
35    /// Note: `right_col_index` starts from `left_cols_num`
36    eq_keys: Vec<(InputRef, InputRef, bool)>,
37
38    left_cols_num: usize,
39    right_cols_num: usize,
40}
41
42impl fmt::Display for EqJoinPredicate {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
44        let mut eq_keys = self.eq_keys().iter();
45        if let Some((k1, k2, null_safe)) = eq_keys.next() {
46            write!(
47                f,
48                "{} {} {}",
49                k1,
50                if *null_safe {
51                    "IS NOT DISTINCT FROM"
52                } else {
53                    "="
54                },
55                k2
56            )?;
57        }
58        for (k1, k2, null_safe) in eq_keys {
59            write!(
60                f,
61                "AND {} {} {}",
62                k1,
63                if *null_safe {
64                    "IS NOT DISTINCT FROM"
65                } else {
66                    "="
67                },
68                k2
69            )?;
70        }
71        if !self.other_cond.always_true() {
72            write!(f, " AND {}", self.other_cond)?;
73        }
74
75        Ok(())
76    }
77}
78
79impl EqJoinPredicate {
80    /// The new method for `JoinPredicate` without any analysis, check or rewrite.
81    pub fn new(
82        other_cond: Condition,
83        eq_keys: Vec<(InputRef, InputRef, bool)>,
84        left_cols_num: usize,
85        right_cols_num: usize,
86    ) -> Self {
87        Self {
88            other_cond,
89            eq_keys,
90            left_cols_num,
91            right_cols_num,
92        }
93    }
94
95    /// `create` will analyze the on clause condition and construct a `JoinPredicate`.
96    /// e.g.
97    /// ```sql
98    ///   select a.v1, a.v2, b.v1, b.v2 from a,b on a.v1 = a.v2 and a.v1 = b.v1 and a.v2 > b.v2
99    /// ```
100    /// will call the `create` function with `left_colsnum` = 2 and `on_clause` is (supposed
101    /// `input_ref` count start from 0)
102    /// ```sql
103    /// input_ref(0) = input_ref(1) and input_ref(0) = input_ref(2) and input_ref(1) > input_ref(3)
104    /// ```
105    /// And the `create functions` should return `JoinPredicate`
106    /// ```sql
107    ///   other_conds = Vec[input_ref(0) = input_ref(1), input_ref(1) > input_ref(3)],
108    ///   keys= Vec[(0,2)]
109    /// ```
110    pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self {
111        let (eq_keys, other_cond) = on_clause.split_eq_keys(left_cols_num, right_cols_num);
112        Self::new(other_cond, eq_keys, left_cols_num, right_cols_num)
113    }
114
115    /// Get join predicate's eq conds.
116    pub fn eq_cond(&self) -> Condition {
117        Condition {
118            conjunctions: self
119                .eq_keys
120                .iter()
121                .cloned()
122                .map(|(l, r, null_safe)| {
123                    FunctionCall::new(
124                        if null_safe {
125                            ExprType::IsNotDistinctFrom
126                        } else {
127                            ExprType::Equal
128                        },
129                        vec![l.into(), r.into()],
130                    )
131                    .unwrap()
132                    .into()
133                })
134                .collect(),
135        }
136    }
137
138    pub fn non_eq_cond(&self) -> Condition {
139        self.other_cond.clone()
140    }
141
142    pub fn all_cond(&self) -> Condition {
143        let cond = self.eq_cond();
144        cond.and(self.non_eq_cond())
145    }
146
147    pub fn has_eq(&self) -> bool {
148        !self.eq_keys.is_empty()
149    }
150
151    pub fn has_non_eq(&self) -> bool {
152        !self.other_cond.always_true()
153    }
154
155    /// Get a reference to the join predicate's other cond.
156    pub fn other_cond(&self) -> &Condition {
157        &self.other_cond
158    }
159
160    /// Get a mutable reference to the join predicate's other cond.
161    pub fn other_cond_mut(&mut self) -> &mut Condition {
162        &mut self.other_cond
163    }
164
165    /// Get the equal predicate
166    pub fn eq_predicate(&self) -> Self {
167        Self {
168            other_cond: Condition::true_cond(),
169            eq_keys: self.eq_keys.clone(),
170            left_cols_num: self.left_cols_num,
171            right_cols_num: self.right_cols_num,
172        }
173    }
174
175    /// Get a reference to the join predicate's eq keys.
176    ///
177    /// Note: `right_col_index` starts from `left_cols_num`
178    pub fn eq_keys(&self) -> &[(InputRef, InputRef, bool)] {
179        self.eq_keys.as_ref()
180    }
181
182    /// `Vec` of `(left_col_index, right_col_index)`.
183    ///
184    /// Note: `right_col_index` starts from `0`
185    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
186        self.eq_keys
187            .iter()
188            .map(|(left, right, _)| (left.index(), right.index() - self.left_cols_num))
189            .collect()
190    }
191
192    /// Returns a list of `(conjunction_index, InequalityInputPair)` where:
193    /// - `left_idx` is the column index from the left input
194    /// - `right_idx` is the column index from the right input (NOT offset by `left_cols_num`)
195    /// - `op` is the comparison operator
196    pub(crate) fn inequality_pairs_v2(&self) -> Vec<(usize, InequalityInputPair)> {
197        self.other_cond()
198            .extract_inequality_keys(self.left_cols_num, self.right_cols_num)
199    }
200
201    /// Note: `right_col_index` starts from `0`
202    pub fn eq_indexes_typed(&self) -> Vec<(InputRef, InputRef)> {
203        self.eq_keys
204            .iter()
205            .cloned()
206            .map(|(left, mut right, _)| {
207                right.index -= self.left_cols_num;
208                (left, right)
209            })
210            .collect()
211    }
212
213    pub fn eq_keys_are_type_aligned(&self) -> bool {
214        let mut aligned = true;
215        for (l, r, _) in &self.eq_keys {
216            aligned &= l.data_type == r.data_type;
217        }
218        aligned
219    }
220
221    pub fn left_eq_indexes(&self) -> Vec<usize> {
222        self.eq_keys
223            .iter()
224            .map(|(left, _, _)| left.index())
225            .collect()
226    }
227
228    /// Note: `right_col_index` starts from `0`
229    pub fn right_eq_indexes(&self) -> Vec<usize> {
230        self.eq_keys
231            .iter()
232            .map(|(_, right, _)| right.index() - self.left_cols_num)
233            .collect()
234    }
235
236    pub fn null_safes(&self) -> Vec<bool> {
237        self.eq_keys
238            .iter()
239            .map(|(_, _, null_safe)| *null_safe)
240            .collect()
241    }
242
243    /// return the eq columns index mapping from right inputs to left inputs
244    pub fn r2l_eq_columns_mapping(
245        &self,
246        left_cols_num: usize,
247        right_cols_num: usize,
248    ) -> ColIndexMapping {
249        let mut map = vec![None; right_cols_num];
250        for (left, right, _) in self.eq_keys() {
251            map[right.index - left_cols_num] = Some(left.index);
252        }
253        ColIndexMapping::new(map, left_cols_num)
254    }
255
256    /// return the eq columns index mapping from left inputs to right inputs
257    pub fn l2r_eq_columns_mapping(
258        &self,
259        left_cols_num: usize,
260        right_cols_num: usize,
261    ) -> ColIndexMapping {
262        let mut map = vec![None; left_cols_num];
263        for (left, right, _) in self.eq_keys() {
264            map[left.index] = Some(right.index - left_cols_num);
265        }
266        ColIndexMapping::new(map, right_cols_num)
267    }
268
269    /// Reorder the `eq_keys` according to the `reorder_idx`.
270    pub fn reorder(self, reorder_idx: &[usize]) -> Self {
271        assert!(reorder_idx.len() <= self.eq_keys.len());
272        let mut new_eq_keys = Vec::with_capacity(self.eq_keys.len());
273        for idx in reorder_idx {
274            new_eq_keys.push(self.eq_keys[*idx].clone());
275        }
276        for idx in 0..self.eq_keys.len() {
277            if !reorder_idx.contains(&idx) {
278                new_eq_keys.push(self.eq_keys[idx].clone());
279            }
280        }
281
282        Self::new(
283            self.other_cond,
284            new_eq_keys,
285            self.left_cols_num,
286            self.right_cols_num,
287        )
288    }
289
290    /// Retain the prefix of `eq_keys` based on the `prefix_len`. The other part is moved to the
291    /// other condition.
292    pub fn retain_prefix_eq_key(self, prefix_len: usize) -> Self {
293        assert!(prefix_len <= self.eq_keys.len());
294        let (retain_eq_key, other_eq_key) = self.eq_keys.split_at(prefix_len);
295        let mut new_other_conjunctions = self.other_cond.conjunctions;
296        new_other_conjunctions.extend(
297            other_eq_key
298                .iter()
299                .cloned()
300                .map(|(l, r, null_safe)| {
301                    FunctionCall::new(
302                        if null_safe {
303                            ExprType::IsNotDistinctFrom
304                        } else {
305                            ExprType::Equal
306                        },
307                        vec![l.into(), r.into()],
308                    )
309                    .unwrap()
310                    .into()
311                })
312                .collect_vec(),
313        );
314
315        let new_other_cond = Condition {
316            conjunctions: new_other_conjunctions,
317        };
318
319        Self::new(
320            new_other_cond,
321            retain_eq_key.to_owned(),
322            self.left_cols_num,
323            self.right_cols_num,
324        )
325    }
326
327    pub fn rewrite_exprs(&self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
328        let mut new = self.clone();
329        new.other_cond = new.other_cond.rewrite_expr(rewriter);
330        new
331    }
332
333    pub fn visit_exprs(&self, v: &mut (impl ExprVisitor + ?Sized)) {
334        self.other_cond.visit_expr(v);
335    }
336}
337
338pub struct EqJoinPredicateDisplay<'a> {
339    pub eq_join_predicate: &'a EqJoinPredicate,
340    pub input_schema: &'a Schema,
341}
342
343impl EqJoinPredicateDisplay<'_> {
344    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
345        let that = self.eq_join_predicate;
346        let mut eq_keys = that.eq_keys().iter();
347        let mut printed_any = false;
348        if let Some((k1, k2, null_safe)) = eq_keys.next() {
349            write!(
350                f,
351                "{} {} {}",
352                InputRefDisplay {
353                    input_ref: k1,
354                    input_schema: self.input_schema
355                },
356                if *null_safe {
357                    "IS NOT DISTINCT FROM"
358                } else {
359                    "="
360                },
361                InputRefDisplay {
362                    input_ref: k2,
363                    input_schema: self.input_schema
364                }
365            )?;
366            printed_any = true;
367        }
368        for (k1, k2, null_safe) in eq_keys {
369            write!(
370                f,
371                " AND {} {} {}",
372                InputRefDisplay {
373                    input_ref: k1,
374                    input_schema: self.input_schema
375                },
376                if *null_safe {
377                    "IS NOT DISTINCT FROM"
378                } else {
379                    "="
380                },
381                InputRefDisplay {
382                    input_ref: k2,
383                    input_schema: self.input_schema
384                }
385            )?;
386            printed_any = true;
387        }
388        if !that.other_cond.always_true() {
389            write!(
390                f,
391                "{}{}",
392                if printed_any { " AND " } else { "" },
393                ConditionDisplay {
394                    condition: &that.other_cond,
395                    input_schema: self.input_schema
396                }
397            )?;
398            printed_any = true;
399        }
400        if !printed_any {
401            write!(f, "true")?;
402        }
403
404        Ok(())
405    }
406}
407
408impl fmt::Display for EqJoinPredicateDisplay<'_> {
409    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
410        self.fmt(f)
411    }
412}
413
414impl fmt::Debug for EqJoinPredicateDisplay<'_> {
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
416        self.fmt(f)
417    }
418}