risingwave_frontend/optimizer/plan_node/generic/
join.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::{EitherOrBoth, Itertools};
16use risingwave_common::catalog::{Field, Schema};
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::OrderType;
19use risingwave_pb::plan_common::JoinType;
20
21use super::{EqJoinPredicate, GenericPlanNode, GenericPlanRef};
22use crate::TableCatalog;
23use crate::expr::{ExprRewriter, ExprVisitor};
24use crate::optimizer::optimizer_context::OptimizerContextRef;
25use crate::optimizer::plan_node::StreamPlanRef;
26use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata as _;
27use crate::optimizer::plan_node::stream::prelude::*;
28use crate::optimizer::plan_node::utils::TableCatalogBuilder;
29use crate::optimizer::property::{FunctionalDependencySet, StreamKind};
30use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
31
32/// Join predicate stored in the join core.
33///
34/// - Logical joins keep the original [`Condition`] for optimizer rules.
35/// - Physical joins keep a fixed [`EqJoinPredicate`] (eq keys are already extracted and must be
36///   preserved even if condition simplification would otherwise drop them).
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum JoinOn {
39    Condition(Condition),
40    EqPredicate(EqJoinPredicate),
41}
42
43impl JoinOn {
44    /// Get the join predicate as a [`Condition`].
45    ///
46    /// - For [`JoinOn::Condition`], this returns the stored condition.
47    /// - For [`JoinOn::EqPredicate`], this converts the stored [`EqJoinPredicate`] back to a
48    ///   condition.
49    ///
50    /// Prefer [`JoinOn::as_condition_ref`] if you only want the original logical condition and
51    /// don't want any conversion.
52    pub fn as_condition(&self) -> Condition {
53        match self {
54            JoinOn::Condition(cond) => cond.clone(),
55            JoinOn::EqPredicate(pred) => pred.all_cond(),
56        }
57    }
58
59    /// Borrow the original logical join condition, if any.
60    ///
61    /// Returns `None` for [`JoinOn::EqPredicate`], because physical plans store a fixed
62    /// [`EqJoinPredicate`] instead of the original condition.
63    pub fn as_condition_ref(&self) -> Option<&Condition> {
64        match self {
65            JoinOn::Condition(cond) => Some(cond),
66            JoinOn::EqPredicate(_) => None,
67        }
68    }
69
70    /// Borrow the fixed [`EqJoinPredicate`], if any.
71    ///
72    /// Returns `None` for [`JoinOn::Condition`]. If the caller needs eq keys, it should extract
73    /// them via [`EqJoinPredicate::create`] (potentially after rewriting/simplifying the
74    /// condition).
75    pub fn as_eq_predicate_ref(&self) -> Option<&EqJoinPredicate> {
76        match self {
77            JoinOn::Condition(_) => None,
78            JoinOn::EqPredicate(pred) => Some(pred),
79        }
80    }
81
82    /// Rewrite expressions inside the predicate.
83    ///
84    /// For [`JoinOn::EqPredicate`], eq keys are treated as fixed (they are expected to be plain
85    /// input refs), so only the "other" condition is rewritten.
86    pub fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
87        match self {
88            JoinOn::Condition(cond) => {
89                *cond = cond.clone().rewrite_expr(r);
90            }
91            JoinOn::EqPredicate(pred) => {
92                *pred = pred.rewrite_exprs(r);
93            }
94        }
95    }
96
97    /// Visit expressions inside the predicate.
98    ///
99    /// For [`JoinOn::EqPredicate`], only non-eq conditions are visited.
100    pub fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
101        match self {
102            JoinOn::Condition(cond) => cond.visit_expr(v),
103            // eq keys are fixed and contain only input refs; only visit non-eq conditions.
104            JoinOn::EqPredicate(pred) => pred.visit_exprs(v),
105        }
106    }
107}
108
109/// [`Join`] combines two relations according to some condition.
110///
111/// Each output row has fields from the left and right inputs. The set of output rows is a subset
112/// of the cartesian product of the two inputs; precisely which subset depends on the join
113/// condition. In addition, the output columns are a subset of the columns of the left and
114/// right columns, dependent on the output indices provided. A repeat output index is illegal.
115#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct Join<PlanRef> {
117    pub left: PlanRef,
118    pub right: PlanRef,
119    pub on: JoinOn,
120    pub join_type: JoinType,
121    pub output_indices: Vec<usize>,
122}
123
/// Report whether any value occurs more than once in `slice`.
///
/// Each element is compared against everything that follows it, so the scan is quadratic in the
/// slice length. An empty or single-element slice never has a repeat.
pub(crate) fn has_repeated_element(slice: &[usize]) -> bool {
    slice
        .iter()
        .enumerate()
        .any(|(pos, value)| slice[pos + 1..].contains(value))
}
127
128impl<PlanRef: GenericPlanRef> Join<PlanRef> {
129    pub(crate) fn clone_with_inputs<OtherPlanRef>(
130        &self,
131        left: OtherPlanRef,
132        right: OtherPlanRef,
133    ) -> Join<OtherPlanRef> {
134        Join {
135            left,
136            right,
137            on: self.on.clone(),
138            join_type: self.join_type,
139            output_indices: self.output_indices.clone(),
140        }
141    }
142
143    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
144        self.on.rewrite_exprs(r);
145    }
146
147    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
148        self.on.visit_exprs(v);
149    }
150
151    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
152        let left_len = self.left.schema().len();
153        let right_len = self.right.schema().len();
154        match &self.on {
155            JoinOn::Condition(on) => {
156                EqJoinPredicate::create(left_len, right_len, on.clone()).eq_indexes()
157            }
158            JoinOn::EqPredicate(pred) => pred.eq_indexes(),
159        }
160    }
161
162    pub fn new(
163        left: PlanRef,
164        right: PlanRef,
165        on: Condition,
166        join_type: JoinType,
167        output_indices: Vec<usize>,
168    ) -> Self {
169        // We cannot deal with repeated output indices in join
170        debug_assert!(!has_repeated_element(&output_indices));
171        Self {
172            left,
173            right,
174            on: JoinOn::Condition(on),
175            join_type,
176            output_indices,
177        }
178    }
179
180    pub fn new_with_eq_predicate(
181        left: PlanRef,
182        right: PlanRef,
183        eq_join_predicate: EqJoinPredicate,
184        join_type: JoinType,
185        output_indices: Vec<usize>,
186    ) -> Self {
187        debug_assert!(!has_repeated_element(&output_indices));
188        Self {
189            left,
190            right,
191            on: JoinOn::EqPredicate(eq_join_predicate),
192            join_type,
193            output_indices,
194        }
195    }
196}
197
198impl Join<StreamPlanRef> {
199    pub fn stream_kind(&self) -> Result<StreamKind> {
200        let left_kind = reject_upsert_input!(self.left, "Join");
201        let right_kind = reject_upsert_input!(self.right, "Join");
202
203        // Inner join won't change the append-only behavior of the stream. The rest might.
204        if let JoinType::Inner | JoinType::AsofInner = self.join_type
205            && let StreamKind::AppendOnly = left_kind
206            && let StreamKind::AppendOnly = right_kind
207        {
208            Ok(StreamKind::AppendOnly)
209        } else {
210            Ok(StreamKind::Retract)
211        }
212    }
213
214    /// Return stream hash join internal table catalog and degree table catalog.
215    pub fn infer_internal_and_degree_table_catalog(
216        input: StreamPlanRef,
217        join_key_indices: Vec<usize>,
218        dk_indices_in_jk: Vec<usize>,
219    ) -> (TableCatalog, TableCatalog, Vec<usize>) {
220        let schema = input.schema();
221
222        let internal_table_dist_keys = dk_indices_in_jk
223            .iter()
224            .map(|idx| join_key_indices[*idx])
225            .collect_vec();
226
227        let degree_table_dist_keys = dk_indices_in_jk.clone();
228
229        // The pk of hash join internal and degree table should be join_key + input_pk.
230        let join_key_len = join_key_indices.len();
231        let mut pk_indices = join_key_indices;
232
233        // dedup the pk in dist key..
234        let mut deduped_input_pk_indices = vec![];
235        for input_pk_idx in input.stream_key().unwrap() {
236            if !pk_indices.contains(input_pk_idx)
237                && !deduped_input_pk_indices.contains(input_pk_idx)
238            {
239                deduped_input_pk_indices.push(*input_pk_idx);
240            }
241        }
242
243        pk_indices.extend(deduped_input_pk_indices.clone());
244
245        // Build internal table
246        let mut internal_table_catalog_builder = TableCatalogBuilder::default();
247        let internal_columns_fields = schema.fields().to_vec();
248
249        internal_columns_fields.iter().for_each(|field| {
250            internal_table_catalog_builder.add_column(field);
251        });
252        pk_indices.iter().for_each(|idx| {
253            internal_table_catalog_builder.add_order_column(*idx, OrderType::ascending())
254        });
255
256        // Build degree table.
257        let mut degree_table_catalog_builder = TableCatalogBuilder::default();
258
259        let degree_column_field = Field::with_name(DataType::Int64, "_degree");
260
261        pk_indices.iter().enumerate().for_each(|(order_idx, idx)| {
262            degree_table_catalog_builder.add_column(&internal_columns_fields[*idx]);
263            degree_table_catalog_builder.add_order_column(order_idx, OrderType::ascending());
264        });
265        degree_table_catalog_builder.add_column(&degree_column_field);
266        degree_table_catalog_builder
267            .set_value_indices(vec![degree_table_catalog_builder.columns().len() - 1]);
268
269        internal_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk.clone());
270        degree_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk);
271
272        (
273            internal_table_catalog_builder.build(internal_table_dist_keys, join_key_len),
274            degree_table_catalog_builder.build(degree_table_dist_keys, join_key_len),
275            deduped_input_pk_indices,
276        )
277    }
278}
279
280impl<PlanRef: GenericPlanRef> GenericPlanNode for Join<PlanRef> {
281    fn schema(&self) -> Schema {
282        let left_schema = self.left.schema();
283        let right_schema = self.right.schema();
284        let i2l = self.i2l_col_mapping();
285        let i2r = self.i2r_col_mapping();
286        let fields = self
287            .output_indices
288            .iter()
289            .map(|&i| match (i2l.try_map(i), i2r.try_map(i)) {
290                (Some(l_i), None) => left_schema.fields()[l_i].clone(),
291                (None, Some(r_i)) => right_schema.fields()[r_i].clone(),
292                _ => panic!(
293                    "left len {}, right len {}, i {}, lmap {:?}, rmap {:?}",
294                    left_schema.len(),
295                    right_schema.len(),
296                    i,
297                    i2l,
298                    i2r
299                ),
300            })
301            .collect();
302        Schema { fields }
303    }
304
305    fn stream_key(&self) -> Option<Vec<usize>> {
306        let eq_indexes = self.eq_indexes();
307        let left_pk = self.left.stream_key()?;
308        let right_pk = self.right.stream_key()?;
309        let l2i = self.l2i_col_mapping();
310        let r2i = self.r2i_col_mapping();
311        let full_out_col_num = self.internal_column_num();
312        let i2o = ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num);
313
314        // Collect PKs in internal column space (without applying i2o mapping yet)
315        let mut pk_indices_internal = left_pk
316            .iter()
317            .map(|index| l2i.try_map(*index))
318            .chain(right_pk.iter().map(|index| r2i.try_map(*index)))
319            .flatten()
320            .collect::<Vec<_>>();
321
322        let either_or_both = self.add_which_join_key_to_pk();
323
324        for (lk, rk) in eq_indexes {
325            match either_or_both {
326                EitherOrBoth::Left(_) => {
327                    // Remove right-side join-key column from pk_indices_internal.
328                    // This may happen when right-side join-key is included in right-side PK.
329                    // e.g. select a, b where a.bid = b.id
330                    // Here the pk_indices should be [a.id, a.bid] instead of [a.id, b.id, a.bid],
331                    // because b.id = a.bid, so either of them would be enough.
332                    if let Some(rk_internal) = r2i.try_map(rk) {
333                        pk_indices_internal.retain(|&x| x != rk_internal);
334                    }
335                    // Add left-side join-key column in pk_indices_internal
336                    if let Some(lk_internal) = l2i.try_map(lk)
337                        && !pk_indices_internal.contains(&lk_internal)
338                    {
339                        pk_indices_internal.push(lk_internal);
340                    }
341                }
342                EitherOrBoth::Right(_) => {
343                    // Remove left-side join-key column from pk_indices_internal
344                    // See the example above
345                    if let Some(lk_internal) = l2i.try_map(lk) {
346                        pk_indices_internal.retain(|&x| x != lk_internal);
347                    }
348                    // Add right-side join-key column in pk_indices_internal
349                    if let Some(rk_internal) = r2i.try_map(rk)
350                        && !pk_indices_internal.contains(&rk_internal)
351                    {
352                        pk_indices_internal.push(rk_internal);
353                    }
354                }
355                EitherOrBoth::Both(_, _) => {
356                    if let Some(lk_internal) = l2i.try_map(lk)
357                        && !pk_indices_internal.contains(&lk_internal)
358                    {
359                        pk_indices_internal.push(lk_internal);
360                    }
361                    if let Some(rk_internal) = r2i.try_map(rk)
362                        && !pk_indices_internal.contains(&rk_internal)
363                    {
364                        pk_indices_internal.push(rk_internal);
365                    }
366                }
367            };
368        }
369
370        // Now apply i2o mapping to get output indices
371        let pk_indices = pk_indices_internal
372            .iter()
373            .map(|&index| i2o.try_map(index))
374            .collect::<Option<Vec<_>>>()?;
375
376        Some(pk_indices)
377    }
378
379    fn ctx(&self) -> OptimizerContextRef {
380        self.left.ctx()
381    }
382
383    fn functional_dependency(&self) -> FunctionalDependencySet {
384        let left_len = self.left.schema().len();
385        let right_len = self.right.schema().len();
386        let left_fd_set = self.left.functional_dependency().clone();
387        let right_fd_set = self.right.functional_dependency().clone();
388
389        let full_out_col_num = self.internal_column_num();
390
391        let get_new_left_fd_set = |left_fd_set: FunctionalDependencySet| {
392            ColIndexMapping::with_shift_offset(left_len, 0)
393                .composite(&ColIndexMapping::identity(full_out_col_num))
394                .rewrite_functional_dependency_set(left_fd_set)
395        };
396        let get_new_right_fd_set = |right_fd_set: FunctionalDependencySet| {
397            ColIndexMapping::with_shift_offset(right_len, left_len.try_into().unwrap())
398                .rewrite_functional_dependency_set(right_fd_set)
399        };
400        let fd_set: FunctionalDependencySet = match self.join_type {
401            JoinType::Inner | JoinType::AsofInner => {
402                let mut fd_set = FunctionalDependencySet::new(full_out_col_num);
403                for i in &self.on.as_condition().conjunctions {
404                    if let Some((col, _)) = i.as_eq_const() {
405                        fd_set.add_constant_columns(&[col.index()])
406                    } else if let Some((left, right)) = i.as_eq_cond() {
407                        fd_set.add_functional_dependency_by_column_indices(
408                            &[left.index()],
409                            &[right.index()],
410                        );
411                        fd_set.add_functional_dependency_by_column_indices(
412                            &[right.index()],
413                            &[left.index()],
414                        );
415                    }
416                }
417                get_new_left_fd_set(left_fd_set)
418                    .into_dependencies()
419                    .into_iter()
420                    .chain(get_new_right_fd_set(right_fd_set).into_dependencies())
421                    .for_each(|fd| fd_set.add_functional_dependency(fd));
422                fd_set
423            }
424            JoinType::LeftOuter | JoinType::AsofLeftOuter => get_new_left_fd_set(left_fd_set),
425            JoinType::RightOuter => get_new_right_fd_set(right_fd_set),
426            JoinType::FullOuter => FunctionalDependencySet::new(full_out_col_num),
427            JoinType::LeftSemi | JoinType::LeftAnti => left_fd_set,
428            JoinType::RightSemi | JoinType::RightAnti => right_fd_set,
429            JoinType::Unspecified => unreachable!(),
430        };
431        ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num)
432            .rewrite_functional_dependency_set(fd_set)
433    }
434}
435
436impl<PlanRef> Join<PlanRef> {
437    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
438        (
439            self.left,
440            self.right,
441            self.on.as_condition(),
442            self.join_type,
443            self.output_indices,
444        )
445    }
446}
447
448impl<PlanRef: GenericPlanRef> Join<PlanRef> {
449    pub fn full_out_col_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize {
450        match join_type {
451            JoinType::Inner
452            | JoinType::LeftOuter
453            | JoinType::RightOuter
454            | JoinType::FullOuter
455            | JoinType::AsofInner
456            | JoinType::AsofLeftOuter => left_len + right_len,
457            JoinType::LeftSemi | JoinType::LeftAnti => left_len,
458            JoinType::RightSemi | JoinType::RightAnti => right_len,
459            JoinType::Unspecified => unreachable!(),
460        }
461    }
462
463    pub fn with_full_output(
464        left: PlanRef,
465        right: PlanRef,
466        join_type: JoinType,
467        on: Condition,
468    ) -> Self {
469        let out_column_num =
470            Self::full_out_col_num(left.schema().len(), right.schema().len(), join_type);
471        Self {
472            left,
473            right,
474            join_type,
475            on: JoinOn::Condition(on),
476            output_indices: (0..out_column_num).collect(),
477        }
478    }
479
480    pub fn with_full_output_eq_predicate(
481        left: PlanRef,
482        right: PlanRef,
483        join_type: JoinType,
484        eq_join_predicate: EqJoinPredicate,
485    ) -> Self {
486        let out_column_num =
487            Self::full_out_col_num(left.schema().len(), right.schema().len(), join_type);
488        Self {
489            left,
490            right,
491            join_type,
492            on: JoinOn::EqPredicate(eq_join_predicate),
493            output_indices: (0..out_column_num).collect(),
494        }
495    }
496
497    pub fn internal_column_num(&self) -> usize {
498        Self::full_out_col_num(
499            self.left.schema().len(),
500            self.right.schema().len(),
501            self.join_type,
502        )
503    }
504
505    pub fn is_full_out(&self) -> bool {
506        self.output_indices.len() == self.internal_column_num()
507    }
508
509    /// Get the Mapping of columnIndex from internal column index to left column index.
510    pub fn i2l_col_mapping(&self) -> ColIndexMapping {
511        let left_len = self.left.schema().len();
512        let right_len = self.right.schema().len();
513
514        match self.join_type {
515            JoinType::Inner
516            | JoinType::LeftOuter
517            | JoinType::RightOuter
518            | JoinType::FullOuter
519            | JoinType::AsofInner
520            | JoinType::AsofLeftOuter => {
521                ColIndexMapping::identity_or_none(left_len + right_len, left_len)
522            }
523
524            JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::identity(left_len),
525            JoinType::RightSemi | JoinType::RightAnti => {
526                ColIndexMapping::empty(right_len, left_len)
527            }
528            JoinType::Unspecified => unreachable!(),
529        }
530    }
531
532    /// Get the Mapping of columnIndex from internal column index to right column index.
533    pub fn i2r_col_mapping(&self) -> ColIndexMapping {
534        let left_len = self.left.schema().len();
535        let right_len = self.right.schema().len();
536
537        match self.join_type {
538            JoinType::Inner
539            | JoinType::LeftOuter
540            | JoinType::RightOuter
541            | JoinType::FullOuter
542            | JoinType::AsofInner
543            | JoinType::AsofLeftOuter => {
544                ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
545            }
546            JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len, right_len),
547            JoinType::RightSemi | JoinType::RightAnti => ColIndexMapping::identity(right_len),
548            JoinType::Unspecified => unreachable!(),
549        }
550    }
551
552    /// TODO: This function may can be merged with `i2l_col_mapping` in future.
553    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
554        let left_len = self.left.schema().len();
555        let right_len = self.right.schema().len();
556
557        ColIndexMapping::identity_or_none(left_len + right_len, left_len)
558    }
559
560    /// TODO: This function may can be merged with `i2r_col_mapping` in future.
561    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
562        let left_len = self.left.schema().len();
563        let right_len = self.right.schema().len();
564
565        ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
566    }
567
568    /// Get the Mapping of columnIndex from left column index to internal column index.
569    pub fn l2i_col_mapping(&self) -> ColIndexMapping {
570        self.i2l_col_mapping()
571            .inverse()
572            .expect("must be invertible")
573    }
574
575    /// Get the Mapping of columnIndex from right column index to internal column index.
576    pub fn r2i_col_mapping(&self) -> ColIndexMapping {
577        self.i2r_col_mapping()
578            .inverse()
579            .expect("must be invertible")
580    }
581
582    /// Get the Mapping of columnIndex from internal column index to output column index
583    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
584        ColIndexMapping::with_remaining_columns(&self.output_indices, self.internal_column_num())
585    }
586
587    /// Get the Mapping of columnIndex from output column index to internal column index
588    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
589        // If output_indices = [0, 0, 1], we should use it as `o2i_col_mapping` directly.
590        // If we use `self.i2o_col_mapping().inverse()`, we will lose the first 0.
591        ColIndexMapping::new(
592            self.output_indices.iter().map(|x| Some(*x)).collect(),
593            self.internal_column_num(),
594        )
595    }
596
597    pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> {
598        match self.join_type {
599            JoinType::Inner | JoinType::AsofInner => {
600                // Theoretically adding either side is ok, but the distribution key of the inner
601                // join derived based on the left side by default, so we choose the left side here
602                // to ensure the pk comprises the distribution key.
603                EitherOrBoth::Left(())
604            }
605            JoinType::LeftOuter
606            | JoinType::LeftSemi
607            | JoinType::LeftAnti
608            | JoinType::AsofLeftOuter => EitherOrBoth::Left(()),
609            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
610                EitherOrBoth::Right(())
611            }
612            JoinType::FullOuter => EitherOrBoth::Both((), ()),
613            JoinType::Unspecified => unreachable!(),
614        }
615    }
616
617    pub fn concat_schema(&self) -> Schema {
618        Schema::new(
619            [
620                self.left.schema().fields.clone(),
621                self.right.schema().fields.clone(),
622            ]
623            .concat(),
624        )
625    }
626}
627
628/// Try to split and pushdown `predicate` into a into a join condition and into the inputs of the
629/// join. Returns the pushed predicates. The pushed part will be removed from the original
630/// predicate.
631///
632/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
633pub fn push_down_into_join(
634    predicate: &mut Condition,
635    left_col_num: usize,
636    right_col_num: usize,
637    ty: JoinType,
638    push_temporal_predicate: bool,
639) -> (Condition, Condition, Condition) {
640    let (left, right) = push_down_to_inputs(
641        predicate,
642        left_col_num,
643        right_col_num,
644        can_push_left_from_filter(ty),
645        can_push_right_from_filter(ty),
646        push_temporal_predicate,
647    );
648
649    let on = if can_push_on_from_filter(ty) {
650        let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
651
652        if push_temporal_predicate {
653            Condition { conjunctions }
654        } else {
655            // Do not push now on to the on, it will be pulled up into a filter instead.
656            let on = Condition {
657                conjunctions: conjunctions
658                    .extract_if(.., |expr| expr.count_nows() == 0)
659                    .collect(),
660            };
661            predicate.conjunctions = conjunctions;
662            on
663        }
664    } else {
665        Condition::true_cond()
666    };
667    (left, right, on)
668}
669
670/// Try to pushes parts of the join condition to its inputs. Returns the pushed predicates. The
671/// pushed part will be removed from the original join predicate.
672///
673/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
674pub fn push_down_join_condition(
675    on_condition: &mut Condition,
676    left_col_num: usize,
677    right_col_num: usize,
678    ty: JoinType,
679    push_temporal_predicate: bool,
680) -> (Condition, Condition) {
681    push_down_to_inputs(
682        on_condition,
683        left_col_num,
684        right_col_num,
685        can_push_left_from_on(ty),
686        can_push_right_from_on(ty),
687        push_temporal_predicate,
688    )
689}
690
691/// Try to split and pushdown `predicate` into a join's left/right child.
692/// Returns the pushed predicates. The pushed part will be removed from the original predicate.
693///
694/// `InputRef`s in the right `Condition` are shifted by `-left_col_num`.
695fn push_down_to_inputs(
696    predicate: &mut Condition,
697    left_col_num: usize,
698    right_col_num: usize,
699    push_left: bool,
700    push_right: bool,
701    push_temporal_predicate: bool,
702) -> (Condition, Condition) {
703    let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
704    let (mut left, right, mut others) = if push_temporal_predicate {
705        Condition { conjunctions }.split(left_col_num, right_col_num)
706    } else {
707        let temporal_filter_cons = conjunctions
708            .extract_if(.., |e| e.count_nows() != 0)
709            .collect_vec();
710        let (left, right, mut others) =
711            Condition { conjunctions }.split(left_col_num, right_col_num);
712
713        others.conjunctions.extend(temporal_filter_cons);
714        (left, right, others)
715    };
716
717    if !push_left {
718        others.conjunctions.extend(left);
719        left = Condition::true_cond();
720    };
721
722    let right = if push_right {
723        let mut mapping = ColIndexMapping::with_shift_offset(
724            left_col_num + right_col_num,
725            -(left_col_num as isize),
726        );
727        right.rewrite_expr(&mut mapping)
728    } else {
729        others.conjunctions.extend(right);
730        Condition::true_cond()
731    };
732
733    predicate.conjunctions = others.conjunctions;
734
735    (left, right)
736}
737
738pub fn can_push_left_from_filter(ty: JoinType) -> bool {
739    matches!(
740        ty,
741        JoinType::Inner
742            | JoinType::LeftOuter
743            | JoinType::LeftSemi
744            | JoinType::LeftAnti
745            | JoinType::AsofInner
746            | JoinType::AsofLeftOuter
747    )
748}
749
750pub fn can_push_right_from_filter(ty: JoinType) -> bool {
751    matches!(
752        ty,
753        JoinType::Inner
754            | JoinType::RightOuter
755            | JoinType::RightSemi
756            | JoinType::RightAnti
757            | JoinType::AsofInner
758    )
759}
760
761pub fn can_push_on_from_filter(ty: JoinType) -> bool {
762    matches!(
763        ty,
764        JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi
765    )
766}
767
768pub fn can_push_left_from_on(ty: JoinType) -> bool {
769    matches!(
770        ty,
771        JoinType::Inner
772            | JoinType::RightOuter
773            | JoinType::LeftSemi
774            | JoinType::AsofInner
775            | JoinType::AsofLeftOuter
776    )
777}
778
779pub fn can_push_right_from_on(ty: JoinType) -> bool {
780    matches!(
781        ty,
782        JoinType::Inner
783            | JoinType::LeftOuter
784            | JoinType::RightSemi
785            | JoinType::AsofInner
786            | JoinType::AsofLeftOuter
787    )
788}