risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31    BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
32    PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch,
33    ToStream, generic,
34};
35use crate::error::{ErrorCode, Result, RwError};
36use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
37use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
38use crate::optimizer::plan_node::generic::DynamicFilter;
39use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
40use crate::optimizer::plan_node::utils::IndicesDisplay;
41use crate::optimizer::plan_node::{
42    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
43    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
44    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
45};
46use crate::optimizer::plan_visitor::LogicalCardinalityExt;
47use crate::optimizer::property::{Distribution, RequiredDist};
48use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
49
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right inputs, as selected by the provided output indices. A repeated output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    // Plan base derived from `core` (see `with_core`).
    pub base: PlanBase<Logical>,
    // The generic join description: inputs, `on` condition, join type and output indices.
    core: generic::Join<PlanRef>,
}
61
62impl Distill for LogicalJoin {
63    fn distill<'a>(&self) -> XmlNode<'a> {
64        let verbose = self.base.ctx().is_explain_verbose();
65        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
66        vec.push(("type", Pretty::debug(&self.join_type())));
67
68        let concat_schema = self.core.concat_schema();
69        let cond = Pretty::debug(&ConditionDisplay {
70            condition: self.on(),
71            input_schema: &concat_schema,
72        });
73        vec.push(("on", cond));
74
75        if verbose {
76            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
77            vec.push(("output", data));
78        }
79
80        childless_record("LogicalJoin", vec)
81    }
82}
83
84impl LogicalJoin {
85    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
86        let core = generic::Join::with_full_output(left, right, join_type, on);
87        Self::with_core(core)
88    }
89
90    pub(crate) fn with_output_indices(
91        left: PlanRef,
92        right: PlanRef,
93        join_type: JoinType,
94        on: Condition,
95        output_indices: Vec<usize>,
96    ) -> Self {
97        let core = generic::Join::new(left, right, on, join_type, output_indices);
98        Self::with_core(core)
99    }
100
101    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
102        let base = PlanBase::new_logical_with_core(&core);
103        LogicalJoin { base, core }
104    }
105
106    pub fn create(
107        left: PlanRef,
108        right: PlanRef,
109        join_type: JoinType,
110        on_clause: ExprImpl,
111    ) -> PlanRef {
112        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
113    }
114
    /// Returns the number of internal columns reported by the core join, i.e. the size of the
    /// column space that `output_indices` select from.
    pub fn internal_column_num(&self) -> usize {
        self.core.internal_column_num()
    }
118
    /// Delegates to the core join's `i2l_col_mapping_ignore_join_type`.
    ///
    /// NOTE(review): presumably the internal-to-left column mapping computed without taking the
    /// join type into account — confirm against `generic::Join`.
    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2l_col_mapping_ignore_join_type()
    }
122
    /// Delegates to the core join's `i2r_col_mapping_ignore_join_type`.
    ///
    /// NOTE(review): presumably the internal-to-right column mapping computed without taking the
    /// join type into account — confirm against `generic::Join`.
    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2r_col_mapping_ignore_join_type()
    }
126
    /// Get a reference to the logical join's `on` condition.
    pub fn on(&self) -> &Condition {
        &self.core.on
    }
131
    /// Get a reference to the underlying generic join description.
    pub fn core(&self) -> &generic::Join<PlanRef> {
        &self.core
    }
135
136    /// Collect all input ref in the on condition. And separate them into left and right.
137    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
138        let input_refs = self
139            .core
140            .on
141            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
142        let index_group = input_refs
143            .ones()
144            .chunk_by(|i| *i < self.core.left.schema().len());
145        let left_index = index_group
146            .into_iter()
147            .next()
148            .map_or(vec![], |group| group.1.collect_vec());
149        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
150            group
151                .1
152                .map(|i| i - self.core.left.schema().len())
153                .collect_vec()
154        });
155        (left_index, right_index)
156    }
157
    /// Get the join type of the logical join, as stored in the core join.
    pub fn join_type(&self) -> JoinType {
        self.core.join_type
    }
162
    /// Get the eq join key of the logical join, as computed by the core join.
    ///
    /// NOTE(review): each pair is presumably `(left index, right index)` — confirm against
    /// `generic::Join::eq_indexes`.
    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
        self.core.eq_indexes()
    }
167
    /// Get the output indices of the logical join, i.e. which internal columns are exposed
    /// in the output and in what order.
    pub fn output_indices(&self) -> &Vec<usize> {
        &self.core.output_indices
    }
172
173    /// Clone with new output indices
174    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
175        Self::with_core(generic::Join {
176            output_indices,
177            ..self.core.clone()
178        })
179    }
180
181    /// Clone with new `on` condition
182    pub fn clone_with_cond(&self, on: Condition) -> Self {
183        Self::with_core(generic::Join {
184            on,
185            ..self.core.clone()
186        })
187    }
188
189    pub fn is_left_join(&self) -> bool {
190        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
191    }
192
193    pub fn is_right_join(&self) -> bool {
194        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
195    }
196
    /// Delegates to the core join's `is_full_out`.
    ///
    /// NOTE(review): presumably true when every internal column is part of the output — confirm
    /// against `generic::Join::is_full_out`.
    pub fn is_full_out(&self) -> bool {
        self.core.is_full_out()
    }
