risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31    ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeBinary, PredicatePushdown,
32    StreamHashJoin, StreamProject, ToBatch, ToStream, generic,
33};
34use crate::error::{ErrorCode, Result, RwError};
35use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::DynamicFilter;
38use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
39use crate::optimizer::plan_node::utils::IndicesDisplay;
40use crate::optimizer::plan_node::{
41    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
42    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
43    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
44};
45use crate::optimizer::plan_visitor::LogicalCardinalityExt;
46use crate::optimizer::property::{Distribution, Order, RequiredDist};
47use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
48
49/// `LogicalJoin` combines two relations according to some condition.
50///
51/// Each output row has fields from the left and right inputs. The set of output rows is a subset
52/// of the cartesian product of the two inputs; precisely which subset depends on the join
53/// condition. In addition, the output columns are a subset of the columns of the left and
54/// right columns, dependent on the output indices provided. A repeat output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan-node base data; `with_core` builds it from `core`.
    pub base: PlanBase<Logical>,
    /// The generic join description (inputs, `on` condition, join type, output indices).
    core: generic::Join<PlanRef>,
}
60
61impl Distill for LogicalJoin {
62    fn distill<'a>(&self) -> XmlNode<'a> {
63        let verbose = self.base.ctx().is_explain_verbose();
64        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
65        vec.push(("type", Pretty::debug(&self.join_type())));
66
67        let concat_schema = self.core.concat_schema();
68        let cond = Pretty::debug(&ConditionDisplay {
69            condition: self.on(),
70            input_schema: &concat_schema,
71        });
72        vec.push(("on", cond));
73
74        if verbose {
75            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
76            vec.push(("output", data));
77        }
78
79        childless_record("LogicalJoin", vec)
80    }
81}
82
83impl LogicalJoin {
    /// Creates a join of `left` and `right` with the given join type and `on` condition.
    ///
    /// The core is built with `generic::Join::with_full_output`, so the caller does not
    /// supply output indices.
    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
        let core = generic::Join::with_full_output(left, right, join_type, on);
        Self::with_core(core)
    }
88
    /// Creates a join of `left` and `right` that outputs only the columns selected by
    /// `output_indices`.
    pub(crate) fn with_output_indices(
        left: PlanRef,
        right: PlanRef,
        join_type: JoinType,
        on: Condition,
        output_indices: Vec<usize>,
    ) -> Self {
        // Note the argument order of `generic::Join::new`: `on` comes before `join_type`.
        let core = generic::Join::new(left, right, on, join_type, output_indices);
        Self::with_core(core)
    }
99
    /// Wraps an existing generic join `core` into a `LogicalJoin`, building the plan base
    /// from it.
    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
        let base = PlanBase::new_logical_with_core(&core);
        LogicalJoin { base, core }
    }
104
    /// Builds a join from a single `on_clause` expression and returns it as a `PlanRef`.
    ///
    /// The expression is converted into a `Condition` via `Condition::with_expr`.
    pub fn create(
        left: PlanRef,
        right: PlanRef,
        join_type: JoinType,
        on_clause: ExprImpl,
    ) -> PlanRef {
        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
    }
113
    /// Returns the internal column count reported by the join core (delegates to
    /// `generic::Join::internal_column_num`).
    pub fn internal_column_num(&self) -> usize {
        self.core.internal_column_num()
    }
117
    /// Returns the core's internal-to-left column mapping (the variant that ignores the
    /// join type); delegates to the join core.
    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2l_col_mapping_ignore_join_type()
    }
121
    /// Returns the core's internal-to-right column mapping (the variant that ignores the
    /// join type); delegates to the join core.
    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2r_col_mapping_ignore_join_type()
    }
125
    /// Returns a reference to the join's `on` condition.
    pub fn on(&self) -> &Condition {
        &self.core.on
    }
130
    /// Returns a reference to the underlying generic join core.
    pub fn core(&self) -> &generic::Join<PlanRef> {
        &self.core
    }
134
135    /// Collect all input ref in the on condition. And separate them into left and right.
136    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
137        let input_refs = self
138            .core
139            .on
140            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
141        let index_group = input_refs
142            .ones()
143            .chunk_by(|i| *i < self.core.left.schema().len());
144        let left_index = index_group
145            .into_iter()
146            .next()
147            .map_or(vec![], |group| group.1.collect_vec());
148        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
149            group
150                .1
151                .map(|i| i - self.core.left.schema().len())
152                .collect_vec()
153        });
154        (left_index, right_index)
155    }
156
    /// Returns the join type of the logical join.
    pub fn join_type(&self) -> JoinType {
        self.core.join_type
    }
161
    /// Returns the equi-join key pairs of the logical join, as reported by the join core.
    ///
    /// NOTE(review): each pair is presumably `(left index, right index)` — confirm against
    /// `generic::Join::eq_indexes`.
    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
        self.core.eq_indexes()
    }
166
    /// Returns the output indices of the logical join.
    pub fn output_indices(&self) -> &Vec<usize> {
        &self.core.output_indices
    }
171
172    /// Clone with new output indices
173    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
174        Self::with_core(generic::Join {
175            output_indices,
176            ..self.core.clone()
177        })
178    }
179
180    /// Clone with new `on` condition
181    pub fn clone_with_cond(&self, on: Condition) -> Self {
182        Self::with_core(generic::Join {
183            on,
184            ..self.core.clone()
185        })
186    }
187
    /// Returns `true` only for `LeftSemi` and `LeftAnti` joins.
    ///
    /// Despite the name, `LeftOuter` is *not* matched here.
    pub fn is_left_join(&self) -> bool {
        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
    }
191
    /// Returns `true` only for `RightSemi` and `RightAnti` joins.
    ///
    /// Despite the name, `RightOuter` is *not* matched here.
    pub fn is_right_join(&self) -> bool {
        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
    }
195
    /// Delegates to `generic::Join::is_full_out` on the join core.
    ///
    /// NOTE(review): presumably reports whether the join outputs all of its internal
    /// columns — confirm against the core implementation.
    pub fn is_full_out(&self) -> bool {
        self.core.is_full_out()
    }
