risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31    ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeBinary, PredicatePushdown,
32    StreamHashJoin, StreamProject, ToBatch, ToStream, generic,
33};
34use crate::error::{ErrorCode, Result, RwError};
35use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::DynamicFilter;
38use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
39use crate::optimizer::plan_node::utils::IndicesDisplay;
40use crate::optimizer::plan_node::{
41    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
42    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
43    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
44};
45use crate::optimizer::plan_visitor::LogicalCardinalityExt;
46use crate::optimizer::property::{Distribution, Order, RequiredDist};
47use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
48
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right columns, dependent on the output indices provided. A repeated output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Logical plan base; built from `core` in [`LogicalJoin::with_core`].
    pub base: PlanBase<Logical>,
    /// Generic join description holding the two inputs, the `on` condition, the join type and
    /// the output indices.
    core: generic::Join<PlanRef>,
}
60
61impl Distill for LogicalJoin {
62    fn distill<'a>(&self) -> XmlNode<'a> {
63        let verbose = self.base.ctx().is_explain_verbose();
64        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
65        vec.push(("type", Pretty::debug(&self.join_type())));
66
67        let concat_schema = self.core.concat_schema();
68        let cond = Pretty::debug(&ConditionDisplay {
69            condition: self.on(),
70            input_schema: &concat_schema,
71        });
72        vec.push(("on", cond));
73
74        if verbose {
75            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
76            vec.push(("output", data));
77        }
78
79        childless_record("LogicalJoin", vec)
80    }
81}
82
83impl LogicalJoin {
84    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
85        let core = generic::Join::with_full_output(left, right, join_type, on);
86        Self::with_core(core)
87    }
88
89    pub(crate) fn with_output_indices(
90        left: PlanRef,
91        right: PlanRef,
92        join_type: JoinType,
93        on: Condition,
94        output_indices: Vec<usize>,
95    ) -> Self {
96        let core = generic::Join::new(left, right, on, join_type, output_indices);
97        Self::with_core(core)
98    }
99
100    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
101        let base = PlanBase::new_logical_with_core(&core);
102        LogicalJoin { base, core }
103    }
104
105    pub fn create(
106        left: PlanRef,
107        right: PlanRef,
108        join_type: JoinType,
109        on_clause: ExprImpl,
110    ) -> PlanRef {
111        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
112    }
113
    /// Returns the core join's internal column count (delegates to
    /// `generic::Join::internal_column_num`).
    pub fn internal_column_num(&self) -> usize {
        self.core.internal_column_num()
    }
117
    /// Delegates to `generic::Join::i2l_col_mapping_ignore_join_type`.
    ///
    /// NOTE(review): judging by the name, this maps internal columns to left-input columns
    /// without considering the join type; confirm against the core implementation.
    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2l_col_mapping_ignore_join_type()
    }
121
    /// Delegates to `generic::Join::i2r_col_mapping_ignore_join_type`.
    ///
    /// NOTE(review): judging by the name, this maps internal columns to right-input columns
    /// without considering the join type; confirm against the core implementation.
    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2r_col_mapping_ignore_join_type()
    }
125
    /// Get a reference to the logical join's `on` condition.
    pub fn on(&self) -> &Condition {
        &self.core.on
    }
130
131    /// Collect all input ref in the on condition. And separate them into left and right.
132    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
133        let input_refs = self
134            .core
135            .on
136            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
137        let index_group = input_refs
138            .ones()
139            .chunk_by(|i| *i < self.core.left.schema().len());
140        let left_index = index_group
141            .into_iter()
142            .next()
143            .map_or(vec![], |group| group.1.collect_vec());
144        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
145            group
146                .1
147                .map(|i| i - self.core.left.schema().len())
148                .collect_vec()
149        });
150        (left_index, right_index)
151    }
152
    /// Get the join type of the logical join (copied out of the core).
    pub fn join_type(&self) -> JoinType {
        self.core.join_type
    }
157
    /// Get the eq join key of the logical join, as returned by `generic::Join::eq_indexes`.
    ///
    /// NOTE(review): each pair is presumably `(left_index, right_index)`; confirm against
    /// the core implementation.
    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
        self.core.eq_indexes()
    }
162
    /// Get the output indices of the logical join.
    ///
    /// Each entry is an internal column index; see `prune_col`, which uses this vector to
    /// translate output positions into internal columns.
    pub fn output_indices(&self) -> &Vec<usize> {
        &self.core.output_indices
    }
167
168    /// Clone with new output indices
169    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
170        Self::with_core(generic::Join {
171            output_indices,
172            ..self.core.clone()
173        })
174    }
175
176    /// Clone with new `on` condition
177    pub fn clone_with_cond(&self, on: Condition) -> Self {
178        Self::with_core(generic::Join {
179            on,
180            ..self.core.clone()
181        })
182    }
183
184    pub fn is_left_join(&self) -> bool {
185        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
186    }
187
188    pub fn is_right_join(&self) -> bool {
189        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
190    }
191
    /// Delegates to `generic::Join::is_full_out`.
    ///
    /// NOTE(review): presumably reports whether the join outputs every internal column;
    /// confirm against the core implementation.
    pub fn is_full_out(&self) -> bool {
        self.core.is_full_out()
    }