200
201    pub fn is_asof_join(&self) -> bool {
202        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
203    }
204
205    pub fn output_indices_are_trivial(&self) -> bool {
206        self.output_indices() == &(0..self.internal_column_num()).collect_vec()
207    }
208
209    /// Try to simplify the outer join with the predicate on the top of the join
210    ///
211    /// now it is just a naive implementation for comparison expression, we can give a more general
212    /// implementation with constant folding in future
213    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
214        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
215            JoinType::LeftOuter => (false, true),
216            JoinType::RightOuter => (true, false),
217            JoinType::FullOuter => (true, true),
218            _ => return join_type,
219        };
220
221        for expr in &predicate.conjunctions {
222            if let ExprImpl::FunctionCall(func) = expr {
223                match func.func_type() {
224                    ExprType::Equal
225                    | ExprType::NotEqual
226                    | ExprType::LessThan
227                    | ExprType::LessThanOrEqual
228                    | ExprType::GreaterThan
229                    | ExprType::GreaterThanOrEqual => {
230                        for input in func.inputs() {
231                            if let ExprImpl::InputRef(input) = input {
232                                let idx = input.index;
233                                if idx < left_col_num {
234                                    gen_null_in_left = false;
235                                } else {
236                                    gen_null_in_right = false;
237                                }
238                            }
239                        }
240                    }
241                    _ => {}
242                };
243            }
244        }
245
246        match (gen_null_in_left, gen_null_in_right) {
247            (true, true) => JoinType::FullOuter,
248            (true, false) => JoinType::RightOuter,
249            (false, true) => JoinType::LeftOuter,
250            (false, false) => JoinType::Inner,
251        }
252    }
253
254    /// Index Join:
255    /// Try to convert logical join into batch lookup join and meanwhile it will do
256    /// the index selection for the lookup table so that we can benefit from indexes.
257    fn to_batch_lookup_join_with_index_selection(
258        &self,
259        predicate: EqJoinPredicate,
260        batch_join: generic::Join<BatchPlanRef>,
261    ) -> Result<Option<BatchLookupJoin>> {
262        match batch_join.join_type {
263            JoinType::Inner
264            | JoinType::LeftOuter
265            | JoinType::LeftSemi
266            | JoinType::LeftAnti
267            | JoinType::AsofInner
268            | JoinType::AsofLeftOuter => {}
269            _ => return Ok(None),
270        };
271
272        // Index selection for index join.
273        let right = self.right();
274        // Lookup Join only supports basic tables on the join's right side.
275        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
276            logical_scan
277        } else {
278            return Ok(None);
279        };
280
281        let mut result_plan = None;
282        // Lookup primary table.
283        if let Some(lookup_join) =
284            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
285        {
286            result_plan = Some(lookup_join);
287        }
288
289        let indexes = logical_scan.table_indexes();
290        for index in indexes {
291            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
292                let index_scan: PlanRef = index_scan.into();
293                let that = self.clone_with_left_right(self.left(), index_scan.clone());
294                let mut new_batch_join = batch_join.clone();
295                new_batch_join.right = index_scan.to_batch().expect("index scan failed to batch");
296
297                // Lookup covered index.
298                if let Some(lookup_join) =
299                    that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
300                {
301                    match &result_plan {
302                        None => result_plan = Some(lookup_join),
303                        Some(prev_lookup_join) => {
304                            // Prefer to choose lookup join with longer lookup prefix len.
305                            if prev_lookup_join.lookup_prefix_len()
306                                < lookup_join.lookup_prefix_len()
307                            {
308                                result_plan = Some(lookup_join)
309                            }
310                        }
311                    }
312                }
313            }
314        }
315
316        Ok(result_plan)
317    }
318
    /// Try to convert logical join into batch lookup join.
    ///
    /// Returns `Ok(None)` whenever a lookup join is not applicable: the join type is not one of
    /// the supported ones, the right side is not a `LogicalScan`, the table has no distribution
    /// key, the equal-join keys do not cover a long enough prefix of the order key, or no
    /// equality condition remains after the scan predicate has been pulled up.
    fn to_batch_lookup_join(
        &self,
        predicate: EqJoinPredicate,
        logical_join: generic::Join<BatchPlanRef>,
    ) -> Result<Option<BatchLookupJoin>> {
        match logical_join.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {}
            _ => return Ok(None),
        };

        let right = self.right();
        // Lookup Join only supports basic tables on the join's right side.
        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
            logical_scan
        } else {
            return Ok(None);
        };
        let table = logical_scan.table();
        let output_column_ids = logical_scan.output_column_ids();

        // Verify that the right join key columns are a prefix of the primary key and
        // also contain the distribution key.
        let order_col_ids = table.order_column_ids();
        let dist_key = table.distribution_key.clone();
        // Position of each distribution key column within the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = table
                .order_column_indices()
                .position(|x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // The shortest prefix of order key that contains distribution key.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // Distributed lookup join can't support lookup table with a singleton distribution.
        if shortest_prefix_len == 0 {
            return Ok(None);
        }

        // Reorder the join equal predicate to match the order key.
        // Walk the order key from the front and stop at the first column that has no matching
        // equal-join key, so `reorder_idx` always describes a contiguous prefix.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        // The matched prefix must at least cover the distribution key.
        if reorder_idx.len() < shortest_prefix_len {
            return Ok(None);
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Extract the predicate from logical scan. Only pure scan is supported.
        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
        // Construct the mapping from the old scan's output columns to the new scan's columns.
        // Every projected expression is expected to be a plain input ref (the `unwrap` below
        // panics otherwise).
        let o2r = if let Some(project_expr) = project_expr {
            project_expr
                .into_iter()
                .map(|x| x.as_input_ref().unwrap().index)
                .collect_vec()
        } else {
            (0..logical_scan.output_col_idx().len()).collect_vec()
        };
        let left_schema_len = logical_join.left.schema().len();

        // Remaps right-side input refs of the join predicate through `o2r`.
        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
            offset: left_schema_len,
            mapping: o2r.clone(),
        };

        let new_eq_cond = predicate
            .eq_cond()
            .rewrite_expr(&mut join_predicate_rewriter);

        // Shifts the pulled-up scan predicate past the left columns so that it can be merged
        // into the join condition.
        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
            offset: left_schema_len,
        };

        let new_other_cond = predicate
            .other_cond()
            .clone()
            .rewrite_expr(&mut join_predicate_rewriter)
            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

        let new_join_on = new_eq_cond.and(new_other_cond);
        let new_predicate = EqJoinPredicate::create(
            left_schema_len,
            new_scan.schema().len(),
            new_join_on.clone(),
        );

        // We discovered that we cannot use a lookup join after pulling up the predicate
        // from one side and simplifying the condition. Let's use some other join instead.
        if !new_predicate.has_eq() {
            return Ok(None);
        }

        // Rewrite the join output indices: every output index that referred to the old scan
        // must be remapped through `o2r`.
        let new_join_output_indices = logical_join
            .output_indices
            .iter()
            .map(|&x| {
                if x < left_schema_len {
                    x
                } else {
                    o2r[x - left_schema_len] + left_schema_len
                }
            })
            .collect_vec();

        let new_scan_output_column_ids = new_scan.output_column_ids();
        let as_of = new_scan.as_of.clone();
        let new_logical_scan: LogicalScan = new_scan.into();

        // Construct a new logical join, because we have changed its RHS.
        let new_logical_join = generic::Join::new(
            logical_join.left,
            new_logical_scan.to_batch()?,
            new_join_on,
            logical_join.join_type,
            new_join_output_indices,
        );

        // Only as-of joins carry an inequality description; it is derived from the
        // non-equal part of the reordered predicate.
        let asof_desc = self
            .is_asof_join()
            .then(|| {
                Self::get_inequality_desc_from_predicate(
                    predicate.other_cond().clone(),
                    left_schema_len,
                )
            })
            .transpose()?;

        Ok(Some(BatchLookupJoin::new(
            new_logical_join,
            new_predicate,
            table.clone(),
            new_scan_output_column_ids,
            lookup_prefix_len,
            false,
            as_of,
            asof_desc,
        )))
    }
483
    /// Consumes the node and returns the parts of the core join.
    ///
    /// NOTE(review): judging by the tuple's types, the parts are presumably
    /// `(left, right, on, join_type, output_indices)` — confirm against
    /// `generic::Join::decompose`.
    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
        self.core.decompose()
    }
487}
488
489impl PlanTreeNodeBinary<Logical> for LogicalJoin {
490    fn left(&self) -> PlanRef {
491        self.core.left.clone()
492    }
493
494    fn right(&self) -> PlanRef {
495        self.core.right.clone()
496    }
497
498    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
499        Self::with_core(generic::Join {
500            left,
501            right,
502            ..self.core.clone()
503        })
504    }
505
506    fn rewrite_with_left_right(
507        &self,
508        left: PlanRef,
509        left_col_change: ColIndexMapping,
510        right: PlanRef,
511        right_col_change: ColIndexMapping,
512    ) -> (Self, ColIndexMapping) {
513        let (new_on, new_output_indices) = {
514            let (mut map, _) = left_col_change.clone().into_parts();
515            let (mut right_map, _) = right_col_change.clone().into_parts();
516            for i in right_map.iter_mut().flatten() {
517                *i += left.schema().len();
518            }
519            map.append(&mut right_map);
520            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
521
522            let new_output_indices = self
523                .output_indices()
524                .iter()
525                .map(|&i| mapping.map(i))
526                .collect::<Vec<_>>();
527            let new_on = self.on().clone().rewrite_expr(&mut mapping);
528            (new_on, new_output_indices)
529        };
530
531        let join = Self::with_output_indices(
532            left,
533            right,
534            self.join_type(),
535            new_on,
536            new_output_indices.clone(),
537        );
538
539        let new_i2o = ColIndexMapping::with_remaining_columns(
540            &new_output_indices,
541            join.internal_column_num(),
542        );
543
544        let old_o2i = self.core.o2i_col_mapping();
545
546        let old_o2l = old_o2i
547            .composite(&self.core.i2l_col_mapping())
548            .composite(&left_col_change);
549        let old_o2r = old_o2i
550            .composite(&self.core.i2r_col_mapping())
551            .composite(&right_col_change);
552        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
553        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
554
555        let out_col_change = old_o2l
556            .composite(&new_l2o)
557            .union(&old_o2r.composite(&new_r2o));
558        (join, out_col_change)
559    }
560}
561
562impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
563
impl ColPrunable for LogicalJoin {
    /// Prunes the join down to `required_cols` (indices into the join's output schema).
    ///
    /// Columns referenced by the `on` condition are kept in addition to the required ones. Both
    /// inputs are pruned to the columns they need to provide, and the `on` condition and output
    /// indices are rewritten to the pruned column space.
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        // make `required_cols` point to internal table instead of output schema.
        let required_cols = required_cols
            .iter()
            .map(|i| self.output_indices()[*i])
            .collect_vec();
        let left_len = self.left().schema().fields.len();

