risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32    BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33    PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34    ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right inputs, as selected by the output indices provided. A repeated output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan-node base; always derived from `core` (see [`LogicalJoin::with_core`]).
    pub base: PlanBase<Logical>,
    /// The generic join description: both inputs, the `on` predicate, the join type and the
    /// output indices.
    core: generic::Join<PlanRef>,
}
62
63impl Distill for LogicalJoin {
64    fn distill<'a>(&self) -> XmlNode<'a> {
65        let verbose = self.base.ctx().is_explain_verbose();
66        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67        vec.push(("type", Pretty::debug(&self.join_type())));
68
69        let concat_schema = self.core.concat_schema();
70        let cond = Pretty::debug(&ConditionDisplay {
71            condition: self.on(),
72            input_schema: &concat_schema,
73        });
74        vec.push(("on", cond));
75
76        if verbose {
77            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78            vec.push(("output", data));
79        }
80
81        childless_record("LogicalJoin", vec)
82    }
83}
84
85impl LogicalJoin {
86    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87        let core = generic::Join::with_full_output(left, right, join_type, on);
88        Self::with_core(core)
89    }
90
91    pub(crate) fn with_output_indices(
92        left: PlanRef,
93        right: PlanRef,
94        join_type: JoinType,
95        on: Condition,
96        output_indices: Vec<usize>,
97    ) -> Self {
98        let core = generic::Join::new(left, right, on, join_type, output_indices);
99        Self::with_core(core)
100    }
101
102    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103        let base = PlanBase::new_logical_with_core(&core);
104        LogicalJoin { base, core }
105    }
106
107    pub fn create(
108        left: PlanRef,
109        right: PlanRef,
110        join_type: JoinType,
111        on_clause: ExprImpl,
112    ) -> PlanRef {
113        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114    }
115
116    pub fn internal_column_num(&self) -> usize {
117        self.core.internal_column_num()
118    }
119
120    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121        self.core.i2l_col_mapping_ignore_join_type()
122    }
123
124    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125        self.core.i2r_col_mapping_ignore_join_type()
126    }
127
128    /// Get a reference to the logical join's on.
129    pub fn on(&self) -> &Condition {
130        self.core
131            .on
132            .as_condition_ref()
133            .expect("logical join should store predicate as Condition")
134    }
135
136    pub fn core(&self) -> &generic::Join<PlanRef> {
137        &self.core
138    }
139
140    /// Collect all input ref in the on condition. And separate them into left and right.
141    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
142        let input_refs = self
143            .core
144            .on
145            .as_condition_ref()
146            .expect("logical join should store predicate as Condition")
147            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
148        let index_group = input_refs
149            .ones()
150            .chunk_by(|i| *i < self.core.left.schema().len());
151        let left_index = index_group
152            .into_iter()
153            .next()
154            .map_or(vec![], |group| group.1.collect_vec());
155        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
156            group
157                .1
158                .map(|i| i - self.core.left.schema().len())
159                .collect_vec()
160        });
161        (left_index, right_index)
162    }
163
164    /// Get the join type of the logical join.
165    pub fn join_type(&self) -> JoinType {
166        self.core.join_type
167    }
168
169    /// Get the eq join key of the logical join.
170    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
171        self.core.eq_indexes()
172    }
173
174    /// Get the output indices of the logical join.
175    pub fn output_indices(&self) -> &Vec<usize> {
176        &self.core.output_indices
177    }
178
179    /// Clone with new output indices
180    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
181        Self::with_core(generic::Join {
182            output_indices,
183            ..self.core.clone()
184        })
185    }
186
187    /// Clone with new `on` condition
188    pub fn clone_with_cond(&self, on: Condition) -> Self {
189        Self::with_core(generic::Join {
190            on: generic::JoinOn::Condition(on),
191            ..self.core.clone()
192        })
193    }
194
195    pub fn is_left_join(&self) -> bool {
196        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
197    }
198
199    pub fn is_right_join(&self) -> bool {
200        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
201    }
202
203    pub fn is_full_out(&self) -> bool {
204        self.core.is_full_out()
205    }
206
207    pub fn is_asof_join(&self) -> bool {
208        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
209    }
210
211    pub fn output_indices_are_trivial(&self) -> bool {
212        itertools::equal(
213            self.output_indices().iter().cloned(),
214            0..self.internal_column_num(),
215        )
216    }
217
218    /// Try to simplify the outer join with the predicate on the top of the join
219    ///
220    /// now it is just a naive implementation for comparison expression, we can give a more general
221    /// implementation with constant folding in future
222    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
223        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
224            JoinType::LeftOuter => (false, true),
225            JoinType::RightOuter => (true, false),
226            JoinType::FullOuter => (true, true),
227            _ => return join_type,
228        };
229
230        for expr in &predicate.conjunctions {
231            if let ExprImpl::FunctionCall(func) = expr {
232                match func.func_type() {
233                    ExprType::Equal
234                    | ExprType::NotEqual
235                    | ExprType::LessThan
236                    | ExprType::LessThanOrEqual
237                    | ExprType::GreaterThan
238                    | ExprType::GreaterThanOrEqual => {
239                        for input in func.inputs() {
240                            if let ExprImpl::InputRef(input) = input {
241                                let idx = input.index;
242                                if idx < left_col_num {
243                                    gen_null_in_left = false;
244                                } else {
245                                    gen_null_in_right = false;
246                                }
247                            }
248                        }
249                    }
250                    _ => {}
251                };
252            }
253        }
254
255        match (gen_null_in_left, gen_null_in_right) {
256            (true, true) => JoinType::FullOuter,
257            (true, false) => JoinType::RightOuter,
258            (false, true) => JoinType::LeftOuter,
259            (false, false) => JoinType::Inner,
260        }
261    }
262
263    /// Index Join:
264    /// Try to convert logical join into batch lookup join and meanwhile it will do
265    /// the index selection for the lookup table so that we can benefit from indexes.
266    fn to_batch_lookup_join_with_index_selection(
267        &self,
268        predicate: EqJoinPredicate,
269        batch_join: generic::Join<BatchPlanRef>,
270    ) -> Result<Option<BatchLookupJoin>> {
271        match batch_join.join_type {
272            JoinType::Inner
273            | JoinType::LeftOuter
274            | JoinType::LeftSemi
275            | JoinType::LeftAnti
276            | JoinType::AsofInner
277            | JoinType::AsofLeftOuter => {}
278            _ => return Ok(None),
279        };
280
281        // Index selection for index join.
282        let right = self.right();
283        // Lookup Join only supports basic tables on the join's right side.
284        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
285            logical_scan
286        } else {
287            return Ok(None);
288        };
289
290        let mut result_plan = None;
291        // Lookup primary table.
292        if let Some(lookup_join) =
293            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
294        {
295            result_plan = Some(lookup_join);
296        }
297
298        if self
299            .core
300            .ctx()
301            .session_ctx()
302            .config()
303            .enable_index_selection()
304        {
305            let indexes = logical_scan.table_indexes();
306            for index in indexes {
307                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
308                    let index_scan: PlanRef = index_scan.into();
309                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
310                    let mut new_batch_join = batch_join.clone();
311                    new_batch_join.right =
312                        index_scan.to_batch().expect("index scan failed to batch");
313
314                    // Lookup covered index.
315                    if let Some(lookup_join) =
316                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
317                    {
318                        match &result_plan {
319                            None => result_plan = Some(lookup_join),
320                            Some(prev_lookup_join) => {
321                                // Prefer to choose lookup join with longer lookup prefix len.
322                                if prev_lookup_join.lookup_prefix_len()
323                                    < lookup_join.lookup_prefix_len()
324                                {
325                                    result_plan = Some(lookup_join)
326                                }
327                            }
328                        }
329                    }
330                }
331            }
332        }
333
334        Ok(result_plan)
335    }
336
337    /// Try to convert logical join into batch lookup join.
338    fn to_batch_lookup_join(
339        &self,
340        predicate: EqJoinPredicate,
341        logical_join: generic::Join<BatchPlanRef>,
342    ) -> Result<Option<BatchLookupJoin>> {
343        let logical_scan: &LogicalScan =
344            if let Some(logical_scan) = self.core.right.as_logical_scan() {
345                logical_scan
346            } else {
347                return Ok(None);
348            };
349        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
350    }
351
352    pub fn gen_batch_lookup_join(
353        logical_scan: &LogicalScan,
354        predicate: EqJoinPredicate,
355        logical_join: generic::Join<BatchPlanRef>,
356        is_as_of: bool,
357    ) -> Result<Option<BatchLookupJoin>> {
358        match logical_join.join_type {
359            JoinType::Inner
360            | JoinType::LeftOuter
361            | JoinType::LeftSemi
362            | JoinType::LeftAnti
363            | JoinType::AsofInner
364            | JoinType::AsofLeftOuter => {}
365            _ => return Ok(None),
366        };
367
368        let table = logical_scan.table();
369        let output_column_ids = logical_scan.output_column_ids();
370
371        // Verify that the right join key columns are the the prefix of the primary key and
372        // also contain the distribution key.
373        let order_col_ids = table.order_column_ids();
374        let dist_key = table.distribution_key.clone();
375        // The at least prefix of order key that contains distribution key.
376        let mut dist_key_in_order_key_pos = vec![];
377        for d in dist_key {
378            let pos = table
379                .order_column_indices()
380                .position(|x| x == d)
381                .expect("dist_key must in order_key");
382            dist_key_in_order_key_pos.push(pos);
383        }
384        // The shortest prefix of order key that contains distribution key.
385        let shortest_prefix_len = dist_key_in_order_key_pos
386            .iter()
387            .max()
388            .map_or(0, |pos| pos + 1);
389
390        // Distributed lookup join can't support lookup table with a singleton distribution.
391        if shortest_prefix_len == 0 {
392            return Ok(None);
393        }
394
395        // Reorder the join equal predicate to match the order key.
396        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
397        for order_col_id in order_col_ids {
398            let mut found = false;
399            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
400                if order_col_id == output_column_ids[eq_idx] {
401                    reorder_idx.push(i);
402                    found = true;
403                    break;
404                }
405            }
406            if !found {
407                break;
408            }
409        }
410        if reorder_idx.len() < shortest_prefix_len {
411            return Ok(None);
412        }
413        let lookup_prefix_len = reorder_idx.len();
414        let predicate = predicate.reorder(&reorder_idx);
415
416        // Extract the predicate from logical scan. Only pure scan is supported.
417        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
418        // Construct output column to require column mapping
419        let o2r = if let Some(project_expr) = project_expr {
420            project_expr
421                .into_iter()
422                .map(|x| x.as_input_ref().unwrap().index)
423                .collect_vec()
424        } else {
425            (0..logical_scan.output_col_idx().len()).collect_vec()
426        };
427        let left_schema_len = logical_join.left.schema().len();
428
429        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
430            offset: left_schema_len,
431            mapping: o2r.clone(),
432        };
433
434        let new_eq_cond = predicate
435            .eq_cond()
436            .rewrite_expr(&mut join_predicate_rewriter);
437
438        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
439            offset: left_schema_len,
440        };
441
442        let new_other_cond = predicate
443            .other_cond()
444            .clone()
445            .rewrite_expr(&mut join_predicate_rewriter)
446            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
447
448        let new_join_on = new_eq_cond.and(new_other_cond);
449        let new_predicate =
450            EqJoinPredicate::create(left_schema_len, new_scan.schema().len(), new_join_on);
451
452        // We discovered that we cannot use a lookup join after pulling up the predicate
453        // from one side and simplifying the condition. Let's use some other join instead.
454        if !new_predicate.has_eq() {
455            return Ok(None);
456        }
457
458        // Rewrite the join output indices and all output indices referred to the old scan need to
459        // rewrite.
460        let new_join_output_indices = logical_join
461            .output_indices
462            .iter()
463            .map(|&x| {
464                if x < left_schema_len {
465                    x
466                } else {
467                    o2r[x - left_schema_len] + left_schema_len
468                }
469            })
470            .collect_vec();
471
472        let new_scan_output_column_ids = new_scan.output_column_ids();
473        let as_of = new_scan.as_of.clone();
474        let new_logical_scan: LogicalScan = new_scan.into();
475
476        // Construct a new logical join, because we have change its RHS.
477        let new_logical_join = generic::Join::new_with_eq_predicate(
478            logical_join.left,
479            new_logical_scan.to_batch()?,
480            new_predicate,
481            logical_join.join_type,
482            new_join_output_indices,
483        );
484
485        let asof_desc = is_as_of
486            .then(|| {
487                Self::get_inequality_desc_from_predicate(
488                    predicate.other_cond().clone(),
489                    left_schema_len,
490                )
491            })
492            .transpose()?;
493
494        Ok(Some(BatchLookupJoin::new(
495            new_logical_join,
496            table.clone(),
497            new_scan_output_column_ids,
498            lookup_prefix_len,
499            false,
500            as_of,
501            asof_desc,
502        )))
503    }
504
505    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
506        self.core.decompose()
507    }
508}
509
510impl PlanTreeNodeBinary<Logical> for LogicalJoin {
511    fn left(&self) -> PlanRef {
512        self.core.left.clone()
513    }
514
515    fn right(&self) -> PlanRef {
516        self.core.right.clone()
517    }
518
519    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
520        Self::with_core(generic::Join {
521            left,
522            right,
523            ..self.core.clone()
524        })
525    }
526
527    fn rewrite_with_left_right(
528        &self,
529        left: PlanRef,
530        left_col_change: ColIndexMapping,
531        right: PlanRef,
532        right_col_change: ColIndexMapping,
533    ) -> (Self, ColIndexMapping) {
534        let (new_on, new_output_indices) = {
535            let (mut map, _) = left_col_change.clone().into_parts();
536            let (mut right_map, _) = right_col_change.clone().into_parts();
537            for i in right_map.iter_mut().flatten() {
538                *i += left.schema().len();
539            }
540            map.append(&mut right_map);
541            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
542
543            let new_output_indices = self
544                .output_indices()
545                .iter()
546                .map(|&i| mapping.map(i))
547                .collect::<Vec<_>>();
548            let new_on = self.on().clone().rewrite_expr(&mut mapping);
549            (new_on, new_output_indices)
550        };
551
552        let join = Self::with_output_indices(
553            left,
554            right,
555            self.join_type(),
556            new_on,
557            new_output_indices.clone(),
558        );
559
560        let new_i2o = ColIndexMapping::with_remaining_columns(
561            &new_output_indices,
562            join.internal_column_num(),
563        );
564
565        let old_o2i = self.core.o2i_col_mapping();
566
567        let old_o2l = old_o2i
568            .composite(&self.core.i2l_col_mapping())
569            .composite(&left_col_change);
570        let old_o2r = old_o2i
571            .composite(&self.core.i2r_col_mapping())
572            .composite(&right_col_change);
573        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
574        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
575
576        let out_col_change = old_o2l
577            .composite(&new_l2o)
578            .union(&old_o2r.composite(&new_r2o));
579        (join, out_col_change)
580    }
581}
582
// NOTE(review): presumably expands to the generic plan-tree-node impl for `LogicalJoin` on top of
// the `PlanTreeNodeBinary<Logical>` impl above; the macro is defined outside this file, so
// confirm against its definition.
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
584
585impl ColPrunable for LogicalJoin {
586    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
587        // make `required_cols` point to internal table instead of output schema.
588        let required_cols = required_cols
589            .iter()
590            .map(|i| self.output_indices()[*i])
591            .collect_vec();
592        let left_len = self.left().schema().fields.len();
593
594        let total_len = self.left().schema().len() + self.right().schema().len();
595        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
596
597        required_cols.iter().for_each(|&i| {
598            if self.is_right_join() {
599                resized_required_cols.insert(left_len + i);
600            } else {
601                resized_required_cols.insert(i);
602            }
603        });
604
605        // add those columns which are required in the join condition to
606        // to those that are required in the output
607        let mut visitor = CollectInputRef::new(resized_required_cols);
608        self.on().visit_expr(&mut visitor);
609        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
610
611        let mut left_required_cols = Vec::new();
612        let mut right_required_cols = Vec::new();
613        left_right_required_cols.iter().for_each(|&i| {
614            if i < left_len {
615                left_required_cols.push(i);
616            } else {
617                right_required_cols.push(i - left_len);
618            }
619        });
620
621        let mut on = self.on().clone();
622        let mut mapping =
623            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
624        on = on.rewrite_expr(&mut mapping);
625
626        let new_output_indices = {
627            let required_inputs_in_output = if self.is_left_join() {
628                &left_required_cols
629            } else if self.is_right_join() {
630                &right_required_cols
631            } else {
632                &left_right_required_cols
633            };
634
635            let mapping =
636                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
637            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
638        };
639
640        LogicalJoin::with_output_indices(
641            self.left().prune_col(&left_required_cols, ctx),
642            self.right().prune_col(&right_required_cols, ctx),
643            self.join_type(),
644            on,
645            new_output_indices,
646        )
647        .into()
648    }
649}
650
651impl ExprRewritable<Logical> for LogicalJoin {
652    fn has_rewritable_expr(&self) -> bool {
653        true
654    }
655
656    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
657        let mut core = self.core.clone();
658        core.rewrite_exprs(r);
659        Self {
660            base: self.base.clone_with_new_plan_id(),
661            core,
662        }
663        .into()
664    }
665}
666
667impl ExprVisitable for LogicalJoin {
668    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
669        self.core.visit_exprs(v);
670    }
671}
672
673/// We are trying to derive a predicate to apply to the other side of a join if all
674/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
675/// with the corresponding eq condition columns of the other side.
676///
677/// Strategy:
678/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
679///    then we proceed. Else abort.
680/// 2. Then, we collect `InputRef`s in the conjunction.
681/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
682/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
683///    the other side.
684///
685/// # Arguments
686///
687/// Suppose we derive a predicate from the left side to be pushed to the right side.
688/// * `expr`: An expr from the left side.
689/// * `col_num`: The number of columns in the left side.
690fn derive_predicate_from_eq_condition(
691    expr: &ExprImpl,
692    eq_condition: &EqJoinPredicate,
693    col_num: usize,
694    expr_is_left: bool,
695) -> Option<ExprImpl> {
696    if expr.is_impure() {
697        return None;
698    }
699    let eq_indices = eq_condition
700        .eq_indexes_typed()
701        .iter()
702        .filter_map(|(l, r)| {
703            if l.return_type() != r.return_type() {
704                None
705            } else if expr_is_left {
706                Some(l.index())
707            } else {
708                Some(r.index())
709            }
710        })
711        .collect_vec();
712    if expr
713        .collect_input_refs(col_num)
714        .ones()
715        .any(|index| !eq_indices.contains(&index))
716    {
717        // expr contains an InputRef not in eq_condition
718        return None;
719    }
720    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
721    // Hence, we can substitute those `InputRef`s with indices from the other side.
722    let other_side_mapping = if expr_is_left {
723        eq_condition.eq_indexes_typed().into_iter().collect()
724    } else {
725        eq_condition
726            .eq_indexes_typed()
727            .into_iter()
728            .map(|(x, y)| (y, x))
729            .collect()
730    };
731    struct InputRefsRewriter {
732        mapping: HashMap<InputRef, InputRef>,
733    }
734    impl ExprRewriter for InputRefsRewriter {
735        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
736            self.mapping[&input_ref].clone().into()
737        }
738    }
739    Some(
740        InputRefsRewriter {
741            mapping: other_side_mapping,
742        }
743        .rewrite_expr(expr.clone()),
744    )
745}
746
747/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
748struct LookupJoinPredicateRewriter {
749    offset: usize,
750    mapping: Vec<usize>,
751}
752impl ExprRewriter for LookupJoinPredicateRewriter {
753    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
754        if input_ref.index() < self.offset {
755            input_ref.into()
756        } else {
757            InputRef::new(
758                self.mapping[input_ref.index() - self.offset] + self.offset,
759                input_ref.return_type(),
760            )
761            .into()
762        }
763    }
764}
765
766/// Rewrite the scan predicate so we can add it to the join predicate.
767struct LookupJoinScanPredicateRewriter {
768    offset: usize,
769}
770impl ExprRewriter for LookupJoinScanPredicateRewriter {
771    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
772        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
773    }
774}
775
776impl PredicatePushdown for LogicalJoin {
777    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
778    ///
779    /// # Which predicates can be pushed
780    ///
781    /// For inner join, we can do all kinds of pushdown.
782    ///
783    /// For left/right semi join, we can push filter to left/right and on-clause,
784    /// and push on-clause to left/right.
785    ///
786    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
787    ///
788    /// ## Outer Join
789    ///
790    /// Preserved Row table
791    /// : The table in an Outer Join that must return all rows.
792    ///
793    /// Null Supplying table
794    /// : This is the table that has nulls filled in for its columns in unmatched rows.
795    ///
796    /// |                          | Preserved Row table | Null Supplying table |
797    /// |--------------------------|---------------------|----------------------|
798    /// | Join predicate (on)      | Not Pushed          | Pushed               |
799    /// | Where predicate (filter) | Pushed              | Not Pushed           |
800    fn predicate_pushdown(
801        &self,
802        predicate: Condition,
803        ctx: &mut PredicatePushdownContext,
804    ) -> PlanRef {
805        // rewrite output col referencing indices as internal cols
806        let mut predicate = {
807            let mut mapping = self.core.o2i_col_mapping();
808            predicate.rewrite_expr(&mut mapping)
809        };
810
811        let left_col_num = self.left().schema().len();
812        let right_col_num = self.right().schema().len();
813        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
814
815        let push_down_temporal_predicate = self.temporal_join_on().is_none();
816
817        let (left_from_filter, right_from_filter, on) = push_down_into_join(
818            &mut predicate,
819            left_col_num,
820            right_col_num,
821            join_type,
822            push_down_temporal_predicate,
823        );
824
825        let mut new_on = self.on().clone().and(on);
826        let (left_from_on, right_from_on) = push_down_join_condition(
827            &mut new_on,
828            left_col_num,
829            right_col_num,
830            join_type,
831            push_down_temporal_predicate,
832        );
833
834        let left_predicate = left_from_filter.and(left_from_on);
835        let right_predicate = right_from_filter.and(right_from_on);
836
837        // Derive conditions to push to the other side based on eq condition columns
838        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
839
840        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
841        let right_from_left = if matches!(
842            join_type,
843            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
844        ) {
845            Condition {
846                conjunctions: left_predicate
847                    .conjunctions
848                    .iter()
849                    .filter_map(|expr| {
850                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
851                    })
852                    .collect(),
853            }
854        } else {
855            Condition::true_cond()
856        };
857
858        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
859        let left_from_right = if matches!(
860            join_type,
861            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
862        ) {
863            Condition {
864                conjunctions: right_predicate
865                    .conjunctions
866                    .iter()
867                    .filter_map(|expr| {
868                        derive_predicate_from_eq_condition(
869                            expr,
870                            &eq_condition,
871                            right_col_num,
872                            false,
873                        )
874                    })
875                    .collect(),
876            }
877        } else {
878            Condition::true_cond()
879        };
880
881        let left_predicate = left_predicate.and(left_from_right);
882        let right_predicate = right_predicate.and(right_from_left);
883
884        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
885        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
886        let new_join = LogicalJoin::with_output_indices(
887            new_left,
888            new_right,
889            join_type,
890            new_on,
891            self.output_indices().clone(),
892        );
893
894        let mut mapping = self.core.i2o_col_mapping();
895        predicate = predicate.rewrite_expr(&mut mapping);
896        LogicalFilter::create(new_join.into(), predicate)
897    }
898}
899
900#[derive(Clone, Copy)]
901struct TemporalJoinScan<'a>(&'a LogicalScan);
902
903impl<'a> Deref for TemporalJoinScan<'a> {
904    type Target = LogicalScan;
905
906    fn deref(&self) -> &Self::Target {
907        self.0
908    }
909}
910
911impl LogicalJoin {
912    fn get_stream_input_for_hash_join(
913        &self,
914        predicate: &EqJoinPredicate,
915        ctx: &mut ToStreamContext,
916    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
917        use super::stream::prelude::*;
918
919        let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
920        let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();
921
922        let logical_right = try_enforce_locality_requirement(self.right(), &rhs_join_key_idx);
923        let mut right = logical_right.to_stream_with_dist_required(
924            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
925            ctx,
926        )?;
927        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
928        let r2l =
929            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
930        let l2r =
931            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
932        let mut left;
933        let right_dist = right.distribution();
934        match right_dist {
935            Distribution::HashShard(_) => {
936                let left_dist = r2l
937                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
938                left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
939            }
940            Distribution::UpstreamHashShard(_, _) => {
941                left = logical_left.to_stream_with_dist_required(
942                    &RequiredDist::shard_by_key(
943                        self.left().schema().len(),
944                        &predicate.left_eq_indexes(),
945                    ),
946                    ctx,
947                )?;
948                let left_dist = left.distribution();
949                match left_dist {
950                    Distribution::HashShard(_) => {
951                        let right_dist = l2r.rewrite_required_distribution(
952                            &RequiredDist::PhysicalDist(left_dist.clone()),
953                        );
954                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
955                    }
956                    Distribution::UpstreamHashShard(_, _) => {
957                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
958                            .streaming_enforce_if_not_satisfies(left)?;
959                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
960                            .streaming_enforce_if_not_satisfies(right)?;
961                    }
962                    _ => unreachable!(),
963                }
964            }
965            _ => unreachable!(),
966        }
967        Ok((left, right))
968    }
969
970    fn to_stream_hash_join(
971        &self,
972        predicate: EqJoinPredicate,
973        ctx: &mut ToStreamContext,
974    ) -> Result<StreamPlanRef> {
975        use super::stream::prelude::*;
976
977        assert!(predicate.has_eq());
978        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
979
980        let mut core = self.core.clone_with_inputs(left, right);
981        core.on = generic::JoinOn::EqPredicate(predicate);
982
983        // Convert to Hash Join for equal joins
984        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
985        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
986        // as opposed to the HashJoin, which applies the condition row-wise.
987        // However, the default behavior of pulling up non-equal conditions can be overridden by the
988        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
989        // materialization of rows only to be filtered later.
990
991        let stream_hash_join = StreamHashJoin::new(core.clone())?;
992        let predicate = stream_hash_join.eq_join_predicate().clone();
993
994        let force_filter_inside_join = self
995            .base
996            .ctx()
997            .session_ctx()
998            .config()
999            .streaming_force_filter_inside_join();
1000
1001        let pull_filter = self.join_type() == JoinType::Inner
1002            && stream_hash_join.eq_join_predicate().has_non_eq()
1003            && stream_hash_join.inequality_pairs().is_empty()
1004            && (!force_filter_inside_join);
1005        if pull_filter {
1006            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1007
1008            let mut core = core;
1009            core.output_indices = default_indices.clone();
1010            // Temporarily remove output indices.
1011            let eq_cond = EqJoinPredicate::new(
1012                Condition::true_cond(),
1013                predicate.eq_keys().to_vec(),
1014                self.left().schema().len(),
1015                self.right().schema().len(),
1016            );
1017            core.on = generic::JoinOn::EqPredicate(eq_cond);
1018            let hash_join = StreamHashJoin::new(core)?.into();
1019            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1020            let plan = StreamFilter::new(logical_filter).into();
1021            if self.output_indices() != &default_indices {
1022                let logical_project = generic::Project::with_mapping(
1023                    plan,
1024                    ColIndexMapping::with_remaining_columns(
1025                        self.output_indices(),
1026                        self.internal_column_num(),
1027                    ),
1028                );
1029                Ok(StreamProject::new(logical_project).into())
1030            } else {
1031                Ok(plan)
1032            }
1033        } else {
1034            Ok(stream_hash_join.into())
1035        }
1036    }
1037
1038    pub fn should_be_temporal_join(&self) -> bool {
1039        self.temporal_join_on().is_some()
1040    }
1041
1042    fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1043        if let Some(logical_scan) = self.core.right.as_logical_scan() {
1044            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1045                .then_some(TemporalJoinScan(logical_scan))
1046        } else {
1047            None
1048        }
1049    }
1050
1051    fn should_be_stream_temporal_join<'a>(
1052        &'a self,
1053        ctx: &ToStreamContext,
1054    ) -> Result<Option<TemporalJoinScan<'a>>> {
1055        Ok(if let Some(scan) = self.temporal_join_on() {
1056            if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1057                return Err(RwError::from(ErrorCode::NotSupported(
1058                    "Temporal join with snapshot backfill not supported".into(),
1059                    "Please use arrangement backfill".into(),
1060                )));
1061            }
1062            if scan.cross_database() {
1063                return Err(RwError::from(ErrorCode::NotSupported(
1064                        "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1065                        "Please ensure both tables are in the same database".into(),
1066                    )));
1067            }
1068            Some(scan)
1069        } else {
1070            None
1071        })
1072    }
1073
1074    fn to_stream_temporal_join_with_index_selection(
1075        &self,
1076        logical_scan: TemporalJoinScan<'_>,
1077        predicate: EqJoinPredicate,
1078        ctx: &mut ToStreamContext,
1079    ) -> Result<StreamPlanRef> {
1080        // Use primary table.
1081        let mut result_plan: Result<StreamTemporalJoin> =
1082            self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1083        // Return directly if this temporal join can match the pk of its right table.
1084        if let Ok(temporal_join) = &result_plan
1085            && temporal_join.eq_join_predicate().eq_indexes().len()
1086                == logical_scan.primary_key().len()
1087        {
1088            return result_plan.map(|x| x.into());
1089        }
1090        if self
1091            .core
1092            .ctx()
1093            .session_ctx()
1094            .config()
1095            .enable_index_selection()
1096        {
1097            let indexes = logical_scan.table_indexes();
1098            for index in indexes {
1099                // Use index table
1100                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1101                    let index_scan: PlanRef = index_scan.into();
1102                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1103                    if let Ok(temporal_join) = that.to_stream_temporal_join(
1104                        that.temporal_join_on().expect(
1105                            "index scan created from temporal join scan must also be temporal join",
1106                        ),
1107                        predicate.clone(),
1108                        ctx,
1109                    ) {
1110                        match &result_plan {
1111                            Err(_) => result_plan = Ok(temporal_join),
1112                            Ok(prev_temporal_join) => {
1113                                // Prefer to the temporal join with a longer lookup prefix len.
1114                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1115                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1116                                {
1117                                    result_plan = Ok(temporal_join)
1118                                }
1119                            }
1120                        }
1121                    }
1122                }
1123            }
1124        }
1125
1126        result_plan.map(|x| x.into())
1127    }
1128
1129    fn temporal_join_scan_predicate_pull_up(
1130        logical_scan: TemporalJoinScan<'_>,
1131        predicate: EqJoinPredicate,
1132        output_indices: &[usize],
1133        left_schema_len: usize,
1134    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1135        // Extract the predicate from logical scan. Only pure scan is supported.
1136        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1137        // Construct output column to require column mapping
1138        let o2r = if let Some(project_expr) = project_expr {
1139            project_expr
1140                .into_iter()
1141                .map(|x| x.as_input_ref().unwrap().index)
1142                .collect_vec()
1143        } else {
1144            (0..logical_scan.output_col_idx().len()).collect_vec()
1145        };
1146        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1147            offset: left_schema_len,
1148            mapping: o2r.clone(),
1149        };
1150
1151        let new_eq_cond = predicate
1152            .eq_cond()
1153            .rewrite_expr(&mut join_predicate_rewriter);
1154
1155        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1156            offset: left_schema_len,
1157        };
1158
1159        let new_other_cond = predicate
1160            .other_cond()
1161            .clone()
1162            .rewrite_expr(&mut join_predicate_rewriter)
1163            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1164
1165        let new_join_on = new_eq_cond.and(new_other_cond);
1166
1167        let new_predicate = EqJoinPredicate::create(
1168            left_schema_len,
1169            new_scan.schema().len(),
1170            new_join_on.clone(),
1171        );
1172
1173        // Rewrite the join output indices and all output indices referred to the old scan need to
1174        // rewrite.
1175        let new_join_output_indices = output_indices
1176            .iter()
1177            .map(|&x| {
1178                if x < left_schema_len {
1179                    x
1180                } else {
1181                    o2r[x - left_schema_len] + left_schema_len
1182                }
1183            })
1184            .collect_vec();
1185
1186        let new_stream_table_scan =
1187            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1188        Ok((
1189            new_stream_table_scan,
1190            new_predicate,
1191            new_join_on,
1192            new_join_output_indices,
1193        ))
1194    }
1195
1196    fn to_stream_temporal_join(
1197        &self,
1198        logical_scan: TemporalJoinScan<'_>,
1199        predicate: EqJoinPredicate,
1200        ctx: &mut ToStreamContext,
1201    ) -> Result<StreamTemporalJoin> {
1202        use super::stream::prelude::*;
1203
1204        assert!(predicate.has_eq());
1205
1206        let table = logical_scan.table();
1207        let output_column_ids = logical_scan.output_column_ids();
1208
1209        // Verify that the right join key columns are the the prefix of the primary key and
1210        // also contain the distribution key.
1211        let order_col_ids = table.order_column_ids();
1212        let dist_key = table.distribution_key.clone();
1213
1214        let mut dist_key_in_order_key_pos = vec![];
1215        for d in dist_key {
1216            let pos = table
1217                .order_column_indices()
1218                .position(|x| x == d)
1219                .expect("dist_key must in order_key");
1220            dist_key_in_order_key_pos.push(pos);
1221        }
1222        // The shortest prefix of order key that contains distribution key.
1223        let shortest_prefix_len = dist_key_in_order_key_pos
1224            .iter()
1225            .max()
1226            .map_or(0, |pos| pos + 1);
1227
1228        // Reorder the join equal predicate to match the order key.
1229        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1230        for order_col_id in order_col_ids {
1231            let mut found = false;
1232            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1233                if order_col_id == output_column_ids[eq_idx] {
1234                    reorder_idx.push(i);
1235                    found = true;
1236                    break;
1237                }
1238            }
1239            if !found {
1240                break;
1241            }
1242        }
1243        if reorder_idx.len() < shortest_prefix_len {
1244            return Err(RwError::from(ErrorCode::NotSupported(
1245                "Temporal join requires the equivalence join condition includes the key columns that form the distribution key of the lookup table".into(),
1246                concat!(
1247                    "Use DESCRIBE <table_name> to view the table's key information.\n",
1248                    "You can create an index on the lookup table to facilitate the temporal join if necessary."
1249                ).into(),
1250            )));
1251        }
1252        let lookup_prefix_len = reorder_idx.len();
1253        let predicate = predicate.reorder(&reorder_idx);
1254
1255        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1256            RequiredDist::single()
1257        } else {
1258            let left_eq_indexes = predicate.left_eq_indexes();
1259            let left_dist_key = dist_key_in_order_key_pos
1260                .iter()
1261                .map(|pos| left_eq_indexes[*pos])
1262                .collect_vec();
1263
1264            RequiredDist::hash_shard(&left_dist_key)
1265        };
1266
1267        let lhs_join_key_idx = predicate
1268            .eq_indexes()
1269            .into_iter()
1270            .map(|(l, _)| l)
1271            .collect_vec();
1272        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
1273        let left = logical_left.to_stream(ctx)?;
1274        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1275        let left = required_dist.stream_enforce(left);
1276
1277        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1278            Self::temporal_join_scan_predicate_pull_up(
1279                logical_scan,
1280                predicate,
1281                self.output_indices(),
1282                self.left().schema().len(),
1283            )?;
1284
1285        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1286        if !new_predicate.has_eq() {
1287            return Err(RwError::from(ErrorCode::NotSupported(
1288                "Temporal join requires a non trivial join condition".into(),
1289                "Please remove the false condition of the join".into(),
1290            )));
1291        }
1292
1293        // Construct a new logical join, because we have change its RHS.
1294        let new_logical_join = generic::Join::new(
1295            left,
1296            right,
1297            new_join_on,
1298            self.join_type(),
1299            new_join_output_indices,
1300        );
1301
1302        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1303
1304        let mut new_logical_join = new_logical_join;
1305        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1306        StreamTemporalJoin::new(new_logical_join, false)
1307    }
1308
1309    fn to_stream_nested_loop_temporal_join(
1310        &self,
1311        logical_scan: TemporalJoinScan<'_>,
1312        predicate: EqJoinPredicate,
1313        ctx: &mut ToStreamContext,
1314    ) -> Result<StreamPlanRef> {
1315        use super::stream::prelude::*;
1316        assert!(!predicate.has_eq());
1317
1318        let left = self.left().to_stream_with_dist_required(
1319            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1320            ctx,
1321        )?;
1322        assert!(left.as_stream_exchange().is_some());
1323
1324        if self.join_type() != JoinType::Inner {
1325            return Err(RwError::from(ErrorCode::NotSupported(
1326                "Temporal join requires an inner join".into(),
1327                "Please use an inner join".into(),
1328            )));
1329        }
1330
1331        if !left.append_only() {
1332            return Err(RwError::from(ErrorCode::NotSupported(
1333                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1334                "Please ensure the left hash side is append only".into(),
1335            )));
1336        }
1337
1338        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1339            Self::temporal_join_scan_predicate_pull_up(
1340                logical_scan,
1341                predicate,
1342                self.output_indices(),
1343                self.left().schema().len(),
1344            )?;
1345
1346        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1347
1348        // Construct a new logical join, because we have change its RHS.
1349        let new_logical_join = generic::Join::new(
1350            left,
1351            right,
1352            new_join_on,
1353            self.join_type(),
1354            new_join_output_indices,
1355        );
1356
1357        let mut new_logical_join = new_logical_join;
1358        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1359        Ok(StreamTemporalJoin::new(new_logical_join, true)?.into())
1360    }
1361
1362    fn to_stream_dynamic_filter(
1363        &self,
1364        predicate: Condition,
1365        ctx: &mut ToStreamContext,
1366    ) -> Result<Option<StreamPlanRef>> {
1367        use super::stream::prelude::*;
1368
1369        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1370        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1371        // `StreamDynamicFilter`
1372
1373        // Check if `Inner`/`LeftSemi`
1374        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1375            return Ok(None);
1376        }
1377
1378        // Check if right side is a scalar
1379        if !self.right().max_one_row() {
1380            return Ok(None);
1381        }
1382        if self.right().schema().len() != 1 {
1383            return Ok(None);
1384        }
1385
1386        // Check if the join condition is a correlated comparison
1387        if predicate.conjunctions.len() > 1 {
1388            return Ok(None);
1389        }
1390        let expr: ExprImpl = predicate.into();
1391        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1392            Some(v) => v,
1393            None => return Ok(None),
1394        };
1395
1396        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1397            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1398        if !condition_cross_inputs {
1399            // Maybe we should panic here because it means some predicates are not pushed down.
1400            return Ok(None);
1401        }
1402
1403        // We align input types on all join predicates with cmp operator
1404        if self.left().schema().fields()[left_ref.index].data_type
1405            != self.right().schema().fields()[0].data_type
1406        {
1407            return Ok(None);
1408        }
1409
1410        // Check if non of the columns from the inner side is required to output
1411        let all_output_from_left = self
1412            .output_indices()
1413            .iter()
1414            .all(|i| *i < self.left().schema().len());
1415        if !all_output_from_left {
1416            return Ok(None);
1417        }
1418
1419        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1420        let right = self.right().to_stream_with_dist_required(
1421            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1422            ctx,
1423        )?;
1424
1425        assert!(right.as_stream_exchange().is_some());
1426        assert_eq!(
1427            *right.inputs().iter().exactly_one().unwrap().distribution(),
1428            Distribution::Single
1429        );
1430
1431        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1432        let plan = StreamDynamicFilter::new(core)?.into();
1433        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1434        if self
1435            .output_indices()
1436            .iter()
1437            .copied()
1438            .ne(0..self.left().schema().len())
1439        {
1440            // The schema of dynamic filter is always the same as the left side now, and we have
1441            // checked that all output columns are from the left side before.
1442            let logical_project = generic::Project::with_mapping(
1443                plan,
1444                ColIndexMapping::with_remaining_columns(
1445                    self.output_indices(),
1446                    self.left().schema().len(),
1447                ),
1448            );
1449            Ok(Some(StreamProject::new(logical_project).into()))
1450        } else {
1451            Ok(Some(plan))
1452        }
1453    }
1454
1455    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1456        let predicate = EqJoinPredicate::create(
1457            self.left().schema().len(),
1458            self.right().schema().len(),
1459            self.on().clone(),
1460        );
1461        assert!(predicate.has_eq());
1462
1463        let join = self
1464            .core
1465            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1466
1467        Ok(self
1468            .to_batch_lookup_join(predicate, join)?
1469            .expect("Fail to convert to lookup join")
1470            .into())
1471    }
1472
1473    fn to_stream_asof_join(
1474        &self,
1475        predicate: EqJoinPredicate,
1476        ctx: &mut ToStreamContext,
1477    ) -> Result<StreamPlanRef> {
1478        use super::stream::prelude::*;
1479
1480        if predicate.eq_keys().is_empty() {
1481            return Err(ErrorCode::InvalidInputSyntax(
1482                "AsOf join requires at least 1 equal condition".to_owned(),
1483            )
1484            .into());
1485        }
1486
1487        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1488        let left_len = left.schema().len();
1489        let mut core = self.core.clone_with_inputs(left, right);
1490        core.on = generic::JoinOn::EqPredicate(predicate);
1491
1492        let inequality_desc = Self::get_inequality_desc_from_predicate(
1493            core.on
1494                .as_eq_predicate_ref()
1495                .expect("core predicate must exist")
1496                .other_cond()
1497                .clone(),
1498            left_len,
1499        )?;
1500
1501        Ok(StreamAsOfJoin::new(core, inequality_desc)?.into())
1502    }
1503
1504    /// Convert the logical join to a Hash join.
1505    fn to_batch_hash_join(
1506        &self,
1507        logical_join: generic::Join<BatchPlanRef>,
1508        predicate: EqJoinPredicate,
1509    ) -> Result<BatchPlanRef> {
1510        use super::batch::prelude::*;
1511
1512        let left_schema_len = logical_join.left.schema().len();
1513        let asof_desc = self
1514            .is_asof_join()
1515            .then(|| {
1516                Self::get_inequality_desc_from_predicate(
1517                    predicate.other_cond().clone(),
1518                    left_schema_len,
1519                )
1520            })
1521            .transpose()?;
1522
1523        let logical_join = generic::Join {
1524            on: generic::JoinOn::EqPredicate(predicate),
1525            ..logical_join
1526        };
1527        let batch_join = BatchHashJoin::new(logical_join, asof_desc);
1528        Ok(batch_join.into())
1529    }
1530
1531    pub fn get_inequality_desc_from_predicate(
1532        predicate: Condition,
1533        left_input_len: usize,
1534    ) -> Result<AsOfJoinDesc> {
1535        let expr: ExprImpl = predicate.into();
1536        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1537            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1538            {
1539                Ok(AsOfJoinDesc {
1540                    left_idx: left_input_ref.index() as u32,
1541                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1542                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1543                })
1544            } else {
1545                bail!("inequal condition from the same side should be push down in optimizer");
1546            }
1547        } else {
1548            Err(ErrorCode::InvalidInputSyntax(
1549                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1550            )
1551            .into())
1552        }
1553    }
1554
1555    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1556        match expr_type {
1557            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1558            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1559            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1560            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1561            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1562                "Invalid comparison type: {}",
1563                expr_type.as_str_name()
1564            ))
1565            .into()),
1566        }
1567    }
1568}
1569
1570impl ToBatch for LogicalJoin {
1571    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1572        let predicate = EqJoinPredicate::create(
1573            self.left().schema().len(),
1574            self.right().schema().len(),
1575            self.on().clone(),
1576        );
1577
1578        let batch_join = self
1579            .core
1580            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1581
1582        let ctx = self.base.ctx();
1583        let config = ctx.session_ctx().config();
1584
1585        if predicate.has_eq() {
1586            if !predicate.eq_keys_are_type_aligned() {
1587                return Err(ErrorCode::InternalError(format!(
1588                    "Join eq keys are not aligned for predicate: {predicate:?}"
1589                ))
1590                .into());
1591            }
1592            if config.batch_enable_lookup_join()
1593                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1594                    predicate.clone(),
1595                    batch_join.clone(),
1596                )?
1597            {
1598                return Ok(lookup_join.into());
1599            }
1600            self.to_batch_hash_join(batch_join, predicate)
1601        } else if self.is_asof_join() {
1602            Err(ErrorCode::InvalidInputSyntax(
1603                "AsOf join requires at least 1 equal condition".to_owned(),
1604            )
1605            .into())
1606        } else {
1607            // Convert to Nested-loop Join for non-equal joins
1608            Ok(BatchNestedLoopJoin::new(batch_join).into())
1609        }
1610    }
1611}
1612
1613impl ToStream for LogicalJoin {
1614    fn to_stream(
1615        &self,
1616        ctx: &mut ToStreamContext,
1617    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1618        if self
1619            .on()
1620            .conjunctions
1621            .iter()
1622            .any(|cond| cond.count_nows() > 0)
1623        {
1624            return Err(ErrorCode::NotSupported(
1625                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1626                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1627        }
1628
1629        let predicate = EqJoinPredicate::create(
1630            self.left().schema().len(),
1631            self.right().schema().len(),
1632            self.on().clone(),
1633        );
1634
1635        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1636            self.to_stream_asof_join(predicate, ctx)
1637        } else if predicate.has_eq() {
1638            if !predicate.eq_keys_are_type_aligned() {
1639                return Err(ErrorCode::InternalError(format!(
1640                    "Join eq keys are not aligned for predicate: {predicate:?}"
1641                ))
1642                .into());
1643            }
1644
1645            if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1646                self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1647            } else {
1648                self.to_stream_hash_join(predicate, ctx)
1649            }
1650        } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1651            self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1652        } else if let Some(dynamic_filter) =
1653            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1654        {
1655            Ok(dynamic_filter)
1656        } else {
1657            Err(RwError::from(ErrorCode::NotSupported(
1658                "streaming nested-loop join".to_owned(),
1659                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1660                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1661                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1662            )))
1663        }
1664    }
1665
1666    fn logical_rewrite_for_stream(
1667        &self,
1668        ctx: &mut RewriteStreamContext,
1669    ) -> Result<(PlanRef, ColIndexMapping)> {
1670        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1671        let left_len = left.schema().len();
1672        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1673        let (join, out_col_change) = self.rewrite_with_left_right(
1674            left.clone(),
1675            left_col_change,
1676            right.clone(),
1677            right_col_change,
1678        );
1679
1680        let mapping = ColIndexMapping::with_remaining_columns(
1681            join.output_indices(),
1682            join.internal_column_num(),
1683        );
1684
1685        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1686        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1687
1688        // Add missing pk indices to the logical join
1689        let mut left_to_add = left
1690            .expect_stream_key()
1691            .iter()
1692            .cloned()
1693            .filter(|i| l2o.try_map(*i).is_none())
1694            .collect_vec();
1695
1696        let mut right_to_add = right
1697            .expect_stream_key()
1698            .iter()
1699            .filter(|&&i| r2o.try_map(i).is_none())
1700            .map(|&i| i + left_len)
1701            .collect_vec();
1702
1703        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1704        // key.
1705        let right_len = right.schema().len();
1706        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1707
1708        let either_or_both = self.core.add_which_join_key_to_pk();
1709
1710        for (lk, rk) in eq_predicate.eq_indexes() {
1711            match either_or_both {
1712                EitherOrBoth::Left(_) => {
1713                    if l2o.try_map(lk).is_none() {
1714                        left_to_add.push(lk);
1715                    }
1716                }
1717                EitherOrBoth::Right(_) => {
1718                    if r2o.try_map(rk).is_none() {
1719                        right_to_add.push(rk + left_len)
1720                    }
1721                }
1722                EitherOrBoth::Both(_, _) => {
1723                    if l2o.try_map(lk).is_none() {
1724                        left_to_add.push(lk);
1725                    }
1726                    if r2o.try_map(rk).is_none() {
1727                        right_to_add.push(rk + left_len)
1728                    }
1729                }
1730            };
1731        }
1732        let left_to_add = left_to_add.into_iter().unique();
1733        let right_to_add = right_to_add.into_iter().unique();
1734        // NOTE(st1page) over
1735
1736        let mut new_output_indices = join.output_indices().clone();
1737        if !join.is_right_join() {
1738            new_output_indices.extend(left_to_add);
1739        }
1740        if !join.is_left_join() {
1741            new_output_indices.extend(right_to_add);
1742        }
1743
1744        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1745
1746        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1747            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1748
1749            let l2o = join_with_pk
1750                .core
1751                .l2i_col_mapping()
1752                .composite(&join_with_pk.core.i2o_col_mapping());
1753            let r2o = join_with_pk
1754                .core
1755                .r2i_col_mapping()
1756                .composite(&join_with_pk.core.i2o_col_mapping());
1757            let left_right_stream_keys = join_with_pk
1758                .left()
1759                .expect_stream_key()
1760                .iter()
1761                .map(|i| l2o.map(*i))
1762                .chain(
1763                    join_with_pk
1764                        .right()
1765                        .expect_stream_key()
1766                        .iter()
1767                        .map(|i| r2o.map(*i)),
1768                )
1769                .collect_vec();
1770            let plan: PlanRef = join_with_pk.into();
1771            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1772        } else {
1773            join_with_pk.into()
1774        };
1775
1776        // the added columns is at the end, so it will not change the exists column index
1777        Ok((plan, out_col_change))
1778    }
1779
1780    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1781        let mut ctx = ToStreamContext::new(false);
1782        // only pass through the locality information if it can be converted to dynamic filter
1783        if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1784            // since dynamic filter only supports left input ref in the output indices, we can safely use o2i mapping to convert the required columns.
1785            let o2i_mapping = self.core.o2i_col_mapping();
1786            let left_input_columns = columns
1787                .iter()
1788                .map(|&col| o2i_mapping.try_map(col))
1789                .collect::<Option<Vec<usize>>>()?;
1790            if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1791                return Some(
1792                    self.clone_with_left_right(better_left_plan, self.right())
1793                        .into(),
1794                );
1795            }
1796        }
1797        None
1798    }
1799}
1800
1801#[cfg(test)]
1802mod tests {
1803
1804    use std::collections::HashSet;
1805
1806    use risingwave_common::catalog::{Field, Schema};
1807    use risingwave_common::types::{DataType, Datum};
1808    use risingwave_pb::expr::expr_node::Type;
1809
1810    use super::*;
1811    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1812    use crate::optimizer::optimizer_context::OptimizerContext;
1813    use crate::optimizer::plan_node::LogicalValues;
1814    use crate::optimizer::property::FunctionalDependency;
1815
1816    /// Pruning
1817    /// ```text
1818    /// Join(on: input_ref(1)=input_ref(3))
1819    ///   TableScan(v1, v2, v3)
1820    ///   TableScan(v4, v5, v6)
1821    /// ```
1822    /// with required columns [2,3] will result in
1823    /// ```text
1824    /// Project(input_ref(1), input_ref(2))
1825    ///   Join(on: input_ref(0)=input_ref(2))
1826    ///     TableScan(v2, v3)
1827    ///     TableScan(v4)
1828    /// ```
1829    #[tokio::test]
1830    async fn test_prune_join() {
1831        let ty = DataType::Int32;
1832        let ctx = OptimizerContext::mock().await;
1833        let fields: Vec<Field> = (1..7)
1834            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1835            .collect();
1836        let left = LogicalValues::new(
1837            vec![],
1838            Schema {
1839                fields: fields[0..3].to_vec(),
1840            },
1841            ctx.clone(),
1842        );
1843        let right = LogicalValues::new(
1844            vec![],
1845            Schema {
1846                fields: fields[3..6].to_vec(),
1847            },
1848            ctx,
1849        );
1850        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1851            FunctionCall::new(
1852                Type::Equal,
1853                vec![
1854                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1855                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1856                ],
1857            )
1858            .unwrap(),
1859        ));
1860        let join_type = JoinType::Inner;
1861        let join: PlanRef = LogicalJoin::new(
1862            left.into(),
1863            right.into(),
1864            join_type,
1865            Condition::with_expr(on),
1866        )
1867        .into();
1868
1869        // Perform the prune
1870        let required_cols = vec![2, 3];
1871        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1872
1873        // Check the result
1874        let join = plan.as_logical_join().unwrap();
1875        assert_eq!(join.schema().fields().len(), 2);
1876        assert_eq!(join.schema().fields()[0], fields[2]);
1877        assert_eq!(join.schema().fields()[1], fields[3]);
1878
1879        let expr: ExprImpl = join.on().clone().into();
1880        let call = expr.as_function_call().unwrap();
1881        assert_eq_input_ref!(&call.inputs()[0], 0);
1882        assert_eq_input_ref!(&call.inputs()[1], 2);
1883
1884        let left = join.left();
1885        let left = left.as_logical_values().unwrap();
1886        assert_eq!(left.schema().fields(), &fields[1..3]);
1887        let right = join.right();
1888        let right = right.as_logical_values().unwrap();
1889        assert_eq!(right.schema().fields(), &fields[3..4]);
1890    }
1891
1892    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1893    #[tokio::test]
1894    async fn test_prune_semi_join() {
1895        let ty = DataType::Int32;
1896        let ctx = OptimizerContext::mock().await;
1897        let fields: Vec<Field> = (1..7)
1898            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1899            .collect();
1900        let left = LogicalValues::new(
1901            vec![],
1902            Schema {
1903                fields: fields[0..3].to_vec(),
1904            },
1905            ctx.clone(),
1906        );
1907        let right = LogicalValues::new(
1908            vec![],
1909            Schema {
1910                fields: fields[3..6].to_vec(),
1911            },
1912            ctx,
1913        );
1914        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1915            FunctionCall::new(
1916                Type::Equal,
1917                vec![
1918                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1919                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1920                ],
1921            )
1922            .unwrap(),
1923        ));
1924        for join_type in [
1925            JoinType::LeftSemi,
1926            JoinType::RightSemi,
1927            JoinType::LeftAnti,
1928            JoinType::RightAnti,
1929        ] {
1930            let join = LogicalJoin::new(
1931                left.clone().into(),
1932                right.clone().into(),
1933                join_type,
1934                Condition::with_expr(on.clone()),
1935            );
1936
1937            let offset = if join.is_right_join() { 3 } else { 0 };
1938            let join: PlanRef = join.into();
1939            // Perform the prune
1940            let required_cols = vec![0];
1941            // key 0 is never used in the join (always key 1)
1942            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1943            let as_plan = plan.as_logical_join().unwrap();
1944            // Check the result
1945            assert_eq!(as_plan.schema().fields().len(), 1);
1946            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1947
1948            // Perform the prune
1949            let required_cols = vec![0, 1, 2];
1950            // should not panic here
1951            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1952            let as_plan = plan.as_logical_join().unwrap();
1953            // Check the result
1954            assert_eq!(as_plan.schema().fields().len(), 3);
1955            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1956            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1957            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1958        }
1959    }
1960
1961    /// Pruning
1962    /// ```text
1963    /// Join(on: input_ref(1)=input_ref(3))
1964    ///   TableScan(v1, v2, v3)
1965    ///   TableScan(v4, v5, v6)
1966    /// ```
1967    /// with required columns [1, 3] will result in
1968    /// ```text
1969    /// Join(on: input_ref(0)=input_ref(1))
1970    ///   TableScan(v2)
1971    ///   TableScan(v4)
1972    /// ```
1973    #[tokio::test]
1974    async fn test_prune_join_no_project() {
1975        let ty = DataType::Int32;
1976        let ctx = OptimizerContext::mock().await;
1977        let fields: Vec<Field> = (1..7)
1978            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1979            .collect();
1980        let left = LogicalValues::new(
1981            vec![],
1982            Schema {
1983                fields: fields[0..3].to_vec(),
1984            },
1985            ctx.clone(),
1986        );
1987        let right = LogicalValues::new(
1988            vec![],
1989            Schema {
1990                fields: fields[3..6].to_vec(),
1991            },
1992            ctx,
1993        );
1994        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1995            FunctionCall::new(
1996                Type::Equal,
1997                vec![
1998                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1999                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2000                ],
2001            )
2002            .unwrap(),
2003        ));
2004        let join_type = JoinType::Inner;
2005        let join: PlanRef = LogicalJoin::new(
2006            left.into(),
2007            right.into(),
2008            join_type,
2009            Condition::with_expr(on),
2010        )
2011        .into();
2012
2013        // Perform the prune
2014        let required_cols = vec![1, 3];
2015        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2016
2017        // Check the result
2018        let join = plan.as_logical_join().unwrap();
2019        assert_eq!(join.schema().fields().len(), 2);
2020        assert_eq!(join.schema().fields()[0], fields[1]);
2021        assert_eq!(join.schema().fields()[1], fields[3]);
2022
2023        let expr: ExprImpl = join.on().clone().into();
2024        let call = expr.as_function_call().unwrap();
2025        assert_eq_input_ref!(&call.inputs()[0], 0);
2026        assert_eq_input_ref!(&call.inputs()[1], 1);
2027
2028        let left = join.left();
2029        let left = left.as_logical_values().unwrap();
2030        assert_eq!(left.schema().fields(), &fields[1..2]);
2031        let right = join.right();
2032        let right = right.as_logical_values().unwrap();
2033        assert_eq!(right.schema().fields(), &fields[3..4]);
2034    }
2035
2036    /// Convert
2037    /// ```text
2038    /// Join(on: ($1 = $3) AND ($2 == 42))
2039    ///   TableScan(v1, v2, v3)
2040    ///   TableScan(v4, v5, v6)
2041    /// ```
2042    /// to
2043    /// ```text
2044    /// Filter($2 == 42)
2045    ///   HashJoin(on: $1 = $3)
2046    ///     TableScan(v1, v2, v3)
2047    ///     TableScan(v4, v5, v6)
2048    /// ```
2049    #[tokio::test]
2050    async fn test_join_to_batch() {
2051        let ctx = OptimizerContext::mock().await;
2052        let fields: Vec<Field> = (1..7)
2053            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
2054            .collect();
2055        let left = LogicalValues::new(
2056            vec![],
2057            Schema {
2058                fields: fields[0..3].to_vec(),
2059            },
2060            ctx.clone(),
2061        );
2062        let right = LogicalValues::new(
2063            vec![],
2064            Schema {
2065                fields: fields[3..6].to_vec(),
2066            },
2067            ctx,
2068        );
2069
2070        fn input_ref(i: usize) -> ExprImpl {
2071            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
2072        }
2073        let eq_cond = ExprImpl::FunctionCall(Box::new(
2074            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
2075        ));
2076        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2077            FunctionCall::new(
2078                Type::Equal,
2079                vec![
2080                    input_ref(2),
2081                    ExprImpl::Literal(Box::new(Literal::new(
2082                        Datum::Some(42_i32.into()),
2083                        DataType::Int32,
2084                    ))),
2085                ],
2086            )
2087            .unwrap(),
2088        ));
2089        // Condition: ($1 = $3) AND ($2 == 42)
2090        let on_cond = ExprImpl::FunctionCall(Box::new(
2091            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2092        ));
2093
2094        let join_type = JoinType::Inner;
2095        let logical_join = LogicalJoin::new(
2096            left.into(),
2097            right.into(),
2098            join_type,
2099            Condition::with_expr(on_cond),
2100        );
2101
2102        // Perform `to_batch`
2103        let result = logical_join.to_batch().unwrap();
2104
2105        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2106        let hash_join = result.as_batch_hash_join().unwrap();
2107        assert_eq!(
2108            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2109            eq_cond
2110        );
2111        assert_eq!(
2112            *hash_join
2113                .eq_join_predicate()
2114                .non_eq_cond()
2115                .conjunctions
2116                .first()
2117                .unwrap(),
2118            non_eq_cond
2119        );
2120    }
2121
    /// Convert
    /// ```text
    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    ///
    /// NOTE: this test is currently a no-op. It is marked `#[ignore]` and its whole body is
    /// commented out, so nothing described above is actually verified.
    #[tokio::test]
    #[ignore] // Ignored since the logical scan refactor; the test seems to duplicate the explain test
    // framework, so it may be removed altogether.
    async fn test_join_to_stream() {
        // The original test body is kept below for reference only; it is not compiled.
        //
        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
        // let fields: Vec<Field> = (1..7)
        //     .map(|i| Field {
        //         data_type: DataType::Int32,
        //         name: format!("v{}", i),
        //     })
        //     .collect();
        // let left = LogicalScan::new(
        //     "left".to_string(),
        //     TableId::new(0),
        //     vec![1.into(), 2.into(), 3.into()],
        //     Schema {
        //         fields: fields[0..3].to_vec(),
        //     },
        //     ctx.clone(),
        // );
        // let right = LogicalScan::new(
        //     "right".to_string(),
        //     TableId::new(0),
        //     vec![4.into(), 5.into(), 6.into()],
        //     Schema {
        //         fields: fields[3..6].to_vec(),
        //     },
        //     ctx,
        // );
        // let eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
        //             ExprImpl::Literal(Box::new(Literal::new(
        //                 Datum::Some(42_i32.into()),
        //                 DataType::Int32,
        //             ))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // // Condition: ($1 = $3) AND ($2 == 42)
        // let on_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
        // ));

        // let join_type = JoinType::LeftOuter;
        // let logical_join = LogicalJoin::new(
        //     left.clone().into(),
        //     right.clone().into(),
        //     join_type,
        //     Condition::with_expr(on_cond.clone()),
        // );

        // // Perform `to_stream`
        // let result = logical_join.to_stream();

        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
        // let hash_join = result.as_stream_hash_join().unwrap();
        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
    }