195
196    pub fn is_asof_join(&self) -> bool {
197        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
198    }
199
200    pub fn output_indices_are_trivial(&self) -> bool {
201        self.output_indices() == &(0..self.internal_column_num()).collect_vec()
202    }
203
204    /// Try to simplify the outer join with the predicate on the top of the join
205    ///
206    /// now it is just a naive implementation for comparison expression, we can give a more general
207    /// implementation with constant folding in future
208    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
209        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
210            JoinType::LeftOuter => (false, true),
211            JoinType::RightOuter => (true, false),
212            JoinType::FullOuter => (true, true),
213            _ => return join_type,
214        };
215
216        for expr in &predicate.conjunctions {
217            if let ExprImpl::FunctionCall(func) = expr {
218                match func.func_type() {
219                    ExprType::Equal
220                    | ExprType::NotEqual
221                    | ExprType::LessThan
222                    | ExprType::LessThanOrEqual
223                    | ExprType::GreaterThan
224                    | ExprType::GreaterThanOrEqual => {
225                        for input in func.inputs() {
226                            if let ExprImpl::InputRef(input) = input {
227                                let idx = input.index;
228                                if idx < left_col_num {
229                                    gen_null_in_left = false;
230                                } else {
231                                    gen_null_in_right = false;
232                                }
233                            }
234                        }
235                    }
236                    _ => {}
237                };
238            }
239        }
240
241        match (gen_null_in_left, gen_null_in_right) {
242            (true, true) => JoinType::FullOuter,
243            (true, false) => JoinType::RightOuter,
244            (false, true) => JoinType::LeftOuter,
245            (false, false) => JoinType::Inner,
246        }
247    }
248
249    /// Index Join:
250    /// Try to convert logical join into batch lookup join and meanwhile it will do
251    /// the index selection for the lookup table so that we can benefit from indexes.
252    fn to_batch_lookup_join_with_index_selection(
253        &self,
254        predicate: EqJoinPredicate,
255        logical_join: generic::Join<PlanRef>,
256    ) -> Result<Option<BatchLookupJoin>> {
257        match logical_join.join_type {
258            JoinType::Inner
259            | JoinType::LeftOuter
260            | JoinType::LeftSemi
261            | JoinType::LeftAnti
262            | JoinType::AsofInner
263            | JoinType::AsofLeftOuter => {}
264            _ => return Ok(None),
265        };
266
267        // Index selection for index join.
268        let right = self.right();
269        // Lookup Join only supports basic tables on the join's right side.
270        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
271            logical_scan
272        } else {
273            return Ok(None);
274        };
275
276        let mut result_plan = None;
277        // Lookup primary table.
278        if let Some(lookup_join) =
279            self.to_batch_lookup_join(predicate.clone(), logical_join.clone())?
280        {
281            result_plan = Some(lookup_join);
282        }
283
284        let indexes = logical_scan.indexes();
285        for index in indexes {
286            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
287                let index_scan: PlanRef = index_scan.into();
288                let that = self.clone_with_left_right(self.left(), index_scan.clone());
289                let mut new_logical_join = logical_join.clone();
290                new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");
291
292                // Lookup covered index.
293                if let Some(lookup_join) =
294                    that.to_batch_lookup_join(predicate.clone(), new_logical_join)?
295                {
296                    match &result_plan {
297                        None => result_plan = Some(lookup_join),
298                        Some(prev_lookup_join) => {
299                            // Prefer to choose lookup join with longer lookup prefix len.
300                            if prev_lookup_join.lookup_prefix_len()
301                                < lookup_join.lookup_prefix_len()
302                            {
303                                result_plan = Some(lookup_join)
304                            }
305                        }
306                    }
307                }
308            }
309        }
310
311        Ok(result_plan)
312    }
313
    /// Try to convert logical join into batch lookup join.
    ///
    /// Returns `Ok(None)` when a lookup join cannot be used:
    /// * the join type is not one of inner / left outer / left semi / left anti / as-of,
    /// * the right side is not a logical scan,
    /// * the lookup table has an empty distribution key,
    /// * the equal join keys do not cover a long enough prefix of the order key, or
    /// * no equal condition remains after pulling up the scan predicate.
    ///
    /// # Errors
    ///
    /// For as-of joins, propagates the error of `get_inequality_desc_from_predicate`.
    ///
    /// # Panics
    ///
    /// Panics if a distribution key column is not part of the order key, or if the scan's
    /// pulled-up projection contains an expression that is not an input ref.
    fn to_batch_lookup_join(
        &self,
        predicate: EqJoinPredicate,
        logical_join: generic::Join<PlanRef>,
    ) -> Result<Option<BatchLookupJoin>> {
        match logical_join.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {}
            _ => return Ok(None),
        };

        let right = self.right();
        // Lookup Join only supports basic tables on the join's right side.
        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
            logical_scan
        } else {
            return Ok(None);
        };
        let table_desc = logical_scan.table_desc().clone();
        let output_column_ids = logical_scan.output_column_ids();

        // Verify that the right join key columns are a prefix of the primary key and
        // also contain the distribution key.
        let order_col_ids = table_desc.order_column_ids();
        let order_key = table_desc.order_column_indices();
        let dist_key = table_desc.distribution_key.clone();
        // Position of each distribution key column within the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = order_key
                .iter()
                .position(|&x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // The shortest prefix of order key that contains distribution key.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // Distributed lookup join can't support lookup table with a singleton distribution.
        if shortest_prefix_len == 0 {
            return Ok(None);
        }

        // Reorder the join equal predicate to match the order key.
        // Walk the order key columns in order and stop at the first one that has no matching
        // equal condition, so `reorder_idx` describes a contiguous order-key prefix.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        // The matched prefix must at least cover the distribution key.
        if reorder_idx.len() < shortest_prefix_len {
            return Ok(None);
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Extract the predicate from logical scan. Only pure scan is supported.
        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
        // Construct the mapping from the old scan's output columns to the new scan's columns.
        let o2r = if let Some(project_expr) = project_expr {
            project_expr
                .into_iter()
                .map(|x| x.as_input_ref().unwrap().index)
                .collect_vec()
        } else {
            (0..logical_scan.output_col_idx().len()).collect_vec()
        };
        let left_schema_len = logical_join.left.schema().len();

        // Remaps right-side input refs of the join predicate through `o2r`.
        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
            offset: left_schema_len,
            mapping: o2r.clone(),
        };

        let new_eq_cond = predicate
            .eq_cond()
            .rewrite_expr(&mut join_predicate_rewriter);

        // Shifts the pulled-up scan predicate so that it refers to the join's right side.
        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
            offset: left_schema_len,
        };

        let new_other_cond = predicate
            .other_cond()
            .clone()
            .rewrite_expr(&mut join_predicate_rewriter)
            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

        let new_join_on = new_eq_cond.and(new_other_cond);
        let new_predicate = EqJoinPredicate::create(
            left_schema_len,
            new_scan.schema().len(),
            new_join_on.clone(),
        );

        // We discovered that we cannot use a lookup join after pulling up the predicate
        // from one side and simplifying the condition. Let's use some other join instead.
        if !new_predicate.has_eq() {
            return Ok(None);
        }

        // Rewrite the join output indices: every output index that referred to the old scan
        // must be remapped through `o2r`.
        let new_join_output_indices = logical_join
            .output_indices
            .iter()
            .map(|&x| {
                if x < left_schema_len {
                    x
                } else {
                    o2r[x - left_schema_len] + left_schema_len
                }
            })
            .collect_vec();

        let new_scan_output_column_ids = new_scan.output_column_ids();
        let as_of = new_scan.as_of.clone();

        // Construct a new logical join, because we have changed its RHS.
        let new_logical_join = generic::Join::new(
            logical_join.left,
            new_scan.into(),
            new_join_on,
            logical_join.join_type,
            new_join_output_indices,
        );

        // Only as-of joins carry an inequality descriptor; it is derived from the reordered
        // (but not yet remapped) predicate's non-equal condition.
        let asof_desc = self
            .is_asof_join()
            .then(|| {
                Self::get_inequality_desc_from_predicate(
                    predicate.other_cond().clone(),
                    left_schema_len,
                )
            })
            .transpose()?;

        Ok(Some(BatchLookupJoin::new(
            new_logical_join,
            new_predicate,
            table_desc,
            new_scan_output_column_ids,
            lookup_prefix_len,
            false,
            as_of,
            asof_desc,
        )))
    }
478
    /// Consumes the join and returns the parts produced by `generic::Join::decompose`.
    ///
    /// NOTE(review): going by the return type, the tuple is presumably
    /// `(left, right, on, join_type, output_indices)`; confirm against the core implementation.
    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
        self.core.decompose()
    }