199
200    pub fn is_asof_join(&self) -> bool {
201        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
202    }
203
204    pub fn output_indices_are_trivial(&self) -> bool {
205        self.output_indices() == &(0..self.internal_column_num()).collect_vec()
206    }
207
208    /// Try to simplify the outer join with the predicate on the top of the join
209    ///
210    /// now it is just a naive implementation for comparison expression, we can give a more general
211    /// implementation with constant folding in future
212    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
213        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
214            JoinType::LeftOuter => (false, true),
215            JoinType::RightOuter => (true, false),
216            JoinType::FullOuter => (true, true),
217            _ => return join_type,
218        };
219
220        for expr in &predicate.conjunctions {
221            if let ExprImpl::FunctionCall(func) = expr {
222                match func.func_type() {
223                    ExprType::Equal
224                    | ExprType::NotEqual
225                    | ExprType::LessThan
226                    | ExprType::LessThanOrEqual
227                    | ExprType::GreaterThan
228                    | ExprType::GreaterThanOrEqual => {
229                        for input in func.inputs() {
230                            if let ExprImpl::InputRef(input) = input {
231                                let idx = input.index;
232                                if idx < left_col_num {
233                                    gen_null_in_left = false;
234                                } else {
235                                    gen_null_in_right = false;
236                                }
237                            }
238                        }
239                    }
240                    _ => {}
241                };
242            }
243        }
244
245        match (gen_null_in_left, gen_null_in_right) {
246            (true, true) => JoinType::FullOuter,
247            (true, false) => JoinType::RightOuter,
248            (false, true) => JoinType::LeftOuter,
249            (false, false) => JoinType::Inner,
250        }
251    }
252
253    /// Index Join:
254    /// Try to convert logical join into batch lookup join and meanwhile it will do
255    /// the index selection for the lookup table so that we can benefit from indexes.
256    fn to_batch_lookup_join_with_index_selection(
257        &self,
258        predicate: EqJoinPredicate,
259        logical_join: generic::Join<PlanRef>,
260    ) -> Result<Option<BatchLookupJoin>> {
261        match logical_join.join_type {
262            JoinType::Inner
263            | JoinType::LeftOuter
264            | JoinType::LeftSemi
265            | JoinType::LeftAnti
266            | JoinType::AsofInner
267            | JoinType::AsofLeftOuter => {}
268            _ => return Ok(None),
269        };
270
271        // Index selection for index join.
272        let right = self.right();
273        // Lookup Join only supports basic tables on the join's right side.
274        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
275            logical_scan
276        } else {
277            return Ok(None);
278        };
279
280        let mut result_plan = None;
281        // Lookup primary table.
282        if let Some(lookup_join) =
283            self.to_batch_lookup_join(predicate.clone(), logical_join.clone())?
284        {
285            result_plan = Some(lookup_join);
286        }
287
288        let indexes = logical_scan.indexes();
289        for index in indexes {
290            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
291                let index_scan: PlanRef = index_scan.into();
292                let that = self.clone_with_left_right(self.left(), index_scan.clone());
293                let mut new_logical_join = logical_join.clone();
294                new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");
295
296                // Lookup covered index.
297                if let Some(lookup_join) =
298                    that.to_batch_lookup_join(predicate.clone(), new_logical_join)?
299                {
300                    match &result_plan {
301                        None => result_plan = Some(lookup_join),
302                        Some(prev_lookup_join) => {
303                            // Prefer to choose lookup join with longer lookup prefix len.
304                            if prev_lookup_join.lookup_prefix_len()
305                                < lookup_join.lookup_prefix_len()
306                            {
307                                result_plan = Some(lookup_join)
308                            }
309                        }
310                    }
311                }
312            }
313        }
314
315        Ok(result_plan)
316    }
317
    /// Try to convert logical join into batch lookup join.
    ///
    /// `predicate` is the equi-join predicate of the join and `logical_join` is the join
    /// core to convert; the right side is taken from `self.right()` and must be a
    /// `LogicalScan`.
    ///
    /// Returns `Ok(None)` when a lookup join cannot be used:
    /// * the join type is not inner, left outer, left semi, left anti, as-of inner or
    ///   as-of left outer;
    /// * the right input is not a `LogicalScan`;
    /// * the table has an empty distribution key;
    /// * the right join keys do not cover a long enough prefix of the order key;
    /// * after pulling up the scan predicate, the new join condition has no equality.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_inequality_desc_from_predicate` for as-of joins.
    fn to_batch_lookup_join(
        &self,
        predicate: EqJoinPredicate,
        logical_join: generic::Join<PlanRef>,
    ) -> Result<Option<BatchLookupJoin>> {
        match logical_join.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {}
            _ => return Ok(None),
        };

        let right = self.right();
        // Lookup Join only supports basic tables on the join's right side.
        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
            logical_scan
        } else {
            return Ok(None);
        };
        let table_desc = logical_scan.table_desc().clone();
        let output_column_ids = logical_scan.output_column_ids();

        // Verify that the right join key columns are a prefix of the primary key and
        // also contain the distribution key.
        let order_col_ids = table_desc.order_column_ids();
        let order_key = table_desc.order_column_indices();
        let dist_key = table_desc.distribution_key.clone();
        // Position of each distribution key column within the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = order_key
                .iter()
                .position(|&x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // The shortest prefix of order key that contains distribution key.
        // It is 0 when the distribution key is empty.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // Distributed lookup join can't support lookup table with a singleton distribution.
        if shortest_prefix_len == 0 {
            return Ok(None);
        }

        // Reorder the join equal predicate to match the order key.
        // Walk the order key columns in order and stop at the first one that is not matched
        // by any right-side equality key, so `reorder_idx` describes a prefix of the order key.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        // The matched prefix must be long enough to contain the whole distribution key.
        if reorder_idx.len() < shortest_prefix_len {
            return Ok(None);
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Extract the predicate from logical scan. Only pure scan is supported.
        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
        // Construct the mapping from the old scan's output columns to the columns of
        // `new_scan`; it is the identity when no projection was pulled up.
        let o2r = if let Some(project_expr) = project_expr {
            project_expr
                .into_iter()
                .map(|x| x.as_input_ref().unwrap().index)
                .collect_vec()
        } else {
            (0..logical_scan.output_col_idx().len()).collect_vec()
        };
        let left_schema_len = logical_join.left.schema().len();

        // Rewrites right-side input refs of the join predicate through `o2r`.
        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
            offset: left_schema_len,
            mapping: o2r.clone(),
        };

        let new_eq_cond = predicate
            .eq_cond()
            .rewrite_expr(&mut join_predicate_rewriter);

        // Shifts the scan predicate's input refs by the left schema length so that it can be
        // merged into the join condition.
        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
            offset: left_schema_len,
        };

        let new_other_cond = predicate
            .other_cond()
            .clone()
            .rewrite_expr(&mut join_predicate_rewriter)
            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

        let new_join_on = new_eq_cond.and(new_other_cond);
        let new_predicate = EqJoinPredicate::create(
            left_schema_len,
            new_scan.schema().len(),
            new_join_on.clone(),
        );

        // We discovered that we cannot use a lookup join after pulling up the predicate
        // from one side and simplifying the condition. Let's use some other join instead.
        if !new_predicate.has_eq() {
            return Ok(None);
        }

        // Rewrite the join output indices: every output index that refers to the old scan
        // is remapped through `o2r`; left-side indices are kept as they are.
        let new_join_output_indices = logical_join
            .output_indices
            .iter()
            .map(|&x| {
                if x < left_schema_len {
                    x
                } else {
                    o2r[x - left_schema_len] + left_schema_len
                }
            })
            .collect_vec();

        let new_scan_output_column_ids = new_scan.output_column_ids();
        let as_of = new_scan.as_of.clone();

        // Construct a new logical join, because we have changed its RHS.
        let new_logical_join = generic::Join::new(
            logical_join.left,
            new_scan.into(),
            new_join_on,
            logical_join.join_type,
            new_join_output_indices,
        );

        // Only as-of joins carry an inequality description; it is derived from the
        // (reordered) predicate's non-equality conditions.
        let asof_desc = self
            .is_asof_join()
            .then(|| {
                Self::get_inequality_desc_from_predicate(
                    predicate.other_cond().clone(),
                    left_schema_len,
                )
            })
            .transpose()?;

        // NOTE(review): the meaning of the `false` argument is not visible from this file —
        // confirm against `BatchLookupJoin::new`.
        Ok(Some(BatchLookupJoin::new(
            new_logical_join,
            new_predicate,
            table_desc,
            new_scan_output_column_ids,
            lookup_prefix_len,
            false,
            as_of,
            asof_desc,
        )))
    }
482
    /// Consumes the join and returns the parts of its core, as produced by
    /// `generic::Join::decompose`.
    ///
    /// NOTE(review): the tuple is presumably `(left, right, on, join_type, output_indices)`
    /// — confirm against the core implementation.
    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
        self.core.decompose()
    }
486}
487
impl PlanTreeNodeBinary for LogicalJoin {
    /// Returns a clone of the left input.
    fn left(&self) -> PlanRef {
        self.core.left.clone()
    }

    /// Returns a clone of the right input.
    fn right(&self) -> PlanRef {
        self.core.right.clone()
    }

    /// Returns a copy of this join with both inputs replaced; the condition, join type and
    /// output indices are kept as they are.
    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
        Self::with_core(generic::Join {
            left,
            right,
            ..self.core.clone()
        })
    }

    /// Rebuilds the join on top of new inputs whose columns have changed.
    ///
    /// `left_col_change` / `right_col_change` describe how the columns of the old left /
    /// right input map to the columns of the new `left` / `right`. The `on` condition and the
    /// output indices are rewritten accordingly.
    ///
    /// Returns the new join together with the mapping from the old join's output columns to
    /// the new join's output columns.
    fn rewrite_with_left_right(
        &self,
        left: PlanRef,
        left_col_change: ColIndexMapping,
        right: PlanRef,
        right_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        let (new_on, new_output_indices) = {
            // Build one mapping over the concatenated input columns: the left map followed
            // by the right map, with right targets shifted by the new left schema length.
            let (mut map, _) = left_col_change.clone().into_parts();
            let (mut right_map, _) = right_col_change.clone().into_parts();
            for i in right_map.iter_mut().flatten() {
                *i += left.schema().len();
            }
            map.append(&mut right_map);
            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());

            let new_output_indices = self
                .output_indices()
                .iter()
                .map(|&i| mapping.map(i))
                .collect::<Vec<_>>();
            let new_on = self.on().clone().rewrite_expr(&mut mapping);
            (new_on, new_output_indices)
        };

        let join = Self::with_output_indices(
            left,
            right,
            self.join_type(),
            new_on,
            new_output_indices.clone(),
        );

        // Internal-to-output mapping of the new join.
        let new_i2o = ColIndexMapping::with_remaining_columns(
            &new_output_indices,
            join.internal_column_num(),
        );

        let old_o2i = self.core.o2i_col_mapping();

        // Old output column -> column of the new left / right input.
        let old_o2l = old_o2i
            .composite(&self.core.i2l_col_mapping())
            .composite(&left_col_change);
        let old_o2r = old_o2i
            .composite(&self.core.i2r_col_mapping())
            .composite(&right_col_change);
        // Column of the new left / right input -> new output column.
        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);

        // Old output -> new output, combining the path through the left input with the
        // path through the right input.
        let out_col_change = old_o2l
            .composite(&new_l2o)
            .union(&old_o2r.composite(&new_r2o));
        (join, out_col_change)
    }
}
560
// NOTE(review): presumably generates the generic plan-tree-node impl for `LogicalJoin`
// from its `PlanTreeNodeBinary` impl — confirm against the macro definition.
impl_plan_tree_node_for_binary! { LogicalJoin }
562
563impl ColPrunable for LogicalJoin {
564    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
565        // make `required_cols` point to internal table instead of output schema.
566        let required_cols = required_cols
567            .iter()
568            .map(|i| self.output_indices()[*i])
569            .collect_vec();
570        let left_len = self.left().schema().fields.len();
571
572        let total_len = self.left().schema().len() + self.right().schema().len();
573        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
574
575        required_cols.iter().for_each(|&i| {
576            if self.is_right_join() {
577                resized_required_cols.insert(left_len + i);
578            } else {
579                resized_required_cols.insert(i);
580            }
581        });
582
583        // add those columns which are required in the join condition to
584        // to those that are required in the output
585        let mut visitor = CollectInputRef::new(resized_required_cols);
586        self.on().visit_expr(&mut visitor);
587        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
588
589        let mut left_required_cols = Vec::new();
590        let mut right_required_cols = Vec::new();
591        left_right_required_cols.iter().for_each(|&i| {
592            if i < left_len {
593                left_required_cols.push(i);
594            } else {
595                right_required_cols.push(i - left_len);
596            }
597        });
598
599        let mut on = self.on().clone();
600        let mut mapping =
601            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
602        on = on.rewrite_expr(&mut mapping);
603
604        let new_output_indices = {
605            let required_inputs_in_output = if self.is_left_join() {
606                &left_required_cols
607            } else if self.is_right_join() {
608                &right_required_cols
609            } else {
610                &left_right_required_cols
611            };
612
613            let mapping =
614                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
615            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
616        };
617
618        LogicalJoin::with_output_indices(
619            self.left().prune_col(&left_required_cols, ctx),
620            self.right().prune_col(&right_required_cols, ctx),
621            self.join_type(),
622            on,
623            new_output_indices,
624        )
625        .into()
626    }
627}
628
impl ExprRewritable for LogicalJoin {
    /// A join always reports that it has rewritable expressions.
    fn has_rewritable_expr(&self) -> bool {
        true
    }