        let total_len = self.left().schema().len() + self.right().schema().len();
        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);

        // Express the required columns as positions in the concatenated (left ++ right) schema.
        // NOTE(review): for right semi/anti joins the internal columns presumably consist of the
        // right input only, hence the shift by `left_len` — confirm against `generic::Join`.
        required_cols.iter().for_each(|&i| {
            if self.is_right_join() {
                resized_required_cols.insert(left_len + i);
            } else {
                resized_required_cols.insert(i);
            }
        });

        // add those columns which are required in the join condition
        // to those that are required in the output
        let mut visitor = CollectInputRef::new(resized_required_cols);
        self.on().visit_expr(&mut visitor);
        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();

        // Split the required columns by side; right-side indices are rebased to the right input.
        let mut left_required_cols = Vec::new();
        let mut right_required_cols = Vec::new();
        left_right_required_cols.iter().for_each(|&i| {
            if i < left_len {
                left_required_cols.push(i);
            } else {
                right_required_cols.push(i - left_len);
            }
        });

        // Rewrite the `on` condition to the pruned (left ++ right) column space.
        let mut on = self.on().clone();
        let mut mapping =
            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
        on = on.rewrite_expr(&mut mapping);

        let new_output_indices = {
            // Semi/anti joins only expose one side, so the output indices are remapped relative
            // to that side's remaining columns.
            let required_inputs_in_output = if self.is_left_join() {
                &left_required_cols
            } else if self.is_right_join() {
                &right_required_cols
            } else {
                &left_right_required_cols
            };

            let mapping =
                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
        };

        LogicalJoin::with_output_indices(
            self.left().prune_col(&left_required_cols, ctx),
            self.right().prune_col(&right_required_cols, ctx),
            self.join_type(),
            on,
            new_output_indices,
        )
        .into()
    }
}
629
630impl ExprRewritable<Logical> for LogicalJoin {
631    fn has_rewritable_expr(&self) -> bool {
632        true
633    }
634
635    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
636        let mut core = self.core.clone();
637        core.rewrite_exprs(r);
638        Self {
639            base: self.base.clone_with_new_plan_id(),
640            core,
641        }
642        .into()
643    }
644}
645
impl ExprVisitable for LogicalJoin {
    /// Lets `v` visit the expressions of the core join.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
651
652/// We are trying to derive a predicate to apply to the other side of a join if all
653/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
654/// with the corresponding eq condition columns of the other side.
655///
656/// Strategy:
657/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
658///    then we proceed. Else abort.
659/// 2. Then, we collect `InputRef`s in the conjunction.
660/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
661/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
662///    the other side.
663///
664/// # Arguments
665///
666/// Suppose we derive a predicate from the left side to be pushed to the right side.
667/// * `expr`: An expr from the left side.
668/// * `col_num`: The number of columns in the left side.
669fn derive_predicate_from_eq_condition(
670    expr: &ExprImpl,
671    eq_condition: &EqJoinPredicate,
672    col_num: usize,
673    expr_is_left: bool,
674) -> Option<ExprImpl> {
675    if expr.is_impure() {
676        return None;
677    }
678    let eq_indices = eq_condition
679        .eq_indexes_typed()
680        .iter()
681        .filter_map(|(l, r)| {
682            if l.return_type() != r.return_type() {
683                None
684            } else if expr_is_left {
685                Some(l.index())
686            } else {
687                Some(r.index())
688            }
689        })
690        .collect_vec();
691    if expr
692        .collect_input_refs(col_num)
693        .ones()
694        .any(|index| !eq_indices.contains(&index))
695    {
696        // expr contains an InputRef not in eq_condition
697        return None;
698    }
699    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
700    // Hence, we can substitute those `InputRef`s with indices from the other side.
701    let other_side_mapping = if expr_is_left {
702        eq_condition.eq_indexes_typed().into_iter().collect()
703    } else {
704        eq_condition
705            .eq_indexes_typed()
706            .into_iter()
707            .map(|(x, y)| (y, x))
708            .collect()
709    };
710    struct InputRefsRewriter {
711        mapping: HashMap<InputRef, InputRef>,
712    }
713    impl ExprRewriter for InputRefsRewriter {
714        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
715            self.mapping[&input_ref].clone().into()
716        }
717    }
718    Some(
719        InputRefsRewriter {
720            mapping: other_side_mapping,
721        }
722        .rewrite_expr(expr.clone()),
723    )
724}
725
726/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
727struct LookupJoinPredicateRewriter {
728    offset: usize,
729    mapping: Vec<usize>,
730}
731impl ExprRewriter for LookupJoinPredicateRewriter {
732    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
733        if input_ref.index() < self.offset {
734            input_ref.into()
735        } else {
736            InputRef::new(
737                self.mapping[input_ref.index() - self.offset] + self.offset,
738                input_ref.return_type(),
739            )
740            .into()
741        }
742    }
743}
744
745/// Rewrite the scan predicate so we can add it to the join predicate.
746struct LookupJoinScanPredicateRewriter {
747    offset: usize,
748}
749impl ExprRewriter for LookupJoinScanPredicateRewriter {
750    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
751        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
752    }
753}
754
impl PredicatePushdown for LogicalJoin {
    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
    ///
    /// # Which predicates can be pushed
    ///
    /// For inner join, we can do all kinds of pushdown.
    ///
    /// For left/right semi join, we can push filter to left/right and on-clause,
    /// and push on-clause to left/right.
    ///
    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
    ///
    /// ## Outer Join
    ///
    /// Preserved Row table
    /// : The table in an Outer Join that must return all rows.
    ///
    /// Null Supplying table
    /// : This is the table that has nulls filled in for its columns in unmatched rows.
    ///
    /// |                          | Preserved Row table | Null Supplying table |
    /// |--------------------------|---------------------|----------------------|
    /// | Join predicate (on)      | Not Pushed          | Pushed               |
    /// | Where predicate (filter) | Pushed              | Not Pushed           |
    ///
    /// # Returns
    ///
    /// A `LogicalFilter` created over the rewritten join, carrying whatever part of
    /// `predicate` could not be pushed into the join or its inputs.
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        // rewrite output col referencing indices as internal cols
        let mut predicate = {
            let mut mapping = self.core.o2i_col_mapping();
            predicate.rewrite_expr(&mut mapping)
        };

        let left_col_num = self.left().schema().len();
        let right_col_num = self.right().schema().len();
        // The filter on top of the join may allow an outer join to be downgraded.
        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

        let push_down_temporal_predicate = !self.should_be_temporal_join();