482}
483
484impl PlanTreeNodeBinary for LogicalJoin {
485    fn left(&self) -> PlanRef {
486        self.core.left.clone()
487    }
488
489    fn right(&self) -> PlanRef {
490        self.core.right.clone()
491    }
492
493    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
494        Self::with_core(generic::Join {
495            left,
496            right,
497            ..self.core.clone()
498        })
499    }
500
501    fn rewrite_with_left_right(
502        &self,
503        left: PlanRef,
504        left_col_change: ColIndexMapping,
505        right: PlanRef,
506        right_col_change: ColIndexMapping,
507    ) -> (Self, ColIndexMapping) {
508        let (new_on, new_output_indices) = {
509            let (mut map, _) = left_col_change.clone().into_parts();
510            let (mut right_map, _) = right_col_change.clone().into_parts();
511            for i in right_map.iter_mut().flatten() {
512                *i += left.schema().len();
513            }
514            map.append(&mut right_map);
515            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
516
517            let new_output_indices = self
518                .output_indices()
519                .iter()
520                .map(|&i| mapping.map(i))
521                .collect::<Vec<_>>();
522            let new_on = self.on().clone().rewrite_expr(&mut mapping);
523            (new_on, new_output_indices)
524        };
525
526        let join = Self::with_output_indices(
527            left,
528            right,
529            self.join_type(),
530            new_on,
531            new_output_indices.clone(),
532        );
533
534        let new_i2o = ColIndexMapping::with_remaining_columns(
535            &new_output_indices,
536            join.internal_column_num(),
537        );
538
539        let old_o2i = self.core.o2i_col_mapping();
540
541        let old_o2l = old_o2i
542            .composite(&self.core.i2l_col_mapping())
543            .composite(&left_col_change);
544        let old_o2r = old_o2i
545            .composite(&self.core.i2r_col_mapping())
546            .composite(&right_col_change);
547        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
548        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
549
550        let out_col_change = old_o2l
551            .composite(&new_l2o)
552            .union(&old_o2r.composite(&new_r2o));
553        (join, out_col_change)
554    }
555}
556
557impl_plan_tree_node_for_binary! { LogicalJoin }
558
559impl ColPrunable for LogicalJoin {
560    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
561        // make `required_cols` point to internal table instead of output schema.
562        let required_cols = required_cols
563            .iter()
564            .map(|i| self.output_indices()[*i])
565            .collect_vec();
566        let left_len = self.left().schema().fields.len();
567
568        let total_len = self.left().schema().len() + self.right().schema().len();
569        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
570
571        required_cols.iter().for_each(|&i| {
572            if self.is_right_join() {
573                resized_required_cols.insert(left_len + i);
574            } else {
575                resized_required_cols.insert(i);
576            }
577        });
578
579        // add those columns which are required in the join condition to
580        // to those that are required in the output
581        let mut visitor = CollectInputRef::new(resized_required_cols);
582        self.on().visit_expr(&mut visitor);
583        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
584
585        let mut left_required_cols = Vec::new();
586        let mut right_required_cols = Vec::new();
587        left_right_required_cols.iter().for_each(|&i| {
588            if i < left_len {
589                left_required_cols.push(i);
590            } else {
591                right_required_cols.push(i - left_len);
592            }
593        });
594
595        let mut on = self.on().clone();
596        let mut mapping =
597            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
598        on = on.rewrite_expr(&mut mapping);
599
600        let new_output_indices = {
601            let required_inputs_in_output = if self.is_left_join() {
602                &left_required_cols
603            } else if self.is_right_join() {
604                &right_required_cols
605            } else {
606                &left_right_required_cols
607            };
608
609            let mapping =
610                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
611            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
612        };
613
614        LogicalJoin::with_output_indices(
615            self.left().prune_col(&left_required_cols, ctx),
616            self.right().prune_col(&right_required_cols, ctx),
617            self.join_type(),
618            on,
619            new_output_indices,
620        )
621        .into()
622    }
623}
624
625impl ExprRewritable for LogicalJoin {
626    fn has_rewritable_expr(&self) -> bool {
627        true
628    }
629
630    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
631        let mut core = self.core.clone();
632        core.rewrite_exprs(r);
633        Self {
634            base: self.base.clone_with_new_plan_id(),
635            core,
636        }
637        .into()
638    }
639}
640
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions of the join by delegating to the core's `visit_exprs`.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
646
647/// We are trying to derive a predicate to apply to the other side of a join if all
648/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
649/// with the corresponding eq condition columns of the other side.
650///
651/// Strategy:
652/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
653///    then we proceed. Else abort.
654/// 2. Then, we collect `InputRef`s in the conjunction.
655/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
656/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
657///    the other side.
658///
659/// # Arguments
660///
661/// Suppose we derive a predicate from the left side to be pushed to the right side.
662/// * `expr`: An expr from the left side.
663/// * `col_num`: The number of columns in the left side.
664fn derive_predicate_from_eq_condition(
665    expr: &ExprImpl,
666    eq_condition: &EqJoinPredicate,
667    col_num: usize,
668    expr_is_left: bool,
669) -> Option<ExprImpl> {
670    if expr.is_impure() {
671        return None;
672    }
673    let eq_indices = eq_condition
674        .eq_indexes_typed()
675        .iter()
676        .filter_map(|(l, r)| {
677            if l.return_type() != r.return_type() {
678                None
679            } else if expr_is_left {
680                Some(l.index())
681            } else {
682                Some(r.index())
683            }
684        })
685        .collect_vec();
686    if expr
687        .collect_input_refs(col_num)
688        .ones()
689        .any(|index| !eq_indices.contains(&index))
690    {
691        // expr contains an InputRef not in eq_condition
692        return None;
693    }
694    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
695    // Hence, we can substitute those `InputRef`s with indices from the other side.
696    let other_side_mapping = if expr_is_left {
697        eq_condition.eq_indexes_typed().into_iter().collect()
698    } else {
699        eq_condition
700            .eq_indexes_typed()
701            .into_iter()
702            .map(|(x, y)| (y, x))
703            .collect()
704    };
705    struct InputRefsRewriter {
706        mapping: HashMap<InputRef, InputRef>,
707    }
708    impl ExprRewriter for InputRefsRewriter {
709        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
710            self.mapping[&input_ref].clone().into()
711        }
712    }
713    Some(
714        InputRefsRewriter {
715            mapping: other_side_mapping,
716        }
717        .rewrite_expr(expr.clone()),
718    )
719}
720
721/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
722struct LookupJoinPredicateRewriter {
723    offset: usize,
724    mapping: Vec<usize>,
725}
726impl ExprRewriter for LookupJoinPredicateRewriter {
727    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
728        if input_ref.index() < self.offset {
729            input_ref.into()
730        } else {
731            InputRef::new(
732                self.mapping[input_ref.index() - self.offset] + self.offset,
733                input_ref.return_type(),
734            )
735            .into()
736        }
737    }
738}
739
740/// Rewrite the scan predicate so we can add it to the join predicate.
741struct LookupJoinScanPredicateRewriter {
742    offset: usize,
743}
744impl ExprRewriter for LookupJoinScanPredicateRewriter {
745    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
746        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
747    }
748}
749
750impl PredicatePushdown for LogicalJoin {
751    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
752    ///
753    /// # Which predicates can be pushed
754    ///
755    /// For inner join, we can do all kinds of pushdown.
756    ///
757    /// For left/right semi join, we can push filter to left/right and on-clause,
758    /// and push on-clause to left/right.
759    ///
760    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
761    ///
762    /// ## Outer Join
763    ///
764    /// Preserved Row table
765    /// : The table in an Outer Join that must return all rows.
766    ///
767    /// Null Supplying table
768    /// : This is the table that has nulls filled in for its columns in unmatched rows.
769    ///
770    /// |                          | Preserved Row table | Null Supplying table |
771    /// |--------------------------|---------------------|----------------------|
772    /// | Join predicate (on)      | Not Pushed          | Pushed               |
773    /// | Where predicate (filter) | Pushed              | Not Pushed           |
774    fn predicate_pushdown(
775        &self,
776        predicate: Condition,
777        ctx: &mut PredicatePushdownContext,
778    ) -> PlanRef {
779        // rewrite output col referencing indices as internal cols
780        let mut predicate = {
781            let mut mapping = self.core.o2i_col_mapping();
782            predicate.rewrite_expr(&mut mapping)
783        };
784
785        let left_col_num = self.left().schema().len();
786        let right_col_num = self.right().schema().len();
787        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
788
789        let push_down_temporal_predicate = !self.should_be_temporal_join();
790
791        let (left_from_filter, right_from_filter, on) = push_down_into_join(
792            &mut predicate,
793            left_col_num,
794            right_col_num,
795            join_type,
796            push_down_temporal_predicate,
797        );
798
799        let mut new_on = self.on().clone().and(on);
800        let (left_from_on, right_from_on) = push_down_join_condition(
801            &mut new_on,
802            left_col_num,
803            right_col_num,
804            join_type,
805            push_down_temporal_predicate,
806        );
807
808        let left_predicate = left_from_filter.and(left_from_on);
809        let right_predicate = right_from_filter.and(right_from_on);
810
811        // Derive conditions to push to the other side based on eq condition columns
812        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
813
814        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
815        let right_from_left = if matches!(
816            join_type,
817            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
818        ) {
819            Condition {
820                conjunctions: left_predicate
821                    .conjunctions
822                    .iter()
823                    .filter_map(|expr| {
824                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
825                    })
826                    .collect(),
827            }
828        } else {
829            Condition::true_cond()
830        };
831
832        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
833        let left_from_right = if matches!(
834            join_type,
835            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
836        ) {
837            Condition {
838                conjunctions: right_predicate
839                    .conjunctions
840                    .iter()
841                    .filter_map(|expr| {
842                        derive_predicate_from_eq_condition(
843                            expr,
844                            &eq_condition,
845                            right_col_num,
846                            false,
847                        )
848                    })
849                    .collect(),
850            }
851        } else {
852            Condition::true_cond()
853        };
854
855        let left_predicate = left_predicate.and(left_from_right);
856        let right_predicate = right_predicate.and(right_from_left);
857
858        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
859        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
860        let new_join = LogicalJoin::with_output_indices(
861            new_left,
862            new_right,
863            join_type,
864            new_on,
865            self.output_indices().clone(),
866        );
867
868        let mut mapping = self.core.i2o_col_mapping();
869        predicate = predicate.rewrite_expr(&mut mapping);
870        LogicalFilter::create(new_join.into(), predicate)
871    }
872}
873
874impl LogicalJoin {
875    fn get_stream_input_for_hash_join(
876        &self,
877        predicate: &EqJoinPredicate,
878        ctx: &mut ToStreamContext,
879    ) -> Result<(PlanRef, PlanRef)> {
880        use super::stream::prelude::*;
881
882        let mut right = self.right().to_stream_with_dist_required(
883            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
884            ctx,
885        )?;
886        let mut left = self.left();
887
888        let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
889        let l2r = predicate.l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
890
891        let right_dist = right.distribution();
892        match right_dist {
893            Distribution::HashShard(_) => {
894                let left_dist = r2l
895                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
896                left = left.to_stream_with_dist_required(&left_dist, ctx)?;
897            }
898            Distribution::UpstreamHashShard(_, _) => {
899                left = left.to_stream_with_dist_required(
900                    &RequiredDist::shard_by_key(
901                        self.left().schema().len(),
902                        &predicate.left_eq_indexes(),
903                    ),
904                    ctx,
905                )?;
906                let left_dist = left.distribution();
907                match left_dist {
908                    Distribution::HashShard(_) => {
909                        let right_dist = l2r.rewrite_required_distribution(
910                            &RequiredDist::PhysicalDist(left_dist.clone()),
911                        );
912                        right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
913                    }
914                    Distribution::UpstreamHashShard(_, _) => {
915                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
916                            .enforce_if_not_satisfies(left, &Order::any())?;
917                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
918                            .enforce_if_not_satisfies(right, &Order::any())?;
919                    }
920                    _ => unreachable!(),
921                }
922            }
923            _ => unreachable!(),
924        }
925        Ok((left, right))
926    }
927
928    fn to_stream_hash_join(
929        &self,
930        predicate: EqJoinPredicate,
931        ctx: &mut ToStreamContext,
932    ) -> Result<PlanRef> {
933        use super::stream::prelude::*;
934
935        assert!(predicate.has_eq());
936        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
937
938        let logical_join = self.clone_with_left_right(left, right);
939
940        // Convert to Hash Join for equal joins
941        // For inner joins, pull non-equal conditions to a filter operator on top of it
942        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
943        // as opposed to the HashJoin, which applies the condition row-wise.
944
945        let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
946        let pull_filter = self.join_type() == JoinType::Inner
947            && stream_hash_join.eq_join_predicate().has_non_eq()
948            && stream_hash_join.inequality_pairs().is_empty();
949        if pull_filter {
950            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
951
952            // Temporarily remove output indices.
953            let logical_join = logical_join.clone_with_output_indices(default_indices.clone());
954            let eq_cond = EqJoinPredicate::new(
955                Condition::true_cond(),
956                predicate.eq_keys().to_vec(),
957                self.left().schema().len(),
958                self.right().schema().len(),
959            );
960            let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
961            let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
962            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
963            let plan = StreamFilter::new(logical_filter).into();
964            if self.output_indices() != &default_indices {
965                let logical_project = generic::Project::with_mapping(
966                    plan,
967                    ColIndexMapping::with_remaining_columns(
968                        self.output_indices(),
969                        self.internal_column_num(),
970                    ),
971                );
972                Ok(StreamProject::new(logical_project).into())
973            } else {
974                Ok(plan)
975            }
976        } else {
977            Ok(stream_hash_join.into())
978        }
979    }
980
981    fn should_be_temporal_join(&self) -> bool {
982        let right = self.right();
983        if let Some(logical_scan) = right.as_logical_scan() {
984            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
985        } else {
986            false
987        }
988    }
989
990    fn to_stream_temporal_join_with_index_selection(
991        &self,
992        predicate: EqJoinPredicate,
993        ctx: &mut ToStreamContext,
994    ) -> Result<StreamTemporalJoin> {
995        // Index selection for temporal join.
996        let right = self.right();
997        // `should_be_temporal_join()` has already check right input for us.
998        let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
999
1000        // Use primary table.
1001        let mut result_plan = self.to_stream_temporal_join(predicate.clone(), ctx);
1002        // Return directly if this temporal join can match the pk of its right table.
1003        if let Ok(temporal_join) = &result_plan
1004            && temporal_join.eq_join_predicate().eq_indexes().len()
1005                == logical_scan.primary_key().len()
1006        {
1007            return result_plan;
1008        }
1009        let indexes = logical_scan.indexes();
1010        for index in indexes {
1011            // Use index table
1012            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1013                let index_scan: PlanRef = index_scan.into();
1014                let that = self.clone_with_left_right(self.left(), index_scan.clone());
1015                if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1016                    match &result_plan {
1017                        Err(_) => result_plan = Ok(temporal_join),
1018                        Ok(prev_temporal_join) => {
1019                            // Prefer to the temporal join with a longer lookup prefix len.
1020                            if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1021                                < temporal_join.eq_join_predicate().eq_indexes().len()
1022                            {
1023                                result_plan = Ok(temporal_join)
1024                            }
1025                        }
1026                    }
1027                }
1028            }
1029        }
1030
1031        result_plan
1032    }
1033
1034    fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1035        let Some(logical_scan) = right.as_logical_scan() else {
1036            return Err(RwError::from(ErrorCode::NotSupported(
1037                "Temporal join requires a table scan as its lookup table".into(),
1038                "Please provide a table scan".into(),
1039            )));
1040        };
1041
1042        if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1043            return Err(RwError::from(ErrorCode::NotSupported(
1044                "Temporal join requires a table defined as temporal table".into(),
1045                "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1046            )));
1047        }
1048        Ok(logical_scan)
1049    }
1050
1051    fn temporal_join_scan_predicate_pull_up(
1052        logical_scan: &LogicalScan,
1053        predicate: EqJoinPredicate,
1054        output_indices: &[usize],
1055        left_schema_len: usize,
1056    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1057        // Extract the predicate from logical scan. Only pure scan is supported.
1058        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1059        // Construct output column to require column mapping
1060        let o2r = if let Some(project_expr) = project_expr {
1061            project_expr
1062                .into_iter()
1063                .map(|x| x.as_input_ref().unwrap().index)
1064                .collect_vec()
1065        } else {
1066            (0..logical_scan.output_col_idx().len()).collect_vec()
1067        };
1068        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1069            offset: left_schema_len,
1070            mapping: o2r.clone(),
1071        };
1072
1073        let new_eq_cond = predicate
1074            .eq_cond()
1075            .rewrite_expr(&mut join_predicate_rewriter);
1076
1077        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1078            offset: left_schema_len,
1079        };
1080
1081        let new_other_cond = predicate
1082            .other_cond()
1083            .clone()
1084            .rewrite_expr(&mut join_predicate_rewriter)
1085            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1086
1087        let new_join_on = new_eq_cond.and(new_other_cond);
1088
1089        let new_predicate = EqJoinPredicate::create(
1090            left_schema_len,
1091            new_scan.schema().len(),
1092            new_join_on.clone(),
1093        );
1094
1095        // Rewrite the join output indices and all output indices referred to the old scan need to
1096        // rewrite.
1097        let new_join_output_indices = output_indices
1098            .iter()
1099            .map(|&x| {
1100                if x < left_schema_len {
1101                    x
1102                } else {
1103                    o2r[x - left_schema_len] + left_schema_len
1104                }
1105            })
1106            .collect_vec();
1107        // Use UpstreamOnly chain type
1108        let new_stream_table_scan =
1109            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1110        Ok((
1111            new_stream_table_scan,
1112            new_predicate,
1113            new_join_on,
1114            new_join_output_indices,
1115        ))
1116    }
1117
1118    fn to_stream_temporal_join(
1119        &self,
1120        predicate: EqJoinPredicate,
1121        ctx: &mut ToStreamContext,
1122    ) -> Result<StreamTemporalJoin> {
1123        use super::stream::prelude::*;
1124
1125        assert!(predicate.has_eq());
1126
1127        let right = self.right();
1128
1129        let logical_scan = Self::check_temporal_rhs(&right)?;
1130
1131        let table_desc = logical_scan.table_desc();
1132        let output_column_ids = logical_scan.output_column_ids();
1133
1134        // Verify that the right join key columns are the the prefix of the primary key and
1135        // also contain the distribution key.
1136        let order_col_ids = table_desc.order_column_ids();
1137        let order_key = table_desc.order_column_indices();
1138        let dist_key = table_desc.distribution_key.clone();
1139
1140        let mut dist_key_in_order_key_pos = vec![];
1141        for d in dist_key {
1142            let pos = order_key
1143                .iter()
1144                .position(|&x| x == d)
1145                .expect("dist_key must in order_key");
1146            dist_key_in_order_key_pos.push(pos);
1147        }
1148        // The shortest prefix of order key that contains distribution key.
1149        let shortest_prefix_len = dist_key_in_order_key_pos
1150            .iter()
1151            .max()
1152            .map_or(0, |pos| pos + 1);
1153
1154        // Reorder the join equal predicate to match the order key.
1155        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1156        for order_col_id in order_col_ids {
1157            let mut found = false;
1158            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1159                if order_col_id == output_column_ids[eq_idx] {
1160                    reorder_idx.push(i);
1161                    found = true;
1162                    break;
1163                }
1164            }
1165            if !found {
1166                break;
1167            }
1168        }
1169        if reorder_idx.len() < shortest_prefix_len {
1170            // TODO: support index selection for temporal join and refine this error message.
1171            return Err(RwError::from(ErrorCode::NotSupported(
1172                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1173                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1174            )));
1175        }
1176        let lookup_prefix_len = reorder_idx.len();
1177        let predicate = predicate.reorder(&reorder_idx);
1178
1179        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1180            RequiredDist::single()
1181        } else {
1182            let left_eq_indexes = predicate.left_eq_indexes();
1183            let left_dist_key = dist_key_in_order_key_pos
1184                .iter()
1185                .map(|pos| left_eq_indexes[*pos])
1186                .collect_vec();
1187
1188            RequiredDist::hash_shard(&left_dist_key)
1189        };
1190
1191        let left = self.left().to_stream(ctx)?;
1192        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1193        let left = required_dist.enforce(left, &Order::any());
1194
1195        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1196            Self::temporal_join_scan_predicate_pull_up(
1197                logical_scan,
1198                predicate,
1199                self.output_indices(),
1200                self.left().schema().len(),
1201            )?;
1202
1203        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1204        if !new_predicate.has_eq() {
1205            return Err(RwError::from(ErrorCode::NotSupported(
1206                "Temporal join requires a non trivial join condition".into(),
1207                "Please remove the false condition of the join".into(),
1208            )));
1209        }
1210
1211        // Construct a new logical join, because we have change its RHS.
1212        let new_logical_join = generic::Join::new(
1213            left,
1214            right,
1215            new_join_on,
1216            self.join_type(),
1217            new_join_output_indices,
1218        );
1219
1220        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1221
1222        Ok(StreamTemporalJoin::new(
1223            new_logical_join,
1224            new_predicate,
1225            false,
1226        ))
1227    }
1228
1229    fn to_stream_nested_loop_temporal_join(
1230        &self,
1231        predicate: EqJoinPredicate,
1232        ctx: &mut ToStreamContext,
1233    ) -> Result<StreamTemporalJoin> {
1234        use super::stream::prelude::*;
1235        assert!(!predicate.has_eq());
1236
1237        let left = self.left().to_stream_with_dist_required(
1238            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1239            ctx,
1240        )?;
1241        assert!(left.as_stream_exchange().is_some());
1242
1243        if self.join_type() != JoinType::Inner {
1244            return Err(RwError::from(ErrorCode::NotSupported(
1245                "Temporal join requires an inner join".into(),
1246                "Please use an inner join".into(),
1247            )));
1248        }
1249
1250        if !left.append_only() {
1251            return Err(RwError::from(ErrorCode::NotSupported(
1252                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1253                "Please ensure the left hash side is append only".into(),
1254            )));
1255        }
1256
1257        let right = self.right();
1258        let logical_scan = Self::check_temporal_rhs(&right)?;
1259
1260        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1261            Self::temporal_join_scan_predicate_pull_up(
1262                logical_scan,
1263                predicate,
1264                self.output_indices(),
1265                self.left().schema().len(),
1266            )?;
1267
1268        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1269
1270        // Construct a new logical join, because we have change its RHS.
1271        let new_logical_join = generic::Join::new(
1272            left,
1273            right,
1274            new_join_on,
1275            self.join_type(),
1276            new_join_output_indices,
1277        );
1278
1279        Ok(StreamTemporalJoin::new(
1280            new_logical_join,
1281            new_predicate,
1282            true,
1283        ))
1284    }
1285
1286    fn to_stream_dynamic_filter(
1287        &self,
1288        predicate: Condition,
1289        ctx: &mut ToStreamContext,
1290    ) -> Result<Option<PlanRef>> {
1291        use super::stream::prelude::*;
1292
1293        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1294        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1295        // `StreamDynamicFilter`
1296
1297        // Check if `Inner`/`LeftSemi`
1298        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1299            return Ok(None);
1300        }
1301
1302        // Check if right side is a scalar
1303        if !self.right().max_one_row() {
1304            return Ok(None);
1305        }
1306        if self.right().schema().len() != 1 {
1307            return Ok(None);
1308        }
1309
1310        // Check if the join condition is a correlated comparison
1311        if predicate.conjunctions.len() > 1 {
1312            return Ok(None);
1313        }
1314        let expr: ExprImpl = predicate.into();
1315        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1316            Some(v) => v,
1317            None => return Ok(None),
1318        };
1319
1320        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1321            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1322        if !condition_cross_inputs {
1323            // Maybe we should panic here because it means some predicates are not pushed down.
1324            return Ok(None);
1325        }
1326
1327        // We align input types on all join predicates with cmp operator
1328        if self.left().schema().fields()[left_ref.index].data_type
1329            != self.right().schema().fields()[0].data_type
1330        {
1331            return Ok(None);
1332        }
1333
1334        // Check if non of the columns from the inner side is required to output
1335        let all_output_from_left = self
1336            .output_indices()
1337            .iter()
1338            .all(|i| *i < self.left().schema().len());
1339        if !all_output_from_left {
1340            return Ok(None);
1341        }
1342
1343        let left = self.left().to_stream(ctx)?;
1344        let right = self.right().to_stream_with_dist_required(
1345            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1346            ctx,
1347        )?;
1348
1349        assert!(right.as_stream_exchange().is_some());
1350        assert_eq!(
1351            *right.inputs().iter().exactly_one().unwrap().distribution(),
1352            Distribution::Single
1353        );
1354
1355        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1356        let plan = StreamDynamicFilter::new(core).into();
1357        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1358        if self
1359            .output_indices()
1360            .iter()
1361            .copied()
1362            .ne(0..self.left().schema().len())
1363        {
1364            // The schema of dynamic filter is always the same as the left side now, and we have
1365            // checked that all output columns are from the left side before.
1366            let logical_project = generic::Project::with_mapping(
1367                plan,
1368                ColIndexMapping::with_remaining_columns(
1369                    self.output_indices(),
1370                    self.left().schema().len(),
1371                ),
1372            );
1373            Ok(Some(StreamProject::new(logical_project).into()))
1374        } else {
1375            Ok(Some(plan))
1376        }
1377    }
1378
1379    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
1380        let predicate = EqJoinPredicate::create(
1381            self.left().schema().len(),
1382            self.right().schema().len(),
1383            self.on().clone(),
1384        );
1385        assert!(predicate.has_eq());
1386
1387        let mut logical_join = self.core.clone();
1388        logical_join.left = logical_join.left.to_batch()?;
1389        logical_join.right = logical_join.right.to_batch()?;
1390
1391        Ok(self
1392            .to_batch_lookup_join(predicate, logical_join)?
1393            .expect("Fail to convert to lookup join")
1394            .into())
1395    }
1396
1397    fn to_stream_asof_join(
1398        &self,
1399        predicate: EqJoinPredicate,
1400        ctx: &mut ToStreamContext,
1401    ) -> Result<StreamAsOfJoin> {
1402        use super::stream::prelude::*;
1403
1404        if predicate.eq_keys().is_empty() {
1405            return Err(ErrorCode::InvalidInputSyntax(
1406                "AsOf join requires at least 1 equal condition".to_owned(),
1407            )
1408            .into());
1409        }
1410
1411        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1412        let left_len = left.schema().len();
1413        let logical_join = self.clone_with_left_right(left, right);
1414
1415        let inequality_desc =
1416            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1417
1418        Ok(StreamAsOfJoin::new(
1419            logical_join.core.clone(),
1420            predicate,
1421            inequality_desc,
1422        ))
1423    }
1424
1425    /// Convert the logical join to a Hash join.
1426    fn to_batch_hash_join(
1427        &self,
1428        logical_join: generic::Join<PlanRef>,
1429        predicate: EqJoinPredicate,
1430    ) -> Result<PlanRef> {
1431        use super::batch::prelude::*;
1432
1433        let left_schema_len = logical_join.left.schema().len();
1434        let asof_desc = self
1435            .is_asof_join()
1436            .then(|| {
1437                Self::get_inequality_desc_from_predicate(
1438                    predicate.other_cond().clone(),
1439                    left_schema_len,
1440                )
1441            })
1442            .transpose()?;
1443
1444        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1445        Ok(batch_join.into())
1446    }
1447
1448    pub fn get_inequality_desc_from_predicate(
1449        predicate: Condition,
1450        left_input_len: usize,
1451    ) -> Result<AsOfJoinDesc> {
1452        let expr: ExprImpl = predicate.into();
1453        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1454            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1455            {
1456                Ok(AsOfJoinDesc {
1457                    left_idx: left_input_ref.index() as u32,
1458                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1459                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1460                })
1461            } else {
1462                bail!("inequal condition from the same side should be push down in optimizer");
1463            }
1464        } else {
1465            Err(ErrorCode::InvalidInputSyntax(
1466                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1467            )
1468            .into())
1469        }
1470    }
1471
1472    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1473        match expr_type {
1474            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1475            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1476            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1477            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1478            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1479                "Invalid comparison type: {}",
1480                expr_type.as_str_name()
1481            ))
1482            .into()),
1483        }
1484    }
1485}
1486
1487impl ToBatch for LogicalJoin {
1488    fn to_batch(&self) -> Result<PlanRef> {
1489        let predicate = EqJoinPredicate::create(
1490            self.left().schema().len(),
1491            self.right().schema().len(),
1492            self.on().clone(),
1493        );
1494
1495        let mut logical_join = self.core.clone();
1496        logical_join.left = logical_join.left.to_batch()?;
1497        logical_join.right = logical_join.right.to_batch()?;
1498
1499        let ctx = self.base.ctx();
1500        let config = ctx.session_ctx().config();
1501
1502        if predicate.has_eq() {
1503            if !predicate.eq_keys_are_type_aligned() {
1504                return Err(ErrorCode::InternalError(format!(
1505                    "Join eq keys are not aligned for predicate: {predicate:?}"
1506                ))
1507                .into());
1508            }
1509            if config.batch_enable_lookup_join() {
1510                if let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1511                    predicate.clone(),
1512                    logical_join.clone(),
1513                )? {
1514                    return Ok(lookup_join.into());
1515                }
1516            }
1517            self.to_batch_hash_join(logical_join, predicate)
1518        } else if self.is_asof_join() {
1519            return Err(ErrorCode::InvalidInputSyntax(
1520                "AsOf join requires at least 1 equal condition".to_owned(),
1521            )
1522            .into());
1523        } else {
1524            // Convert to Nested-loop Join for non-equal joins
1525            Ok(BatchNestedLoopJoin::new(logical_join).into())
1526        }
1527    }
1528}
1529
1530impl ToStream for LogicalJoin {
1531    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1532        if self
1533            .on()
1534            .conjunctions
1535            .iter()
1536            .any(|cond| cond.count_nows() > 0)
1537        {
1538            return Err(ErrorCode::NotSupported(
1539                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1540                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1541        }
1542
1543        let predicate = EqJoinPredicate::create(
1544            self.left().schema().len(),
1545            self.right().schema().len(),
1546            self.on().clone(),
1547        );
1548
1549        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1550            self.to_stream_asof_join(predicate, ctx).map(|x| x.into())
1551        } else if predicate.has_eq() {
1552            if !predicate.eq_keys_are_type_aligned() {
1553                return Err(ErrorCode::InternalError(format!(
1554                    "Join eq keys are not aligned for predicate: {predicate:?}"
1555                ))
1556                .into());
1557            }
1558
1559            if self.should_be_temporal_join() {
1560                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1561                    .map(|x| x.into())
1562            } else {
1563                self.to_stream_hash_join(predicate, ctx)
1564            }
1565        } else if self.should_be_temporal_join() {
1566            self.to_stream_nested_loop_temporal_join(predicate, ctx)
1567                .map(|x| x.into())
1568        } else if let Some(dynamic_filter) =
1569            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1570        {
1571            Ok(dynamic_filter)
1572        } else {
1573            Err(RwError::from(ErrorCode::NotSupported(
1574                "streaming nested-loop join".to_owned(),
1575                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1576                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1577                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1578            )))
1579        }
1580    }
1581
1582    fn logical_rewrite_for_stream(
1583        &self,
1584        ctx: &mut RewriteStreamContext,
1585    ) -> Result<(PlanRef, ColIndexMapping)> {
1586        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1587        let left_len = left.schema().len();
1588        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1589        let (join, out_col_change) = self.rewrite_with_left_right(
1590            left.clone(),
1591            left_col_change,
1592            right.clone(),
1593            right_col_change,
1594        );
1595
1596        let mapping = ColIndexMapping::with_remaining_columns(
1597            join.output_indices(),
1598            join.internal_column_num(),
1599        );
1600
1601        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1602        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1603
1604        // Add missing pk indices to the logical join
1605        let mut left_to_add = left
1606            .expect_stream_key()
1607            .iter()
1608            .cloned()
1609            .filter(|i| l2o.try_map(*i).is_none())
1610            .collect_vec();
1611
1612        let mut right_to_add = right
1613            .expect_stream_key()
1614            .iter()
1615            .filter(|&&i| r2o.try_map(i).is_none())
1616            .map(|&i| i + left_len)
1617            .collect_vec();
1618
1619        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1620        // key.
1621        let right_len = right.schema().len();
1622        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1623
1624        let either_or_both = self.core.add_which_join_key_to_pk();
1625
1626        for (lk, rk) in eq_predicate.eq_indexes() {
1627            match either_or_both {
1628                EitherOrBoth::Left(_) => {
1629                    if l2o.try_map(lk).is_none() {
1630                        left_to_add.push(lk);
1631                    }
1632                }
1633                EitherOrBoth::Right(_) => {
1634                    if r2o.try_map(rk).is_none() {
1635                        right_to_add.push(rk + left_len)
1636                    }
1637                }
1638                EitherOrBoth::Both(_, _) => {
1639                    if l2o.try_map(lk).is_none() {
1640                        left_to_add.push(lk);
1641                    }
1642                    if r2o.try_map(rk).is_none() {
1643                        right_to_add.push(rk + left_len)
1644                    }
1645                }
1646            };
1647        }
1648        let left_to_add = left_to_add.into_iter().unique();
1649        let right_to_add = right_to_add.into_iter().unique();
1650        // NOTE(st1page) over
1651
1652        let mut new_output_indices = join.output_indices().clone();
1653        if !join.is_right_join() {
1654            new_output_indices.extend(left_to_add);
1655        }
1656        if !join.is_left_join() {
1657            new_output_indices.extend(right_to_add);
1658        }
1659
1660        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1661
1662        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1663            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1664
1665            let l2o = join_with_pk
1666                .core
1667                .l2i_col_mapping()
1668                .composite(&join_with_pk.core.i2o_col_mapping());
1669            let r2o = join_with_pk
1670                .core
1671                .r2i_col_mapping()
1672                .composite(&join_with_pk.core.i2o_col_mapping());
1673            let left_right_stream_keys = join_with_pk
1674                .left()
1675                .expect_stream_key()
1676                .iter()
1677                .map(|i| l2o.map(*i))
1678                .chain(
1679                    join_with_pk
1680                        .right()
1681                        .expect_stream_key()
1682                        .iter()
1683                        .map(|i| r2o.map(*i)),
1684                )
1685                .collect_vec();
1686            let plan: PlanRef = join_with_pk.into();
1687            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1688        } else {
1689            join_with_pk.into()
1690        };
1691
1692        // the added columns is at the end, so it will not change the exists column index
1693        Ok((plan, out_col_change))
1694    }
1695}
1696
1697#[cfg(test)]
1698mod tests {
1699
1700    use std::collections::HashSet;
1701
1702    use risingwave_common::catalog::{Field, Schema};
1703    use risingwave_common::types::{DataType, Datum};
1704    use risingwave_pb::expr::expr_node::Type;
1705
1706    use super::*;
1707    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1708    use crate::optimizer::optimizer_context::OptimizerContext;
1709    use crate::optimizer::plan_node::LogicalValues;
1710    use crate::optimizer::property::FunctionalDependency;
1711
1712    /// Pruning
1713    /// ```text
1714    /// Join(on: input_ref(1)=input_ref(3))
1715    ///   TableScan(v1, v2, v3)
1716    ///   TableScan(v4, v5, v6)
1717    /// ```
1718    /// with required columns [2,3] will result in
1719    /// ```text
1720    /// Project(input_ref(1), input_ref(2))
1721    ///   Join(on: input_ref(0)=input_ref(2))
1722    ///     TableScan(v2, v3)
1723    ///     TableScan(v4)
1724    /// ```
1725    #[tokio::test]
1726    async fn test_prune_join() {
1727        let ty = DataType::Int32;
1728        let ctx = OptimizerContext::mock().await;
1729        let fields: Vec<Field> = (1..7)
1730            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1731            .collect();
1732        let left = LogicalValues::new(
1733            vec![],
1734            Schema {
1735                fields: fields[0..3].to_vec(),
1736            },
1737            ctx.clone(),
1738        );
1739        let right = LogicalValues::new(
1740            vec![],
1741            Schema {
1742                fields: fields[3..6].to_vec(),
1743            },
1744            ctx,
1745        );
1746        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1747            FunctionCall::new(
1748                Type::Equal,
1749                vec![
1750                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1751                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1752                ],
1753            )
1754            .unwrap(),
1755        ));
1756        let join_type = JoinType::Inner;
1757        let join: PlanRef = LogicalJoin::new(
1758            left.into(),
1759            right.into(),
1760            join_type,
1761            Condition::with_expr(on),
1762        )
1763        .into();
1764
1765        // Perform the prune
1766        let required_cols = vec![2, 3];
1767        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1768
1769        // Check the result
1770        let join = plan.as_logical_join().unwrap();
1771        assert_eq!(join.schema().fields().len(), 2);
1772        assert_eq!(join.schema().fields()[0], fields[2]);
1773        assert_eq!(join.schema().fields()[1], fields[3]);
1774
1775        let expr: ExprImpl = join.on().clone().into();
1776        let call = expr.as_function_call().unwrap();
1777        assert_eq_input_ref!(&call.inputs()[0], 0);
1778        assert_eq_input_ref!(&call.inputs()[1], 2);
1779
1780        let left = join.left();
1781        let left = left.as_logical_values().unwrap();
1782        assert_eq!(left.schema().fields(), &fields[1..3]);
1783        let right = join.right();
1784        let right = right.as_logical_values().unwrap();
1785        assert_eq!(right.schema().fields(), &fields[3..4]);
1786    }
1787
1788    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1789    #[tokio::test]
1790    async fn test_prune_semi_join() {
1791        let ty = DataType::Int32;
1792        let ctx = OptimizerContext::mock().await;
1793        let fields: Vec<Field> = (1..7)
1794            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1795            .collect();
1796        let left = LogicalValues::new(
1797            vec![],
1798            Schema {
1799                fields: fields[0..3].to_vec(),
1800            },
1801            ctx.clone(),
1802        );
1803        let right = LogicalValues::new(
1804            vec![],
1805            Schema {
1806                fields: fields[3..6].to_vec(),
1807            },
1808            ctx,
1809        );
1810        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1811            FunctionCall::new(
1812                Type::Equal,
1813                vec![
1814                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1815                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1816                ],
1817            )
1818            .unwrap(),
1819        ));
1820        for join_type in [
1821            JoinType::LeftSemi,
1822            JoinType::RightSemi,
1823            JoinType::LeftAnti,
1824            JoinType::RightAnti,
1825        ] {
1826            let join = LogicalJoin::new(
1827                left.clone().into(),
1828                right.clone().into(),
1829                join_type,
1830                Condition::with_expr(on.clone()),
1831            );
1832
1833            let offset = if join.is_right_join() { 3 } else { 0 };
1834            let join: PlanRef = join.into();
1835            // Perform the prune
1836            let required_cols = vec![0];
1837            // key 0 is never used in the join (always key 1)
1838            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1839            let as_plan = plan.as_logical_join().unwrap();
1840            // Check the result
1841            assert_eq!(as_plan.schema().fields().len(), 1);
1842            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1843
1844            // Perform the prune
1845            let required_cols = vec![0, 1, 2];
1846            // should not panic here
1847            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1848            let as_plan = plan.as_logical_join().unwrap();
1849            // Check the result
1850            assert_eq!(as_plan.schema().fields().len(), 3);
1851            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1852            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1853            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1854        }
1855    }
1856
1857    /// Pruning
1858    /// ```text
1859    /// Join(on: input_ref(1)=input_ref(3))
1860    ///   TableScan(v1, v2, v3)
1861    ///   TableScan(v4, v5, v6)
1862    /// ```
1863    /// with required columns [1,3] will result in
1864    /// ```text
1865    /// Join(on: input_ref(0)=input_ref(1))
1866    ///   TableScan(v2)
1867    ///   TableScan(v4)
1868    /// ```
1869    #[tokio::test]
1870    async fn test_prune_join_no_project() {
1871        let ty = DataType::Int32;
1872        let ctx = OptimizerContext::mock().await;
1873        let fields: Vec<Field> = (1..7)
1874            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1875            .collect();
1876        let left = LogicalValues::new(
1877            vec![],
1878            Schema {
1879                fields: fields[0..3].to_vec(),
1880            },
1881            ctx.clone(),
1882        );
1883        let right = LogicalValues::new(
1884            vec![],
1885            Schema {
1886                fields: fields[3..6].to_vec(),
1887            },
1888            ctx,
1889        );
1890        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1891            FunctionCall::new(
1892                Type::Equal,
1893                vec![
1894                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1895                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1896                ],
1897            )
1898            .unwrap(),
1899        ));
1900        let join_type = JoinType::Inner;
1901        let join: PlanRef = LogicalJoin::new(
1902            left.into(),
1903            right.into(),
1904            join_type,
1905            Condition::with_expr(on),
1906        )
1907        .into();
1908
1909        // Perform the prune
1910        let required_cols = vec![1, 3];
1911        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1912
1913        // Check the result
1914        let join = plan.as_logical_join().unwrap();
1915        assert_eq!(join.schema().fields().len(), 2);
1916        assert_eq!(join.schema().fields()[0], fields[1]);
1917        assert_eq!(join.schema().fields()[1], fields[3]);
1918
1919        let expr: ExprImpl = join.on().clone().into();
1920        let call = expr.as_function_call().unwrap();
1921        assert_eq_input_ref!(&call.inputs()[0], 0);
1922        assert_eq_input_ref!(&call.inputs()[1], 1);
1923
1924        let left = join.left();
1925        let left = left.as_logical_values().unwrap();
1926        assert_eq!(left.schema().fields(), &fields[1..2]);
1927        let right = join.right();
1928        let right = right.as_logical_values().unwrap();
1929        assert_eq!(right.schema().fields(), &fields[3..4]);
1930    }
1931
1932    /// Convert
1933    /// ```text
1934    /// Join(on: ($1 = $3) AND ($2 == 42))
1935    ///   TableScan(v1, v2, v3)
1936    ///   TableScan(v4, v5, v6)
1937    /// ```
1938    /// to
1939    /// ```text
1940    /// Filter($2 == 42)
1941    ///   HashJoin(on: $1 = $3)
1942    ///     TableScan(v1, v2, v3)
1943    ///     TableScan(v4, v5, v6)
1944    /// ```
1945    #[tokio::test]
1946    async fn test_join_to_batch() {
1947        let ctx = OptimizerContext::mock().await;
1948        let fields: Vec<Field> = (1..7)
1949            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
1950            .collect();
1951        let left = LogicalValues::new(
1952            vec![],
1953            Schema {
1954                fields: fields[0..3].to_vec(),
1955            },
1956            ctx.clone(),
1957        );
1958        let right = LogicalValues::new(
1959            vec![],
1960            Schema {
1961                fields: fields[3..6].to_vec(),
1962            },
1963            ctx,
1964        );
1965
1966        fn input_ref(i: usize) -> ExprImpl {
1967            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
1968        }
1969        let eq_cond = ExprImpl::FunctionCall(Box::new(
1970            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
1971        ));
1972        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
1973            FunctionCall::new(
1974                Type::Equal,
1975                vec![
1976                    input_ref(2),
1977                    ExprImpl::Literal(Box::new(Literal::new(
1978                        Datum::Some(42_i32.into()),
1979                        DataType::Int32,
1980                    ))),
1981                ],
1982            )
1983            .unwrap(),
1984        ));
1985        // Condition: ($1 = $3) AND ($2 == 42)
1986        let on_cond = ExprImpl::FunctionCall(Box::new(
1987            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
1988        ));
1989
1990        let join_type = JoinType::Inner;
1991        let logical_join = LogicalJoin::new(
1992            left.into(),
1993            right.into(),
1994            join_type,
1995            Condition::with_expr(on_cond),
1996        );
1997
1998        // Perform `to_batch`
1999        let result = logical_join.to_batch().unwrap();
2000
2001        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2002        let hash_join = result.as_batch_hash_join().unwrap();
2003        assert_eq!(
2004            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2005            eq_cond
2006        );
2007        assert_eq!(
2008            *hash_join
2009                .eq_join_predicate()
2010                .non_eq_cond()
2011                .conjunctions
2012                .first()
2013                .unwrap(),
2014            non_eq_cond
2015        );
2016    }
2017
    /// Placeholder for a test of converting a left outer join into a stream hash join.
    ///
    /// The test is currently disabled: it is marked `#[ignore]` and its body is fully
    /// commented out, so it asserts nothing. The commented-out body describes converting
    /// ```text
    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    #[tokio::test]
    #[ignore] // Ignored since the logical scan refactor. The test also seems to duplicate what the
    // explain test framework covers, so it may be removed altogether.
    async fn test_join_to_stream() {
        // NOTE(review): the body below is kept for reference only. It builds its context with
        // `QueryContext::mock()` whereas the live tests in this module use
        // `OptimizerContext::mock()`, so it presumably needs porting before it can be
        // re-enabled -- confirm against the current APIs.
        //
        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
        // let fields: Vec<Field> = (1..7)
        //     .map(|i| Field {
        //         data_type: DataType::Int32,
        //         name: format!("v{}", i),
        //     })
        //     .collect();
        // let left = LogicalScan::new(
        //     "left".to_string(),
        //     TableId::new(0),
        //     vec![1.into(), 2.into(), 3.into()],
        //     Schema {
        //         fields: fields[0..3].to_vec(),
        //     },
        //     ctx.clone(),
        // );
        // let right = LogicalScan::new(
        //     "right".to_string(),
        //     TableId::new(0),
        //     vec![4.into(), 5.into(), 6.into()],
        //     Schema {
        //         fields: fields[3..6].to_vec(),
        //     },
        //     ctx,
        // );
        // let eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
        //             ExprImpl::Literal(Box::new(Literal::new(
        //                 Datum::Some(42_i32.into()),
        //                 DataType::Int32,
        //             ))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // // Condition: ($1 = $3) AND ($2 == 42)
        // let on_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
        // ));

        // let join_type = JoinType::LeftOuter;
        // let logical_join = LogicalJoin::new(
        //     left.into(),
        //     right.into(),
        //     join_type,
        //     Condition::with_expr(on_cond.clone()),
        // );

        // // Perform `to_stream`
        // let result = logical_join.to_stream();

        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
        // let hash_join = result.as_stream_hash_join().unwrap();
        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
    }