2206    /// Pruning
2207    /// ```text
2208    /// Join(on: input_ref(1)=input_ref(3))
2209    ///   TableScan(v1, v2, v3)
2210    ///   TableScan(v4, v5, v6)
2211    /// ```
2212    /// with required columns [3, 2] will result in
2213    /// ```text
2214    /// Project(input_ref(2), input_ref(1))
2215    ///   Join(on: input_ref(0)=input_ref(2))
2216    ///     TableScan(v2, v3)
2217    ///     TableScan(v4)
2218    /// ```
2219    #[tokio::test]
2220    async fn test_join_column_prune_with_order_required() {
2221        let ty = DataType::Int32;
2222        let ctx = OptimizerContext::mock().await;
2223        let fields: Vec<Field> = (1..7)
2224            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2225            .collect();
2226        let left = LogicalValues::new(
2227            vec![],
2228            Schema {
2229                fields: fields[0..3].to_vec(),
2230            },
2231            ctx.clone(),
2232        );
2233        let right = LogicalValues::new(
2234            vec![],
2235            Schema {
2236                fields: fields[3..6].to_vec(),
2237            },
2238            ctx,
2239        );
2240        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2241            FunctionCall::new(
2242                Type::Equal,
2243                vec![
2244                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2245                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2246                ],
2247            )
2248            .unwrap(),
2249        ));
2250        let join_type = JoinType::Inner;
2251        let join: PlanRef = LogicalJoin::new(
2252            left.into(),
2253            right.into(),
2254            join_type,
2255            Condition::with_expr(on),
2256        )
2257        .into();
2258
2259        // Perform the prune
2260        let required_cols = vec![3, 2];
2261        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2262
2263        // Check the result
2264        let join = plan.as_logical_join().unwrap();
2265        assert_eq!(join.schema().fields().len(), 2);
2266        assert_eq!(join.schema().fields()[0], fields[3]);
2267        assert_eq!(join.schema().fields()[1], fields[2]);
2268
2269        let expr: ExprImpl = join.on().clone().into();
2270        let call = expr.as_function_call().unwrap();
2271        assert_eq_input_ref!(&call.inputs()[0], 0);
2272        assert_eq_input_ref!(&call.inputs()[1], 2);
2273
2274        let left = join.left();
2275        let left = left.as_logical_values().unwrap();
2276        assert_eq!(left.schema().fields(), &fields[1..3]);
2277        let right = join.right();
2278        let right = right.as_logical_values().unwrap();
2279        assert_eq!(right.schema().fields(), &fields[3..4]);
2280    }
2281
2282    #[tokio::test]
2283    async fn fd_derivation_inner_outer_join() {
2284        // left: [l0, l1], right: [r0, r1, r2]
2285        // FD: l0 --> l1, r0 --> { r1, r2 }
2286        // On: l0 = 0 AND l1 = r1
2287        //
2288        // Inner Join:
2289        //  Schema: [l0, l1, r0, r1, r2]
2290        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2291        // Left Outer Join:
2292        //  Schema: [l0, l1, r0, r1, r2]
2293        //  FD: l0 --> l1
2294        // Right Outer Join:
2295        //  Schema: [l0, l1, r0, r1, r2]
2296        //  FD: r0 --> { r1, r2 }
2297        // Full Outer Join:
2298        //  Schema: [l0, l1, r0, r1, r2]
2299        //  FD: empty
2300        // Left Semi/Anti Join:
2301        //  Schema: [l0, l1]
2302        //  FD: l0 --> l1
2303        // Right Semi/Anti Join:
2304        //  Schema: [r0, r1, r2]
2305        //  FD: r0 --> {r1, r2}
2306        let ctx = OptimizerContext::mock().await;
2307        let left = {
2308            let fields: Vec<Field> = vec![
2309                Field::with_name(DataType::Int32, "l0"),
2310                Field::with_name(DataType::Int32, "l1"),
2311            ];
2312            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2313            // 0 --> 1
2314            values
2315                .base
2316                .functional_dependency_mut()
2317                .add_functional_dependency_by_column_indices(&[0], &[1]);
2318            values
2319        };
2320        let right = {
2321            let fields: Vec<Field> = vec![
2322                Field::with_name(DataType::Int32, "r0"),
2323                Field::with_name(DataType::Int32, "r1"),
2324                Field::with_name(DataType::Int32, "r2"),
2325            ];
2326            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2327            // 0 --> 1, 2
2328            values
2329                .base
2330                .functional_dependency_mut()
2331                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2332            values
2333        };
2334        // l0 = 0 AND l1 = r1
2335        let on: ExprImpl = FunctionCall::new(
2336            Type::And,
2337            vec![
2338                FunctionCall::new(
2339                    Type::Equal,
2340                    vec![
2341                        InputRef::new(0, DataType::Int32).into(),
2342                        ExprImpl::literal_int(0),
2343                    ],
2344                )
2345                .unwrap()
2346                .into(),
2347                FunctionCall::new(
2348                    Type::Equal,
2349                    vec![
2350                        InputRef::new(1, DataType::Int32).into(),
2351                        InputRef::new(3, DataType::Int32).into(),
2352                    ],
2353                )
2354                .unwrap()
2355                .into(),
2356            ],
2357        )
2358        .unwrap()
2359        .into();
2360        let expected_fd_set = [
2361            (
2362                JoinType::Inner,
2363                [
2364                    // inherit from left
2365                    FunctionalDependency::with_indices(5, &[0], &[1]),
2366                    // inherit from right
2367                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2368                    // constant column in join condition
2369                    FunctionalDependency::with_indices(5, &[], &[0]),
2370                    // eq column in join condition
2371                    FunctionalDependency::with_indices(5, &[1], &[3]),
2372                    FunctionalDependency::with_indices(5, &[3], &[1]),
2373                ]
2374                .into_iter()
2375                .collect::<HashSet<_>>(),
2376            ),
2377            (JoinType::FullOuter, HashSet::new()),
2378            (
2379                JoinType::RightOuter,
2380                [
2381                    // inherit from right
2382                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2383                ]
2384                .into_iter()
2385                .collect::<HashSet<_>>(),
2386            ),
2387            (
2388                JoinType::LeftOuter,
2389                [
2390                    // inherit from left
2391                    FunctionalDependency::with_indices(5, &[0], &[1]),
2392                ]
2393                .into_iter()
2394                .collect::<HashSet<_>>(),
2395            ),
2396            (
2397                JoinType::LeftSemi,
2398                [
2399                    // inherit from left
2400                    FunctionalDependency::with_indices(2, &[0], &[1]),
2401                ]
2402                .into_iter()
2403                .collect::<HashSet<_>>(),
2404            ),
2405            (
2406                JoinType::LeftAnti,
2407                [
2408                    // inherit from left
2409                    FunctionalDependency::with_indices(2, &[0], &[1]),
2410                ]
2411                .into_iter()
2412                .collect::<HashSet<_>>(),
2413            ),
2414            (
2415                JoinType::RightSemi,
2416                [
2417                    // inherit from right
2418                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2419                ]
2420                .into_iter()
2421                .collect::<HashSet<_>>(),
2422            ),
2423            (
2424                JoinType::RightAnti,
2425                [
2426                    // inherit from right
2427                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2428                ]
2429                .into_iter()
2430                .collect::<HashSet<_>>(),
2431            ),
2432        ];
2433
2434        for (join_type, expected_res) in expected_fd_set {
2435            let join = LogicalJoin::new(
2436                left.clone().into(),
2437                right.clone().into(),
2438                join_type,
2439                Condition::with_expr(on.clone()),
2440            );
2441            let fd_set = join
2442                .functional_dependency()
2443                .as_dependencies()
2444                .iter()
2445                .cloned()
2446                .collect::<HashSet<_>>();
2447            assert_eq!(fd_set, expected_res);
2448        }
2449    }
2450}