        // Split the filter into the parts that go to the left input, the right input and the
        // `on` condition; whatever cannot be pushed stays in `predicate`.
        let (left_from_filter, right_from_filter, on) = push_down_into_join(
            &mut predicate,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        // Extend the `on` condition with the pushed part, then push what is possible from the
        // `on` condition further down into the inputs.
        let mut new_on = self.on().clone().and(on);
        let (left_from_on, right_from_on) = push_down_join_condition(
            &mut new_on,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        let left_predicate = left_from_filter.and(left_from_on);
        let right_predicate = right_from_filter.and(right_from_on);

        // Derive conditions to push to the other side based on eq condition columns
        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());

        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
        let right_from_left = if matches!(
            join_type,
            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
        ) {
            Condition {
                conjunctions: left_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
        let left_from_right = if matches!(
            join_type,
            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
        ) {
            Condition {
                conjunctions: right_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(
                            expr,
                            &eq_condition,
                            right_col_num,
                            false,
                        )
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        let left_predicate = left_predicate.and(left_from_right);
        let right_predicate = right_predicate.and(right_from_left);

        // Recurse into both inputs with their respective predicates and rebuild the join with
        // the (possibly simplified) join type and the unchanged output indices.
        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
        let new_join = LogicalJoin::with_output_indices(
            new_left,
            new_right,
            join_type,
            new_on,
            self.output_indices().clone(),
        );

        // The remaining predicate is expressed over internal columns; map it back to the
        // output columns before placing it in a filter above the join.
        let mut mapping = self.core.i2o_col_mapping();
        predicate = predicate.rewrite_expr(&mut mapping);
        LogicalFilter::create(new_join.into(), predicate)
    }
}
878
879impl LogicalJoin {
    /// Converts both join inputs to stream plans whose distributions line up on the equi-join
    /// keys of `predicate`.
    ///
    /// The right input is converted first, required to be sharded by the right equi-join
    /// columns. How the left input is converted then depends on the distribution the right
    /// side actually ended up with.
    ///
    /// Returns the `(left, right)` stream plans, or the first error raised while converting an
    /// input or enforcing a distribution.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!()` if the converted right input (or, in the `UpstreamHashShard`
    /// case, the converted left input) reports a distribution other than `HashShard` or
    /// `UpstreamHashShard`.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        let mut right = self.right().to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        let logical_left = self.left();

        // Column mappings between the two sides, derived from the equi-join key pairs.
        let r2l =
            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());

        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // The right side is hash-sharded: translate that exact distribution to the left
                // columns and require it from the left input.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = self
                    .core
                    .left
                    .to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                // The right side is `UpstreamHashShard`: convert the left input with a
                // shard-by-key requirement on its own equi-join columns, then decide based on
                // the distribution it yields.
                left = self.core.left.to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Translate the left distribution to the right columns and enforce it
                        // on the right side if it is not already satisfied.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Both sides are `UpstreamHashShard`: enforce a hash shard on the
                        // equi-join keys of each side where it is not already satisfied.
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
938
939    fn to_stream_hash_join(
940        &self,
941        predicate: EqJoinPredicate,
942        ctx: &mut ToStreamContext,
943    ) -> Result<StreamPlanRef> {
944        use super::stream::prelude::*;
945
946        assert!(predicate.has_eq());
947        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
948
949        let core = self.core.clone_with_inputs(left, right);
950
951        // Convert to Hash Join for equal joins
952        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
953        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
954        // as opposed to the HashJoin, which applies the condition row-wise.
955        // However, the default behavior of pulling up non-equal conditions can be overridden by the
956        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
957        // materialization of rows only to be filtered later.
958
959        let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
960
961        let force_filter_inside_join = self
962            .base
963            .ctx()
964            .session_ctx()
965            .config()
966            .streaming_force_filter_inside_join();
967
968        let pull_filter = self.join_type() == JoinType::Inner
969            && stream_hash_join.eq_join_predicate().has_non_eq()
970            && stream_hash_join.inequality_pairs().is_empty()
971            && (!force_filter_inside_join);
972        if pull_filter {
973            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
974
975            let mut core = core.clone();
976            core.output_indices = default_indices.clone();
977            // Temporarily remove output indices.
978            let eq_cond = EqJoinPredicate::new(
979                Condition::true_cond(),
980                predicate.eq_keys().to_vec(),
981                self.left().schema().len(),
982                self.right().schema().len(),
983            );
984            core.on = eq_cond.eq_cond();
985            let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
986            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
987            let plan = StreamFilter::new(logical_filter).into();
988            if self.output_indices() != &default_indices {
989                let logical_project = generic::Project::with_mapping(
990                    plan,
991                    ColIndexMapping::with_remaining_columns(
992                        self.output_indices(),
993                        self.internal_column_num(),
994                    ),
995                );
996                Ok(StreamProject::new(logical_project).into())
997            } else {
998                Ok(plan)
999            }
1000        } else {
1001            Ok(stream_hash_join.into())
1002        }
1003    }
1004
1005    fn should_be_temporal_join(&self) -> bool {
1006        let right = self.right();
1007        if let Some(logical_scan) = right.as_logical_scan() {
1008            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1009        } else {
1010            false
1011        }
1012    }
1013
1014    fn to_stream_temporal_join_with_index_selection(
1015        &self,
1016        predicate: EqJoinPredicate,
1017        ctx: &mut ToStreamContext,
1018    ) -> Result<StreamPlanRef> {
1019        // Index selection for temporal join.
1020        let right = self.right();
1021        // `should_be_temporal_join()` has already check right input for us.
1022        let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1023
1024        // Use primary table.
1025        let mut result_plan: Result<StreamTemporalJoin> =
1026            self.to_stream_temporal_join(predicate.clone(), ctx);
1027        // Return directly if this temporal join can match the pk of its right table.
1028        if let Ok(temporal_join) = &result_plan
1029            && temporal_join.eq_join_predicate().eq_indexes().len()
1030                == logical_scan.primary_key().len()
1031        {
1032            return result_plan.map(|x| x.into());
1033        }
1034        let indexes = logical_scan.table_indexes();
1035        for index in indexes {
1036            // Use index table
1037            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1038                let index_scan: PlanRef = index_scan.into();
1039                let that = self.clone_with_left_right(self.left(), index_scan.clone());
1040                if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1041                    match &result_plan {
1042                        Err(_) => result_plan = Ok(temporal_join),
1043                        Ok(prev_temporal_join) => {
1044                            // Prefer to the temporal join with a longer lookup prefix len.
1045                            if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1046                                < temporal_join.eq_join_predicate().eq_indexes().len()
1047                            {
1048                                result_plan = Ok(temporal_join)
1049                            }
1050                        }
1051                    }
1052                }
1053            }
1054        }
1055
1056        result_plan.map(|x| x.into())
1057    }
1058
1059    fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1060        let Some(logical_scan) = right.as_logical_scan() else {
1061            return Err(RwError::from(ErrorCode::NotSupported(
1062                "Temporal join requires a table scan as its lookup table".into(),
1063                "Please provide a table scan".into(),
1064            )));
1065        };
1066
1067        if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1068            return Err(RwError::from(ErrorCode::NotSupported(
1069                "Temporal join requires a table defined as temporal table".into(),
1070                "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1071            )));
1072        }
1073        Ok(logical_scan)
1074    }
1075
1076    fn temporal_join_scan_predicate_pull_up(
1077        logical_scan: &LogicalScan,
1078        predicate: EqJoinPredicate,
1079        output_indices: &[usize],
1080        left_schema_len: usize,
1081    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1082        // Extract the predicate from logical scan. Only pure scan is supported.
1083        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1084        // Construct output column to require column mapping
1085        let o2r = if let Some(project_expr) = project_expr {
1086            project_expr
1087                .into_iter()
1088                .map(|x| x.as_input_ref().unwrap().index)
1089                .collect_vec()
1090        } else {
1091            (0..logical_scan.output_col_idx().len()).collect_vec()
1092        };
1093        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1094            offset: left_schema_len,
1095            mapping: o2r.clone(),
1096        };
1097
1098        let new_eq_cond = predicate
1099            .eq_cond()
1100            .rewrite_expr(&mut join_predicate_rewriter);
1101
1102        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1103            offset: left_schema_len,
1104        };
1105
1106        let new_other_cond = predicate
1107            .other_cond()
1108            .clone()
1109            .rewrite_expr(&mut join_predicate_rewriter)
1110            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1111
1112        let new_join_on = new_eq_cond.and(new_other_cond);
1113
1114        let new_predicate = EqJoinPredicate::create(
1115            left_schema_len,
1116            new_scan.schema().len(),
1117            new_join_on.clone(),
1118        );
1119
1120        // Rewrite the join output indices and all output indices referred to the old scan need to
1121        // rewrite.
1122        let new_join_output_indices = output_indices
1123            .iter()
1124            .map(|&x| {
1125                if x < left_schema_len {
1126                    x
1127                } else {
1128                    o2r[x - left_schema_len] + left_schema_len
1129                }
1130            })
1131            .collect_vec();
1132        // Use UpstreamOnly chain type
1133        let new_stream_table_scan =
1134            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1135        Ok((
1136            new_stream_table_scan,
1137            new_predicate,
1138            new_join_on,
1139            new_join_output_indices,
1140        ))
1141    }
1142
1143    fn to_stream_temporal_join(
1144        &self,
1145        predicate: EqJoinPredicate,
1146        ctx: &mut ToStreamContext,
1147    ) -> Result<StreamTemporalJoin> {
1148        use super::stream::prelude::*;
1149
1150        assert!(predicate.has_eq());
1151
1152        let right = self.right();
1153
1154        let logical_scan = Self::check_temporal_rhs(&right)?;
1155
1156        let table = logical_scan.table();
1157        let output_column_ids = logical_scan.output_column_ids();
1158
1159        // Verify that the right join key columns are the the prefix of the primary key and
1160        // also contain the distribution key.
1161        let order_col_ids = table.order_column_ids();
1162        let dist_key = table.distribution_key.clone();
1163
1164        let mut dist_key_in_order_key_pos = vec![];
1165        for d in dist_key {
1166            let pos = table
1167                .order_column_indices()
1168                .position(|x| x == d)
1169                .expect("dist_key must in order_key");
1170            dist_key_in_order_key_pos.push(pos);
1171        }
1172        // The shortest prefix of order key that contains distribution key.
1173        let shortest_prefix_len = dist_key_in_order_key_pos
1174            .iter()
1175            .max()
1176            .map_or(0, |pos| pos + 1);
1177
1178        // Reorder the join equal predicate to match the order key.
1179        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1180        for order_col_id in order_col_ids {
1181            let mut found = false;
1182            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1183                if order_col_id == output_column_ids[eq_idx] {
1184                    reorder_idx.push(i);
1185                    found = true;
1186                    break;
1187                }
1188            }
1189            if !found {
1190                break;
1191            }
1192        }
1193        if reorder_idx.len() < shortest_prefix_len {
1194            // TODO: support index selection for temporal join and refine this error message.
1195            return Err(RwError::from(ErrorCode::NotSupported(
1196                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1197                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1198            )));
1199        }
1200        let lookup_prefix_len = reorder_idx.len();
1201        let predicate = predicate.reorder(&reorder_idx);
1202
1203        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1204            RequiredDist::single()
1205        } else {
1206            let left_eq_indexes = predicate.left_eq_indexes();
1207            let left_dist_key = dist_key_in_order_key_pos
1208                .iter()
1209                .map(|pos| left_eq_indexes[*pos])
1210                .collect_vec();
1211
1212            RequiredDist::hash_shard(&left_dist_key)
1213        };
1214
1215        let left = self.left().to_stream(ctx)?;
1216        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1217        let left = required_dist.stream_enforce(left);
1218
1219        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1220            Self::temporal_join_scan_predicate_pull_up(
1221                logical_scan,
1222                predicate,
1223                self.output_indices(),
1224                self.left().schema().len(),
1225            )?;
1226
1227        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1228        if !new_predicate.has_eq() {
1229            return Err(RwError::from(ErrorCode::NotSupported(
1230                "Temporal join requires a non trivial join condition".into(),
1231                "Please remove the false condition of the join".into(),
1232            )));
1233        }
1234
1235        // Construct a new logical join, because we have change its RHS.
1236        let new_logical_join = generic::Join::new(
1237            left,
1238            right,
1239            new_join_on,
1240            self.join_type(),
1241            new_join_output_indices,
1242        );
1243
1244        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1245
1246        StreamTemporalJoin::new(new_logical_join, new_predicate, false)
1247    }
1248
1249    fn to_stream_nested_loop_temporal_join(
1250        &self,
1251        predicate: EqJoinPredicate,
1252        ctx: &mut ToStreamContext,
1253    ) -> Result<StreamPlanRef> {
1254        use super::stream::prelude::*;
1255        assert!(!predicate.has_eq());
1256
1257        let left = self.left().to_stream_with_dist_required(
1258            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1259            ctx,
1260        )?;
1261        assert!(left.as_stream_exchange().is_some());
1262
1263        if self.join_type() != JoinType::Inner {
1264            return Err(RwError::from(ErrorCode::NotSupported(
1265                "Temporal join requires an inner join".into(),
1266                "Please use an inner join".into(),
1267            )));
1268        }
1269
1270        if !left.append_only() {
1271            return Err(RwError::from(ErrorCode::NotSupported(
1272                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1273                "Please ensure the left hash side is append only".into(),
1274            )));
1275        }
1276
1277        let right = self.right();
1278        let logical_scan = Self::check_temporal_rhs(&right)?;
1279
1280        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1281            Self::temporal_join_scan_predicate_pull_up(
1282                logical_scan,
1283                predicate,
1284                self.output_indices(),
1285                self.left().schema().len(),
1286            )?;
1287
1288        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1289
1290        // Construct a new logical join, because we have change its RHS.
1291        let new_logical_join = generic::Join::new(
1292            left,
1293            right,
1294            new_join_on,
1295            self.join_type(),
1296            new_join_output_indices,
1297        );
1298
1299        Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1300    }
1301
1302    fn to_stream_dynamic_filter(
1303        &self,
1304        predicate: Condition,
1305        ctx: &mut ToStreamContext,
1306    ) -> Result<Option<StreamPlanRef>> {
1307        use super::stream::prelude::*;
1308
1309        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1310        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1311        // `StreamDynamicFilter`
1312
1313        // Check if `Inner`/`LeftSemi`
1314        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1315            return Ok(None);
1316        }
1317
1318        // Check if right side is a scalar
1319        if !self.right().max_one_row() {
1320            return Ok(None);
1321        }
1322        if self.right().schema().len() != 1 {
1323            return Ok(None);
1324        }
1325
1326        // Check if the join condition is a correlated comparison
1327        if predicate.conjunctions.len() > 1 {
1328            return Ok(None);
1329        }
1330        let expr: ExprImpl = predicate.into();
1331        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1332            Some(v) => v,
1333            None => return Ok(None),
1334        };
1335
1336        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1337            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1338        if !condition_cross_inputs {
1339            // Maybe we should panic here because it means some predicates are not pushed down.
1340            return Ok(None);
1341        }
1342
1343        // We align input types on all join predicates with cmp operator
1344        if self.left().schema().fields()[left_ref.index].data_type
1345            != self.right().schema().fields()[0].data_type
1346        {
1347            return Ok(None);
1348        }
1349
1350        // Check if non of the columns from the inner side is required to output
1351        let all_output_from_left = self
1352            .output_indices()
1353            .iter()
1354            .all(|i| *i < self.left().schema().len());
1355        if !all_output_from_left {
1356            return Ok(None);
1357        }
1358
1359        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1360        let right = self.right().to_stream_with_dist_required(
1361            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1362            ctx,
1363        )?;
1364
1365        assert!(right.as_stream_exchange().is_some());
1366        assert_eq!(
1367            *right.inputs().iter().exactly_one().unwrap().distribution(),
1368            Distribution::Single
1369        );
1370
1371        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1372        let plan = StreamDynamicFilter::new(core)?.into();
1373        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1374        if self
1375            .output_indices()
1376            .iter()
1377            .copied()
1378            .ne(0..self.left().schema().len())
1379        {
1380            // The schema of dynamic filter is always the same as the left side now, and we have
1381            // checked that all output columns are from the left side before.
1382            let logical_project = generic::Project::with_mapping(
1383                plan,
1384                ColIndexMapping::with_remaining_columns(
1385                    self.output_indices(),
1386                    self.left().schema().len(),
1387                ),
1388            );
1389            Ok(Some(StreamProject::new(logical_project).into()))
1390        } else {
1391            Ok(Some(plan))
1392        }
1393    }
1394
1395    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1396        let predicate = EqJoinPredicate::create(
1397            self.left().schema().len(),
1398            self.right().schema().len(),
1399            self.on().clone(),
1400        );
1401        assert!(predicate.has_eq());
1402
1403        let join = self
1404            .core
1405            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1406
1407        Ok(self
1408            .to_batch_lookup_join(predicate, join)?
1409            .expect("Fail to convert to lookup join")
1410            .into())
1411    }
1412
1413    fn to_stream_asof_join(
1414        &self,
1415        predicate: EqJoinPredicate,
1416        ctx: &mut ToStreamContext,
1417    ) -> Result<StreamPlanRef> {
1418        use super::stream::prelude::*;
1419
1420        if predicate.eq_keys().is_empty() {
1421            return Err(ErrorCode::InvalidInputSyntax(
1422                "AsOf join requires at least 1 equal condition".to_owned(),
1423            )
1424            .into());
1425        }
1426
1427        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1428        let left_len = left.schema().len();
1429        let core = self.core.clone_with_inputs(left, right);
1430
1431        let inequality_desc =
1432            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1433
1434        Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1435    }
1436
1437    /// Convert the logical join to a Hash join.
1438    fn to_batch_hash_join(
1439        &self,
1440        logical_join: generic::Join<BatchPlanRef>,
1441        predicate: EqJoinPredicate,
1442    ) -> Result<BatchPlanRef> {
1443        use super::batch::prelude::*;
1444
1445        let left_schema_len = logical_join.left.schema().len();
1446        let asof_desc = self
1447            .is_asof_join()
1448            .then(|| {
1449                Self::get_inequality_desc_from_predicate(
1450                    predicate.other_cond().clone(),
1451                    left_schema_len,
1452                )
1453            })
1454            .transpose()?;
1455
1456        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1457        Ok(batch_join.into())
1458    }
1459
1460    pub fn get_inequality_desc_from_predicate(
1461        predicate: Condition,
1462        left_input_len: usize,
1463    ) -> Result<AsOfJoinDesc> {
1464        let expr: ExprImpl = predicate.into();
1465        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1466            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1467            {
1468                Ok(AsOfJoinDesc {
1469                    left_idx: left_input_ref.index() as u32,
1470                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1471                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1472                })
1473            } else {
1474                bail!("inequal condition from the same side should be push down in optimizer");
1475            }
1476        } else {
1477            Err(ErrorCode::InvalidInputSyntax(
1478                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1479            )
1480            .into())
1481        }
1482    }
1483
1484    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1485        match expr_type {
1486            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1487            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1488            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1489            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1490            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1491                "Invalid comparison type: {}",
1492                expr_type.as_str_name()
1493            ))
1494            .into()),
1495        }
1496    }
1497}
1498
1499impl ToBatch for LogicalJoin {
1500    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1501        let predicate = EqJoinPredicate::create(
1502            self.left().schema().len(),
1503            self.right().schema().len(),
1504            self.on().clone(),
1505        );
1506
1507        let batch_join = self
1508            .core
1509            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1510
1511        let ctx = self.base.ctx();
1512        let config = ctx.session_ctx().config();
1513
1514        if predicate.has_eq() {
1515            if !predicate.eq_keys_are_type_aligned() {
1516                return Err(ErrorCode::InternalError(format!(
1517                    "Join eq keys are not aligned for predicate: {predicate:?}"
1518                ))
1519                .into());
1520            }
1521            if config.batch_enable_lookup_join()
1522                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1523                    predicate.clone(),
1524                    batch_join.clone(),
1525                )?
1526            {
1527                return Ok(lookup_join.into());
1528            }
1529            self.to_batch_hash_join(batch_join, predicate)
1530        } else if self.is_asof_join() {
1531            Err(ErrorCode::InvalidInputSyntax(
1532                "AsOf join requires at least 1 equal condition".to_owned(),
1533            )
1534            .into())
1535        } else {
1536            // Convert to Nested-loop Join for non-equal joins
1537            Ok(BatchNestedLoopJoin::new(batch_join).into())
1538        }
1539    }
1540}
1541
1542impl ToStream for LogicalJoin {
1543    fn to_stream(
1544        &self,
1545        ctx: &mut ToStreamContext,
1546    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1547        if self
1548            .on()
1549            .conjunctions
1550            .iter()
1551            .any(|cond| cond.count_nows() > 0)
1552        {
1553            return Err(ErrorCode::NotSupported(
1554                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1555                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1556        }
1557
1558        let predicate = EqJoinPredicate::create(
1559            self.left().schema().len(),
1560            self.right().schema().len(),
1561            self.on().clone(),
1562        );
1563
1564        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1565            self.to_stream_asof_join(predicate, ctx)
1566        } else if predicate.has_eq() {
1567            if !predicate.eq_keys_are_type_aligned() {
1568                return Err(ErrorCode::InternalError(format!(
1569                    "Join eq keys are not aligned for predicate: {predicate:?}"
1570                ))
1571                .into());
1572            }
1573
1574            if self.should_be_temporal_join() {
1575                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1576            } else {
1577                self.to_stream_hash_join(predicate, ctx)
1578            }
1579        } else if self.should_be_temporal_join() {
1580            self.to_stream_nested_loop_temporal_join(predicate, ctx)
1581        } else if let Some(dynamic_filter) =
1582            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1583        {
1584            Ok(dynamic_filter)
1585        } else {
1586            Err(RwError::from(ErrorCode::NotSupported(
1587                "streaming nested-loop join".to_owned(),
1588                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1589                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1590                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1591            )))
1592        }
1593    }
1594
1595    fn logical_rewrite_for_stream(
1596        &self,
1597        ctx: &mut RewriteStreamContext,
1598    ) -> Result<(PlanRef, ColIndexMapping)> {
1599        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1600        let left_len = left.schema().len();
1601        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1602        let (join, out_col_change) = self.rewrite_with_left_right(
1603            left.clone(),
1604            left_col_change,
1605            right.clone(),
1606            right_col_change,
1607        );
1608
1609        let mapping = ColIndexMapping::with_remaining_columns(
1610            join.output_indices(),
1611            join.internal_column_num(),
1612        );
1613
1614        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1615        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1616
1617        // Add missing pk indices to the logical join
1618        let mut left_to_add = left
1619            .expect_stream_key()
1620            .iter()
1621            .cloned()
1622            .filter(|i| l2o.try_map(*i).is_none())
1623            .collect_vec();
1624
1625        let mut right_to_add = right
1626            .expect_stream_key()
1627            .iter()
1628            .filter(|&&i| r2o.try_map(i).is_none())
1629            .map(|&i| i + left_len)
1630            .collect_vec();
1631
1632        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1633        // key.
1634        let right_len = right.schema().len();
1635        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1636
1637        let either_or_both = self.core.add_which_join_key_to_pk();
1638
1639        for (lk, rk) in eq_predicate.eq_indexes() {
1640            match either_or_both {
1641                EitherOrBoth::Left(_) => {
1642                    if l2o.try_map(lk).is_none() {
1643                        left_to_add.push(lk);
1644                    }
1645                }
1646                EitherOrBoth::Right(_) => {
1647                    if r2o.try_map(rk).is_none() {
1648                        right_to_add.push(rk + left_len)
1649                    }
1650                }
1651                EitherOrBoth::Both(_, _) => {
1652                    if l2o.try_map(lk).is_none() {
1653                        left_to_add.push(lk);
1654                    }
1655                    if r2o.try_map(rk).is_none() {
1656                        right_to_add.push(rk + left_len)
1657                    }
1658                }
1659            };
1660        }
1661        let left_to_add = left_to_add.into_iter().unique();
1662        let right_to_add = right_to_add.into_iter().unique();
1663        // NOTE(st1page) over
1664
1665        let mut new_output_indices = join.output_indices().clone();
1666        if !join.is_right_join() {
1667            new_output_indices.extend(left_to_add);
1668        }
1669        if !join.is_left_join() {
1670            new_output_indices.extend(right_to_add);
1671        }
1672
1673        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1674
1675        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1676            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1677
1678            let l2o = join_with_pk
1679                .core
1680                .l2i_col_mapping()
1681                .composite(&join_with_pk.core.i2o_col_mapping());
1682            let r2o = join_with_pk
1683                .core
1684                .r2i_col_mapping()
1685                .composite(&join_with_pk.core.i2o_col_mapping());
1686            let left_right_stream_keys = join_with_pk
1687                .left()
1688                .expect_stream_key()
1689                .iter()
1690                .map(|i| l2o.map(*i))
1691                .chain(
1692                    join_with_pk
1693                        .right()
1694                        .expect_stream_key()
1695                        .iter()
1696                        .map(|i| r2o.map(*i)),
1697                )
1698                .collect_vec();
1699            let plan: PlanRef = join_with_pk.into();
1700            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1701        } else {
1702            join_with_pk.into()
1703        };
1704
1705        // the added columns is at the end, so it will not change the exists column index
1706        Ok((plan, out_col_change))
1707    }
1708}
1709
1710#[cfg(test)]
1711mod tests {
1712
1713    use std::collections::HashSet;
1714
1715    use risingwave_common::catalog::{Field, Schema};
1716    use risingwave_common::types::{DataType, Datum};
1717    use risingwave_pb::expr::expr_node::Type;
1718
1719    use super::*;
1720    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1721    use crate::optimizer::optimizer_context::OptimizerContext;
1722    use crate::optimizer::plan_node::LogicalValues;
1723    use crate::optimizer::property::FunctionalDependency;
1724
1725    /// Pruning
1726    /// ```text
1727    /// Join(on: input_ref(1)=input_ref(3))
1728    ///   TableScan(v1, v2, v3)
1729    ///   TableScan(v4, v5, v6)
1730    /// ```
1731    /// with required columns [2,3] will result in
1732    /// ```text
1733    /// Project(input_ref(1), input_ref(2))
1734    ///   Join(on: input_ref(0)=input_ref(2))
1735    ///     TableScan(v2, v3)
1736    ///     TableScan(v4)
1737    /// ```
1738    #[tokio::test]
1739    async fn test_prune_join() {
1740        let ty = DataType::Int32;
1741        let ctx = OptimizerContext::mock().await;
1742        let fields: Vec<Field> = (1..7)
1743            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1744            .collect();
1745        let left = LogicalValues::new(
1746            vec![],
1747            Schema {
1748                fields: fields[0..3].to_vec(),
1749            },
1750            ctx.clone(),
1751        );
1752        let right = LogicalValues::new(
1753            vec![],
1754            Schema {
1755                fields: fields[3..6].to_vec(),
1756            },
1757            ctx,
1758        );
1759        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1760            FunctionCall::new(
1761                Type::Equal,
1762                vec![
1763                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1764                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1765                ],
1766            )
1767            .unwrap(),
1768        ));
1769        let join_type = JoinType::Inner;
1770        let join: PlanRef = LogicalJoin::new(
1771            left.into(),
1772            right.into(),
1773            join_type,
1774            Condition::with_expr(on),
1775        )
1776        .into();
1777
1778        // Perform the prune
1779        let required_cols = vec![2, 3];
1780        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1781
1782        // Check the result
1783        let join = plan.as_logical_join().unwrap();
1784        assert_eq!(join.schema().fields().len(), 2);
1785        assert_eq!(join.schema().fields()[0], fields[2]);
1786        assert_eq!(join.schema().fields()[1], fields[3]);
1787
1788        let expr: ExprImpl = join.on().clone().into();
1789        let call = expr.as_function_call().unwrap();
1790        assert_eq_input_ref!(&call.inputs()[0], 0);
1791        assert_eq_input_ref!(&call.inputs()[1], 2);
1792
1793        let left = join.left();
1794        let left = left.as_logical_values().unwrap();
1795        assert_eq!(left.schema().fields(), &fields[1..3]);
1796        let right = join.right();
1797        let right = right.as_logical_values().unwrap();
1798        assert_eq!(right.schema().fields(), &fields[3..4]);
1799    }
1800
1801    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1802    #[tokio::test]
1803    async fn test_prune_semi_join() {
1804        let ty = DataType::Int32;
1805        let ctx = OptimizerContext::mock().await;
1806        let fields: Vec<Field> = (1..7)
1807            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1808            .collect();
1809        let left = LogicalValues::new(
1810            vec![],
1811            Schema {
1812                fields: fields[0..3].to_vec(),
1813            },
1814            ctx.clone(),
1815        );
1816        let right = LogicalValues::new(
1817            vec![],
1818            Schema {
1819                fields: fields[3..6].to_vec(),
1820            },
1821            ctx,
1822        );
1823        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1824            FunctionCall::new(
1825                Type::Equal,
1826                vec![
1827                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1828                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1829                ],
1830            )
1831            .unwrap(),
1832        ));
1833        for join_type in [
1834            JoinType::LeftSemi,
1835            JoinType::RightSemi,
1836            JoinType::LeftAnti,
1837            JoinType::RightAnti,
1838        ] {
1839            let join = LogicalJoin::new(
1840                left.clone().into(),
1841                right.clone().into(),
1842                join_type,
1843                Condition::with_expr(on.clone()),
1844            );
1845
1846            let offset = if join.is_right_join() { 3 } else { 0 };
1847            let join: PlanRef = join.into();
1848            // Perform the prune
1849            let required_cols = vec![0];
1850            // key 0 is never used in the join (always key 1)
1851            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1852            let as_plan = plan.as_logical_join().unwrap();
1853            // Check the result
1854            assert_eq!(as_plan.schema().fields().len(), 1);
1855            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1856
1857            // Perform the prune
1858            let required_cols = vec![0, 1, 2];
1859            // should not panic here
1860            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1861            let as_plan = plan.as_logical_join().unwrap();
1862            // Check the result
1863            assert_eq!(as_plan.schema().fields().len(), 3);
1864            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1865            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1866            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1867        }
1868    }
1869
1870    /// Pruning
1871    /// ```text
1872    /// Join(on: input_ref(1)=input_ref(3))
1873    ///   TableScan(v1, v2, v3)
1874    ///   TableScan(v4, v5, v6)
1875    /// ```
1876    /// with required columns [1, 3] will result in
1877    /// ```text
1878    /// Join(on: input_ref(0)=input_ref(1))
1879    ///   TableScan(v2)
1880    ///   TableScan(v4)
1881    /// ```
1882    #[tokio::test]
1883    async fn test_prune_join_no_project() {
1884        let ty = DataType::Int32;
1885        let ctx = OptimizerContext::mock().await;
1886        let fields: Vec<Field> = (1..7)
1887            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1888            .collect();
1889        let left = LogicalValues::new(
1890            vec![],
1891            Schema {
1892                fields: fields[0..3].to_vec(),
1893            },
1894            ctx.clone(),
1895        );
1896        let right = LogicalValues::new(
1897            vec![],
1898            Schema {
1899                fields: fields[3..6].to_vec(),
1900            },
1901            ctx,
1902        );
1903        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1904            FunctionCall::new(
1905                Type::Equal,
1906                vec![
1907                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1908                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1909                ],
1910            )
1911            .unwrap(),
1912        ));
1913        let join_type = JoinType::Inner;
1914        let join: PlanRef = LogicalJoin::new(
1915            left.into(),
1916            right.into(),
1917            join_type,
1918            Condition::with_expr(on),
1919        )
1920        .into();
1921
1922        // Perform the prune
1923        let required_cols = vec![1, 3];
1924        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1925
1926        // Check the result
1927        let join = plan.as_logical_join().unwrap();
1928        assert_eq!(join.schema().fields().len(), 2);
1929        assert_eq!(join.schema().fields()[0], fields[1]);
1930        assert_eq!(join.schema().fields()[1], fields[3]);
1931
1932        let expr: ExprImpl = join.on().clone().into();
1933        let call = expr.as_function_call().unwrap();
1934        assert_eq_input_ref!(&call.inputs()[0], 0);
1935        assert_eq_input_ref!(&call.inputs()[1], 1);
1936
1937        let left = join.left();
1938        let left = left.as_logical_values().unwrap();
1939        assert_eq!(left.schema().fields(), &fields[1..2]);
1940        let right = join.right();
1941        let right = right.as_logical_values().unwrap();
1942        assert_eq!(right.schema().fields(), &fields[3..4]);
1943    }
1944
1945    /// Convert
1946    /// ```text
1947    /// Join(on: ($1 = $3) AND ($2 == 42))
1948    ///   TableScan(v1, v2, v3)
1949    ///   TableScan(v4, v5, v6)
1950    /// ```
1951    /// to
1952    /// ```text
1953    /// Filter($2 == 42)
1954    ///   HashJoin(on: $1 = $3)
1955    ///     TableScan(v1, v2, v3)
1956    ///     TableScan(v4, v5, v6)
1957    /// ```
1958    #[tokio::test]
1959    async fn test_join_to_batch() {
1960        let ctx = OptimizerContext::mock().await;
1961        let fields: Vec<Field> = (1..7)
1962            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
1963            .collect();
1964        let left = LogicalValues::new(
1965            vec![],
1966            Schema {
1967                fields: fields[0..3].to_vec(),
1968            },
1969            ctx.clone(),
1970        );
1971        let right = LogicalValues::new(
1972            vec![],
1973            Schema {
1974                fields: fields[3..6].to_vec(),
1975            },
1976            ctx,
1977        );
1978
1979        fn input_ref(i: usize) -> ExprImpl {
1980            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
1981        }
1982        let eq_cond = ExprImpl::FunctionCall(Box::new(
1983            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
1984        ));
1985        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
1986            FunctionCall::new(
1987                Type::Equal,
1988                vec![
1989                    input_ref(2),
1990                    ExprImpl::Literal(Box::new(Literal::new(
1991                        Datum::Some(42_i32.into()),
1992                        DataType::Int32,
1993                    ))),
1994                ],
1995            )
1996            .unwrap(),
1997        ));
1998        // Condition: ($1 = $3) AND ($2 == 42)
1999        let on_cond = ExprImpl::FunctionCall(Box::new(
2000            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2001        ));
2002
2003        let join_type = JoinType::Inner;
2004        let logical_join = LogicalJoin::new(
2005            left.into(),
2006            right.into(),
2007            join_type,
2008            Condition::with_expr(on_cond),
2009        );
2010
2011        // Perform `to_batch`
2012        let result = logical_join.to_batch().unwrap();
2013
2014        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2015        let hash_join = result.as_batch_hash_join().unwrap();
2016        assert_eq!(
2017            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2018            eq_cond
2019        );
2020        assert_eq!(
2021            *hash_join
2022                .eq_join_predicate()
2023                .non_eq_cond()
2024                .conjunctions
2025                .first()
2026                .unwrap(),
2027            non_eq_cond
2028        );
2029    }
2030
    /// Placeholder for a `to_stream` conversion test.
    ///
    /// The commented-out body below sketches a check that
    /// ```text
    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    /// is converted to
    /// ```text
    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    ///
    /// The whole body is currently commented out, so this test asserts nothing even when it is
    /// run explicitly.
    #[tokio::test]
    #[ignore] // Ignored since the logical scan refactor. This test seems to duplicate the explain
    // test framework, so it may be removed instead of being revived.
    async fn test_join_to_stream() {
        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
        // let fields: Vec<Field> = (1..7)
        //     .map(|i| Field {
        //         data_type: DataType::Int32,
        //         name: format!("v{}", i),
        //     })
        //     .collect();
        // let left = LogicalScan::new(
        //     "left".to_string(),
        //     TableId::new(0),
        //     vec![1.into(), 2.into(), 3.into()],
        //     Schema {
        //         fields: fields[0..3].to_vec(),
        //     },
        //     ctx.clone(),
        // );
        // let right = LogicalScan::new(
        //     "right".to_string(),
        //     TableId::new(0),
        //     vec![4.into(), 5.into(), 6.into()],
        //     Schema {
        //                 fields: fields[3..6].to_vec(),
        //     },
        //     ctx,
        // );
        // let eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
        //             ExprImpl::Literal(Box::new(Literal::new(
        //                 Datum::Some(42_i32.into()),
        //                 DataType::Int32,
        //             ))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // // Condition: ($1 = $3) AND ($2 == 42)
        // let on_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
        // ));

        // let join_type = JoinType::LeftOuter;
        // let logical_join = LogicalJoin::new(
        //     left.clone().into(),
        //     right.clone().into(),
        //     join_type,
        //     Condition::with_expr(on_cond.clone()),
        // );

        // // Perform `to_stream`
        // let result = logical_join.to_stream();

        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
        // let hash_join = result.as_stream_hash_join().unwrap();
        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
    }