    /// Returns a new plan node whose core expressions have been rewritten by `r`.
    ///
    /// The core is cloned before rewriting, so `self` is left untouched; the base is cloned
    /// with a new plan id.
    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
        let mut core = self.core.clone();
        core.rewrite_exprs(r);
        Self {
            base: self.base.clone_with_new_plan_id(),
            core,
        }
        .into()
    }
}
644
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions of the join core with `v` (delegates to the core).
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
650
651/// We are trying to derive a predicate to apply to the other side of a join if all
652/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
653/// with the corresponding eq condition columns of the other side.
654///
655/// Strategy:
656/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
657///    then we proceed. Else abort.
658/// 2. Then, we collect `InputRef`s in the conjunction.
659/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
660/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
661///    the other side.
662///
663/// # Arguments
664///
665/// Suppose we derive a predicate from the left side to be pushed to the right side.
666/// * `expr`: An expr from the left side.
667/// * `col_num`: The number of columns in the left side.
668fn derive_predicate_from_eq_condition(
669    expr: &ExprImpl,
670    eq_condition: &EqJoinPredicate,
671    col_num: usize,
672    expr_is_left: bool,
673) -> Option<ExprImpl> {
674    if expr.is_impure() {
675        return None;
676    }
677    let eq_indices = eq_condition
678        .eq_indexes_typed()
679        .iter()
680        .filter_map(|(l, r)| {
681            if l.return_type() != r.return_type() {
682                None
683            } else if expr_is_left {
684                Some(l.index())
685            } else {
686                Some(r.index())
687            }
688        })
689        .collect_vec();
690    if expr
691        .collect_input_refs(col_num)
692        .ones()
693        .any(|index| !eq_indices.contains(&index))
694    {
695        // expr contains an InputRef not in eq_condition
696        return None;
697    }
698    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
699    // Hence, we can substitute those `InputRef`s with indices from the other side.
700    let other_side_mapping = if expr_is_left {
701        eq_condition.eq_indexes_typed().into_iter().collect()
702    } else {
703        eq_condition
704            .eq_indexes_typed()
705            .into_iter()
706            .map(|(x, y)| (y, x))
707            .collect()
708    };
709    struct InputRefsRewriter {
710        mapping: HashMap<InputRef, InputRef>,
711    }
712    impl ExprRewriter for InputRefsRewriter {
713        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
714            self.mapping[&input_ref].clone().into()
715        }
716    }
717    Some(
718        InputRefsRewriter {
719            mapping: other_side_mapping,
720        }
721        .rewrite_expr(expr.clone()),
722    )
723}
724
/// Rewrites a join predicate so that every column referring to the scan (right) side goes
/// through `mapping`; columns of the left side are left untouched.
struct LookupJoinPredicateRewriter {
    /// Input refs with an index below `offset` are kept as they are. The lookup join
    /// conversion sets this to the left schema length.
    offset: usize,
    /// For an input ref at or past `offset`, `mapping[index - offset]` is its new position
    /// relative to `offset`.
    mapping: Vec<usize>,
}
impl ExprRewriter for LookupJoinPredicateRewriter {
    /// Keeps refs below `offset`; remaps the others through `mapping`, preserving the
    /// return type.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        if input_ref.index() < self.offset {
            input_ref.into()
        } else {
            InputRef::new(
                self.mapping[input_ref.index() - self.offset] + self.offset,
                input_ref.return_type(),
            )
            .into()
        }
    }
}
743
/// Rewrites the scan predicate so that it can be added to the join predicate.
struct LookupJoinScanPredicateRewriter {
    /// Amount added to every input ref index. The lookup join conversion sets this to the
    /// left schema length.
    offset: usize,
}
impl ExprRewriter for LookupJoinScanPredicateRewriter {
    /// Shifts the input ref's index by `offset`, preserving its return type.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
    }
}
753
impl PredicatePushdown for LogicalJoin {
    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
    ///
    /// # Which predicates can be pushed
    ///
    /// For inner join, we can do all kinds of pushdown.
    ///
    /// For left/right semi join, we can push filter to left/right and on-clause,
    /// and push on-clause to left/right.
    ///
    /// For left/right anti join, we can push filter to left/right, but on-clause can not be
    /// pushed.
    ///
    /// ## Outer Join
    ///
    /// Preserved Row table
    /// : The table in an Outer Join that must return all rows.
    ///
    /// Null Supplying table
    /// : This is the table that has nulls filled in for its columns in unmatched rows.
    ///
    /// |                          | Preserved Row table | Null Supplying table |
    /// |--------------------------|---------------------|----------------------|
    /// | Join predicate (on)      | Not Pushed          | Pushed               |
    /// | Where predicate (filter) | Pushed              | Not Pushed           |
    ///
    /// # Returns
    ///
    /// A `LogicalFilter` created over the new join; it carries what is left of `predicate`
    /// after the pushdown.
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        // rewrite output col referencing indices as internal cols
        let mut predicate = {
            let mut mapping = self.core.o2i_col_mapping();
            predicate.rewrite_expr(&mut mapping)
        };

        let left_col_num = self.left().schema().len();
        let right_col_num = self.right().schema().len();
        // The filter above the join may allow an outer join to be simplified; the simplified
        // join type is used for everything below.
        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

        let push_down_temporal_predicate = !self.should_be_temporal_join();