2102    /// Pruning
2103    /// ```text
2104    /// Join(on: input_ref(1)=input_ref(3))
2105    ///   TableScan(v1, v2, v3)
2106    ///   TableScan(v4, v5, v6)
2107    /// ```
2108    /// with required columns [3, 2] will result in
2109    /// ```text
2110    /// Project(input_ref(2), input_ref(1))
2111    ///   Join(on: input_ref(0)=input_ref(2))
2112    ///     TableScan(v2, v3)
2113    ///     TableScan(v4)
2114    /// ```
2115    #[tokio::test]
2116    async fn test_join_column_prune_with_order_required() {
2117        let ty = DataType::Int32;
2118        let ctx = OptimizerContext::mock().await;
2119        let fields: Vec<Field> = (1..7)
2120            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2121            .collect();
2122        let left = LogicalValues::new(
2123            vec![],
2124            Schema {
2125                fields: fields[0..3].to_vec(),
2126            },
2127            ctx.clone(),
2128        );
2129        let right = LogicalValues::new(
2130            vec![],
2131            Schema {
2132                fields: fields[3..6].to_vec(),
2133            },
2134            ctx,
2135        );
2136        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2137            FunctionCall::new(
2138                Type::Equal,
2139                vec![
2140                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2141                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2142                ],
2143            )
2144            .unwrap(),
2145        ));
2146        let join_type = JoinType::Inner;
2147        let join: PlanRef = LogicalJoin::new(
2148            left.into(),
2149            right.into(),
2150            join_type,
2151            Condition::with_expr(on),
2152        )
2153        .into();
2154
2155        // Perform the prune
2156        let required_cols = vec![3, 2];
2157        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2158
2159        // Check the result
2160        let join = plan.as_logical_join().unwrap();
2161        assert_eq!(join.schema().fields().len(), 2);
2162        assert_eq!(join.schema().fields()[0], fields[3]);
2163        assert_eq!(join.schema().fields()[1], fields[2]);
2164
2165        let expr: ExprImpl = join.on().clone().into();
2166        let call = expr.as_function_call().unwrap();
2167        assert_eq_input_ref!(&call.inputs()[0], 0);
2168        assert_eq_input_ref!(&call.inputs()[1], 2);
2169
2170        let left = join.left();
2171        let left = left.as_logical_values().unwrap();
2172        assert_eq!(left.schema().fields(), &fields[1..3]);
2173        let right = join.right();
2174        let right = right.as_logical_values().unwrap();
2175        assert_eq!(right.schema().fields(), &fields[3..4]);
2176    }
2177
2178    #[tokio::test]
2179    async fn fd_derivation_inner_outer_join() {
2180        // left: [l0, l1], right: [r0, r1, r2]
2181        // FD: l0 --> l1, r0 --> { r1, r2 }
2182        // On: l0 = 0 AND l1 = r1
2183        //
2184        // Inner Join:
2185        //  Schema: [l0, l1, r0, r1, r2]
2186        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2187        // Left Outer Join:
2188        //  Schema: [l0, l1, r0, r1, r2]
2189        //  FD: l0 --> l1
2190        // Right Outer Join:
2191        //  Schema: [l0, l1, r0, r1, r2]
2192        //  FD: r0 --> { r1, r2 }
2193        // Full Outer Join:
2194        //  Schema: [l0, l1, r0, r1, r2]
2195        //  FD: empty
2196        // Left Semi/Anti Join:
2197        //  Schema: [l0, l1]
2198        //  FD: l0 --> l1
2199        // Right Semi/Anti Join:
2200        //  Schema: [r0, r1, r2]
2201        //  FD: r0 --> {r1, r2}
2202        let ctx = OptimizerContext::mock().await;
2203        let left = {
2204            let fields: Vec<Field> = vec![
2205                Field::with_name(DataType::Int32, "l0"),
2206                Field::with_name(DataType::Int32, "l1"),
2207            ];
2208            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2209            // 0 --> 1
2210            values
2211                .base
2212                .functional_dependency_mut()
2213                .add_functional_dependency_by_column_indices(&[0], &[1]);
2214            values
2215        };
2216        let right = {
2217            let fields: Vec<Field> = vec![
2218                Field::with_name(DataType::Int32, "r0"),
2219                Field::with_name(DataType::Int32, "r1"),
2220                Field::with_name(DataType::Int32, "r2"),
2221            ];
2222            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2223            // 0 --> 1, 2
2224            values
2225                .base
2226                .functional_dependency_mut()
2227                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2228            values
2229        };
2230        // l0 = 0 AND l1 = r1
2231        let on: ExprImpl = FunctionCall::new(
2232            Type::And,
2233            vec![
2234                FunctionCall::new(
2235                    Type::Equal,
2236                    vec![
2237                        InputRef::new(0, DataType::Int32).into(),
2238                        ExprImpl::literal_int(0),
2239                    ],
2240                )
2241                .unwrap()
2242                .into(),
2243                FunctionCall::new(
2244                    Type::Equal,
2245                    vec![
2246                        InputRef::new(1, DataType::Int32).into(),
2247                        InputRef::new(3, DataType::Int32).into(),
2248                    ],
2249                )
2250                .unwrap()
2251                .into(),
2252            ],
2253        )
2254        .unwrap()
2255        .into();
2256        let expected_fd_set = [
2257            (
2258                JoinType::Inner,
2259                [
2260                    // inherit from left
2261                    FunctionalDependency::with_indices(5, &[0], &[1]),
2262                    // inherit from right
2263                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2264                    // constant column in join condition
2265                    FunctionalDependency::with_indices(5, &[], &[0]),
2266                    // eq column in join condition
2267                    FunctionalDependency::with_indices(5, &[1], &[3]),
2268                    FunctionalDependency::with_indices(5, &[3], &[1]),
2269                ]
2270                .into_iter()
2271                .collect::<HashSet<_>>(),
2272            ),
2273            (JoinType::FullOuter, HashSet::new()),
2274            (
2275                JoinType::RightOuter,
2276                [
2277                    // inherit from right
2278                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2279                ]
2280                .into_iter()
2281                .collect::<HashSet<_>>(),
2282            ),
2283            (
2284                JoinType::LeftOuter,
2285                [
2286                    // inherit from left
2287                    FunctionalDependency::with_indices(5, &[0], &[1]),
2288                ]
2289                .into_iter()
2290                .collect::<HashSet<_>>(),
2291            ),
2292            (
2293                JoinType::LeftSemi,
2294                [
2295                    // inherit from left
2296                    FunctionalDependency::with_indices(2, &[0], &[1]),
2297                ]
2298                .into_iter()
2299                .collect::<HashSet<_>>(),
2300            ),
2301            (
2302                JoinType::LeftAnti,
2303                [
2304                    // inherit from left
2305                    FunctionalDependency::with_indices(2, &[0], &[1]),
2306                ]
2307                .into_iter()
2308                .collect::<HashSet<_>>(),
2309            ),
2310            (
2311                JoinType::RightSemi,
2312                [
2313                    // inherit from right
2314                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2315                ]
2316                .into_iter()
2317                .collect::<HashSet<_>>(),
2318            ),
2319            (
2320                JoinType::RightAnti,
2321                [
2322                    // inherit from right
2323                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2324                ]
2325                .into_iter()
2326                .collect::<HashSet<_>>(),
2327            ),
2328        ];
2329
2330        for (join_type, expected_res) in expected_fd_set {
2331            let join = LogicalJoin::new(
2332                left.clone().into(),
2333                right.clone().into(),
2334                join_type,
2335                Condition::with_expr(on.clone()),
2336            );
2337            let fd_set = join
2338                .functional_dependency()
2339                .as_dependencies()
2340                .iter()
2341                .cloned()
2342                .collect::<HashSet<_>>();
2343            assert_eq!(fd_set, expected_res);
2344        }
2345    }
2346}