risingwave_frontend/optimizer/plan_node/eq_join_predicate.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use itertools::Itertools;
18use risingwave_common::catalog::Schema;
19
20use crate::expr::{
21    ExprRewriter, ExprType, ExprVisitor, FunctionCall, InequalityInputPair, InputRef,
22    InputRefDisplay,
23};
24use crate::utils::{ColIndexMapping, Condition, ConditionDisplay};
25
26/// The join predicate used in optimizer
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub struct EqJoinPredicate {
29    /// Other conditions, linked with `AND` conjunction.
30    other_cond: Condition,
31
32    /// `Vec` of `(left_col_index, right_col_index, null_safe)`,
33    /// representing a conjunction of `left_col_index = right_col_index`
34    ///
35    /// Note: `right_col_index` starts from `left_cols_num`
36    eq_keys: Vec<(InputRef, InputRef, bool)>,
37
38    left_cols_num: usize,
39    right_cols_num: usize,
40}
41
42impl fmt::Display for EqJoinPredicate {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
44        let mut eq_keys = self.eq_keys().iter();
45        if let Some((k1, k2, null_safe)) = eq_keys.next() {
46            write!(
47                f,
48                "{} {} {}",
49                k1,
50                if *null_safe {
51                    "IS NOT DISTINCT FROM"
52                } else {
53                    "="
54                },
55                k2
56            )?;
57        }
58        for (k1, k2, null_safe) in eq_keys {
59            write!(
60                f,
61                "AND {} {} {}",
62                k1,
63                if *null_safe {
64                    "IS NOT DISTINCT FROM"
65                } else {
66                    "="
67                },
68                k2
69            )?;
70        }
71        if !self.other_cond.always_true() {
72            write!(f, " AND {}", self.other_cond)?;
73        }
74
75        Ok(())
76    }
77}
78
79impl EqJoinPredicate {
80    /// The new method for `JoinPredicate` without any analysis, check or rewrite.
81    pub fn new(
82        other_cond: Condition,
83        eq_keys: Vec<(InputRef, InputRef, bool)>,
84        left_cols_num: usize,
85        right_cols_num: usize,
86    ) -> Self {
87        Self {
88            other_cond,
89            eq_keys,
90            left_cols_num,
91            right_cols_num,
92        }
93    }
94
95    /// `create` will analyze the on clause condition and construct a `JoinPredicate`.
96    /// e.g.
97    /// ```sql
98    ///   select a.v1, a.v2, b.v1, b.v2 from a,b on a.v1 = a.v2 and a.v1 = b.v1 and a.v2 > b.v2
99    /// ```
100    /// will call the `create` function with `left_colsnum` = 2 and `on_clause` is (supposed
101    /// `input_ref` count start from 0)
102    /// ```sql
103    /// input_ref(0) = input_ref(1) and input_ref(0) = input_ref(2) and input_ref(1) > input_ref(3)
104    /// ```
105    /// And the `create functions` should return `JoinPredicate`
106    /// ```sql
107    ///   other_conds = Vec[input_ref(0) = input_ref(1), input_ref(1) > input_ref(3)],
108    ///   keys= Vec[(0,2)]
109    /// ```
110    pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self {
111        let (eq_keys, other_cond) = on_clause.split_eq_keys(left_cols_num, right_cols_num);
112        Self::new(other_cond, eq_keys, left_cols_num, right_cols_num)
113    }
114
115    /// Get join predicate's eq conds.
116    pub fn eq_cond(&self) -> Condition {
117        Condition {
118            conjunctions: self
119                .eq_keys
120                .iter()
121                .cloned()
122                .map(|(l, r, null_safe)| {
123                    FunctionCall::new(
124                        if null_safe {
125                            ExprType::IsNotDistinctFrom
126                        } else {
127                            ExprType::Equal
128                        },
129                        vec![l.into(), r.into()],
130                    )
131                    .unwrap()
132                    .into()
133                })
134                .collect(),
135        }
136    }
137
138    pub fn non_eq_cond(&self) -> Condition {
139        self.other_cond.clone()
140    }
141
142    pub fn all_cond(&self) -> Condition {
143        let cond = self.eq_cond();
144        cond.and(self.non_eq_cond())
145    }
146
147    pub fn has_eq(&self) -> bool {
148        !self.eq_keys.is_empty()
149    }
150
151    pub fn has_non_eq(&self) -> bool {
152        !self.other_cond.always_true()
153    }
154
155    /// Get a reference to the join predicate's other cond.
156    pub fn other_cond(&self) -> &Condition {
157        &self.other_cond
158    }
159
160    /// Get a mutable reference to the join predicate's other cond.
161    pub fn other_cond_mut(&mut self) -> &mut Condition {
162        &mut self.other_cond
163    }
164
165    /// Get the equal predicate
166    pub fn eq_predicate(&self) -> Self {
167        Self {
168            other_cond: Condition::true_cond(),
169            eq_keys: self.eq_keys.clone(),
170            left_cols_num: self.left_cols_num,
171            right_cols_num: self.right_cols_num,
172        }
173    }
174
175    /// Get a reference to the join predicate's eq keys.
176    ///
177    /// Note: `right_col_index` starts from `left_cols_num`
178    pub fn eq_keys(&self) -> &[(InputRef, InputRef, bool)] {
179        self.eq_keys.as_ref()
180    }
181
182    /// `Vec` of `(left_col_index, right_col_index)`.
183    ///
184    /// Note: `right_col_index` starts from `0`
185    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
186        self.eq_keys
187            .iter()
188            .map(|(left, right, _)| (left.index(), right.index() - self.left_cols_num))
189            .collect()
190    }
191
192    pub(crate) fn inequality_pairs(&self) -> (usize, Vec<(usize, InequalityInputPair)>) {
193        (
194            self.left_cols_num,
195            self.other_cond()
196                .extract_inequality_keys(self.left_cols_num, self.right_cols_num),
197        )
198    }
199
200    /// Note: `right_col_index` starts from `0`
201    pub fn eq_indexes_typed(&self) -> Vec<(InputRef, InputRef)> {
202        self.eq_keys
203            .iter()
204            .cloned()
205            .map(|(left, mut right, _)| {
206                right.index -= self.left_cols_num;
207                (left, right)
208            })
209            .collect()
210    }
211
212    pub fn eq_keys_are_type_aligned(&self) -> bool {
213        let mut aligned = true;
214        for (l, r, _) in &self.eq_keys {
215            aligned &= l.data_type == r.data_type;
216        }
217        aligned
218    }
219
220    pub fn left_eq_indexes(&self) -> Vec<usize> {
221        self.eq_keys
222            .iter()
223            .map(|(left, _, _)| left.index())
224            .collect()
225    }
226
227    /// Note: `right_col_index` starts from `0`
228    pub fn right_eq_indexes(&self) -> Vec<usize> {
229        self.eq_keys
230            .iter()
231            .map(|(_, right, _)| right.index() - self.left_cols_num)
232            .collect()
233    }
234
235    pub fn null_safes(&self) -> Vec<bool> {
236        self.eq_keys
237            .iter()
238            .map(|(_, _, null_safe)| *null_safe)
239            .collect()
240    }
241
242    /// return the eq columns index mapping from right inputs to left inputs
243    pub fn r2l_eq_columns_mapping(
244        &self,
245        left_cols_num: usize,
246        right_cols_num: usize,
247    ) -> ColIndexMapping {
248        let mut map = vec![None; right_cols_num];
249        for (left, right, _) in self.eq_keys() {
250            map[right.index - left_cols_num] = Some(left.index);
251        }
252        ColIndexMapping::new(map, left_cols_num)
253    }
254
255    /// return the eq columns index mapping from left inputs to right inputs
256    pub fn l2r_eq_columns_mapping(
257        &self,
258        left_cols_num: usize,
259        right_cols_num: usize,
260    ) -> ColIndexMapping {
261        let mut map = vec![None; left_cols_num];
262        for (left, right, _) in self.eq_keys() {
263            map[left.index] = Some(right.index - left_cols_num);
264        }
265        ColIndexMapping::new(map, right_cols_num)
266    }
267
268    /// Reorder the `eq_keys` according to the `reorder_idx`.
269    pub fn reorder(self, reorder_idx: &[usize]) -> Self {
270        assert!(reorder_idx.len() <= self.eq_keys.len());
271        let mut new_eq_keys = Vec::with_capacity(self.eq_keys.len());
272        for idx in reorder_idx {
273            new_eq_keys.push(self.eq_keys[*idx].clone());
274        }
275        for idx in 0..self.eq_keys.len() {
276            if !reorder_idx.contains(&idx) {
277                new_eq_keys.push(self.eq_keys[idx].clone());
278            }
279        }
280
281        Self::new(
282            self.other_cond,
283            new_eq_keys,
284            self.left_cols_num,
285            self.right_cols_num,
286        )
287    }
288
289    /// Retain the prefix of `eq_keys` based on the `prefix_len`. The other part is moved to the
290    /// other condition.
291    pub fn retain_prefix_eq_key(self, prefix_len: usize) -> Self {
292        assert!(prefix_len <= self.eq_keys.len());
293        let (retain_eq_key, other_eq_key) = self.eq_keys.split_at(prefix_len);
294        let mut new_other_conjunctions = self.other_cond.conjunctions;
295        new_other_conjunctions.extend(
296            other_eq_key
297                .iter()
298                .cloned()
299                .map(|(l, r, null_safe)| {
300                    FunctionCall::new(
301                        if null_safe {
302                            ExprType::IsNotDistinctFrom
303                        } else {
304                            ExprType::Equal
305                        },
306                        vec![l.into(), r.into()],
307                    )
308                    .unwrap()
309                    .into()
310                })
311                .collect_vec(),
312        );
313
314        let new_other_cond = Condition {
315            conjunctions: new_other_conjunctions,
316        };
317
318        Self::new(
319            new_other_cond,
320            retain_eq_key.to_owned(),
321            self.left_cols_num,
322            self.right_cols_num,
323        )
324    }
325
326    pub fn rewrite_exprs(&self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
327        let mut new = self.clone();
328        new.other_cond = new.other_cond.rewrite_expr(rewriter);
329        new
330    }
331
332    pub fn visit_exprs(&self, v: &mut (impl ExprVisitor + ?Sized)) {
333        self.other_cond.visit_expr(v);
334    }
335}
336
337pub struct EqJoinPredicateDisplay<'a> {
338    pub eq_join_predicate: &'a EqJoinPredicate,
339    pub input_schema: &'a Schema,
340}
341
342impl EqJoinPredicateDisplay<'_> {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
344        let that = self.eq_join_predicate;
345        let mut eq_keys = that.eq_keys().iter();
346        if let Some((k1, k2, null_safe)) = eq_keys.next() {
347            write!(
348                f,
349                "{} {} {}",
350                InputRefDisplay {
351                    input_ref: k1,
352                    input_schema: self.input_schema
353                },
354                if *null_safe {
355                    "IS NOT DISTINCT FROM"
356                } else {
357                    "="
358                },
359                InputRefDisplay {
360                    input_ref: k2,
361                    input_schema: self.input_schema
362                }
363            )?;
364        }
365        for (k1, k2, null_safe) in eq_keys {
366            write!(
367                f,
368                " AND {} {} {}",
369                InputRefDisplay {
370                    input_ref: k1,
371                    input_schema: self.input_schema
372                },
373                if *null_safe {
374                    "IS NOT DISTINCT FROM"
375                } else {
376                    "="
377                },
378                InputRefDisplay {
379                    input_ref: k2,
380                    input_schema: self.input_schema
381                }
382            )?;
383        }
384        if !that.other_cond.always_true() {
385            write!(
386                f,
387                " AND {}",
388                ConditionDisplay {
389                    condition: &that.other_cond,
390                    input_schema: self.input_schema
391                }
392            )?;
393        }
394
395        Ok(())
396    }
397}
398
399impl fmt::Display for EqJoinPredicateDisplay<'_> {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
401        self.fmt(f)
402    }
403}
404
405impl fmt::Debug for EqJoinPredicateDisplay<'_> {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
407        self.fmt(f)
408    }
409}