        // Split the filter into parts for the left input, the right input and the on-clause.
        let (left_from_filter, right_from_filter, on) = push_down_into_join(
            &mut predicate,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        // Merge the part that moved into the on-clause, then push parts of the on-clause
        // further down into the inputs.
        let mut new_on = self.on().clone().and(on);
        let (left_from_on, right_from_on) = push_down_join_condition(
            &mut new_on,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        let left_predicate = left_from_filter.and(left_from_on);
        let right_predicate = right_from_filter.and(right_from_on);

        // Derive conditions to push to the other side based on eq condition columns
        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());

        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
        let right_from_left = if matches!(
            join_type,
            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
        ) {
            Condition {
                conjunctions: left_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
        let left_from_right = if matches!(
            join_type,
            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
        ) {
            Condition {
                conjunctions: right_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(
                            expr,
                            &eq_condition,
                            right_col_num,
                            false,
                        )
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        let left_predicate = left_predicate.and(left_from_right);
        let right_predicate = right_predicate.and(right_from_left);

        // Continue the pushdown into both inputs and rebuild the join with the (possibly
        // simplified) join type, the new on-clause and the original output indices.
        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
        let new_join = LogicalJoin::with_output_indices(
            new_left,
            new_right,
            join_type,
            new_on,
            self.output_indices().clone(),
        );

        // Map what is left of the filter back to output columns and keep it above the join.
        // NOTE(review): presumably `push_down_into_join` removed the pushed conjunctions from
        // `predicate` — confirm.
        let mut mapping = self.core.i2o_col_mapping();
        predicate = predicate.rewrite_expr(&mut mapping);
        LogicalFilter::create(new_join.into(), predicate)
    }
}
877
878impl LogicalJoin {
879    fn get_stream_input_for_hash_join(
880        &self,
881        predicate: &EqJoinPredicate,
882        ctx: &mut ToStreamContext,
883    ) -> Result<(PlanRef, PlanRef)> {
884        use super::stream::prelude::*;
885
886        let mut right = self.right().to_stream_with_dist_required(
887            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
888            ctx,
889        )?;
890        let mut left = self.left();
891
892        let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
893        let l2r = predicate.l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
894
895        let right_dist = right.distribution();
896        match right_dist {
897            Distribution::HashShard(_) => {
898                let left_dist = r2l
899                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
900                left = left.to_stream_with_dist_required(&left_dist, ctx)?;
901            }
902            Distribution::UpstreamHashShard(_, _) => {
903                left = left.to_stream_with_dist_required(
904                    &RequiredDist::shard_by_key(
905                        self.left().schema().len(),
906                        &predicate.left_eq_indexes(),
907                    ),
908                    ctx,
909                )?;
910                let left_dist = left.distribution();
911                match left_dist {
912                    Distribution::HashShard(_) => {
913                        let right_dist = l2r.rewrite_required_distribution(
914                            &RequiredDist::PhysicalDist(left_dist.clone()),
915                        );
916                        right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
917                    }
918                    Distribution::UpstreamHashShard(_, _) => {
919                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
920                            .enforce_if_not_satisfies(left, &Order::any())?;
921                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
922                            .enforce_if_not_satisfies(right, &Order::any())?;
923                    }
924                    _ => unreachable!(),
925                }
926            }
927            _ => unreachable!(),
928        }
929        Ok((left, right))
930    }
931
932    fn to_stream_hash_join(
933        &self,
934        predicate: EqJoinPredicate,
935        ctx: &mut ToStreamContext,
936    ) -> Result<PlanRef> {
937        use super::stream::prelude::*;
938
939        assert!(predicate.has_eq());
940        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
941
942        let logical_join = self.clone_with_left_right(left, right);
943
944        // Convert to Hash Join for equal joins
945        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
946        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
947        // as opposed to the HashJoin, which applies the condition row-wise.
948        // However, the default behavior of pulling up non-equal conditions can be overridden by the
949        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
950        // materialization of rows only to be filtered later.
951
952        let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
953
954        let force_filter_inside_join = self
955            .base
956            .ctx()
957            .session_ctx()
958            .config()
959            .streaming_force_filter_inside_join();
960
961        let pull_filter = self.join_type() == JoinType::Inner
962            && stream_hash_join.eq_join_predicate().has_non_eq()
963            && stream_hash_join.inequality_pairs().is_empty()
964            && (!force_filter_inside_join);
965        if pull_filter {
966            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
967
968            // Temporarily remove output indices.
969            let logical_join = logical_join.clone_with_output_indices(default_indices.clone());
970            let eq_cond = EqJoinPredicate::new(
971                Condition::true_cond(),
972                predicate.eq_keys().to_vec(),
973                self.left().schema().len(),
974                self.right().schema().len(),
975            );
976            let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
977            let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
978            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
979            let plan = StreamFilter::new(logical_filter).into();
980            if self.output_indices() != &default_indices {
981                let logical_project = generic::Project::with_mapping(
982                    plan,
983                    ColIndexMapping::with_remaining_columns(
984                        self.output_indices(),
985                        self.internal_column_num(),
986                    ),
987                );
988                Ok(StreamProject::new(logical_project).into())
989            } else {
990                Ok(plan)
991            }
992        } else {
993            Ok(stream_hash_join.into())
994        }
995    }
996
997    fn should_be_temporal_join(&self) -> bool {
998        let right = self.right();
999        if let Some(logical_scan) = right.as_logical_scan() {
1000            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1001        } else {
1002            false
1003        }
1004    }
1005
1006    fn to_stream_temporal_join_with_index_selection(
1007        &self,
1008        predicate: EqJoinPredicate,
1009        ctx: &mut ToStreamContext,
1010    ) -> Result<PlanRef> {
1011        // Index selection for temporal join.
1012        let right = self.right();
1013        // `should_be_temporal_join()` has already check right input for us.
1014        let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1015
1016        // Use primary table.
1017        let mut result_plan: Result<StreamTemporalJoin> =
1018            self.to_stream_temporal_join(predicate.clone(), ctx);
1019        // Return directly if this temporal join can match the pk of its right table.
1020        if let Ok(temporal_join) = &result_plan
1021            && temporal_join.eq_join_predicate().eq_indexes().len()
1022                == logical_scan.primary_key().len()
1023        {
1024            return result_plan.map(|x| x.into());
1025        }
1026        let indexes = logical_scan.indexes();
1027        for index in indexes {
1028            // Use index table
1029            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1030                let index_scan: PlanRef = index_scan.into();
1031                let that = self.clone_with_left_right(self.left(), index_scan.clone());
1032                if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1033                    match &result_plan {
1034                        Err(_) => result_plan = Ok(temporal_join),
1035                        Ok(prev_temporal_join) => {
1036                            // Prefer to the temporal join with a longer lookup prefix len.
1037                            if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1038                                < temporal_join.eq_join_predicate().eq_indexes().len()
1039                            {
1040                                result_plan = Ok(temporal_join)
1041                            }
1042                        }
1043                    }
1044                }
1045            }
1046        }
1047
1048        result_plan.map(|x| x.into())
1049    }
1050
1051    fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1052        let Some(logical_scan) = right.as_logical_scan() else {
1053            return Err(RwError::from(ErrorCode::NotSupported(
1054                "Temporal join requires a table scan as its lookup table".into(),
1055                "Please provide a table scan".into(),
1056            )));
1057        };
1058
1059        if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1060            return Err(RwError::from(ErrorCode::NotSupported(
1061                "Temporal join requires a table defined as temporal table".into(),
1062                "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1063            )));
1064        }
1065        Ok(logical_scan)
1066    }
1067
1068    fn temporal_join_scan_predicate_pull_up(
1069        logical_scan: &LogicalScan,
1070        predicate: EqJoinPredicate,
1071        output_indices: &[usize],
1072        left_schema_len: usize,
1073    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1074        // Extract the predicate from logical scan. Only pure scan is supported.
1075        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1076        // Construct output column to require column mapping
1077        let o2r = if let Some(project_expr) = project_expr {
1078            project_expr
1079                .into_iter()
1080                .map(|x| x.as_input_ref().unwrap().index)
1081                .collect_vec()
1082        } else {
1083            (0..logical_scan.output_col_idx().len()).collect_vec()
1084        };
1085        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1086            offset: left_schema_len,
1087            mapping: o2r.clone(),
1088        };
1089
1090        let new_eq_cond = predicate
1091            .eq_cond()
1092            .rewrite_expr(&mut join_predicate_rewriter);
1093
1094        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1095            offset: left_schema_len,
1096        };
1097
1098        let new_other_cond = predicate
1099            .other_cond()
1100            .clone()
1101            .rewrite_expr(&mut join_predicate_rewriter)
1102            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1103
1104        let new_join_on = new_eq_cond.and(new_other_cond);
1105
1106        let new_predicate = EqJoinPredicate::create(
1107            left_schema_len,
1108            new_scan.schema().len(),
1109            new_join_on.clone(),
1110        );
1111
1112        // Rewrite the join output indices and all output indices referred to the old scan need to
1113        // rewrite.
1114        let new_join_output_indices = output_indices
1115            .iter()
1116            .map(|&x| {
1117                if x < left_schema_len {
1118                    x
1119                } else {
1120                    o2r[x - left_schema_len] + left_schema_len
1121                }
1122            })
1123            .collect_vec();
1124        // Use UpstreamOnly chain type
1125        let new_stream_table_scan =
1126            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1127        Ok((
1128            new_stream_table_scan,
1129            new_predicate,
1130            new_join_on,
1131            new_join_output_indices,
1132        ))
1133    }
1134
1135    fn to_stream_temporal_join(
1136        &self,
1137        predicate: EqJoinPredicate,
1138        ctx: &mut ToStreamContext,
1139    ) -> Result<StreamTemporalJoin> {
1140        use super::stream::prelude::*;
1141
1142        assert!(predicate.has_eq());
1143
1144        let right = self.right();
1145
1146        let logical_scan = Self::check_temporal_rhs(&right)?;
1147
1148        let table_desc = logical_scan.table_desc();
1149        let output_column_ids = logical_scan.output_column_ids();
1150
1151        // Verify that the right join key columns are the the prefix of the primary key and
1152        // also contain the distribution key.
1153        let order_col_ids = table_desc.order_column_ids();
1154        let order_key = table_desc.order_column_indices();
1155        let dist_key = table_desc.distribution_key.clone();
1156
1157        let mut dist_key_in_order_key_pos = vec![];
1158        for d in dist_key {
1159            let pos = order_key
1160                .iter()
1161                .position(|&x| x == d)
1162                .expect("dist_key must in order_key");
1163            dist_key_in_order_key_pos.push(pos);
1164        }
1165        // The shortest prefix of order key that contains distribution key.
1166        let shortest_prefix_len = dist_key_in_order_key_pos
1167            .iter()
1168            .max()
1169            .map_or(0, |pos| pos + 1);
1170
1171        // Reorder the join equal predicate to match the order key.
1172        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1173        for order_col_id in order_col_ids {
1174            let mut found = false;
1175            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1176                if order_col_id == output_column_ids[eq_idx] {
1177                    reorder_idx.push(i);
1178                    found = true;
1179                    break;
1180                }
1181            }
1182            if !found {
1183                break;
1184            }
1185        }
1186        if reorder_idx.len() < shortest_prefix_len {
1187            // TODO: support index selection for temporal join and refine this error message.
1188            return Err(RwError::from(ErrorCode::NotSupported(
1189                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1190                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1191            )));
1192        }
1193        let lookup_prefix_len = reorder_idx.len();
1194        let predicate = predicate.reorder(&reorder_idx);
1195
1196        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1197            RequiredDist::single()
1198        } else {
1199            let left_eq_indexes = predicate.left_eq_indexes();
1200            let left_dist_key = dist_key_in_order_key_pos
1201                .iter()
1202                .map(|pos| left_eq_indexes[*pos])
1203                .collect_vec();
1204
1205            RequiredDist::hash_shard(&left_dist_key)
1206        };
1207
1208        let left = self.left().to_stream(ctx)?;
1209        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1210        let left = required_dist.enforce(left, &Order::any());
1211
1212        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1213            Self::temporal_join_scan_predicate_pull_up(
1214                logical_scan,
1215                predicate,
1216                self.output_indices(),
1217                self.left().schema().len(),
1218            )?;
1219
1220        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1221        if !new_predicate.has_eq() {
1222            return Err(RwError::from(ErrorCode::NotSupported(
1223                "Temporal join requires a non trivial join condition".into(),
1224                "Please remove the false condition of the join".into(),
1225            )));
1226        }
1227
1228        // Construct a new logical join, because we have change its RHS.
1229        let new_logical_join = generic::Join::new(
1230            left,
1231            right,
1232            new_join_on,
1233            self.join_type(),
1234            new_join_output_indices,
1235        );
1236
1237        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1238
1239        Ok(StreamTemporalJoin::new(
1240            new_logical_join,
1241            new_predicate,
1242            false,
1243        ))
1244    }
1245
1246    fn to_stream_nested_loop_temporal_join(
1247        &self,
1248        predicate: EqJoinPredicate,
1249        ctx: &mut ToStreamContext,
1250    ) -> Result<PlanRef> {
1251        use super::stream::prelude::*;
1252        assert!(!predicate.has_eq());
1253
1254        let left = self.left().to_stream_with_dist_required(
1255            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1256            ctx,
1257        )?;
1258        assert!(left.as_stream_exchange().is_some());
1259
1260        if self.join_type() != JoinType::Inner {
1261            return Err(RwError::from(ErrorCode::NotSupported(
1262                "Temporal join requires an inner join".into(),
1263                "Please use an inner join".into(),
1264            )));
1265        }
1266
1267        if !left.append_only() {
1268            return Err(RwError::from(ErrorCode::NotSupported(
1269                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1270                "Please ensure the left hash side is append only".into(),
1271            )));
1272        }
1273
1274        let right = self.right();
1275        let logical_scan = Self::check_temporal_rhs(&right)?;
1276
1277        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1278            Self::temporal_join_scan_predicate_pull_up(
1279                logical_scan,
1280                predicate,
1281                self.output_indices(),
1282                self.left().schema().len(),
1283            )?;
1284
1285        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1286
1287        // Construct a new logical join, because we have change its RHS.
1288        let new_logical_join = generic::Join::new(
1289            left,
1290            right,
1291            new_join_on,
1292            self.join_type(),
1293            new_join_output_indices,
1294        );
1295
1296        Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true).into())
1297    }
1298
1299    fn to_stream_dynamic_filter(
1300        &self,
1301        predicate: Condition,
1302        ctx: &mut ToStreamContext,
1303    ) -> Result<Option<PlanRef>> {
1304        use super::stream::prelude::*;
1305
1306        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1307        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1308        // `StreamDynamicFilter`
1309
1310        // Check if `Inner`/`LeftSemi`
1311        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1312            return Ok(None);
1313        }
1314
1315        // Check if right side is a scalar
1316        if !self.right().max_one_row() {
1317            return Ok(None);
1318        }
1319        if self.right().schema().len() != 1 {
1320            return Ok(None);
1321        }
1322
1323        // Check if the join condition is a correlated comparison
1324        if predicate.conjunctions.len() > 1 {
1325            return Ok(None);
1326        }
1327        let expr: ExprImpl = predicate.into();
1328        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1329            Some(v) => v,
1330            None => return Ok(None),
1331        };
1332
1333        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1334            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1335        if !condition_cross_inputs {
1336            // Maybe we should panic here because it means some predicates are not pushed down.
1337            return Ok(None);
1338        }
1339
1340        // We align input types on all join predicates with cmp operator
1341        if self.left().schema().fields()[left_ref.index].data_type
1342            != self.right().schema().fields()[0].data_type
1343        {
1344            return Ok(None);
1345        }
1346
1347        // Check if non of the columns from the inner side is required to output
1348        let all_output_from_left = self
1349            .output_indices()
1350            .iter()
1351            .all(|i| *i < self.left().schema().len());
1352        if !all_output_from_left {
1353            return Ok(None);
1354        }
1355
1356        let left = self.left().to_stream(ctx)?;
1357        let right = self.right().to_stream_with_dist_required(
1358            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1359            ctx,
1360        )?;
1361
1362        assert!(right.as_stream_exchange().is_some());
1363        assert_eq!(
1364            *right.inputs().iter().exactly_one().unwrap().distribution(),
1365            Distribution::Single
1366        );
1367
1368        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1369        let plan = StreamDynamicFilter::new(core).into();
1370        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1371        if self
1372            .output_indices()
1373            .iter()
1374            .copied()
1375            .ne(0..self.left().schema().len())
1376        {
1377            // The schema of dynamic filter is always the same as the left side now, and we have
1378            // checked that all output columns are from the left side before.
1379            let logical_project = generic::Project::with_mapping(
1380                plan,
1381                ColIndexMapping::with_remaining_columns(
1382                    self.output_indices(),
1383                    self.left().schema().len(),
1384                ),
1385            );
1386            Ok(Some(StreamProject::new(logical_project).into()))
1387        } else {
1388            Ok(Some(plan))
1389        }
1390    }
1391
1392    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
1393        let predicate = EqJoinPredicate::create(
1394            self.left().schema().len(),
1395            self.right().schema().len(),
1396            self.on().clone(),
1397        );
1398        assert!(predicate.has_eq());
1399
1400        let mut logical_join = self.core.clone();
1401        logical_join.left = logical_join.left.to_batch()?;
1402        logical_join.right = logical_join.right.to_batch()?;
1403
1404        Ok(self
1405            .to_batch_lookup_join(predicate, logical_join)?
1406            .expect("Fail to convert to lookup join")
1407            .into())
1408    }
1409
1410    fn to_stream_asof_join(
1411        &self,
1412        predicate: EqJoinPredicate,
1413        ctx: &mut ToStreamContext,
1414    ) -> Result<PlanRef> {
1415        use super::stream::prelude::*;
1416
1417        if predicate.eq_keys().is_empty() {
1418            return Err(ErrorCode::InvalidInputSyntax(
1419                "AsOf join requires at least 1 equal condition".to_owned(),
1420            )
1421            .into());
1422        }
1423
1424        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1425        let left_len = left.schema().len();
1426        let logical_join = self.clone_with_left_right(left, right);
1427
1428        let inequality_desc =
1429            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1430
1431        Ok(StreamAsOfJoin::new(logical_join.core.clone(), predicate, inequality_desc).into())
1432    }
1433
1434    /// Convert the logical join to a Hash join.
1435    fn to_batch_hash_join(
1436        &self,
1437        logical_join: generic::Join<PlanRef>,
1438        predicate: EqJoinPredicate,
1439    ) -> Result<PlanRef> {
1440        use super::batch::prelude::*;
1441
1442        let left_schema_len = logical_join.left.schema().len();
1443        let asof_desc = self
1444            .is_asof_join()
1445            .then(|| {
1446                Self::get_inequality_desc_from_predicate(
1447                    predicate.other_cond().clone(),
1448                    left_schema_len,
1449                )
1450            })
1451            .transpose()?;
1452
1453        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1454        Ok(batch_join.into())
1455    }
1456
1457    pub fn get_inequality_desc_from_predicate(
1458        predicate: Condition,
1459        left_input_len: usize,
1460    ) -> Result<AsOfJoinDesc> {
1461        let expr: ExprImpl = predicate.into();
1462        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1463            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1464            {
1465                Ok(AsOfJoinDesc {
1466                    left_idx: left_input_ref.index() as u32,
1467                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1468                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1469                })
1470            } else {
1471                bail!("inequal condition from the same side should be push down in optimizer");
1472            }
1473        } else {
1474            Err(ErrorCode::InvalidInputSyntax(
1475                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1476            )
1477            .into())
1478        }
1479    }
1480
1481    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1482        match expr_type {
1483            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1484            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1485            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1486            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1487            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1488                "Invalid comparison type: {}",
1489                expr_type.as_str_name()
1490            ))
1491            .into()),
1492        }
1493    }
1494}
1495
1496impl ToBatch for LogicalJoin {
1497    fn to_batch(&self) -> Result<PlanRef> {
1498        let predicate = EqJoinPredicate::create(
1499            self.left().schema().len(),
1500            self.right().schema().len(),
1501            self.on().clone(),
1502        );
1503
1504        let mut logical_join = self.core.clone();
1505        logical_join.left = logical_join.left.to_batch()?;
1506        logical_join.right = logical_join.right.to_batch()?;
1507
1508        let ctx = self.base.ctx();
1509        let config = ctx.session_ctx().config();
1510
1511        if predicate.has_eq() {
1512            if !predicate.eq_keys_are_type_aligned() {
1513                return Err(ErrorCode::InternalError(format!(
1514                    "Join eq keys are not aligned for predicate: {predicate:?}"
1515                ))
1516                .into());
1517            }
1518            if config.batch_enable_lookup_join()
1519                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1520                    predicate.clone(),
1521                    logical_join.clone(),
1522                )?
1523            {
1524                return Ok(lookup_join.into());
1525            }
1526            self.to_batch_hash_join(logical_join, predicate)
1527        } else if self.is_asof_join() {
1528            Err(ErrorCode::InvalidInputSyntax(
1529                "AsOf join requires at least 1 equal condition".to_owned(),
1530            )
1531            .into())
1532        } else {
1533            // Convert to Nested-loop Join for non-equal joins
1534            Ok(BatchNestedLoopJoin::new(logical_join).into())
1535        }
1536    }
1537}
1538
1539impl ToStream for LogicalJoin {
1540    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1541        if self
1542            .on()
1543            .conjunctions
1544            .iter()
1545            .any(|cond| cond.count_nows() > 0)
1546        {
1547            return Err(ErrorCode::NotSupported(
1548                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1549                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1550        }
1551
1552        let predicate = EqJoinPredicate::create(
1553            self.left().schema().len(),
1554            self.right().schema().len(),
1555            self.on().clone(),
1556        );
1557
1558        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1559            self.to_stream_asof_join(predicate, ctx)
1560        } else if predicate.has_eq() {
1561            if !predicate.eq_keys_are_type_aligned() {
1562                return Err(ErrorCode::InternalError(format!(
1563                    "Join eq keys are not aligned for predicate: {predicate:?}"
1564                ))
1565                .into());
1566            }
1567
1568            if self.should_be_temporal_join() {
1569                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1570            } else {
1571                self.to_stream_hash_join(predicate, ctx)
1572            }
1573        } else if self.should_be_temporal_join() {
1574            self.to_stream_nested_loop_temporal_join(predicate, ctx)
1575        } else if let Some(dynamic_filter) =
1576            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1577        {
1578            Ok(dynamic_filter)
1579        } else {
1580            Err(RwError::from(ErrorCode::NotSupported(
1581                "streaming nested-loop join".to_owned(),
1582                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1583                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1584                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1585            )))
1586        }
1587    }
1588
1589    fn logical_rewrite_for_stream(
1590        &self,
1591        ctx: &mut RewriteStreamContext,
1592    ) -> Result<(PlanRef, ColIndexMapping)> {
1593        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1594        let left_len = left.schema().len();
1595        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1596        let (join, out_col_change) = self.rewrite_with_left_right(
1597            left.clone(),
1598            left_col_change,
1599            right.clone(),
1600            right_col_change,
1601        );
1602
1603        let mapping = ColIndexMapping::with_remaining_columns(
1604            join.output_indices(),
1605            join.internal_column_num(),
1606        );
1607
1608        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1609        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1610
1611        // Add missing pk indices to the logical join
1612        let mut left_to_add = left
1613            .expect_stream_key()
1614            .iter()
1615            .cloned()
1616            .filter(|i| l2o.try_map(*i).is_none())
1617            .collect_vec();
1618
1619        let mut right_to_add = right
1620            .expect_stream_key()
1621            .iter()
1622            .filter(|&&i| r2o.try_map(i).is_none())
1623            .map(|&i| i + left_len)
1624            .collect_vec();
1625
1626        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1627        // key.
1628        let right_len = right.schema().len();
1629        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1630
1631        let either_or_both = self.core.add_which_join_key_to_pk();
1632
1633        for (lk, rk) in eq_predicate.eq_indexes() {
1634            match either_or_both {
1635                EitherOrBoth::Left(_) => {
1636                    if l2o.try_map(lk).is_none() {
1637                        left_to_add.push(lk);
1638                    }
1639                }
1640                EitherOrBoth::Right(_) => {
1641                    if r2o.try_map(rk).is_none() {
1642                        right_to_add.push(rk + left_len)
1643                    }
1644                }
1645                EitherOrBoth::Both(_, _) => {
1646                    if l2o.try_map(lk).is_none() {
1647                        left_to_add.push(lk);
1648                    }
1649                    if r2o.try_map(rk).is_none() {
1650                        right_to_add.push(rk + left_len)
1651                    }
1652                }
1653            };
1654        }
1655        let left_to_add = left_to_add.into_iter().unique();
1656        let right_to_add = right_to_add.into_iter().unique();
1657        // NOTE(st1page) over
1658
1659        let mut new_output_indices = join.output_indices().clone();
1660        if !join.is_right_join() {
1661            new_output_indices.extend(left_to_add);
1662        }
1663        if !join.is_left_join() {
1664            new_output_indices.extend(right_to_add);
1665        }
1666
1667        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1668
1669        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1670            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1671
1672            let l2o = join_with_pk
1673                .core
1674                .l2i_col_mapping()
1675                .composite(&join_with_pk.core.i2o_col_mapping());
1676            let r2o = join_with_pk
1677                .core
1678                .r2i_col_mapping()
1679                .composite(&join_with_pk.core.i2o_col_mapping());
1680            let left_right_stream_keys = join_with_pk
1681                .left()
1682                .expect_stream_key()
1683                .iter()
1684                .map(|i| l2o.map(*i))
1685                .chain(
1686                    join_with_pk
1687                        .right()
1688                        .expect_stream_key()
1689                        .iter()
1690                        .map(|i| r2o.map(*i)),
1691                )
1692                .collect_vec();
1693            let plan: PlanRef = join_with_pk.into();
1694            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1695        } else {
1696            join_with_pk.into()
1697        };
1698
1699        // the added columns is at the end, so it will not change the exists column index
1700        Ok((plan, out_col_change))
1701    }
1702}
1703
1704#[cfg(test)]
1705mod tests {
1706
1707    use std::collections::HashSet;
1708
1709    use risingwave_common::catalog::{Field, Schema};
1710    use risingwave_common::types::{DataType, Datum};
1711    use risingwave_pb::expr::expr_node::Type;
1712
1713    use super::*;
1714    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1715    use crate::optimizer::optimizer_context::OptimizerContext;
1716    use crate::optimizer::plan_node::LogicalValues;
1717    use crate::optimizer::property::FunctionalDependency;
1718
1719    /// Pruning
1720    /// ```text
1721    /// Join(on: input_ref(1)=input_ref(3))
1722    ///   TableScan(v1, v2, v3)
1723    ///   TableScan(v4, v5, v6)
1724    /// ```
1725    /// with required columns [2,3] will result in
1726    /// ```text
1727    /// Project(input_ref(1), input_ref(2))
1728    ///   Join(on: input_ref(0)=input_ref(2))
1729    ///     TableScan(v2, v3)
1730    ///     TableScan(v4)
1731    /// ```
1732    #[tokio::test]
1733    async fn test_prune_join() {
1734        let ty = DataType::Int32;
1735        let ctx = OptimizerContext::mock().await;
1736        let fields: Vec<Field> = (1..7)
1737            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1738            .collect();
1739        let left = LogicalValues::new(
1740            vec![],
1741            Schema {
1742                fields: fields[0..3].to_vec(),
1743            },
1744            ctx.clone(),
1745        );
1746        let right = LogicalValues::new(
1747            vec![],
1748            Schema {
1749                fields: fields[3..6].to_vec(),
1750            },
1751            ctx,
1752        );
1753        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1754            FunctionCall::new(
1755                Type::Equal,
1756                vec![
1757                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1758                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1759                ],
1760            )
1761            .unwrap(),
1762        ));
1763        let join_type = JoinType::Inner;
1764        let join: PlanRef = LogicalJoin::new(
1765            left.into(),
1766            right.into(),
1767            join_type,
1768            Condition::with_expr(on),
1769        )
1770        .into();
1771
1772        // Perform the prune
1773        let required_cols = vec![2, 3];
1774        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1775
1776        // Check the result
1777        let join = plan.as_logical_join().unwrap();
1778        assert_eq!(join.schema().fields().len(), 2);
1779        assert_eq!(join.schema().fields()[0], fields[2]);
1780        assert_eq!(join.schema().fields()[1], fields[3]);
1781
1782        let expr: ExprImpl = join.on().clone().into();
1783        let call = expr.as_function_call().unwrap();
1784        assert_eq_input_ref!(&call.inputs()[0], 0);
1785        assert_eq_input_ref!(&call.inputs()[1], 2);
1786
1787        let left = join.left();
1788        let left = left.as_logical_values().unwrap();
1789        assert_eq!(left.schema().fields(), &fields[1..3]);
1790        let right = join.right();
1791        let right = right.as_logical_values().unwrap();
1792        assert_eq!(right.schema().fields(), &fields[3..4]);
1793    }
1794
1795    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1796    #[tokio::test]
1797    async fn test_prune_semi_join() {
1798        let ty = DataType::Int32;
1799        let ctx = OptimizerContext::mock().await;
1800        let fields: Vec<Field> = (1..7)
1801            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1802            .collect();
1803        let left = LogicalValues::new(
1804            vec![],
1805            Schema {
1806                fields: fields[0..3].to_vec(),
1807            },
1808            ctx.clone(),
1809        );
1810        let right = LogicalValues::new(
1811            vec![],
1812            Schema {
1813                fields: fields[3..6].to_vec(),
1814            },
1815            ctx,
1816        );
1817        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1818            FunctionCall::new(
1819                Type::Equal,
1820                vec![
1821                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1822                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1823                ],
1824            )
1825            .unwrap(),
1826        ));
1827        for join_type in [
1828            JoinType::LeftSemi,
1829            JoinType::RightSemi,
1830            JoinType::LeftAnti,
1831            JoinType::RightAnti,
1832        ] {
1833            let join = LogicalJoin::new(
1834                left.clone().into(),
1835                right.clone().into(),
1836                join_type,
1837                Condition::with_expr(on.clone()),
1838            );
1839
1840            let offset = if join.is_right_join() { 3 } else { 0 };
1841            let join: PlanRef = join.into();
1842            // Perform the prune
1843            let required_cols = vec![0];
1844            // key 0 is never used in the join (always key 1)
1845            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1846            let as_plan = plan.as_logical_join().unwrap();
1847            // Check the result
1848            assert_eq!(as_plan.schema().fields().len(), 1);
1849            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1850
1851            // Perform the prune
1852            let required_cols = vec![0, 1, 2];
1853            // should not panic here
1854            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1855            let as_plan = plan.as_logical_join().unwrap();
1856            // Check the result
1857            assert_eq!(as_plan.schema().fields().len(), 3);
1858            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1859            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1860            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1861        }
1862    }
1863
1864    /// Pruning
1865    /// ```text
1866    /// Join(on: input_ref(1)=input_ref(3))
1867    ///   TableScan(v1, v2, v3)
1868    ///   TableScan(v4, v5, v6)
1869    /// ```
1870    /// with required columns [1,3] will result in
1871    /// ```text
1872    /// Join(on: input_ref(0)=input_ref(1))
1873    ///   TableScan(v2)
1874    ///   TableScan(v4)
1875    /// ```
1876    #[tokio::test]
1877    async fn test_prune_join_no_project() {
1878        let ty = DataType::Int32;
1879        let ctx = OptimizerContext::mock().await;
1880        let fields: Vec<Field> = (1..7)
1881            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1882            .collect();
1883        let left = LogicalValues::new(
1884            vec![],
1885            Schema {
1886                fields: fields[0..3].to_vec(),
1887            },
1888            ctx.clone(),
1889        );
1890        let right = LogicalValues::new(
1891            vec![],
1892            Schema {
1893                fields: fields[3..6].to_vec(),
1894            },
1895            ctx,
1896        );
1897        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1898            FunctionCall::new(
1899                Type::Equal,
1900                vec![
1901                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1902                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1903                ],
1904            )
1905            .unwrap(),
1906        ));
1907        let join_type = JoinType::Inner;
1908        let join: PlanRef = LogicalJoin::new(
1909            left.into(),
1910            right.into(),
1911            join_type,
1912            Condition::with_expr(on),
1913        )
1914        .into();
1915
1916        // Perform the prune
1917        let required_cols = vec![1, 3];
1918        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1919
1920        // Check the result
1921        let join = plan.as_logical_join().unwrap();
1922        assert_eq!(join.schema().fields().len(), 2);
1923        assert_eq!(join.schema().fields()[0], fields[1]);
1924        assert_eq!(join.schema().fields()[1], fields[3]);
1925
1926        let expr: ExprImpl = join.on().clone().into();
1927        let call = expr.as_function_call().unwrap();
1928        assert_eq_input_ref!(&call.inputs()[0], 0);
1929        assert_eq_input_ref!(&call.inputs()[1], 1);
1930
1931        let left = join.left();
1932        let left = left.as_logical_values().unwrap();
1933        assert_eq!(left.schema().fields(), &fields[1..2]);
1934        let right = join.right();
1935        let right = right.as_logical_values().unwrap();
1936        assert_eq!(right.schema().fields(), &fields[3..4]);
1937    }
1938
1939    /// Convert
1940    /// ```text
1941    /// Join(on: ($1 = $3) AND ($2 == 42))
1942    ///   TableScan(v1, v2, v3)
1943    ///   TableScan(v4, v5, v6)
1944    /// ```
1945    /// to
1946    /// ```text
1947    /// Filter($2 == 42)
1948    ///   HashJoin(on: $1 = $3)
1949    ///     TableScan(v1, v2, v3)
1950    ///     TableScan(v4, v5, v6)
1951    /// ```
1952    #[tokio::test]
1953    async fn test_join_to_batch() {
1954        let ctx = OptimizerContext::mock().await;
1955        let fields: Vec<Field> = (1..7)
1956            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
1957            .collect();
1958        let left = LogicalValues::new(
1959            vec![],
1960            Schema {
1961                fields: fields[0..3].to_vec(),
1962            },
1963            ctx.clone(),
1964        );
1965        let right = LogicalValues::new(
1966            vec![],
1967            Schema {
1968                fields: fields[3..6].to_vec(),
1969            },
1970            ctx,
1971        );
1972
1973        fn input_ref(i: usize) -> ExprImpl {
1974            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
1975        }
1976        let eq_cond = ExprImpl::FunctionCall(Box::new(
1977            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
1978        ));
1979        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
1980            FunctionCall::new(
1981                Type::Equal,
1982                vec![
1983                    input_ref(2),
1984                    ExprImpl::Literal(Box::new(Literal::new(
1985                        Datum::Some(42_i32.into()),
1986                        DataType::Int32,
1987                    ))),
1988                ],
1989            )
1990            .unwrap(),
1991        ));
1992        // Condition: ($1 = $3) AND ($2 == 42)
1993        let on_cond = ExprImpl::FunctionCall(Box::new(
1994            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
1995        ));
1996
1997        let join_type = JoinType::Inner;
1998        let logical_join = LogicalJoin::new(
1999            left.into(),
2000            right.into(),
2001            join_type,
2002            Condition::with_expr(on_cond),
2003        );
2004
2005        // Perform `to_batch`
2006        let result = logical_join.to_batch().unwrap();
2007
2008        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2009        let hash_join = result.as_batch_hash_join().unwrap();
2010        assert_eq!(
2011            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2012            eq_cond
2013        );
2014        assert_eq!(
2015            *hash_join
2016                .eq_join_predicate()
2017                .non_eq_cond()
2018                .conjunctions
2019                .first()
2020                .unwrap(),
2021            non_eq_cond
2022        );
2023    }
2024
2025    /// Convert
2026    /// ```text
2027    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2028    ///   TableScan(v1, v2, v3)
2029    ///   TableScan(v4, v5, v6)
2030    /// ```
2031    /// to
2032    /// ```text
2033    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2034    ///   TableScan(v1, v2, v3)
2035    ///   TableScan(v4, v5, v6)
2036    /// ```
2037    #[tokio::test]
2038    #[ignore] // ignore due to refactor logical scan, but the test seem to duplicate with the explain test
2039    // framework, maybe we will remove it?
2040    async fn test_join_to_stream() {
2041        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
2042        // let fields: Vec<Field> = (1..7)
2043        //     .map(|i| Field {
2044        //         data_type: DataType::Int32,
2045        //         name: format!("v{}", i),
2046        //     })
2047        //     .collect();
2048        // let left = LogicalScan::new(
2049        //     "left".to_string(),
2050        //     TableId::new(0),
2051        //     vec![1.into(), 2.into(), 3.into()],
2052        //     Schema {
2053        //         fields: fields[0..3].to_vec(),
2054        //     },
2055        //     ctx.clone(),
2056        // );
2057        // let right = LogicalScan::new(
2058        //     "right".to_string(),
2059        //     TableId::new(0),
2060        //     vec![4.into(), 5.into(), 6.into()],
2061        //     Schema {
2062        //         fields: fields[3..6].to_vec(),
2063        //     },
2064        //     ctx,
2065        // );
2066        // let eq_cond = ExprImpl::FunctionCall(Box::new(
2067        //     FunctionCall::new(
2068        //         Type::Equal,
2069        //         vec![
2070        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
2071        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
2072        //         ],
2073        //     )
2074        //     .unwrap(),
2075        // ));
2076        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2077        //     FunctionCall::new(
2078        //         Type::Equal,
2079        //         vec![
2080        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
2081        //             ExprImpl::Literal(Box::new(Literal::new(
2082        //                 Datum::Some(42_i32.into()),
2083        //                 DataType::Int32,
2084        //             ))),
2085        //         ],
2086        //     )
2087        //     .unwrap(),
2088        // ));
2089        // // Condition: ($1 = $3) AND ($2 == 42)
2090        // let on_cond = ExprImpl::FunctionCall(Box::new(
2091        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
2092        // ));
2093
2094        // let join_type = JoinType::LeftOuter;
2095        // let logical_join = LogicalJoin::new(
2096        //     left.into(),
2097        //     right.into(),
2098        //     join_type,
2099        //     Condition::with_expr(on_cond.clone()),
2100        // );
2101
2102        // // Perform `to_stream`
2103        // let result = logical_join.to_stream();
2104
2105        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
2106        // let hash_join = result.as_stream_hash_join().unwrap();
2107        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
2108    }
2109    /// Pruning
2110    /// ```text
2111    /// Join(on: input_ref(1)=input_ref(3))
2112    ///   TableScan(v1, v2, v3)
2113    ///   TableScan(v4, v5, v6)
2114    /// ```
2115    /// with required columns [3, 2] will result in
2116    /// ```text
2117    /// Project(input_ref(2), input_ref(1))
2118    ///   Join(on: input_ref(0)=input_ref(2))
2119    ///     TableScan(v2, v3)
2120    ///     TableScan(v4)
2121    /// ```
2122    #[tokio::test]
2123    async fn test_join_column_prune_with_order_required() {
2124        let ty = DataType::Int32;
2125        let ctx = OptimizerContext::mock().await;
2126        let fields: Vec<Field> = (1..7)
2127            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2128            .collect();
2129        let left = LogicalValues::new(
2130            vec![],
2131            Schema {
2132                fields: fields[0..3].to_vec(),
2133            },
2134            ctx.clone(),
2135        );
2136        let right = LogicalValues::new(
2137            vec![],
2138            Schema {
2139                fields: fields[3..6].to_vec(),
2140            },
2141            ctx,
2142        );
2143        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2144            FunctionCall::new(
2145                Type::Equal,
2146                vec![
2147                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2148                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2149                ],
2150            )
2151            .unwrap(),
2152        ));
2153        let join_type = JoinType::Inner;
2154        let join: PlanRef = LogicalJoin::new(
2155            left.into(),
2156            right.into(),
2157            join_type,
2158            Condition::with_expr(on),
2159        )
2160        .into();
2161
2162        // Perform the prune
2163        let required_cols = vec![3, 2];
2164        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2165
2166        // Check the result
2167        let join = plan.as_logical_join().unwrap();
2168        assert_eq!(join.schema().fields().len(), 2);
2169        assert_eq!(join.schema().fields()[0], fields[3]);
2170        assert_eq!(join.schema().fields()[1], fields[2]);
2171
2172        let expr: ExprImpl = join.on().clone().into();
2173        let call = expr.as_function_call().unwrap();
2174        assert_eq_input_ref!(&call.inputs()[0], 0);
2175        assert_eq_input_ref!(&call.inputs()[1], 2);
2176
2177        let left = join.left();
2178        let left = left.as_logical_values().unwrap();
2179        assert_eq!(left.schema().fields(), &fields[1..3]);
2180        let right = join.right();
2181        let right = right.as_logical_values().unwrap();
2182        assert_eq!(right.schema().fields(), &fields[3..4]);
2183    }
2184
2185    #[tokio::test]
2186    async fn fd_derivation_inner_outer_join() {
2187        // left: [l0, l1], right: [r0, r1, r2]
2188        // FD: l0 --> l1, r0 --> { r1, r2 }
2189        // On: l0 = 0 AND l1 = r1
2190        //
2191        // Inner Join:
2192        //  Schema: [l0, l1, r0, r1, r2]
2193        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2194        // Left Outer Join:
2195        //  Schema: [l0, l1, r0, r1, r2]
2196        //  FD: l0 --> l1
2197        // Right Outer Join:
2198        //  Schema: [l0, l1, r0, r1, r2]
2199        //  FD: r0 --> { r1, r2 }
2200        // Full Outer Join:
2201        //  Schema: [l0, l1, r0, r1, r2]
2202        //  FD: empty
2203        // Left Semi/Anti Join:
2204        //  Schema: [l0, l1]
2205        //  FD: l0 --> l1
2206        // Right Semi/Anti Join:
2207        //  Schema: [r0, r1, r2]
2208        //  FD: r0 --> {r1, r2}
2209        let ctx = OptimizerContext::mock().await;
2210        let left = {
2211            let fields: Vec<Field> = vec![
2212                Field::with_name(DataType::Int32, "l0"),
2213                Field::with_name(DataType::Int32, "l1"),
2214            ];
2215            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2216            // 0 --> 1
2217            values
2218                .base
2219                .functional_dependency_mut()
2220                .add_functional_dependency_by_column_indices(&[0], &[1]);
2221            values
2222        };
2223        let right = {
2224            let fields: Vec<Field> = vec![
2225                Field::with_name(DataType::Int32, "r0"),
2226                Field::with_name(DataType::Int32, "r1"),
2227                Field::with_name(DataType::Int32, "r2"),
2228            ];
2229            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2230            // 0 --> 1, 2
2231            values
2232                .base
2233                .functional_dependency_mut()
2234                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2235            values
2236        };
2237        // l0 = 0 AND l1 = r1
2238        let on: ExprImpl = FunctionCall::new(
2239            Type::And,
2240            vec![
2241                FunctionCall::new(
2242                    Type::Equal,
2243                    vec![
2244                        InputRef::new(0, DataType::Int32).into(),
2245                        ExprImpl::literal_int(0),
2246                    ],
2247                )
2248                .unwrap()
2249                .into(),
2250                FunctionCall::new(
2251                    Type::Equal,
2252                    vec![
2253                        InputRef::new(1, DataType::Int32).into(),
2254                        InputRef::new(3, DataType::Int32).into(),
2255                    ],
2256                )
2257                .unwrap()
2258                .into(),
2259            ],
2260        )
2261        .unwrap()
2262        .into();
2263        let expected_fd_set = [
2264            (
2265                JoinType::Inner,
2266                [
2267                    // inherit from left
2268                    FunctionalDependency::with_indices(5, &[0], &[1]),
2269                    // inherit from right
2270                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2271                    // constant column in join condition
2272                    FunctionalDependency::with_indices(5, &[], &[0]),
2273                    // eq column in join condition
2274                    FunctionalDependency::with_indices(5, &[1], &[3]),
2275                    FunctionalDependency::with_indices(5, &[3], &[1]),
2276                ]
2277                .into_iter()
2278                .collect::<HashSet<_>>(),
2279            ),
2280            (JoinType::FullOuter, HashSet::new()),
2281            (
2282                JoinType::RightOuter,
2283                [
2284                    // inherit from right
2285                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2286                ]
2287                .into_iter()
2288                .collect::<HashSet<_>>(),
2289            ),
2290            (
2291                JoinType::LeftOuter,
2292                [
2293                    // inherit from left
2294                    FunctionalDependency::with_indices(5, &[0], &[1]),
2295                ]
2296                .into_iter()
2297                .collect::<HashSet<_>>(),
2298            ),
2299            (
2300                JoinType::LeftSemi,
2301                [
2302                    // inherit from left
2303                    FunctionalDependency::with_indices(2, &[0], &[1]),
2304                ]
2305                .into_iter()
2306                .collect::<HashSet<_>>(),
2307            ),
2308            (
2309                JoinType::LeftAnti,
2310                [
2311                    // inherit from left
2312                    FunctionalDependency::with_indices(2, &[0], &[1]),
2313                ]
2314                .into_iter()
2315                .collect::<HashSet<_>>(),
2316            ),
2317            (
2318                JoinType::RightSemi,
2319                [
2320                    // inherit from right
2321                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2322                ]
2323                .into_iter()
2324                .collect::<HashSet<_>>(),
2325            ),
2326            (
2327                JoinType::RightAnti,
2328                [
2329                    // inherit from right
2330                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2331                ]
2332                .into_iter()
2333                .collect::<HashSet<_>>(),
2334            ),
2335        ];
2336
2337        for (join_type, expected_res) in expected_fd_set {
2338            let join = LogicalJoin::new(
2339                left.clone().into(),
2340                right.clone().into(),
2341                join_type,
2342                Condition::with_expr(on.clone()),
2343            );
2344            let fd_set = join
2345                .functional_dependency()
2346                .as_dependencies()
2347                .iter()
2348                .cloned()
2349                .collect::<HashSet<_>>();
2350            assert_eq!(fd_set, expected_res);
2351        }
2352    }
2353}