2115    /// Pruning
2116    /// ```text
2117    /// Join(on: input_ref(1)=input_ref(3))
2118    ///   TableScan(v1, v2, v3)
2119    ///   TableScan(v4, v5, v6)
2120    /// ```
2121    /// with required columns [3, 2] will result in
2122    /// ```text
2123    /// Project(input_ref(2), input_ref(1))
2124    ///   Join(on: input_ref(0)=input_ref(2))
2125    ///     TableScan(v2, v3)
2126    ///     TableScan(v4)
2127    /// ```
2128    #[tokio::test]
2129    async fn test_join_column_prune_with_order_required() {
2130        let ty = DataType::Int32;
2131        let ctx = OptimizerContext::mock().await;
2132        let fields: Vec<Field> = (1..7)
2133            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2134            .collect();
2135        let left = LogicalValues::new(
2136            vec![],
2137            Schema {
2138                fields: fields[0..3].to_vec(),
2139            },
2140            ctx.clone(),
2141        );
2142        let right = LogicalValues::new(
2143            vec![],
2144            Schema {
2145                fields: fields[3..6].to_vec(),
2146            },
2147            ctx,
2148        );
2149        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2150            FunctionCall::new(
2151                Type::Equal,
2152                vec![
2153                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2154                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2155                ],
2156            )
2157            .unwrap(),
2158        ));
2159        let join_type = JoinType::Inner;
2160        let join: PlanRef = LogicalJoin::new(
2161            left.into(),
2162            right.into(),
2163            join_type,
2164            Condition::with_expr(on),
2165        )
2166        .into();
2167
2168        // Perform the prune
2169        let required_cols = vec![3, 2];
2170        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2171
2172        // Check the result
2173        let join = plan.as_logical_join().unwrap();
2174        assert_eq!(join.schema().fields().len(), 2);
2175        assert_eq!(join.schema().fields()[0], fields[3]);
2176        assert_eq!(join.schema().fields()[1], fields[2]);
2177
2178        let expr: ExprImpl = join.on().clone().into();
2179        let call = expr.as_function_call().unwrap();
2180        assert_eq_input_ref!(&call.inputs()[0], 0);
2181        assert_eq_input_ref!(&call.inputs()[1], 2);
2182
2183        let left = join.left();
2184        let left = left.as_logical_values().unwrap();
2185        assert_eq!(left.schema().fields(), &fields[1..3]);
2186        let right = join.right();
2187        let right = right.as_logical_values().unwrap();
2188        assert_eq!(right.schema().fields(), &fields[3..4]);
2189    }
2190
2191    #[tokio::test]
2192    async fn fd_derivation_inner_outer_join() {
2193        // left: [l0, l1], right: [r0, r1, r2]
2194        // FD: l0 --> l1, r0 --> { r1, r2 }
2195        // On: l0 = 0 AND l1 = r1
2196        //
2197        // Inner Join:
2198        //  Schema: [l0, l1, r0, r1, r2]
2199        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2200        // Left Outer Join:
2201        //  Schema: [l0, l1, r0, r1, r2]
2202        //  FD: l0 --> l1
2203        // Right Outer Join:
2204        //  Schema: [l0, l1, r0, r1, r2]
2205        //  FD: r0 --> { r1, r2 }
2206        // Full Outer Join:
2207        //  Schema: [l0, l1, r0, r1, r2]
2208        //  FD: empty
2209        // Left Semi/Anti Join:
2210        //  Schema: [l0, l1]
2211        //  FD: l0 --> l1
2212        // Right Semi/Anti Join:
2213        //  Schema: [r0, r1, r2]
2214        //  FD: r0 --> {r1, r2}
2215        let ctx = OptimizerContext::mock().await;
2216        let left = {
2217            let fields: Vec<Field> = vec![
2218                Field::with_name(DataType::Int32, "l0"),
2219                Field::with_name(DataType::Int32, "l1"),
2220            ];
2221            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2222            // 0 --> 1
2223            values
2224                .base
2225                .functional_dependency_mut()
2226                .add_functional_dependency_by_column_indices(&[0], &[1]);
2227            values
2228        };
2229        let right = {
2230            let fields: Vec<Field> = vec![
2231                Field::with_name(DataType::Int32, "r0"),
2232                Field::with_name(DataType::Int32, "r1"),
2233                Field::with_name(DataType::Int32, "r2"),
2234            ];
2235            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2236            // 0 --> 1, 2
2237            values
2238                .base
2239                .functional_dependency_mut()
2240                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2241            values
2242        };
2243        // l0 = 0 AND l1 = r1
2244        let on: ExprImpl = FunctionCall::new(
2245            Type::And,
2246            vec![
2247                FunctionCall::new(
2248                    Type::Equal,
2249                    vec![
2250                        InputRef::new(0, DataType::Int32).into(),
2251                        ExprImpl::literal_int(0),
2252                    ],
2253                )
2254                .unwrap()
2255                .into(),
2256                FunctionCall::new(
2257                    Type::Equal,
2258                    vec![
2259                        InputRef::new(1, DataType::Int32).into(),
2260                        InputRef::new(3, DataType::Int32).into(),
2261                    ],
2262                )
2263                .unwrap()
2264                .into(),
2265            ],
2266        )
2267        .unwrap()
2268        .into();
2269        let expected_fd_set = [
2270            (
2271                JoinType::Inner,
2272                [
2273                    // inherit from left
2274                    FunctionalDependency::with_indices(5, &[0], &[1]),
2275                    // inherit from right
2276                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2277                    // constant column in join condition
2278                    FunctionalDependency::with_indices(5, &[], &[0]),
2279                    // eq column in join condition
2280                    FunctionalDependency::with_indices(5, &[1], &[3]),
2281                    FunctionalDependency::with_indices(5, &[3], &[1]),
2282                ]
2283                .into_iter()
2284                .collect::<HashSet<_>>(),
2285            ),
2286            (JoinType::FullOuter, HashSet::new()),
2287            (
2288                JoinType::RightOuter,
2289                [
2290                    // inherit from right
2291                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2292                ]
2293                .into_iter()
2294                .collect::<HashSet<_>>(),
2295            ),
2296            (
2297                JoinType::LeftOuter,
2298                [
2299                    // inherit from left
2300                    FunctionalDependency::with_indices(5, &[0], &[1]),
2301                ]
2302                .into_iter()
2303                .collect::<HashSet<_>>(),
2304            ),
2305            (
2306                JoinType::LeftSemi,
2307                [
2308                    // inherit from left
2309                    FunctionalDependency::with_indices(2, &[0], &[1]),
2310                ]
2311                .into_iter()
2312                .collect::<HashSet<_>>(),
2313            ),
2314            (
2315                JoinType::LeftAnti,
2316                [
2317                    // inherit from left
2318                    FunctionalDependency::with_indices(2, &[0], &[1]),
2319                ]
2320                .into_iter()
2321                .collect::<HashSet<_>>(),
2322            ),
2323            (
2324                JoinType::RightSemi,
2325                [
2326                    // inherit from right
2327                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2328                ]
2329                .into_iter()
2330                .collect::<HashSet<_>>(),
2331            ),
2332            (
2333                JoinType::RightAnti,
2334                [
2335                    // inherit from right
2336                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2337                ]
2338                .into_iter()
2339                .collect::<HashSet<_>>(),
2340            ),
2341        ];
2342
2343        for (join_type, expected_res) in expected_fd_set {
2344            let join = LogicalJoin::new(
2345                left.clone().into(),
2346                right.clone().into(),
2347                join_type,
2348                Condition::with_expr(on.clone()),
2349            );
2350            let fd_set = join
2351                .functional_dependency()
2352                .as_dependencies()
2353                .iter()
2354                .cloned()
2355                .collect::<HashSet<_>>();
2356            assert_eq!(fd_set, expected_res);
2357        }
2358    }
2359}