risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31    BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
32    PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch,
33    ToStream, generic, try_enforce_locality_requirement,
34};
35use crate::error::{ErrorCode, Result, RwError};
36use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
37use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
38use crate::optimizer::plan_node::generic::DynamicFilter;
39use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
40use crate::optimizer::plan_node::utils::IndicesDisplay;
41use crate::optimizer::plan_node::{
42    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
43    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
44    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
45};
46use crate::optimizer::plan_visitor::LogicalCardinalityExt;
47use crate::optimizer::property::{Distribution, RequiredDist};
48use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
49
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each internal row has fields from the left and right inputs. For an inner join the set of
/// rows is the subset of the cartesian product of the two inputs that satisfies the join
/// condition; outer joins additionally emit unmatched rows padded with NULLs on the
/// null-supplying side, and semi/anti joins emit rows of only one side. The output columns are
/// a subset of the internal (left ++ right) columns, selected by the output indices provided.
/// Repeating an output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan-node properties, derived from `core` when the node is built (see `with_core`).
    pub base: PlanBase<Logical>,
    /// The generic join description: inputs, join type, `on` condition and output indices.
    core: generic::Join<PlanRef>,
}
61
62impl Distill for LogicalJoin {
63    fn distill<'a>(&self) -> XmlNode<'a> {
64        let verbose = self.base.ctx().is_explain_verbose();
65        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
66        vec.push(("type", Pretty::debug(&self.join_type())));
67
68        let concat_schema = self.core.concat_schema();
69        let cond = Pretty::debug(&ConditionDisplay {
70            condition: self.on(),
71            input_schema: &concat_schema,
72        });
73        vec.push(("on", cond));
74
75        if verbose {
76            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
77            vec.push(("output", data));
78        }
79
80        childless_record("LogicalJoin", vec)
81    }
82}
83
84impl LogicalJoin {
85    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
86        let core = generic::Join::with_full_output(left, right, join_type, on);
87        Self::with_core(core)
88    }
89
90    pub(crate) fn with_output_indices(
91        left: PlanRef,
92        right: PlanRef,
93        join_type: JoinType,
94        on: Condition,
95        output_indices: Vec<usize>,
96    ) -> Self {
97        let core = generic::Join::new(left, right, on, join_type, output_indices);
98        Self::with_core(core)
99    }
100
101    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
102        let base = PlanBase::new_logical_with_core(&core);
103        LogicalJoin { base, core }
104    }
105
106    pub fn create(
107        left: PlanRef,
108        right: PlanRef,
109        join_type: JoinType,
110        on_clause: ExprImpl,
111    ) -> PlanRef {
112        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
113    }
114
115    pub fn internal_column_num(&self) -> usize {
116        self.core.internal_column_num()
117    }
118
119    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
120        self.core.i2l_col_mapping_ignore_join_type()
121    }
122
123    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
124        self.core.i2r_col_mapping_ignore_join_type()
125    }
126
127    /// Get a reference to the logical join's on.
128    pub fn on(&self) -> &Condition {
129        &self.core.on
130    }
131
132    pub fn core(&self) -> &generic::Join<PlanRef> {
133        &self.core
134    }
135
136    /// Collect all input ref in the on condition. And separate them into left and right.
137    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
138        let input_refs = self
139            .core
140            .on
141            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
142        let index_group = input_refs
143            .ones()
144            .chunk_by(|i| *i < self.core.left.schema().len());
145        let left_index = index_group
146            .into_iter()
147            .next()
148            .map_or(vec![], |group| group.1.collect_vec());
149        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
150            group
151                .1
152                .map(|i| i - self.core.left.schema().len())
153                .collect_vec()
154        });
155        (left_index, right_index)
156    }
157
158    /// Get the join type of the logical join.
159    pub fn join_type(&self) -> JoinType {
160        self.core.join_type
161    }
162
163    /// Get the eq join key of the logical join.
164    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
165        self.core.eq_indexes()
166    }
167
168    /// Get the output indices of the logical join.
169    pub fn output_indices(&self) -> &Vec<usize> {
170        &self.core.output_indices
171    }
172
173    /// Clone with new output indices
174    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
175        Self::with_core(generic::Join {
176            output_indices,
177            ..self.core.clone()
178        })
179    }
180
181    /// Clone with new `on` condition
182    pub fn clone_with_cond(&self, on: Condition) -> Self {
183        Self::with_core(generic::Join {
184            on,
185            ..self.core.clone()
186        })
187    }
188
189    pub fn is_left_join(&self) -> bool {
190        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
191    }
192
193    pub fn is_right_join(&self) -> bool {
194        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
195    }
196
197    pub fn is_full_out(&self) -> bool {
198        self.core.is_full_out()
199    }
200
201    pub fn is_asof_join(&self) -> bool {
202        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
203    }
204
205    pub fn output_indices_are_trivial(&self) -> bool {
206        self.output_indices() == &(0..self.internal_column_num()).collect_vec()
207    }
208
209    /// Try to simplify the outer join with the predicate on the top of the join
210    ///
211    /// now it is just a naive implementation for comparison expression, we can give a more general
212    /// implementation with constant folding in future
213    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
214        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
215            JoinType::LeftOuter => (false, true),
216            JoinType::RightOuter => (true, false),
217            JoinType::FullOuter => (true, true),
218            _ => return join_type,
219        };
220
221        for expr in &predicate.conjunctions {
222            if let ExprImpl::FunctionCall(func) = expr {
223                match func.func_type() {
224                    ExprType::Equal
225                    | ExprType::NotEqual
226                    | ExprType::LessThan
227                    | ExprType::LessThanOrEqual
228                    | ExprType::GreaterThan
229                    | ExprType::GreaterThanOrEqual => {
230                        for input in func.inputs() {
231                            if let ExprImpl::InputRef(input) = input {
232                                let idx = input.index;
233                                if idx < left_col_num {
234                                    gen_null_in_left = false;
235                                } else {
236                                    gen_null_in_right = false;
237                                }
238                            }
239                        }
240                    }
241                    _ => {}
242                };
243            }
244        }
245
246        match (gen_null_in_left, gen_null_in_right) {
247            (true, true) => JoinType::FullOuter,
248            (true, false) => JoinType::RightOuter,
249            (false, true) => JoinType::LeftOuter,
250            (false, false) => JoinType::Inner,
251        }
252    }
253
254    /// Index Join:
255    /// Try to convert logical join into batch lookup join and meanwhile it will do
256    /// the index selection for the lookup table so that we can benefit from indexes.
257    fn to_batch_lookup_join_with_index_selection(
258        &self,
259        predicate: EqJoinPredicate,
260        batch_join: generic::Join<BatchPlanRef>,
261    ) -> Result<Option<BatchLookupJoin>> {
262        match batch_join.join_type {
263            JoinType::Inner
264            | JoinType::LeftOuter
265            | JoinType::LeftSemi
266            | JoinType::LeftAnti
267            | JoinType::AsofInner
268            | JoinType::AsofLeftOuter => {}
269            _ => return Ok(None),
270        };
271
272        // Index selection for index join.
273        let right = self.right();
274        // Lookup Join only supports basic tables on the join's right side.
275        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
276            logical_scan
277        } else {
278            return Ok(None);
279        };
280
281        let mut result_plan = None;
282        // Lookup primary table.
283        if let Some(lookup_join) =
284            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
285        {
286            result_plan = Some(lookup_join);
287        }
288
289        if self
290            .core
291            .ctx()
292            .session_ctx()
293            .config()
294            .enable_index_selection()
295        {
296            let indexes = logical_scan.table_indexes();
297            for index in indexes {
298                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
299                    let index_scan: PlanRef = index_scan.into();
300                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
301                    let mut new_batch_join = batch_join.clone();
302                    new_batch_join.right =
303                        index_scan.to_batch().expect("index scan failed to batch");
304
305                    // Lookup covered index.
306                    if let Some(lookup_join) =
307                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
308                    {
309                        match &result_plan {
310                            None => result_plan = Some(lookup_join),
311                            Some(prev_lookup_join) => {
312                                // Prefer to choose lookup join with longer lookup prefix len.
313                                if prev_lookup_join.lookup_prefix_len()
314                                    < lookup_join.lookup_prefix_len()
315                                {
316                                    result_plan = Some(lookup_join)
317                                }
318                            }
319                        }
320                    }
321                }
322            }
323        }
324
325        Ok(result_plan)
326    }
327
328    /// Try to convert logical join into batch lookup join.
329    fn to_batch_lookup_join(
330        &self,
331        predicate: EqJoinPredicate,
332        logical_join: generic::Join<BatchPlanRef>,
333    ) -> Result<Option<BatchLookupJoin>> {
334        let logical_scan: &LogicalScan =
335            if let Some(logical_scan) = self.core.right.as_logical_scan() {
336                logical_scan
337            } else {
338                return Ok(None);
339            };
340        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
341    }
342
343    pub fn gen_batch_lookup_join(
344        logical_scan: &LogicalScan,
345        predicate: EqJoinPredicate,
346        logical_join: generic::Join<BatchPlanRef>,
347        is_as_of: bool,
348    ) -> Result<Option<BatchLookupJoin>> {
349        match logical_join.join_type {
350            JoinType::Inner
351            | JoinType::LeftOuter
352            | JoinType::LeftSemi
353            | JoinType::LeftAnti
354            | JoinType::AsofInner
355            | JoinType::AsofLeftOuter => {}
356            _ => return Ok(None),
357        };
358
359        let table = logical_scan.table();
360        let output_column_ids = logical_scan.output_column_ids();
361
362        // Verify that the right join key columns are the the prefix of the primary key and
363        // also contain the distribution key.
364        let order_col_ids = table.order_column_ids();
365        let dist_key = table.distribution_key.clone();
366        // The at least prefix of order key that contains distribution key.
367        let mut dist_key_in_order_key_pos = vec![];
368        for d in dist_key {
369            let pos = table
370                .order_column_indices()
371                .position(|x| x == d)
372                .expect("dist_key must in order_key");
373            dist_key_in_order_key_pos.push(pos);
374        }
375        // The shortest prefix of order key that contains distribution key.
376        let shortest_prefix_len = dist_key_in_order_key_pos
377            .iter()
378            .max()
379            .map_or(0, |pos| pos + 1);
380
381        // Distributed lookup join can't support lookup table with a singleton distribution.
382        if shortest_prefix_len == 0 {
383            return Ok(None);
384        }
385
386        // Reorder the join equal predicate to match the order key.
387        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
388        for order_col_id in order_col_ids {
389            let mut found = false;
390            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
391                if order_col_id == output_column_ids[eq_idx] {
392                    reorder_idx.push(i);
393                    found = true;
394                    break;
395                }
396            }
397            if !found {
398                break;
399            }
400        }
401        if reorder_idx.len() < shortest_prefix_len {
402            return Ok(None);
403        }
404        let lookup_prefix_len = reorder_idx.len();
405        let predicate = predicate.reorder(&reorder_idx);
406
407        // Extract the predicate from logical scan. Only pure scan is supported.
408        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
409        // Construct output column to require column mapping
410        let o2r = if let Some(project_expr) = project_expr {
411            project_expr
412                .into_iter()
413                .map(|x| x.as_input_ref().unwrap().index)
414                .collect_vec()
415        } else {
416            (0..logical_scan.output_col_idx().len()).collect_vec()
417        };
418        let left_schema_len = logical_join.left.schema().len();
419
420        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
421            offset: left_schema_len,
422            mapping: o2r.clone(),
423        };
424
425        let new_eq_cond = predicate
426            .eq_cond()
427            .rewrite_expr(&mut join_predicate_rewriter);
428
429        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
430            offset: left_schema_len,
431        };
432
433        let new_other_cond = predicate
434            .other_cond()
435            .clone()
436            .rewrite_expr(&mut join_predicate_rewriter)
437            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
438
439        let new_join_on = new_eq_cond.and(new_other_cond);
440        let new_predicate = EqJoinPredicate::create(
441            left_schema_len,
442            new_scan.schema().len(),
443            new_join_on.clone(),
444        );
445
446        // We discovered that we cannot use a lookup join after pulling up the predicate
447        // from one side and simplifying the condition. Let's use some other join instead.
448        if !new_predicate.has_eq() {
449            return Ok(None);
450        }
451
452        // Rewrite the join output indices and all output indices referred to the old scan need to
453        // rewrite.
454        let new_join_output_indices = logical_join
455            .output_indices
456            .iter()
457            .map(|&x| {
458                if x < left_schema_len {
459                    x
460                } else {
461                    o2r[x - left_schema_len] + left_schema_len
462                }
463            })
464            .collect_vec();
465
466        let new_scan_output_column_ids = new_scan.output_column_ids();
467        let as_of = new_scan.as_of.clone();
468        let new_logical_scan: LogicalScan = new_scan.into();
469
470        // Construct a new logical join, because we have change its RHS.
471        let new_logical_join = generic::Join::new(
472            logical_join.left,
473            new_logical_scan.to_batch()?,
474            new_join_on,
475            logical_join.join_type,
476            new_join_output_indices,
477        );
478
479        let asof_desc = is_as_of
480            .then(|| {
481                Self::get_inequality_desc_from_predicate(
482                    predicate.other_cond().clone(),
483                    left_schema_len,
484                )
485            })
486            .transpose()?;
487
488        Ok(Some(BatchLookupJoin::new(
489            new_logical_join,
490            new_predicate,
491            table.clone(),
492            new_scan_output_column_ids,
493            lookup_prefix_len,
494            false,
495            as_of,
496            asof_desc,
497        )))
498    }
499
500    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
501        self.core.decompose()
502    }
503}
504
505impl PlanTreeNodeBinary<Logical> for LogicalJoin {
506    fn left(&self) -> PlanRef {
507        self.core.left.clone()
508    }
509
510    fn right(&self) -> PlanRef {
511        self.core.right.clone()
512    }
513
514    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
515        Self::with_core(generic::Join {
516            left,
517            right,
518            ..self.core.clone()
519        })
520    }
521
522    fn rewrite_with_left_right(
523        &self,
524        left: PlanRef,
525        left_col_change: ColIndexMapping,
526        right: PlanRef,
527        right_col_change: ColIndexMapping,
528    ) -> (Self, ColIndexMapping) {
529        let (new_on, new_output_indices) = {
530            let (mut map, _) = left_col_change.clone().into_parts();
531            let (mut right_map, _) = right_col_change.clone().into_parts();
532            for i in right_map.iter_mut().flatten() {
533                *i += left.schema().len();
534            }
535            map.append(&mut right_map);
536            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
537
538            let new_output_indices = self
539                .output_indices()
540                .iter()
541                .map(|&i| mapping.map(i))
542                .collect::<Vec<_>>();
543            let new_on = self.on().clone().rewrite_expr(&mut mapping);
544            (new_on, new_output_indices)
545        };
546
547        let join = Self::with_output_indices(
548            left,
549            right,
550            self.join_type(),
551            new_on,
552            new_output_indices.clone(),
553        );
554
555        let new_i2o = ColIndexMapping::with_remaining_columns(
556            &new_output_indices,
557            join.internal_column_num(),
558        );
559
560        let old_o2i = self.core.o2i_col_mapping();
561
562        let old_o2l = old_o2i
563            .composite(&self.core.i2l_col_mapping())
564            .composite(&left_col_change);
565        let old_o2r = old_o2i
566            .composite(&self.core.i2r_col_mapping())
567            .composite(&right_col_change);
568        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
569        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
570
571        let out_col_change = old_o2l
572            .composite(&new_l2o)
573            .union(&old_o2r.composite(&new_r2o));
574        (join, out_col_change)
575    }
576}
577
// NOTE(review): presumably generates the generic plan-tree-node impl for `LogicalJoin` on top of
// its `PlanTreeNodeBinary` impl above — confirm against the macro definition.
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
579
580impl ColPrunable for LogicalJoin {
581    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
582        // make `required_cols` point to internal table instead of output schema.
583        let required_cols = required_cols
584            .iter()
585            .map(|i| self.output_indices()[*i])
586            .collect_vec();
587        let left_len = self.left().schema().fields.len();
588
589        let total_len = self.left().schema().len() + self.right().schema().len();
590        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
591
592        required_cols.iter().for_each(|&i| {
593            if self.is_right_join() {
594                resized_required_cols.insert(left_len + i);
595            } else {
596                resized_required_cols.insert(i);
597            }
598        });
599
600        // add those columns which are required in the join condition to
601        // to those that are required in the output
602        let mut visitor = CollectInputRef::new(resized_required_cols);
603        self.on().visit_expr(&mut visitor);
604        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
605
606        let mut left_required_cols = Vec::new();
607        let mut right_required_cols = Vec::new();
608        left_right_required_cols.iter().for_each(|&i| {
609            if i < left_len {
610                left_required_cols.push(i);
611            } else {
612                right_required_cols.push(i - left_len);
613            }
614        });
615
616        let mut on = self.on().clone();
617        let mut mapping =
618            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
619        on = on.rewrite_expr(&mut mapping);
620
621        let new_output_indices = {
622            let required_inputs_in_output = if self.is_left_join() {
623                &left_required_cols
624            } else if self.is_right_join() {
625                &right_required_cols
626            } else {
627                &left_right_required_cols
628            };
629
630            let mapping =
631                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
632            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
633        };
634
635        LogicalJoin::with_output_indices(
636            self.left().prune_col(&left_required_cols, ctx),
637            self.right().prune_col(&right_required_cols, ctx),
638            self.join_type(),
639            on,
640            new_output_indices,
641        )
642        .into()
643    }
644}
645
646impl ExprRewritable<Logical> for LogicalJoin {
647    fn has_rewritable_expr(&self) -> bool {
648        true
649    }
650
651    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
652        let mut core = self.core.clone();
653        core.rewrite_exprs(r);
654        Self {
655            base: self.base.clone_with_new_plan_id(),
656            core,
657        }
658        .into()
659    }
660}
661
662impl ExprVisitable for LogicalJoin {
663    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
664        self.core.visit_exprs(v);
665    }
666}
667
668/// We are trying to derive a predicate to apply to the other side of a join if all
669/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
670/// with the corresponding eq condition columns of the other side.
671///
672/// Strategy:
673/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
674///    then we proceed. Else abort.
675/// 2. Then, we collect `InputRef`s in the conjunction.
676/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
677/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
678///    the other side.
679///
680/// # Arguments
681///
682/// Suppose we derive a predicate from the left side to be pushed to the right side.
683/// * `expr`: An expr from the left side.
684/// * `col_num`: The number of columns in the left side.
685fn derive_predicate_from_eq_condition(
686    expr: &ExprImpl,
687    eq_condition: &EqJoinPredicate,
688    col_num: usize,
689    expr_is_left: bool,
690) -> Option<ExprImpl> {
691    if expr.is_impure() {
692        return None;
693    }
694    let eq_indices = eq_condition
695        .eq_indexes_typed()
696        .iter()
697        .filter_map(|(l, r)| {
698            if l.return_type() != r.return_type() {
699                None
700            } else if expr_is_left {
701                Some(l.index())
702            } else {
703                Some(r.index())
704            }
705        })
706        .collect_vec();
707    if expr
708        .collect_input_refs(col_num)
709        .ones()
710        .any(|index| !eq_indices.contains(&index))
711    {
712        // expr contains an InputRef not in eq_condition
713        return None;
714    }
715    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
716    // Hence, we can substitute those `InputRef`s with indices from the other side.
717    let other_side_mapping = if expr_is_left {
718        eq_condition.eq_indexes_typed().into_iter().collect()
719    } else {
720        eq_condition
721            .eq_indexes_typed()
722            .into_iter()
723            .map(|(x, y)| (y, x))
724            .collect()
725    };
726    struct InputRefsRewriter {
727        mapping: HashMap<InputRef, InputRef>,
728    }
729    impl ExprRewriter for InputRefsRewriter {
730        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
731            self.mapping[&input_ref].clone().into()
732        }
733    }
734    Some(
735        InputRefsRewriter {
736            mapping: other_side_mapping,
737        }
738        .rewrite_expr(expr.clone()),
739    )
740}
741
742/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
743struct LookupJoinPredicateRewriter {
744    offset: usize,
745    mapping: Vec<usize>,
746}
747impl ExprRewriter for LookupJoinPredicateRewriter {
748    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
749        if input_ref.index() < self.offset {
750            input_ref.into()
751        } else {
752            InputRef::new(
753                self.mapping[input_ref.index() - self.offset] + self.offset,
754                input_ref.return_type(),
755            )
756            .into()
757        }
758    }
759}
760
761/// Rewrite the scan predicate so we can add it to the join predicate.
762struct LookupJoinScanPredicateRewriter {
763    offset: usize,
764}
765impl ExprRewriter for LookupJoinScanPredicateRewriter {
766    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
767        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
768    }
769}
770
771impl PredicatePushdown for LogicalJoin {
772    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
773    ///
774    /// # Which predicates can be pushed
775    ///
776    /// For inner join, we can do all kinds of pushdown.
777    ///
778    /// For left/right semi join, we can push filter to left/right and on-clause,
779    /// and push on-clause to left/right.
780    ///
781    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
782    ///
783    /// ## Outer Join
784    ///
785    /// Preserved Row table
786    /// : The table in an Outer Join that must return all rows.
787    ///
788    /// Null Supplying table
789    /// : This is the table that has nulls filled in for its columns in unmatched rows.
790    ///
791    /// |                          | Preserved Row table | Null Supplying table |
792    /// |--------------------------|---------------------|----------------------|
793    /// | Join predicate (on)      | Not Pushed          | Pushed               |
794    /// | Where predicate (filter) | Pushed              | Not Pushed           |
795    fn predicate_pushdown(
796        &self,
797        predicate: Condition,
798        ctx: &mut PredicatePushdownContext,
799    ) -> PlanRef {
800        // rewrite output col referencing indices as internal cols
801        let mut predicate = {
802            let mut mapping = self.core.o2i_col_mapping();
803            predicate.rewrite_expr(&mut mapping)
804        };
805
806        let left_col_num = self.left().schema().len();
807        let right_col_num = self.right().schema().len();
808        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
809
810        let push_down_temporal_predicate = !self.should_be_temporal_join();
811
812        let (left_from_filter, right_from_filter, on) = push_down_into_join(
813            &mut predicate,
814            left_col_num,
815            right_col_num,
816            join_type,
817            push_down_temporal_predicate,
818        );
819
820        let mut new_on = self.on().clone().and(on);
821        let (left_from_on, right_from_on) = push_down_join_condition(
822            &mut new_on,
823            left_col_num,
824            right_col_num,
825            join_type,
826            push_down_temporal_predicate,
827        );
828
829        let left_predicate = left_from_filter.and(left_from_on);
830        let right_predicate = right_from_filter.and(right_from_on);
831
832        // Derive conditions to push to the other side based on eq condition columns
833        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
834
835        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
836        let right_from_left = if matches!(
837            join_type,
838            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
839        ) {
840            Condition {
841                conjunctions: left_predicate
842                    .conjunctions
843                    .iter()
844                    .filter_map(|expr| {
845                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
846                    })
847                    .collect(),
848            }
849        } else {
850            Condition::true_cond()
851        };
852
853        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
854        let left_from_right = if matches!(
855            join_type,
856            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
857        ) {
858            Condition {
859                conjunctions: right_predicate
860                    .conjunctions
861                    .iter()
862                    .filter_map(|expr| {
863                        derive_predicate_from_eq_condition(
864                            expr,
865                            &eq_condition,
866                            right_col_num,
867                            false,
868                        )
869                    })
870                    .collect(),
871            }
872        } else {
873            Condition::true_cond()
874        };
875
876        let left_predicate = left_predicate.and(left_from_right);
877        let right_predicate = right_predicate.and(right_from_left);
878
879        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
880        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
881        let new_join = LogicalJoin::with_output_indices(
882            new_left,
883            new_right,
884            join_type,
885            new_on,
886            self.output_indices().clone(),
887        );
888
889        let mut mapping = self.core.i2o_col_mapping();
890        predicate = predicate.rewrite_expr(&mut mapping);
891        LogicalFilter::create(new_join.into(), predicate)
892    }
893}
894
895impl LogicalJoin {
    /// Plans both inputs of a hash-partitioned equi-join as streams and makes their
    /// distributions line up on the equi-join keys.
    ///
    /// The right input is planned first with a `shard_by_key` requirement on its equi-join
    /// keys. The left input's requirement then depends on the distribution the right side
    /// actually ended up with:
    ///
    /// - `HashShard`: that distribution is translated to left-side columns through the
    ///   right-to-left eq-column mapping and required of the left input.
    /// - `UpstreamHashShard`: the left input is planned with `shard_by_key` on its own
    ///   equi-join keys. If the left comes back as `HashShard`, the right input is re-shuffled
    ///   to match it (through the left-to-right mapping); if it is also `UpstreamHashShard`,
    ///   both inputs are enforced to a plain `hash_shard` on their respective equi-join keys.
    ///
    /// Any other distribution on either side is treated as impossible (`unreachable!`).
    ///
    /// Returns the planned `(left, right)` stream inputs.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        // Per-side join-key columns handed to `try_enforce_locality_requirement`.
        // NOTE(review): these come from `self.eq_indexes()` rather than `predicate`; presumably
        // the same keys — confirm.
        let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
        let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();

        // Plan the right side first; its resulting distribution drives the left requirement.
        let logical_right = try_enforce_locality_requirement(self.right(), &rhs_join_key_idx);
        let mut right = logical_right.to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
        // Mappings between the two sides' equi-join key columns, used to translate a
        // distribution on one side into the equivalent requirement on the other.
        let r2l =
            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        // Assigned in every arm of the match below that does not panic.
        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // Require the left side to follow the right side's hash distribution.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                // The right side kept its upstream distribution; plan the left side on its own
                // keys first and reconcile the two afterwards.
                left = logical_left.to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Re-shuffle the right side to match the left side's hash distribution.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Both sides are upstream-sharded; force both onto a plain hash shard
                        // of their equi-join keys.
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
953
954    fn to_stream_hash_join(
955        &self,
956        predicate: EqJoinPredicate,
957        ctx: &mut ToStreamContext,
958    ) -> Result<StreamPlanRef> {
959        use super::stream::prelude::*;
960
961        assert!(predicate.has_eq());
962        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
963
964        let core = self.core.clone_with_inputs(left, right);
965
966        // Convert to Hash Join for equal joins
967        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
968        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
969        // as opposed to the HashJoin, which applies the condition row-wise.
970        // However, the default behavior of pulling up non-equal conditions can be overridden by the
971        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
972        // materialization of rows only to be filtered later.
973
974        let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
975
976        let force_filter_inside_join = self
977            .base
978            .ctx()
979            .session_ctx()
980            .config()
981            .streaming_force_filter_inside_join();
982
983        let pull_filter = self.join_type() == JoinType::Inner
984            && stream_hash_join.eq_join_predicate().has_non_eq()
985            && stream_hash_join.inequality_pairs().is_empty()
986            && (!force_filter_inside_join);
987        if pull_filter {
988            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
989
990            let mut core = core;
991            core.output_indices = default_indices.clone();
992            // Temporarily remove output indices.
993            let eq_cond = EqJoinPredicate::new(
994                Condition::true_cond(),
995                predicate.eq_keys().to_vec(),
996                self.left().schema().len(),
997                self.right().schema().len(),
998            );
999            core.on = eq_cond.eq_cond();
1000            let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
1001            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1002            let plan = StreamFilter::new(logical_filter).into();
1003            if self.output_indices() != &default_indices {
1004                let logical_project = generic::Project::with_mapping(
1005                    plan,
1006                    ColIndexMapping::with_remaining_columns(
1007                        self.output_indices(),
1008                        self.internal_column_num(),
1009                    ),
1010                );
1011                Ok(StreamProject::new(logical_project).into())
1012            } else {
1013                Ok(plan)
1014            }
1015        } else {
1016            Ok(stream_hash_join.into())
1017        }
1018    }
1019
1020    fn should_be_temporal_join(&self) -> bool {
1021        let right = self.right();
1022        if let Some(logical_scan) = right.as_logical_scan() {
1023            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1024        } else {
1025            false
1026        }
1027    }
1028
1029    fn to_stream_temporal_join_with_index_selection(
1030        &self,
1031        predicate: EqJoinPredicate,
1032        ctx: &mut ToStreamContext,
1033    ) -> Result<StreamPlanRef> {
1034        // Index selection for temporal join.
1035        let right = self.right();
1036        // `should_be_temporal_join()` has already check right input for us.
1037        let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1038
1039        // Use primary table.
1040        let mut result_plan: Result<StreamTemporalJoin> =
1041            self.to_stream_temporal_join(predicate.clone(), ctx);
1042        // Return directly if this temporal join can match the pk of its right table.
1043        if let Ok(temporal_join) = &result_plan
1044            && temporal_join.eq_join_predicate().eq_indexes().len()
1045                == logical_scan.primary_key().len()
1046        {
1047            return result_plan.map(|x| x.into());
1048        }
1049        if self
1050            .core
1051            .ctx()
1052            .session_ctx()
1053            .config()
1054            .enable_index_selection()
1055        {
1056            let indexes = logical_scan.table_indexes();
1057            for index in indexes {
1058                // Use index table
1059                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1060                    let index_scan: PlanRef = index_scan.into();
1061                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1062                    if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx)
1063                    {
1064                        match &result_plan {
1065                            Err(_) => result_plan = Ok(temporal_join),
1066                            Ok(prev_temporal_join) => {
1067                                // Prefer to the temporal join with a longer lookup prefix len.
1068                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1069                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1070                                {
1071                                    result_plan = Ok(temporal_join)
1072                                }
1073                            }
1074                        }
1075                    }
1076                }
1077            }
1078        }
1079
1080        result_plan.map(|x| x.into())
1081    }
1082
1083    fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1084        let Some(logical_scan) = right.as_logical_scan() else {
1085            return Err(RwError::from(ErrorCode::NotSupported(
1086                "Temporal join requires a table scan as its lookup table".into(),
1087                "Please provide a table scan".into(),
1088            )));
1089        };
1090
1091        if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1092            return Err(RwError::from(ErrorCode::NotSupported(
1093                "Temporal join requires a table defined as temporal table".into(),
1094                "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1095            )));
1096        }
1097        Ok(logical_scan)
1098    }
1099
1100    fn temporal_join_scan_predicate_pull_up(
1101        logical_scan: &LogicalScan,
1102        predicate: EqJoinPredicate,
1103        output_indices: &[usize],
1104        left_schema_len: usize,
1105    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1106        // Extract the predicate from logical scan. Only pure scan is supported.
1107        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1108        // Construct output column to require column mapping
1109        let o2r = if let Some(project_expr) = project_expr {
1110            project_expr
1111                .into_iter()
1112                .map(|x| x.as_input_ref().unwrap().index)
1113                .collect_vec()
1114        } else {
1115            (0..logical_scan.output_col_idx().len()).collect_vec()
1116        };
1117        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1118            offset: left_schema_len,
1119            mapping: o2r.clone(),
1120        };
1121
1122        let new_eq_cond = predicate
1123            .eq_cond()
1124            .rewrite_expr(&mut join_predicate_rewriter);
1125
1126        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1127            offset: left_schema_len,
1128        };
1129
1130        let new_other_cond = predicate
1131            .other_cond()
1132            .clone()
1133            .rewrite_expr(&mut join_predicate_rewriter)
1134            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1135
1136        let new_join_on = new_eq_cond.and(new_other_cond);
1137
1138        let new_predicate = EqJoinPredicate::create(
1139            left_schema_len,
1140            new_scan.schema().len(),
1141            new_join_on.clone(),
1142        );
1143
1144        // Rewrite the join output indices and all output indices referred to the old scan need to
1145        // rewrite.
1146        let new_join_output_indices = output_indices
1147            .iter()
1148            .map(|&x| {
1149                if x < left_schema_len {
1150                    x
1151                } else {
1152                    o2r[x - left_schema_len] + left_schema_len
1153                }
1154            })
1155            .collect_vec();
1156
1157        // Use UpstreamOnly chain type
1158        if new_scan.cross_database() {
1159            return Err(RwError::from(ErrorCode::NotSupported(
1160                "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1161                "Please ensure both tables are in the same database".into(),
1162            )));
1163        }
1164        let new_stream_table_scan =
1165            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1166        Ok((
1167            new_stream_table_scan,
1168            new_predicate,
1169            new_join_on,
1170            new_join_output_indices,
1171        ))
1172    }
1173
1174    fn to_stream_temporal_join(
1175        &self,
1176        predicate: EqJoinPredicate,
1177        ctx: &mut ToStreamContext,
1178    ) -> Result<StreamTemporalJoin> {
1179        use super::stream::prelude::*;
1180
1181        assert!(predicate.has_eq());
1182
1183        let right = self.right();
1184
1185        let logical_scan = Self::check_temporal_rhs(&right)?;
1186
1187        let table = logical_scan.table();
1188        let output_column_ids = logical_scan.output_column_ids();
1189
1190        // Verify that the right join key columns are the the prefix of the primary key and
1191        // also contain the distribution key.
1192        let order_col_ids = table.order_column_ids();
1193        let dist_key = table.distribution_key.clone();
1194
1195        let mut dist_key_in_order_key_pos = vec![];
1196        for d in dist_key {
1197            let pos = table
1198                .order_column_indices()
1199                .position(|x| x == d)
1200                .expect("dist_key must in order_key");
1201            dist_key_in_order_key_pos.push(pos);
1202        }
1203        // The shortest prefix of order key that contains distribution key.
1204        let shortest_prefix_len = dist_key_in_order_key_pos
1205            .iter()
1206            .max()
1207            .map_or(0, |pos| pos + 1);
1208
1209        // Reorder the join equal predicate to match the order key.
1210        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1211        for order_col_id in order_col_ids {
1212            let mut found = false;
1213            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1214                if order_col_id == output_column_ids[eq_idx] {
1215                    reorder_idx.push(i);
1216                    found = true;
1217                    break;
1218                }
1219            }
1220            if !found {
1221                break;
1222            }
1223        }
1224        if reorder_idx.len() < shortest_prefix_len {
1225            // TODO: support index selection for temporal join and refine this error message.
1226            return Err(RwError::from(ErrorCode::NotSupported(
1227                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1228                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1229            )));
1230        }
1231        let lookup_prefix_len = reorder_idx.len();
1232        let predicate = predicate.reorder(&reorder_idx);
1233
1234        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1235            RequiredDist::single()
1236        } else {
1237            let left_eq_indexes = predicate.left_eq_indexes();
1238            let left_dist_key = dist_key_in_order_key_pos
1239                .iter()
1240                .map(|pos| left_eq_indexes[*pos])
1241                .collect_vec();
1242
1243            RequiredDist::hash_shard(&left_dist_key)
1244        };
1245
1246        let lhs_join_key_idx = predicate
1247            .eq_indexes()
1248            .into_iter()
1249            .map(|(l, _)| l)
1250            .collect_vec();
1251        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
1252        let left = logical_left.to_stream(ctx)?;
1253        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1254        let left = required_dist.stream_enforce(left);
1255
1256        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1257            Self::temporal_join_scan_predicate_pull_up(
1258                logical_scan,
1259                predicate,
1260                self.output_indices(),
1261                self.left().schema().len(),
1262            )?;
1263
1264        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1265        if !new_predicate.has_eq() {
1266            return Err(RwError::from(ErrorCode::NotSupported(
1267                "Temporal join requires a non trivial join condition".into(),
1268                "Please remove the false condition of the join".into(),
1269            )));
1270        }
1271
1272        // Construct a new logical join, because we have change its RHS.
1273        let new_logical_join = generic::Join::new(
1274            left,
1275            right,
1276            new_join_on,
1277            self.join_type(),
1278            new_join_output_indices,
1279        );
1280
1281        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1282
1283        StreamTemporalJoin::new(new_logical_join, new_predicate, false)
1284    }
1285
1286    fn to_stream_nested_loop_temporal_join(
1287        &self,
1288        predicate: EqJoinPredicate,
1289        ctx: &mut ToStreamContext,
1290    ) -> Result<StreamPlanRef> {
1291        use super::stream::prelude::*;
1292        assert!(!predicate.has_eq());
1293
1294        let left = self.left().to_stream_with_dist_required(
1295            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1296            ctx,
1297        )?;
1298        assert!(left.as_stream_exchange().is_some());
1299
1300        if self.join_type() != JoinType::Inner {
1301            return Err(RwError::from(ErrorCode::NotSupported(
1302                "Temporal join requires an inner join".into(),
1303                "Please use an inner join".into(),
1304            )));
1305        }
1306
1307        if !left.append_only() {
1308            return Err(RwError::from(ErrorCode::NotSupported(
1309                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1310                "Please ensure the left hash side is append only".into(),
1311            )));
1312        }
1313
1314        let right = self.right();
1315        let logical_scan = Self::check_temporal_rhs(&right)?;
1316
1317        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1318            Self::temporal_join_scan_predicate_pull_up(
1319                logical_scan,
1320                predicate,
1321                self.output_indices(),
1322                self.left().schema().len(),
1323            )?;
1324
1325        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1326
1327        // Construct a new logical join, because we have change its RHS.
1328        let new_logical_join = generic::Join::new(
1329            left,
1330            right,
1331            new_join_on,
1332            self.join_type(),
1333            new_join_output_indices,
1334        );
1335
1336        Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1337    }
1338
1339    fn to_stream_dynamic_filter(
1340        &self,
1341        predicate: Condition,
1342        ctx: &mut ToStreamContext,
1343    ) -> Result<Option<StreamPlanRef>> {
1344        use super::stream::prelude::*;
1345
1346        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1347        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1348        // `StreamDynamicFilter`
1349
1350        // Check if `Inner`/`LeftSemi`
1351        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1352            return Ok(None);
1353        }
1354
1355        // Check if right side is a scalar
1356        if !self.right().max_one_row() {
1357            return Ok(None);
1358        }
1359        if self.right().schema().len() != 1 {
1360            return Ok(None);
1361        }
1362
1363        // Check if the join condition is a correlated comparison
1364        if predicate.conjunctions.len() > 1 {
1365            return Ok(None);
1366        }
1367        let expr: ExprImpl = predicate.into();
1368        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1369            Some(v) => v,
1370            None => return Ok(None),
1371        };
1372
1373        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1374            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1375        if !condition_cross_inputs {
1376            // Maybe we should panic here because it means some predicates are not pushed down.
1377            return Ok(None);
1378        }
1379
1380        // We align input types on all join predicates with cmp operator
1381        if self.left().schema().fields()[left_ref.index].data_type
1382            != self.right().schema().fields()[0].data_type
1383        {
1384            return Ok(None);
1385        }
1386
1387        // Check if non of the columns from the inner side is required to output
1388        let all_output_from_left = self
1389            .output_indices()
1390            .iter()
1391            .all(|i| *i < self.left().schema().len());
1392        if !all_output_from_left {
1393            return Ok(None);
1394        }
1395
1396        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1397        let right = self.right().to_stream_with_dist_required(
1398            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1399            ctx,
1400        )?;
1401
1402        assert!(right.as_stream_exchange().is_some());
1403        assert_eq!(
1404            *right.inputs().iter().exactly_one().unwrap().distribution(),
1405            Distribution::Single
1406        );
1407
1408        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1409        let plan = StreamDynamicFilter::new(core)?.into();
1410        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1411        if self
1412            .output_indices()
1413            .iter()
1414            .copied()
1415            .ne(0..self.left().schema().len())
1416        {
1417            // The schema of dynamic filter is always the same as the left side now, and we have
1418            // checked that all output columns are from the left side before.
1419            let logical_project = generic::Project::with_mapping(
1420                plan,
1421                ColIndexMapping::with_remaining_columns(
1422                    self.output_indices(),
1423                    self.left().schema().len(),
1424                ),
1425            );
1426            Ok(Some(StreamProject::new(logical_project).into()))
1427        } else {
1428            Ok(Some(plan))
1429        }
1430    }
1431
1432    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1433        let predicate = EqJoinPredicate::create(
1434            self.left().schema().len(),
1435            self.right().schema().len(),
1436            self.on().clone(),
1437        );
1438        assert!(predicate.has_eq());
1439
1440        let join = self
1441            .core
1442            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1443
1444        Ok(self
1445            .to_batch_lookup_join(predicate, join)?
1446            .expect("Fail to convert to lookup join")
1447            .into())
1448    }
1449
1450    fn to_stream_asof_join(
1451        &self,
1452        predicate: EqJoinPredicate,
1453        ctx: &mut ToStreamContext,
1454    ) -> Result<StreamPlanRef> {
1455        use super::stream::prelude::*;
1456
1457        if predicate.eq_keys().is_empty() {
1458            return Err(ErrorCode::InvalidInputSyntax(
1459                "AsOf join requires at least 1 equal condition".to_owned(),
1460            )
1461            .into());
1462        }
1463
1464        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1465        let left_len = left.schema().len();
1466        let core = self.core.clone_with_inputs(left, right);
1467
1468        let inequality_desc =
1469            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1470
1471        Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1472    }
1473
1474    /// Convert the logical join to a Hash join.
1475    fn to_batch_hash_join(
1476        &self,
1477        logical_join: generic::Join<BatchPlanRef>,
1478        predicate: EqJoinPredicate,
1479    ) -> Result<BatchPlanRef> {
1480        use super::batch::prelude::*;
1481
1482        let left_schema_len = logical_join.left.schema().len();
1483        let asof_desc = self
1484            .is_asof_join()
1485            .then(|| {
1486                Self::get_inequality_desc_from_predicate(
1487                    predicate.other_cond().clone(),
1488                    left_schema_len,
1489                )
1490            })
1491            .transpose()?;
1492
1493        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1494        Ok(batch_join.into())
1495    }
1496
1497    pub fn get_inequality_desc_from_predicate(
1498        predicate: Condition,
1499        left_input_len: usize,
1500    ) -> Result<AsOfJoinDesc> {
1501        let expr: ExprImpl = predicate.into();
1502        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1503            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1504            {
1505                Ok(AsOfJoinDesc {
1506                    left_idx: left_input_ref.index() as u32,
1507                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1508                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1509                })
1510            } else {
1511                bail!("inequal condition from the same side should be push down in optimizer");
1512            }
1513        } else {
1514            Err(ErrorCode::InvalidInputSyntax(
1515                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1516            )
1517            .into())
1518        }
1519    }
1520
1521    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1522        match expr_type {
1523            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1524            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1525            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1526            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1527            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1528                "Invalid comparison type: {}",
1529                expr_type.as_str_name()
1530            ))
1531            .into()),
1532        }
1533    }
1534}
1535
1536impl ToBatch for LogicalJoin {
1537    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1538        let predicate = EqJoinPredicate::create(
1539            self.left().schema().len(),
1540            self.right().schema().len(),
1541            self.on().clone(),
1542        );
1543
1544        let batch_join = self
1545            .core
1546            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1547
1548        let ctx = self.base.ctx();
1549        let config = ctx.session_ctx().config();
1550
1551        if predicate.has_eq() {
1552            if !predicate.eq_keys_are_type_aligned() {
1553                return Err(ErrorCode::InternalError(format!(
1554                    "Join eq keys are not aligned for predicate: {predicate:?}"
1555                ))
1556                .into());
1557            }
1558            if config.batch_enable_lookup_join()
1559                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1560                    predicate.clone(),
1561                    batch_join.clone(),
1562                )?
1563            {
1564                return Ok(lookup_join.into());
1565            }
1566            self.to_batch_hash_join(batch_join, predicate)
1567        } else if self.is_asof_join() {
1568            Err(ErrorCode::InvalidInputSyntax(
1569                "AsOf join requires at least 1 equal condition".to_owned(),
1570            )
1571            .into())
1572        } else {
1573            // Convert to Nested-loop Join for non-equal joins
1574            Ok(BatchNestedLoopJoin::new(batch_join).into())
1575        }
1576    }
1577}
1578
1579impl ToStream for LogicalJoin {
1580    fn to_stream(
1581        &self,
1582        ctx: &mut ToStreamContext,
1583    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1584        if self
1585            .on()
1586            .conjunctions
1587            .iter()
1588            .any(|cond| cond.count_nows() > 0)
1589        {
1590            return Err(ErrorCode::NotSupported(
1591                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1592                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1593        }
1594
1595        let predicate = EqJoinPredicate::create(
1596            self.left().schema().len(),
1597            self.right().schema().len(),
1598            self.on().clone(),
1599        );
1600
1601        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1602            self.to_stream_asof_join(predicate, ctx)
1603        } else if predicate.has_eq() {
1604            if !predicate.eq_keys_are_type_aligned() {
1605                return Err(ErrorCode::InternalError(format!(
1606                    "Join eq keys are not aligned for predicate: {predicate:?}"
1607                ))
1608                .into());
1609            }
1610
1611            if self.should_be_temporal_join() {
1612                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1613            } else {
1614                self.to_stream_hash_join(predicate, ctx)
1615            }
1616        } else if self.should_be_temporal_join() {
1617            self.to_stream_nested_loop_temporal_join(predicate, ctx)
1618        } else if let Some(dynamic_filter) =
1619            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1620        {
1621            Ok(dynamic_filter)
1622        } else {
1623            Err(RwError::from(ErrorCode::NotSupported(
1624                "streaming nested-loop join".to_owned(),
1625                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1626                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1627                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1628            )))
1629        }
1630    }
1631
1632    fn logical_rewrite_for_stream(
1633        &self,
1634        ctx: &mut RewriteStreamContext,
1635    ) -> Result<(PlanRef, ColIndexMapping)> {
1636        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1637        let left_len = left.schema().len();
1638        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1639        let (join, out_col_change) = self.rewrite_with_left_right(
1640            left.clone(),
1641            left_col_change,
1642            right.clone(),
1643            right_col_change,
1644        );
1645
1646        let mapping = ColIndexMapping::with_remaining_columns(
1647            join.output_indices(),
1648            join.internal_column_num(),
1649        );
1650
1651        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1652        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1653
1654        // Add missing pk indices to the logical join
1655        let mut left_to_add = left
1656            .expect_stream_key()
1657            .iter()
1658            .cloned()
1659            .filter(|i| l2o.try_map(*i).is_none())
1660            .collect_vec();
1661
1662        let mut right_to_add = right
1663            .expect_stream_key()
1664            .iter()
1665            .filter(|&&i| r2o.try_map(i).is_none())
1666            .map(|&i| i + left_len)
1667            .collect_vec();
1668
1669        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1670        // key.
1671        let right_len = right.schema().len();
1672        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1673
1674        let either_or_both = self.core.add_which_join_key_to_pk();
1675
1676        for (lk, rk) in eq_predicate.eq_indexes() {
1677            match either_or_both {
1678                EitherOrBoth::Left(_) => {
1679                    if l2o.try_map(lk).is_none() {
1680                        left_to_add.push(lk);
1681                    }
1682                }
1683                EitherOrBoth::Right(_) => {
1684                    if r2o.try_map(rk).is_none() {
1685                        right_to_add.push(rk + left_len)
1686                    }
1687                }
1688                EitherOrBoth::Both(_, _) => {
1689                    if l2o.try_map(lk).is_none() {
1690                        left_to_add.push(lk);
1691                    }
1692                    if r2o.try_map(rk).is_none() {
1693                        right_to_add.push(rk + left_len)
1694                    }
1695                }
1696            };
1697        }
1698        let left_to_add = left_to_add.into_iter().unique();
1699        let right_to_add = right_to_add.into_iter().unique();
1700        // NOTE(st1page) over
1701
1702        let mut new_output_indices = join.output_indices().clone();
1703        if !join.is_right_join() {
1704            new_output_indices.extend(left_to_add);
1705        }
1706        if !join.is_left_join() {
1707            new_output_indices.extend(right_to_add);
1708        }
1709
1710        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1711
1712        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1713            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1714
1715            let l2o = join_with_pk
1716                .core
1717                .l2i_col_mapping()
1718                .composite(&join_with_pk.core.i2o_col_mapping());
1719            let r2o = join_with_pk
1720                .core
1721                .r2i_col_mapping()
1722                .composite(&join_with_pk.core.i2o_col_mapping());
1723            let left_right_stream_keys = join_with_pk
1724                .left()
1725                .expect_stream_key()
1726                .iter()
1727                .map(|i| l2o.map(*i))
1728                .chain(
1729                    join_with_pk
1730                        .right()
1731                        .expect_stream_key()
1732                        .iter()
1733                        .map(|i| r2o.map(*i)),
1734                )
1735                .collect_vec();
1736            let plan: PlanRef = join_with_pk.into();
1737            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1738        } else {
1739            join_with_pk.into()
1740        };
1741
1742        // the added columns is at the end, so it will not change the exists column index
1743        Ok((plan, out_col_change))
1744    }
1745
1746    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1747        let mut ctx = ToStreamContext::new(false);
1748        // only pass through the locality information if it can be converted to dynamic filter
1749        if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1750            // since dynamic filter only supports left input ref in the output indices, we can safely use o2i mapping to convert the required columns.
1751            let o2i_mapping = self.core.o2i_col_mapping();
1752            let left_input_columns = columns
1753                .iter()
1754                .map(|&col| o2i_mapping.try_map(col))
1755                .collect::<Option<Vec<usize>>>()?;
1756            if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1757                return Some(
1758                    self.clone_with_left_right(better_left_plan, self.right())
1759                        .into(),
1760                );
1761            }
1762        }
1763        None
1764    }
1765}
1766
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Builds the fixture shared by most tests below. It returns:
    /// - six `Int32` fields named `v1`..`v6`;
    /// - an empty `LogicalValues` holding `v1, v2, v3`, used as the left input;
    /// - an empty `LogicalValues` holding `v4, v5, v6`, used as the right input.
    ///
    /// Both inputs share one mocked optimizer context.
    async fn six_int32_columns() -> (Vec<Field>, LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{i}")))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        (fields, left, right)
    }

    /// Pruning
    /// ```text
    /// Join(on: input_ref(1)=input_ref(3))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// with required columns [2,3] will result in
    /// ```text
    /// Join(on: input_ref(0)=input_ref(2), output_indices: [1, 2])
    ///   Values(v2, v3)
    ///   Values(v4)
    /// ```
    #[tokio::test]
    async fn test_prune_join() {
        let ty = DataType::Int32;
        let (fields, left, right) = six_int32_columns().await;
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        // Perform the prune
        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // Check the result
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let ty = DataType::Int32;
        let (fields, left, right) = six_int32_columns().await;
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
                ],
            )
            .unwrap(),
        ));
        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Semi/anti joins only output one side; right variants output the right input.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();
            // Perform the prune
            let required_cols = vec![0];
            // Output column 0 is never a join key on either side (the keys are
            // internal columns 1 and 4).
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            // Check the result
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            // Perform the prune
            let required_cols = vec![0, 1, 2];
            // should not panic here
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            // Check the result
            assert_eq!(as_plan.schema().fields().len(), 3);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
        }
    }

    /// Pruning
    /// ```text
    /// Join(on: input_ref(1)=input_ref(3))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// with required columns [1, 3] will result in
    /// ```text
    /// Join(on: input_ref(0)=input_ref(1))
    ///   Values(v2)
    ///   Values(v4)
    /// ```
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let ty = DataType::Int32;
        let (fields, left, right) = six_int32_columns().await;
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        // Perform the prune
        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // Check the result
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Convert
    /// ```text
    /// Join(on: ($1 = $3) AND ($2 == 42))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(eq cond: $1 = $3, non-eq cond: $2 == 42)
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    #[tokio::test]
    async fn test_join_to_batch() {
        let (_, left, right) = six_int32_columns().await;

        fn input_ref(i: usize) -> ExprImpl {
            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
        }
        let eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
        ));
        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    input_ref(2),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Datum::Some(42_i32.into()),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));
        // Condition: ($1 = $3) AND ($2 == 42)
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let join_type = JoinType::Inner;
        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on_cond),
        );

        // Perform `to_batch`
        let result = logical_join.to_batch().unwrap();

        // Expected plan: HashJoin($1 = $3 AND $2 == 42), with `$2 == 42` kept as
        // the non-equi condition of the hash join rather than a separate filter.
        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Convert
    /// ```text
    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    #[tokio::test]
    #[ignore] // ignore due to refactor logical scan, but the test seem to duplicate with the explain test
    // framework, maybe we will remove it?
    async fn test_join_to_stream() {
        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
        // let fields: Vec<Field> = (1..7)
        //     .map(|i| Field {
        //         data_type: DataType::Int32,
        //         name: format!("v{}", i),
        //     })
        //     .collect();
        // let left = LogicalScan::new(
        //     "left".to_string(),
        //     TableId::new(0),
        //     vec![1.into(), 2.into(), 3.into()],
        //     Schema {
        //         fields: fields[0..3].to_vec(),
        //     },
        //     ctx.clone(),
        // );
        // let right = LogicalScan::new(
        //     "right".to_string(),
        //     TableId::new(0),
        //     vec![4.into(), 5.into(), 6.into()],
        //     Schema {
        //                 fields: fields[3..6].to_vec(),
        //     },
        //     ctx,
        // );
        // let eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
        //             ExprImpl::Literal(Box::new(Literal::new(
        //                 Datum::Some(42_i32.into()),
        //                 DataType::Int32,
        //             ))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // // Condition: ($1 = $3) AND ($2 == 42)
        // let on_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
        // ));

        // let join_type = JoinType::LeftOuter;
        // let logical_join = LogicalJoin::new(
        //     left.clone().into(),
        //     right.clone().into(),
        //     join_type,
        //     Condition::with_expr(on_cond.clone()),
        // );

        // // Perform `to_stream`
        // let result = logical_join.to_stream();

        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
        // let hash_join = result.as_stream_hash_join().unwrap();
        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
    }

    /// Pruning
    /// ```text
    /// Join(on: input_ref(1)=input_ref(3))
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// with required columns [3, 2] will result in
    /// ```text
    /// Join(on: input_ref(0)=input_ref(2), output_indices: [2, 1])
    ///   Values(v2, v3)
    ///   Values(v4)
    /// ```
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let ty = DataType::Int32;
        let (fields, left, right) = six_int32_columns().await;
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        // Perform the prune
        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // Check the result
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        // left: [l0, l1], right: [r0, r1, r2]
        // FD: l0 --> l1, r0 --> { r1, r2 }
        // On: l0 = 0 AND l1 = r1
        //
        // Inner Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
        // Left Outer Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: l0 --> l1
        // Right Outer Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: r0 --> { r1, r2 }
        // Full Outer Join:
        //  Schema: [l0, l1, r0, r1, r2]
        //  FD: empty
        // Left Semi/Anti Join:
        //  Schema: [l0, l1]
        //  FD: l0 --> l1
        // Right Semi/Anti Join:
        //  Schema: [r0, r1, r2]
        //  FD: r0 --> {r1, r2}
        let ctx = OptimizerContext::mock().await;
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            // 0 --> 1
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            // 0 --> 1, 2
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };
        // l0 = 0 AND l1 = r1
        let on: ExprImpl = FunctionCall::new(
            Type::And,
            vec![
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(0, DataType::Int32).into(),
                        ExprImpl::literal_int(0),
                    ],
                )
                .unwrap()
                .into(),
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(1, DataType::Int32).into(),
                        InputRef::new(3, DataType::Int32).into(),
                    ],
                )
                .unwrap()
                .into(),
            ],
        )
        .unwrap()
        .into();
        let expected_fd_set = [
            (
                JoinType::Inner,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    // inherit from right
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    // constant column in join condition
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    // eq column in join condition
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                [
                    // inherit from right
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftOuter,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftSemi,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftAnti,
                [
                    // inherit from left
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightSemi,
                [
                    // inherit from right
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightAnti,
                [
                    // inherit from right
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}