risingwave_frontend/optimizer/plan_node/logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32    BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33    PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34    ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
51/// `LogicalJoin` combines two relations according to some condition.
52///
53/// Each output row has fields from the left and right inputs. The set of output rows is a subset
54/// of the cartesian product of the two inputs; precisely which subset depends on the join
55/// condition. In addition, the output columns are a subset of the columns of the left and
56/// right columns, dependent on the output indices provided. A repeat output index is illegal.
57#[derive(Debug, Clone, PartialEq, Eq, Hash)]
58pub struct LogicalJoin {
59    pub base: PlanBase<Logical>,
60    core: generic::Join<PlanRef>,
61}
62
63impl Distill for LogicalJoin {
64    fn distill<'a>(&self) -> XmlNode<'a> {
65        let verbose = self.base.ctx().is_explain_verbose();
66        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67        vec.push(("type", Pretty::debug(&self.join_type())));
68
69        let concat_schema = self.core.concat_schema();
70        let cond = Pretty::debug(&ConditionDisplay {
71            condition: self.on(),
72            input_schema: &concat_schema,
73        });
74        vec.push(("on", cond));
75
76        if verbose {
77            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78            vec.push(("output", data));
79        }
80
81        childless_record("LogicalJoin", vec)
82    }
83}
84
85impl LogicalJoin {
86    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87        let core = generic::Join::with_full_output(left, right, join_type, on);
88        Self::with_core(core)
89    }
90
91    pub(crate) fn with_output_indices(
92        left: PlanRef,
93        right: PlanRef,
94        join_type: JoinType,
95        on: Condition,
96        output_indices: Vec<usize>,
97    ) -> Self {
98        let core = generic::Join::new(left, right, on, join_type, output_indices);
99        Self::with_core(core)
100    }
101
102    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103        let base = PlanBase::new_logical_with_core(&core);
104        LogicalJoin { base, core }
105    }
106
107    pub fn create(
108        left: PlanRef,
109        right: PlanRef,
110        join_type: JoinType,
111        on_clause: ExprImpl,
112    ) -> PlanRef {
113        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114    }
115
116    pub fn internal_column_num(&self) -> usize {
117        self.core.internal_column_num()
118    }
119
120    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121        self.core.i2l_col_mapping_ignore_join_type()
122    }
123
124    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125        self.core.i2r_col_mapping_ignore_join_type()
126    }
127
128    /// Get a reference to the logical join's on.
129    pub fn on(&self) -> &Condition {
130        &self.core.on
131    }
132
133    pub fn core(&self) -> &generic::Join<PlanRef> {
134        &self.core
135    }
136
137    /// Collect all input ref in the on condition. And separate them into left and right.
138    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
139        let input_refs = self
140            .core
141            .on
142            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
143        let index_group = input_refs
144            .ones()
145            .chunk_by(|i| *i < self.core.left.schema().len());
146        let left_index = index_group
147            .into_iter()
148            .next()
149            .map_or(vec![], |group| group.1.collect_vec());
150        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
151            group
152                .1
153                .map(|i| i - self.core.left.schema().len())
154                .collect_vec()
155        });
156        (left_index, right_index)
157    }
158
159    /// Get the join type of the logical join.
160    pub fn join_type(&self) -> JoinType {
161        self.core.join_type
162    }
163
164    /// Get the eq join key of the logical join.
165    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
166        self.core.eq_indexes()
167    }
168
169    /// Get the output indices of the logical join.
170    pub fn output_indices(&self) -> &Vec<usize> {
171        &self.core.output_indices
172    }
173
174    /// Clone with new output indices
175    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
176        Self::with_core(generic::Join {
177            output_indices,
178            ..self.core.clone()
179        })
180    }
181
182    /// Clone with new `on` condition
183    pub fn clone_with_cond(&self, on: Condition) -> Self {
184        Self::with_core(generic::Join {
185            on,
186            ..self.core.clone()
187        })
188    }
189
190    pub fn is_left_join(&self) -> bool {
191        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
192    }
193
194    pub fn is_right_join(&self) -> bool {
195        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
196    }
197
198    pub fn is_full_out(&self) -> bool {
199        self.core.is_full_out()
200    }
201
202    pub fn is_asof_join(&self) -> bool {
203        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
204    }
205
206    pub fn output_indices_are_trivial(&self) -> bool {
207        itertools::equal(
208            self.output_indices().iter().cloned(),
209            0..self.internal_column_num(),
210        )
211    }
212
213    /// Try to simplify the outer join with the predicate on the top of the join
214    ///
215    /// now it is just a naive implementation for comparison expression, we can give a more general
216    /// implementation with constant folding in future
217    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
218        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
219            JoinType::LeftOuter => (false, true),
220            JoinType::RightOuter => (true, false),
221            JoinType::FullOuter => (true, true),
222            _ => return join_type,
223        };
224
225        for expr in &predicate.conjunctions {
226            if let ExprImpl::FunctionCall(func) = expr {
227                match func.func_type() {
228                    ExprType::Equal
229                    | ExprType::NotEqual
230                    | ExprType::LessThan
231                    | ExprType::LessThanOrEqual
232                    | ExprType::GreaterThan
233                    | ExprType::GreaterThanOrEqual => {
234                        for input in func.inputs() {
235                            if let ExprImpl::InputRef(input) = input {
236                                let idx = input.index;
237                                if idx < left_col_num {
238                                    gen_null_in_left = false;
239                                } else {
240                                    gen_null_in_right = false;
241                                }
242                            }
243                        }
244                    }
245                    _ => {}
246                };
247            }
248        }
249
250        match (gen_null_in_left, gen_null_in_right) {
251            (true, true) => JoinType::FullOuter,
252            (true, false) => JoinType::RightOuter,
253            (false, true) => JoinType::LeftOuter,
254            (false, false) => JoinType::Inner,
255        }
256    }
257
258    /// Index Join:
259    /// Try to convert logical join into batch lookup join and meanwhile it will do
260    /// the index selection for the lookup table so that we can benefit from indexes.
261    fn to_batch_lookup_join_with_index_selection(
262        &self,
263        predicate: EqJoinPredicate,
264        batch_join: generic::Join<BatchPlanRef>,
265    ) -> Result<Option<BatchLookupJoin>> {
266        match batch_join.join_type {
267            JoinType::Inner
268            | JoinType::LeftOuter
269            | JoinType::LeftSemi
270            | JoinType::LeftAnti
271            | JoinType::AsofInner
272            | JoinType::AsofLeftOuter => {}
273            _ => return Ok(None),
274        };
275
276        // Index selection for index join.
277        let right = self.right();
278        // Lookup Join only supports basic tables on the join's right side.
279        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
280            logical_scan
281        } else {
282            return Ok(None);
283        };
284
285        let mut result_plan = None;
286        // Lookup primary table.
287        if let Some(lookup_join) =
288            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
289        {
290            result_plan = Some(lookup_join);
291        }
292
293        if self
294            .core
295            .ctx()
296            .session_ctx()
297            .config()
298            .enable_index_selection()
299        {
300            let indexes = logical_scan.table_indexes();
301            for index in indexes {
302                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
303                    let index_scan: PlanRef = index_scan.into();
304                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
305                    let mut new_batch_join = batch_join.clone();
306                    new_batch_join.right =
307                        index_scan.to_batch().expect("index scan failed to batch");
308
309                    // Lookup covered index.
310                    if let Some(lookup_join) =
311                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
312                    {
313                        match &result_plan {
314                            None => result_plan = Some(lookup_join),
315                            Some(prev_lookup_join) => {
316                                // Prefer to choose lookup join with longer lookup prefix len.
317                                if prev_lookup_join.lookup_prefix_len()
318                                    < lookup_join.lookup_prefix_len()
319                                {
320                                    result_plan = Some(lookup_join)
321                                }
322                            }
323                        }
324                    }
325                }
326            }
327        }
328
329        Ok(result_plan)
330    }
331
332    /// Try to convert logical join into batch lookup join.
333    fn to_batch_lookup_join(
334        &self,
335        predicate: EqJoinPredicate,
336        logical_join: generic::Join<BatchPlanRef>,
337    ) -> Result<Option<BatchLookupJoin>> {
338        let logical_scan: &LogicalScan =
339            if let Some(logical_scan) = self.core.right.as_logical_scan() {
340                logical_scan
341            } else {
342                return Ok(None);
343            };
344        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
345    }
346
347    pub fn gen_batch_lookup_join(
348        logical_scan: &LogicalScan,
349        predicate: EqJoinPredicate,
350        logical_join: generic::Join<BatchPlanRef>,
351        is_as_of: bool,
352    ) -> Result<Option<BatchLookupJoin>> {
353        match logical_join.join_type {
354            JoinType::Inner
355            | JoinType::LeftOuter
356            | JoinType::LeftSemi
357            | JoinType::LeftAnti
358            | JoinType::AsofInner
359            | JoinType::AsofLeftOuter => {}
360            _ => return Ok(None),
361        };
362
363        let table = logical_scan.table();
364        let output_column_ids = logical_scan.output_column_ids();
365
366        // Verify that the right join key columns are the the prefix of the primary key and
367        // also contain the distribution key.
368        let order_col_ids = table.order_column_ids();
369        let dist_key = table.distribution_key.clone();
370        // The at least prefix of order key that contains distribution key.
371        let mut dist_key_in_order_key_pos = vec![];
372        for d in dist_key {
373            let pos = table
374                .order_column_indices()
375                .position(|x| x == d)
376                .expect("dist_key must in order_key");
377            dist_key_in_order_key_pos.push(pos);
378        }
379        // The shortest prefix of order key that contains distribution key.
380        let shortest_prefix_len = dist_key_in_order_key_pos
381            .iter()
382            .max()
383            .map_or(0, |pos| pos + 1);
384
385        // Distributed lookup join can't support lookup table with a singleton distribution.
386        if shortest_prefix_len == 0 {
387            return Ok(None);
388        }
389
390        // Reorder the join equal predicate to match the order key.
391        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
392        for order_col_id in order_col_ids {
393            let mut found = false;
394            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
395                if order_col_id == output_column_ids[eq_idx] {
396                    reorder_idx.push(i);
397                    found = true;
398                    break;
399                }
400            }
401            if !found {
402                break;
403            }
404        }
405        if reorder_idx.len() < shortest_prefix_len {
406            return Ok(None);
407        }
408        let lookup_prefix_len = reorder_idx.len();
409        let predicate = predicate.reorder(&reorder_idx);
410
411        // Extract the predicate from logical scan. Only pure scan is supported.
412        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
413        // Construct output column to require column mapping
414        let o2r = if let Some(project_expr) = project_expr {
415            project_expr
416                .into_iter()
417                .map(|x| x.as_input_ref().unwrap().index)
418                .collect_vec()
419        } else {
420            (0..logical_scan.output_col_idx().len()).collect_vec()
421        };
422        let left_schema_len = logical_join.left.schema().len();
423
424        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
425            offset: left_schema_len,
426            mapping: o2r.clone(),
427        };
428
429        let new_eq_cond = predicate
430            .eq_cond()
431            .rewrite_expr(&mut join_predicate_rewriter);
432
433        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
434            offset: left_schema_len,
435        };
436
437        let new_other_cond = predicate
438            .other_cond()
439            .clone()
440            .rewrite_expr(&mut join_predicate_rewriter)
441            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
442
443        let new_join_on = new_eq_cond.and(new_other_cond);
444        let new_predicate = EqJoinPredicate::create(
445            left_schema_len,
446            new_scan.schema().len(),
447            new_join_on.clone(),
448        );
449
450        // We discovered that we cannot use a lookup join after pulling up the predicate
451        // from one side and simplifying the condition. Let's use some other join instead.
452        if !new_predicate.has_eq() {
453            return Ok(None);
454        }
455
456        // Rewrite the join output indices and all output indices referred to the old scan need to
457        // rewrite.
458        let new_join_output_indices = logical_join
459            .output_indices
460            .iter()
461            .map(|&x| {
462                if x < left_schema_len {
463                    x
464                } else {
465                    o2r[x - left_schema_len] + left_schema_len
466                }
467            })
468            .collect_vec();
469
470        let new_scan_output_column_ids = new_scan.output_column_ids();
471        let as_of = new_scan.as_of.clone();
472        let new_logical_scan: LogicalScan = new_scan.into();
473
474        // Construct a new logical join, because we have change its RHS.
475        let new_logical_join = generic::Join::new(
476            logical_join.left,
477            new_logical_scan.to_batch()?,
478            new_join_on,
479            logical_join.join_type,
480            new_join_output_indices,
481        );
482
483        let asof_desc = is_as_of
484            .then(|| {
485                Self::get_inequality_desc_from_predicate(
486                    predicate.other_cond().clone(),
487                    left_schema_len,
488                )
489            })
490            .transpose()?;
491
492        Ok(Some(BatchLookupJoin::new(
493            new_logical_join,
494            new_predicate,
495            table.clone(),
496            new_scan_output_column_ids,
497            lookup_prefix_len,
498            false,
499            as_of,
500            asof_desc,
501        )))
502    }
503
504    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
505        self.core.decompose()
506    }
507}
508
509impl PlanTreeNodeBinary<Logical> for LogicalJoin {
510    fn left(&self) -> PlanRef {
511        self.core.left.clone()
512    }
513
514    fn right(&self) -> PlanRef {
515        self.core.right.clone()
516    }
517
518    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
519        Self::with_core(generic::Join {
520            left,
521            right,
522            ..self.core.clone()
523        })
524    }
525
526    fn rewrite_with_left_right(
527        &self,
528        left: PlanRef,
529        left_col_change: ColIndexMapping,
530        right: PlanRef,
531        right_col_change: ColIndexMapping,
532    ) -> (Self, ColIndexMapping) {
533        let (new_on, new_output_indices) = {
534            let (mut map, _) = left_col_change.clone().into_parts();
535            let (mut right_map, _) = right_col_change.clone().into_parts();
536            for i in right_map.iter_mut().flatten() {
537                *i += left.schema().len();
538            }
539            map.append(&mut right_map);
540            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
541
542            let new_output_indices = self
543                .output_indices()
544                .iter()
545                .map(|&i| mapping.map(i))
546                .collect::<Vec<_>>();
547            let new_on = self.on().clone().rewrite_expr(&mut mapping);
548            (new_on, new_output_indices)
549        };
550
551        let join = Self::with_output_indices(
552            left,
553            right,
554            self.join_type(),
555            new_on,
556            new_output_indices.clone(),
557        );
558
559        let new_i2o = ColIndexMapping::with_remaining_columns(
560            &new_output_indices,
561            join.internal_column_num(),
562        );
563
564        let old_o2i = self.core.o2i_col_mapping();
565
566        let old_o2l = old_o2i
567            .composite(&self.core.i2l_col_mapping())
568            .composite(&left_col_change);
569        let old_o2r = old_o2i
570            .composite(&self.core.i2r_col_mapping())
571            .composite(&right_col_change);
572        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
573        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
574
575        let out_col_change = old_o2l
576            .composite(&new_l2o)
577            .union(&old_o2r.composite(&new_r2o));
578        (join, out_col_change)
579    }
580}
581
582impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
583
584impl ColPrunable for LogicalJoin {
585    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
586        // make `required_cols` point to internal table instead of output schema.
587        let required_cols = required_cols
588            .iter()
589            .map(|i| self.output_indices()[*i])
590            .collect_vec();
591        let left_len = self.left().schema().fields.len();
592
593        let total_len = self.left().schema().len() + self.right().schema().len();
594        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
595
596        required_cols.iter().for_each(|&i| {
597            if self.is_right_join() {
598                resized_required_cols.insert(left_len + i);
599            } else {
600                resized_required_cols.insert(i);
601            }
602        });
603
604        // add those columns which are required in the join condition to
605        // to those that are required in the output
606        let mut visitor = CollectInputRef::new(resized_required_cols);
607        self.on().visit_expr(&mut visitor);
608        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
609
610        let mut left_required_cols = Vec::new();
611        let mut right_required_cols = Vec::new();
612        left_right_required_cols.iter().for_each(|&i| {
613            if i < left_len {
614                left_required_cols.push(i);
615            } else {
616                right_required_cols.push(i - left_len);
617            }
618        });
619
620        let mut on = self.on().clone();
621        let mut mapping =
622            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
623        on = on.rewrite_expr(&mut mapping);
624
625        let new_output_indices = {
626            let required_inputs_in_output = if self.is_left_join() {
627                &left_required_cols
628            } else if self.is_right_join() {
629                &right_required_cols
630            } else {
631                &left_right_required_cols
632            };
633
634            let mapping =
635                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
636            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
637        };
638
639        LogicalJoin::with_output_indices(
640            self.left().prune_col(&left_required_cols, ctx),
641            self.right().prune_col(&right_required_cols, ctx),
642            self.join_type(),
643            on,
644            new_output_indices,
645        )
646        .into()
647    }
648}
649
650impl ExprRewritable<Logical> for LogicalJoin {
651    fn has_rewritable_expr(&self) -> bool {
652        true
653    }
654
655    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
656        let mut core = self.core.clone();
657        core.rewrite_exprs(r);
658        Self {
659            base: self.base.clone_with_new_plan_id(),
660            core,
661        }
662        .into()
663    }
664}
665
666impl ExprVisitable for LogicalJoin {
667    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
668        self.core.visit_exprs(v);
669    }
670}
671
672/// We are trying to derive a predicate to apply to the other side of a join if all
673/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
674/// with the corresponding eq condition columns of the other side.
675///
676/// Strategy:
677/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
678///    then we proceed. Else abort.
679/// 2. Then, we collect `InputRef`s in the conjunction.
680/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
681/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
682///    the other side.
683///
684/// # Arguments
685///
686/// Suppose we derive a predicate from the left side to be pushed to the right side.
687/// * `expr`: An expr from the left side.
688/// * `col_num`: The number of columns in the left side.
689fn derive_predicate_from_eq_condition(
690    expr: &ExprImpl,
691    eq_condition: &EqJoinPredicate,
692    col_num: usize,
693    expr_is_left: bool,
694) -> Option<ExprImpl> {
695    if expr.is_impure() {
696        return None;
697    }
698    let eq_indices = eq_condition
699        .eq_indexes_typed()
700        .iter()
701        .filter_map(|(l, r)| {
702            if l.return_type() != r.return_type() {
703                None
704            } else if expr_is_left {
705                Some(l.index())
706            } else {
707                Some(r.index())
708            }
709        })
710        .collect_vec();
711    if expr
712        .collect_input_refs(col_num)
713        .ones()
714        .any(|index| !eq_indices.contains(&index))
715    {
716        // expr contains an InputRef not in eq_condition
717        return None;
718    }
719    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
720    // Hence, we can substitute those `InputRef`s with indices from the other side.
721    let other_side_mapping = if expr_is_left {
722        eq_condition.eq_indexes_typed().into_iter().collect()
723    } else {
724        eq_condition
725            .eq_indexes_typed()
726            .into_iter()
727            .map(|(x, y)| (y, x))
728            .collect()
729    };
730    struct InputRefsRewriter {
731        mapping: HashMap<InputRef, InputRef>,
732    }
733    impl ExprRewriter for InputRefsRewriter {
734        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
735            self.mapping[&input_ref].clone().into()
736        }
737    }
738    Some(
739        InputRefsRewriter {
740            mapping: other_side_mapping,
741        }
742        .rewrite_expr(expr.clone()),
743    )
744}
745
746/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
747struct LookupJoinPredicateRewriter {
748    offset: usize,
749    mapping: Vec<usize>,
750}
751impl ExprRewriter for LookupJoinPredicateRewriter {
752    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
753        if input_ref.index() < self.offset {
754            input_ref.into()
755        } else {
756            InputRef::new(
757                self.mapping[input_ref.index() - self.offset] + self.offset,
758                input_ref.return_type(),
759            )
760            .into()
761        }
762    }
763}
764
765/// Rewrite the scan predicate so we can add it to the join predicate.
766struct LookupJoinScanPredicateRewriter {
767    offset: usize,
768}
769impl ExprRewriter for LookupJoinScanPredicateRewriter {
770    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
771        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
772    }
773}
774
775impl PredicatePushdown for LogicalJoin {
776    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
777    ///
778    /// # Which predicates can be pushed
779    ///
780    /// For inner join, we can do all kinds of pushdown.
781    ///
782    /// For left/right semi join, we can push filter to left/right and on-clause,
783    /// and push on-clause to left/right.
784    ///
785    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
786    ///
787    /// ## Outer Join
788    ///
789    /// Preserved Row table
790    /// : The table in an Outer Join that must return all rows.
791    ///
792    /// Null Supplying table
793    /// : This is the table that has nulls filled in for its columns in unmatched rows.
794    ///
795    /// |                          | Preserved Row table | Null Supplying table |
796    /// |--------------------------|---------------------|----------------------|
797    /// | Join predicate (on)      | Not Pushed          | Pushed               |
798    /// | Where predicate (filter) | Pushed              | Not Pushed           |
799    fn predicate_pushdown(
800        &self,
801        predicate: Condition,
802        ctx: &mut PredicatePushdownContext,
803    ) -> PlanRef {
804        // rewrite output col referencing indices as internal cols
805        let mut predicate = {
806            let mut mapping = self.core.o2i_col_mapping();
807            predicate.rewrite_expr(&mut mapping)
808        };
809
810        let left_col_num = self.left().schema().len();
811        let right_col_num = self.right().schema().len();
812        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
813
814        let push_down_temporal_predicate = self.temporal_join_on().is_none();
815
816        let (left_from_filter, right_from_filter, on) = push_down_into_join(
817            &mut predicate,
818            left_col_num,
819            right_col_num,
820            join_type,
821            push_down_temporal_predicate,
822        );
823
824        let mut new_on = self.on().clone().and(on);
825        let (left_from_on, right_from_on) = push_down_join_condition(
826            &mut new_on,
827            left_col_num,
828            right_col_num,
829            join_type,
830            push_down_temporal_predicate,
831        );
832
833        let left_predicate = left_from_filter.and(left_from_on);
834        let right_predicate = right_from_filter.and(right_from_on);
835
836        // Derive conditions to push to the other side based on eq condition columns
837        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
838
839        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
840        let right_from_left = if matches!(
841            join_type,
842            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
843        ) {
844            Condition {
845                conjunctions: left_predicate
846                    .conjunctions
847                    .iter()
848                    .filter_map(|expr| {
849                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
850                    })
851                    .collect(),
852            }
853        } else {
854            Condition::true_cond()
855        };
856
857        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
858        let left_from_right = if matches!(
859            join_type,
860            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
861        ) {
862            Condition {
863                conjunctions: right_predicate
864                    .conjunctions
865                    .iter()
866                    .filter_map(|expr| {
867                        derive_predicate_from_eq_condition(
868                            expr,
869                            &eq_condition,
870                            right_col_num,
871                            false,
872                        )
873                    })
874                    .collect(),
875            }
876        } else {
877            Condition::true_cond()
878        };
879
880        let left_predicate = left_predicate.and(left_from_right);
881        let right_predicate = right_predicate.and(right_from_left);
882
883        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
884        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
885        let new_join = LogicalJoin::with_output_indices(
886            new_left,
887            new_right,
888            join_type,
889            new_on,
890            self.output_indices().clone(),
891        );
892
893        let mut mapping = self.core.i2o_col_mapping();
894        predicate = predicate.rewrite_expr(&mut mapping);
895        LogicalFilter::create(new_join.into(), predicate)
896    }
897}
898
899#[derive(Clone, Copy)]
900struct TemporalJoinScan<'a>(&'a LogicalScan);
901
902impl<'a> Deref for TemporalJoinScan<'a> {
903    type Target = LogicalScan;
904
905    fn deref(&self) -> &Self::Target {
906        self.0
907    }
908}
909
910impl LogicalJoin {
911    fn get_stream_input_for_hash_join(
912        &self,
913        predicate: &EqJoinPredicate,
914        ctx: &mut ToStreamContext,
915    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
916        use super::stream::prelude::*;
917
918        let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
919        let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();
920
921        let logical_right = try_enforce_locality_requirement(self.right(), &rhs_join_key_idx);
922        let mut right = logical_right.to_stream_with_dist_required(
923            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
924            ctx,
925        )?;
926        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
927        let r2l =
928            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
929        let l2r =
930            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
931        let mut left;
932        let right_dist = right.distribution();
933        match right_dist {
934            Distribution::HashShard(_) => {
935                let left_dist = r2l
936                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
937                left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
938            }
939            Distribution::UpstreamHashShard(_, _) => {
940                left = logical_left.to_stream_with_dist_required(
941                    &RequiredDist::shard_by_key(
942                        self.left().schema().len(),
943                        &predicate.left_eq_indexes(),
944                    ),
945                    ctx,
946                )?;
947                let left_dist = left.distribution();
948                match left_dist {
949                    Distribution::HashShard(_) => {
950                        let right_dist = l2r.rewrite_required_distribution(
951                            &RequiredDist::PhysicalDist(left_dist.clone()),
952                        );
953                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
954                    }
955                    Distribution::UpstreamHashShard(_, _) => {
956                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
957                            .streaming_enforce_if_not_satisfies(left)?;
958                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
959                            .streaming_enforce_if_not_satisfies(right)?;
960                    }
961                    _ => unreachable!(),
962                }
963            }
964            _ => unreachable!(),
965        }
966        Ok((left, right))
967    }
968
969    fn to_stream_hash_join(
970        &self,
971        predicate: EqJoinPredicate,
972        ctx: &mut ToStreamContext,
973    ) -> Result<StreamPlanRef> {
974        use super::stream::prelude::*;
975
976        assert!(predicate.has_eq());
977        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
978
979        let core = self.core.clone_with_inputs(left, right);
980
981        // Convert to Hash Join for equal joins
982        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
983        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
984        // as opposed to the HashJoin, which applies the condition row-wise.
985        // However, the default behavior of pulling up non-equal conditions can be overridden by the
986        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
987        // materialization of rows only to be filtered later.
988
989        let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
990
991        let force_filter_inside_join = self
992            .base
993            .ctx()
994            .session_ctx()
995            .config()
996            .streaming_force_filter_inside_join();
997
998        let pull_filter = self.join_type() == JoinType::Inner
999            && stream_hash_join.eq_join_predicate().has_non_eq()
1000            && stream_hash_join.inequality_pairs().is_empty()
1001            && (!force_filter_inside_join);
1002        if pull_filter {
1003            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1004
1005            let mut core = core;
1006            core.output_indices = default_indices.clone();
1007            // Temporarily remove output indices.
1008            let eq_cond = EqJoinPredicate::new(
1009                Condition::true_cond(),
1010                predicate.eq_keys().to_vec(),
1011                self.left().schema().len(),
1012                self.right().schema().len(),
1013            );
1014            core.on = eq_cond.eq_cond();
1015            let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
1016            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1017            let plan = StreamFilter::new(logical_filter).into();
1018            if self.output_indices() != &default_indices {
1019                let logical_project = generic::Project::with_mapping(
1020                    plan,
1021                    ColIndexMapping::with_remaining_columns(
1022                        self.output_indices(),
1023                        self.internal_column_num(),
1024                    ),
1025                );
1026                Ok(StreamProject::new(logical_project).into())
1027            } else {
1028                Ok(plan)
1029            }
1030        } else {
1031            Ok(stream_hash_join.into())
1032        }
1033    }
1034
1035    fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1036        if let Some(logical_scan) = self.core.right.as_logical_scan() {
1037            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1038                .then_some(TemporalJoinScan(logical_scan))
1039        } else {
1040            None
1041        }
1042    }
1043
1044    fn should_be_stream_temporal_join<'a>(
1045        &'a self,
1046        ctx: &ToStreamContext,
1047    ) -> Result<Option<TemporalJoinScan<'a>>> {
1048        Ok(if let Some(scan) = self.temporal_join_on() {
1049            if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1050                return Err(RwError::from(ErrorCode::NotSupported(
1051                    "Temporal join with snapshot backfill not supported".into(),
1052                    "Please use arrangement backfill".into(),
1053                )));
1054            }
1055            if scan.cross_database() {
1056                return Err(RwError::from(ErrorCode::NotSupported(
1057                        "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1058                        "Please ensure both tables are in the same database".into(),
1059                    )));
1060            }
1061            Some(scan)
1062        } else {
1063            None
1064        })
1065    }
1066
1067    fn to_stream_temporal_join_with_index_selection(
1068        &self,
1069        logical_scan: TemporalJoinScan<'_>,
1070        predicate: EqJoinPredicate,
1071        ctx: &mut ToStreamContext,
1072    ) -> Result<StreamPlanRef> {
1073        // Use primary table.
1074        let mut result_plan: Result<StreamTemporalJoin> =
1075            self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1076        // Return directly if this temporal join can match the pk of its right table.
1077        if let Ok(temporal_join) = &result_plan
1078            && temporal_join.eq_join_predicate().eq_indexes().len()
1079                == logical_scan.primary_key().len()
1080        {
1081            return result_plan.map(|x| x.into());
1082        }
1083        if self
1084            .core
1085            .ctx()
1086            .session_ctx()
1087            .config()
1088            .enable_index_selection()
1089        {
1090            let indexes = logical_scan.table_indexes();
1091            for index in indexes {
1092                // Use index table
1093                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1094                    let index_scan: PlanRef = index_scan.into();
1095                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1096                    if let Ok(temporal_join) = that.to_stream_temporal_join(
1097                        that.temporal_join_on().expect(
1098                            "index scan created from temporal join scan must also be temporal join",
1099                        ),
1100                        predicate.clone(),
1101                        ctx,
1102                    ) {
1103                        match &result_plan {
1104                            Err(_) => result_plan = Ok(temporal_join),
1105                            Ok(prev_temporal_join) => {
1106                                // Prefer to the temporal join with a longer lookup prefix len.
1107                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1108                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1109                                {
1110                                    result_plan = Ok(temporal_join)
1111                                }
1112                            }
1113                        }
1114                    }
1115                }
1116            }
1117        }
1118
1119        result_plan.map(|x| x.into())
1120    }
1121
1122    fn temporal_join_scan_predicate_pull_up(
1123        logical_scan: TemporalJoinScan<'_>,
1124        predicate: EqJoinPredicate,
1125        output_indices: &[usize],
1126        left_schema_len: usize,
1127    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1128        // Extract the predicate from logical scan. Only pure scan is supported.
1129        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1130        // Construct output column to require column mapping
1131        let o2r = if let Some(project_expr) = project_expr {
1132            project_expr
1133                .into_iter()
1134                .map(|x| x.as_input_ref().unwrap().index)
1135                .collect_vec()
1136        } else {
1137            (0..logical_scan.output_col_idx().len()).collect_vec()
1138        };
1139        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1140            offset: left_schema_len,
1141            mapping: o2r.clone(),
1142        };
1143
1144        let new_eq_cond = predicate
1145            .eq_cond()
1146            .rewrite_expr(&mut join_predicate_rewriter);
1147
1148        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1149            offset: left_schema_len,
1150        };
1151
1152        let new_other_cond = predicate
1153            .other_cond()
1154            .clone()
1155            .rewrite_expr(&mut join_predicate_rewriter)
1156            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1157
1158        let new_join_on = new_eq_cond.and(new_other_cond);
1159
1160        let new_predicate = EqJoinPredicate::create(
1161            left_schema_len,
1162            new_scan.schema().len(),
1163            new_join_on.clone(),
1164        );
1165
1166        // Rewrite the join output indices and all output indices referred to the old scan need to
1167        // rewrite.
1168        let new_join_output_indices = output_indices
1169            .iter()
1170            .map(|&x| {
1171                if x < left_schema_len {
1172                    x
1173                } else {
1174                    o2r[x - left_schema_len] + left_schema_len
1175                }
1176            })
1177            .collect_vec();
1178
1179        let new_stream_table_scan =
1180            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1181        Ok((
1182            new_stream_table_scan,
1183            new_predicate,
1184            new_join_on,
1185            new_join_output_indices,
1186        ))
1187    }
1188
    /// Plans a lookup-style `StreamTemporalJoin` whose right side is the process-time temporal
    /// scan `logical_scan`, using the equi-join keys of `predicate` as lookup keys.
    ///
    /// The right-side eq keys must cover a prefix of the lookup table's order key, and that
    /// prefix must include every distribution-key column. The left input is sharded to match
    /// the lookup table's distribution, so the right side can be attached with a `no_shuffle`
    /// exchange.
    ///
    /// # Errors
    /// Returns `NotSupported` in two cases: the eq keys do not cover the required order-key
    /// prefix, or no eq key remains after the scan's predicate is pulled up. Errors from
    /// lowering the left input or from `temporal_join_scan_predicate_pull_up` are propagated.
    ///
    /// # Panics
    /// Panics if `predicate` has no eq keys, or if a distribution-key column of the lookup
    /// table is not part of its order key.
    fn to_stream_temporal_join(
        &self,
        logical_scan: TemporalJoinScan<'_>,
        predicate: EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<StreamTemporalJoin> {
        use super::stream::prelude::*;

        assert!(predicate.has_eq());

        let table = logical_scan.table();
        let output_column_ids = logical_scan.output_column_ids();

        // Verify that the right join key columns are a prefix of the primary key and
        // also contain the distribution key.
        let order_col_ids = table.order_column_ids();
        let dist_key = table.distribution_key.clone();

        // For each distribution-key column, its position within the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = table
                .order_column_indices()
                .position(|x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // The shortest prefix of order key that contains distribution key.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // Reorder the join equal predicate to match the order key.
        // `reorder_idx[k]` is the position, within the eq keys, of the first key whose right
        // column is the k-th order-key column. The scan stops at the first order-key column
        // that no eq key references, so only the covered prefix is recorded.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        // The covered prefix must reach every distribution-key column.
        if reorder_idx.len() < shortest_prefix_len {
            // TODO: support index selection for temporal join and refine this error message.
            return Err(RwError::from(ErrorCode::NotSupported(
                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
            )));
        }
        // Number of leading order-key columns covered by eq keys; used as the lookup prefix.
        let lookup_prefix_len = reorder_idx.len();
        // NOTE(review): presumably moves the eq keys selected by `reorder_idx` to the front, in
        // order-key order. Confirm against `EqJoinPredicate::reorder`.
        let predicate = predicate.reorder(&reorder_idx);

        // Shard the left input by the left columns equated with the lookup table's
        // distribution-key columns. Use a single distribution if the table has no
        // distribution key.
        let required_dist = if dist_key_in_order_key_pos.is_empty() {
            RequiredDist::single()
        } else {
            let left_eq_indexes = predicate.left_eq_indexes();
            let left_dist_key = dist_key_in_order_key_pos
                .iter()
                .map(|pos| left_eq_indexes[*pos])
                .collect_vec();

            RequiredDist::hash_shard(&left_dist_key)
        };

        // Left-side join keys, used for the locality requirement of the left input.
        let lhs_join_key_idx = predicate
            .eq_indexes()
            .into_iter()
            .map(|(l, _)| l)
            .collect_vec();
        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
        let left = logical_left.to_stream(ctx)?;
        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
        let left = required_dist.stream_enforce(left);

        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
            Self::temporal_join_scan_predicate_pull_up(
                logical_scan,
                predicate,
                self.output_indices(),
                self.left().schema().len(),
            )?;

        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
        if !new_predicate.has_eq() {
            return Err(RwError::from(ErrorCode::NotSupported(
                "Temporal join requires a non trivial join condition".into(),
                "Please remove the false condition of the join".into(),
            )));
        }

        // Construct a new logical join, because we have changed its RHS.
        let new_logical_join = generic::Join::new(
            left,
            right,
            new_join_on,
            self.join_type(),
            new_join_output_indices,
        );

        // Keep only the first `lookup_prefix_len` eq keys as lookup keys. See
        // `EqJoinPredicate::retain_prefix_eq_key` for what happens to the remaining ones.
        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);

        StreamTemporalJoin::new(new_logical_join, new_predicate, false)
    }
1297
1298    fn to_stream_nested_loop_temporal_join(
1299        &self,
1300        logical_scan: TemporalJoinScan<'_>,
1301        predicate: EqJoinPredicate,
1302        ctx: &mut ToStreamContext,
1303    ) -> Result<StreamPlanRef> {
1304        use super::stream::prelude::*;
1305        assert!(!predicate.has_eq());
1306
1307        let left = self.left().to_stream_with_dist_required(
1308            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1309            ctx,
1310        )?;
1311        assert!(left.as_stream_exchange().is_some());
1312
1313        if self.join_type() != JoinType::Inner {
1314            return Err(RwError::from(ErrorCode::NotSupported(
1315                "Temporal join requires an inner join".into(),
1316                "Please use an inner join".into(),
1317            )));
1318        }
1319
1320        if !left.append_only() {
1321            return Err(RwError::from(ErrorCode::NotSupported(
1322                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1323                "Please ensure the left hash side is append only".into(),
1324            )));
1325        }
1326
1327        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1328            Self::temporal_join_scan_predicate_pull_up(
1329                logical_scan,
1330                predicate,
1331                self.output_indices(),
1332                self.left().schema().len(),
1333            )?;
1334
1335        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1336
1337        // Construct a new logical join, because we have change its RHS.
1338        let new_logical_join = generic::Join::new(
1339            left,
1340            right,
1341            new_join_on,
1342            self.join_type(),
1343            new_join_output_indices,
1344        );
1345
1346        Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1347    }
1348
1349    fn to_stream_dynamic_filter(
1350        &self,
1351        predicate: Condition,
1352        ctx: &mut ToStreamContext,
1353    ) -> Result<Option<StreamPlanRef>> {
1354        use super::stream::prelude::*;
1355
1356        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1357        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1358        // `StreamDynamicFilter`
1359
1360        // Check if `Inner`/`LeftSemi`
1361        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1362            return Ok(None);
1363        }
1364
1365        // Check if right side is a scalar
1366        if !self.right().max_one_row() {
1367            return Ok(None);
1368        }
1369        if self.right().schema().len() != 1 {
1370            return Ok(None);
1371        }
1372
1373        // Check if the join condition is a correlated comparison
1374        if predicate.conjunctions.len() > 1 {
1375            return Ok(None);
1376        }
1377        let expr: ExprImpl = predicate.into();
1378        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1379            Some(v) => v,
1380            None => return Ok(None),
1381        };
1382
1383        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1384            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1385        if !condition_cross_inputs {
1386            // Maybe we should panic here because it means some predicates are not pushed down.
1387            return Ok(None);
1388        }
1389
1390        // We align input types on all join predicates with cmp operator
1391        if self.left().schema().fields()[left_ref.index].data_type
1392            != self.right().schema().fields()[0].data_type
1393        {
1394            return Ok(None);
1395        }
1396
1397        // Check if non of the columns from the inner side is required to output
1398        let all_output_from_left = self
1399            .output_indices()
1400            .iter()
1401            .all(|i| *i < self.left().schema().len());
1402        if !all_output_from_left {
1403            return Ok(None);
1404        }
1405
1406        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1407        let right = self.right().to_stream_with_dist_required(
1408            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1409            ctx,
1410        )?;
1411
1412        assert!(right.as_stream_exchange().is_some());
1413        assert_eq!(
1414            *right.inputs().iter().exactly_one().unwrap().distribution(),
1415            Distribution::Single
1416        );
1417
1418        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1419        let plan = StreamDynamicFilter::new(core)?.into();
1420        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1421        if self
1422            .output_indices()
1423            .iter()
1424            .copied()
1425            .ne(0..self.left().schema().len())
1426        {
1427            // The schema of dynamic filter is always the same as the left side now, and we have
1428            // checked that all output columns are from the left side before.
1429            let logical_project = generic::Project::with_mapping(
1430                plan,
1431                ColIndexMapping::with_remaining_columns(
1432                    self.output_indices(),
1433                    self.left().schema().len(),
1434                ),
1435            );
1436            Ok(Some(StreamProject::new(logical_project).into()))
1437        } else {
1438            Ok(Some(plan))
1439        }
1440    }
1441
1442    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1443        let predicate = EqJoinPredicate::create(
1444            self.left().schema().len(),
1445            self.right().schema().len(),
1446            self.on().clone(),
1447        );
1448        assert!(predicate.has_eq());
1449
1450        let join = self
1451            .core
1452            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1453
1454        Ok(self
1455            .to_batch_lookup_join(predicate, join)?
1456            .expect("Fail to convert to lookup join")
1457            .into())
1458    }
1459
1460    fn to_stream_asof_join(
1461        &self,
1462        predicate: EqJoinPredicate,
1463        ctx: &mut ToStreamContext,
1464    ) -> Result<StreamPlanRef> {
1465        use super::stream::prelude::*;
1466
1467        if predicate.eq_keys().is_empty() {
1468            return Err(ErrorCode::InvalidInputSyntax(
1469                "AsOf join requires at least 1 equal condition".to_owned(),
1470            )
1471            .into());
1472        }
1473
1474        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1475        let left_len = left.schema().len();
1476        let core = self.core.clone_with_inputs(left, right);
1477
1478        let inequality_desc =
1479            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1480
1481        Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1482    }
1483
1484    /// Convert the logical join to a Hash join.
1485    fn to_batch_hash_join(
1486        &self,
1487        logical_join: generic::Join<BatchPlanRef>,
1488        predicate: EqJoinPredicate,
1489    ) -> Result<BatchPlanRef> {
1490        use super::batch::prelude::*;
1491
1492        let left_schema_len = logical_join.left.schema().len();
1493        let asof_desc = self
1494            .is_asof_join()
1495            .then(|| {
1496                Self::get_inequality_desc_from_predicate(
1497                    predicate.other_cond().clone(),
1498                    left_schema_len,
1499                )
1500            })
1501            .transpose()?;
1502
1503        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1504        Ok(batch_join.into())
1505    }
1506
1507    pub fn get_inequality_desc_from_predicate(
1508        predicate: Condition,
1509        left_input_len: usize,
1510    ) -> Result<AsOfJoinDesc> {
1511        let expr: ExprImpl = predicate.into();
1512        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1513            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1514            {
1515                Ok(AsOfJoinDesc {
1516                    left_idx: left_input_ref.index() as u32,
1517                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1518                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1519                })
1520            } else {
1521                bail!("inequal condition from the same side should be push down in optimizer");
1522            }
1523        } else {
1524            Err(ErrorCode::InvalidInputSyntax(
1525                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1526            )
1527            .into())
1528        }
1529    }
1530
1531    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1532        match expr_type {
1533            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1534            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1535            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1536            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1537            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1538                "Invalid comparison type: {}",
1539                expr_type.as_str_name()
1540            ))
1541            .into()),
1542        }
1543    }
1544}
1545
1546impl ToBatch for LogicalJoin {
1547    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1548        let predicate = EqJoinPredicate::create(
1549            self.left().schema().len(),
1550            self.right().schema().len(),
1551            self.on().clone(),
1552        );
1553
1554        let batch_join = self
1555            .core
1556            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1557
1558        let ctx = self.base.ctx();
1559        let config = ctx.session_ctx().config();
1560
1561        if predicate.has_eq() {
1562            if !predicate.eq_keys_are_type_aligned() {
1563                return Err(ErrorCode::InternalError(format!(
1564                    "Join eq keys are not aligned for predicate: {predicate:?}"
1565                ))
1566                .into());
1567            }
1568            if config.batch_enable_lookup_join()
1569                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1570                    predicate.clone(),
1571                    batch_join.clone(),
1572                )?
1573            {
1574                return Ok(lookup_join.into());
1575            }
1576            self.to_batch_hash_join(batch_join, predicate)
1577        } else if self.is_asof_join() {
1578            Err(ErrorCode::InvalidInputSyntax(
1579                "AsOf join requires at least 1 equal condition".to_owned(),
1580            )
1581            .into())
1582        } else {
1583            // Convert to Nested-loop Join for non-equal joins
1584            Ok(BatchNestedLoopJoin::new(batch_join).into())
1585        }
1586    }
1587}
1588
1589impl ToStream for LogicalJoin {
1590    fn to_stream(
1591        &self,
1592        ctx: &mut ToStreamContext,
1593    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1594        if self
1595            .on()
1596            .conjunctions
1597            .iter()
1598            .any(|cond| cond.count_nows() > 0)
1599        {
1600            return Err(ErrorCode::NotSupported(
1601                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1602                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1603        }
1604
1605        let predicate = EqJoinPredicate::create(
1606            self.left().schema().len(),
1607            self.right().schema().len(),
1608            self.on().clone(),
1609        );
1610
1611        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1612            self.to_stream_asof_join(predicate, ctx)
1613        } else if predicate.has_eq() {
1614            if !predicate.eq_keys_are_type_aligned() {
1615                return Err(ErrorCode::InternalError(format!(
1616                    "Join eq keys are not aligned for predicate: {predicate:?}"
1617                ))
1618                .into());
1619            }
1620
1621            if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1622                self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1623            } else {
1624                self.to_stream_hash_join(predicate, ctx)
1625            }
1626        } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1627            self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1628        } else if let Some(dynamic_filter) =
1629            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1630        {
1631            Ok(dynamic_filter)
1632        } else {
1633            Err(RwError::from(ErrorCode::NotSupported(
1634                "streaming nested-loop join".to_owned(),
1635                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1636                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1637                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1638            )))
1639        }
1640    }
1641
1642    fn logical_rewrite_for_stream(
1643        &self,
1644        ctx: &mut RewriteStreamContext,
1645    ) -> Result<(PlanRef, ColIndexMapping)> {
1646        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1647        let left_len = left.schema().len();
1648        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1649        let (join, out_col_change) = self.rewrite_with_left_right(
1650            left.clone(),
1651            left_col_change,
1652            right.clone(),
1653            right_col_change,
1654        );
1655
1656        let mapping = ColIndexMapping::with_remaining_columns(
1657            join.output_indices(),
1658            join.internal_column_num(),
1659        );
1660
1661        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1662        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1663
1664        // Add missing pk indices to the logical join
1665        let mut left_to_add = left
1666            .expect_stream_key()
1667            .iter()
1668            .cloned()
1669            .filter(|i| l2o.try_map(*i).is_none())
1670            .collect_vec();
1671
1672        let mut right_to_add = right
1673            .expect_stream_key()
1674            .iter()
1675            .filter(|&&i| r2o.try_map(i).is_none())
1676            .map(|&i| i + left_len)
1677            .collect_vec();
1678
1679        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1680        // key.
1681        let right_len = right.schema().len();
1682        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1683
1684        let either_or_both = self.core.add_which_join_key_to_pk();
1685
1686        for (lk, rk) in eq_predicate.eq_indexes() {
1687            match either_or_both {
1688                EitherOrBoth::Left(_) => {
1689                    if l2o.try_map(lk).is_none() {
1690                        left_to_add.push(lk);
1691                    }
1692                }
1693                EitherOrBoth::Right(_) => {
1694                    if r2o.try_map(rk).is_none() {
1695                        right_to_add.push(rk + left_len)
1696                    }
1697                }
1698                EitherOrBoth::Both(_, _) => {
1699                    if l2o.try_map(lk).is_none() {
1700                        left_to_add.push(lk);
1701                    }
1702                    if r2o.try_map(rk).is_none() {
1703                        right_to_add.push(rk + left_len)
1704                    }
1705                }
1706            };
1707        }
1708        let left_to_add = left_to_add.into_iter().unique();
1709        let right_to_add = right_to_add.into_iter().unique();
1710        // NOTE(st1page) over
1711
1712        let mut new_output_indices = join.output_indices().clone();
1713        if !join.is_right_join() {
1714            new_output_indices.extend(left_to_add);
1715        }
1716        if !join.is_left_join() {
1717            new_output_indices.extend(right_to_add);
1718        }
1719
1720        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1721
1722        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1723            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1724
1725            let l2o = join_with_pk
1726                .core
1727                .l2i_col_mapping()
1728                .composite(&join_with_pk.core.i2o_col_mapping());
1729            let r2o = join_with_pk
1730                .core
1731                .r2i_col_mapping()
1732                .composite(&join_with_pk.core.i2o_col_mapping());
1733            let left_right_stream_keys = join_with_pk
1734                .left()
1735                .expect_stream_key()
1736                .iter()
1737                .map(|i| l2o.map(*i))
1738                .chain(
1739                    join_with_pk
1740                        .right()
1741                        .expect_stream_key()
1742                        .iter()
1743                        .map(|i| r2o.map(*i)),
1744                )
1745                .collect_vec();
1746            let plan: PlanRef = join_with_pk.into();
1747            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1748        } else {
1749            join_with_pk.into()
1750        };
1751
1752        // the added columns is at the end, so it will not change the exists column index
1753        Ok((plan, out_col_change))
1754    }
1755
1756    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1757        let mut ctx = ToStreamContext::new(false);
1758        // only pass through the locality information if it can be converted to dynamic filter
1759        if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1760            // since dynamic filter only supports left input ref in the output indices, we can safely use o2i mapping to convert the required columns.
1761            let o2i_mapping = self.core.o2i_col_mapping();
1762            let left_input_columns = columns
1763                .iter()
1764                .map(|&col| o2i_mapping.try_map(col))
1765                .collect::<Option<Vec<usize>>>()?;
1766            if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1767                return Some(
1768                    self.clone_with_left_right(better_left_plan, self.right())
1769                        .into(),
1770                );
1771            }
1772        }
1773        None
1774    }
1775}
1776
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Builds the inputs shared by most tests in this module.
    ///
    /// Returns:
    /// - six `Int32` fields named `v1`..`v6`;
    /// - a row-less left `LogicalValues` with schema `(v1, v2, v3)`;
    /// - a row-less right `LogicalValues` with schema `(v4, v5, v6)`.
    ///
    /// Both `LogicalValues` share one mock optimizer context.
    async fn mock_inputs() -> (Vec<Field>, LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        (fields, left, right)
    }

    /// An `Int32` `InputRef` to column `index` of the join's concatenated
    /// `(left ++ right)` input.
    fn input_ref(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// Builds the expression `lhs = rhs`.
    fn equal(lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![lhs, rhs]).unwrap(),
        ))
    }

    /// Pruning
    /// ```text
    /// Join(on: input_ref(1)=input_ref(3))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// with required columns [2,3] will result in
    /// ```text
    /// Join(on: input_ref(0)=input_ref(2), output: [v3, v4])
    ///   Values(v2, v3)
    ///   Values(v4)
    /// ```
    /// The projection is absorbed into the join's output columns: the pruned plan
    /// is itself a `LogicalJoin`, not a `Project` on top of one.
    #[tokio::test]
    async fn test_prune_join() {
        let (fields, left, right) = mock_inputs().await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(equal(input_ref(1), input_ref(3))),
        )
        .into();

        // Perform the prune
        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // Check the result
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let (fields, left, right) = mock_inputs().await;
        let on = equal(input_ref(1), input_ref(4));
        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Semi/anti joins only output one side; right variants output v4..v6.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();

            // Perform the prune. Column 0 of either side is never referenced by the
            // join condition, which only uses column 1 of each side.
            let required_cols = vec![0];
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            // Check the result
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            // Perform the prune
            let required_cols = vec![0, 1, 2];
            // should not panic here
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            // Check the result
            assert_eq!(as_plan.schema().fields().len(), 3);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
        }
    }

    /// Pruning
    /// ```text
    /// Join(on: input_ref(1)=input_ref(3))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// with required columns [1, 3] will result in
    /// ```text
    /// Join(on: input_ref(0)=input_ref(1))
    ///   Values(v2)
    ///   Values(v4)
    /// ```
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let (fields, left, right) = mock_inputs().await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(equal(input_ref(1), input_ref(3))),
        )
        .into();

        // Perform the prune
        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // Check the result
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Convert
    /// ```text
    /// Join(on: ($1 = $3) AND ($2 == 42))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(eq: $1 = $3, non-eq: $2 == 42)
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// The non-equi conjunct is kept inside the hash join's condition rather than
    /// being split out into a separate `Filter`.
    #[tokio::test]
    async fn test_join_to_batch() {
        let (_, left, right) = mock_inputs().await;

        let eq_cond = equal(input_ref(1), input_ref(3));
        let non_eq_cond = equal(
            input_ref(2),
            ExprImpl::Literal(Box::new(Literal::new(
                Datum::Some(42_i32.into()),
                DataType::Int32,
            ))),
        );
        // Condition: ($1 = $3) AND ($2 == 42)
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on_cond),
        );

        // Perform `to_batch`
        let result = logical_join.to_batch().unwrap();

        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Pruning
    /// ```text
    /// Join(on: input_ref(1)=input_ref(3))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// with required columns [3, 2] will result in
    /// ```text
    /// Join(on: input_ref(0)=input_ref(2), output: [v4, v3])
    ///   Values(v2, v3)
    ///   Values(v4)
    /// ```
    /// i.e. the requested column order is reflected in the join's own output.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let (fields, left, right) = mock_inputs().await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(equal(input_ref(1), input_ref(3))),
        )
        .into();

        // Perform the prune
        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // Check the result
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        // left: [l0, l1], right: [r0, r1, r2]
        // FD: l0 --> l1, r0 --> { r1, r2 }
        // On: l0 = 0 AND l1 = r1
        //
        // Inner Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
        // Left Outer Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: l0 --> l1
        // Right Outer Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: r0 --> { r1, r2 }
        // Full Outer Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: empty
        // Left Semi/Anti Join:
        //  Schema: [l0, l1]
        //  FD: l0 --> l1
        // Right Semi/Anti Join:
        //  Schema: [r0, r1, r2]
        //  FD: r0 --> {r1, r2}
        let ctx = OptimizerContext::mock().await;
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            // 0 --> 1
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            // 0 --> 1, 2
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };
        // l0 = 0 AND l1 = r1
        let on: ExprImpl = FunctionCall::new(
            Type::And,
            vec![
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(0, DataType::Int32).into(),
                        ExprImpl::literal_int(0),
                    ],
                )
                .unwrap()
                .into(),
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(1, DataType::Int32).into(),
                        InputRef::new(3, DataType::Int32).into(),
                    ],
                )
                .unwrap()
                .into(),
            ],
        )
        .unwrap()
        .into();
        let expected_fd_set = [
            (
                JoinType::Inner,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    // inherit from right
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    // constant column in join condition
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    // eq column in join condition
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                [
                    // inherit from right
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftOuter,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftSemi,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftAnti,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightSemi,
                [
                    // inherit from right
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightAnti,
                [
                    // inherit from right
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}