risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31    BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
32    PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch,
33    ToStream, generic,
34};
35use crate::error::{ErrorCode, Result, RwError};
36use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
37use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
38use crate::optimizer::plan_node::generic::DynamicFilter;
39use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
40use crate::optimizer::plan_node::utils::IndicesDisplay;
41use crate::optimizer::plan_node::{
42    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
43    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
44    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
45};
46use crate::optimizer::plan_visitor::LogicalCardinalityExt;
47use crate::optimizer::property::{Distribution, RequiredDist};
48use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
49
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right inputs, as selected by the output indices provided. A repeated output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Logical plan base; derived from `core` whenever the node is built via `with_core`.
    pub base: PlanBase<Logical>,
    /// Generic join description: both inputs, the `on` condition, the join type and the
    /// output indices.
    core: generic::Join<PlanRef>,
}
61
62impl Distill for LogicalJoin {
63    fn distill<'a>(&self) -> XmlNode<'a> {
64        let verbose = self.base.ctx().is_explain_verbose();
65        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
66        vec.push(("type", Pretty::debug(&self.join_type())));
67
68        let concat_schema = self.core.concat_schema();
69        let cond = Pretty::debug(&ConditionDisplay {
70            condition: self.on(),
71            input_schema: &concat_schema,
72        });
73        vec.push(("on", cond));
74
75        if verbose {
76            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
77            vec.push(("output", data));
78        }
79
80        childless_record("LogicalJoin", vec)
81    }
82}
83
84impl LogicalJoin {
85    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
86        let core = generic::Join::with_full_output(left, right, join_type, on);
87        Self::with_core(core)
88    }
89
90    pub(crate) fn with_output_indices(
91        left: PlanRef,
92        right: PlanRef,
93        join_type: JoinType,
94        on: Condition,
95        output_indices: Vec<usize>,
96    ) -> Self {
97        let core = generic::Join::new(left, right, on, join_type, output_indices);
98        Self::with_core(core)
99    }
100
101    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
102        let base = PlanBase::new_logical_with_core(&core);
103        LogicalJoin { base, core }
104    }
105
106    pub fn create(
107        left: PlanRef,
108        right: PlanRef,
109        join_type: JoinType,
110        on_clause: ExprImpl,
111    ) -> PlanRef {
112        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
113    }
114
    /// Number of internal columns of the join, as reported by the generic join core.
    pub fn internal_column_num(&self) -> usize {
        self.core.internal_column_num()
    }
118
    /// Delegates to the core's `i2l_col_mapping_ignore_join_type`.
    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2l_col_mapping_ignore_join_type()
    }
122
    /// Delegates to the core's `i2r_col_mapping_ignore_join_type`.
    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2r_col_mapping_ignore_join_type()
    }
126
    /// Get a reference to the logical join's `on` condition.
    pub fn on(&self) -> &Condition {
        &self.core.on
    }
131
    /// Get a reference to the underlying generic join core.
    pub fn core(&self) -> &generic::Join<PlanRef> {
        &self.core
    }
135
136    /// Collect all input ref in the on condition. And separate them into left and right.
137    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
138        let input_refs = self
139            .core
140            .on
141            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
142        let index_group = input_refs
143            .ones()
144            .chunk_by(|i| *i < self.core.left.schema().len());
145        let left_index = index_group
146            .into_iter()
147            .next()
148            .map_or(vec![], |group| group.1.collect_vec());
149        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
150            group
151                .1
152                .map(|i| i - self.core.left.schema().len())
153                .collect_vec()
154        });
155        (left_index, right_index)
156    }
157
    /// Get the join type of the logical join (returned by value).
    pub fn join_type(&self) -> JoinType {
        self.core.join_type
    }
162
    /// Get the eq join keys of the logical join as pairs of column indices, as computed by
    /// the generic join core.
    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
        self.core.eq_indexes()
    }
167
    /// Get the output indices of the logical join, i.e. which internal columns are emitted
    /// and in which order.
    pub fn output_indices(&self) -> &Vec<usize> {
        &self.core.output_indices
    }
172
173    /// Clone with new output indices
174    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
175        Self::with_core(generic::Join {
176            output_indices,
177            ..self.core.clone()
178        })
179    }
180
181    /// Clone with new `on` condition
182    pub fn clone_with_cond(&self, on: Condition) -> Self {
183        Self::with_core(generic::Join {
184            on,
185            ..self.core.clone()
186        })
187    }
188
    /// Whether this is a left semi or left anti join.
    ///
    /// Note: despite the name, this is `false` for `LeftOuter` joins.
    pub fn is_left_join(&self) -> bool {
        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
    }
192
    /// Whether this is a right semi or right anti join.
    ///
    /// Note: despite the name, this is `false` for `RightOuter` joins.
    pub fn is_right_join(&self) -> bool {
        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
    }
196
    /// Delegates to the core's `is_full_out`.
    pub fn is_full_out(&self) -> bool {
        self.core.is_full_out()
    }
200
201    pub fn is_asof_join(&self) -> bool {
202        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
203    }
204
205    pub fn output_indices_are_trivial(&self) -> bool {
206        self.output_indices() == &(0..self.internal_column_num()).collect_vec()
207    }
208
209    /// Try to simplify the outer join with the predicate on the top of the join
210    ///
211    /// now it is just a naive implementation for comparison expression, we can give a more general
212    /// implementation with constant folding in future
213    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
214        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
215            JoinType::LeftOuter => (false, true),
216            JoinType::RightOuter => (true, false),
217            JoinType::FullOuter => (true, true),
218            _ => return join_type,
219        };
220
221        for expr in &predicate.conjunctions {
222            if let ExprImpl::FunctionCall(func) = expr {
223                match func.func_type() {
224                    ExprType::Equal
225                    | ExprType::NotEqual
226                    | ExprType::LessThan
227                    | ExprType::LessThanOrEqual
228                    | ExprType::GreaterThan
229                    | ExprType::GreaterThanOrEqual => {
230                        for input in func.inputs() {
231                            if let ExprImpl::InputRef(input) = input {
232                                let idx = input.index;
233                                if idx < left_col_num {
234                                    gen_null_in_left = false;
235                                } else {
236                                    gen_null_in_right = false;
237                                }
238                            }
239                        }
240                    }
241                    _ => {}
242                };
243            }
244        }
245
246        match (gen_null_in_left, gen_null_in_right) {
247            (true, true) => JoinType::FullOuter,
248            (true, false) => JoinType::RightOuter,
249            (false, true) => JoinType::LeftOuter,
250            (false, false) => JoinType::Inner,
251        }
252    }
253
254    /// Index Join:
255    /// Try to convert logical join into batch lookup join and meanwhile it will do
256    /// the index selection for the lookup table so that we can benefit from indexes.
257    fn to_batch_lookup_join_with_index_selection(
258        &self,
259        predicate: EqJoinPredicate,
260        batch_join: generic::Join<BatchPlanRef>,
261    ) -> Result<Option<BatchLookupJoin>> {
262        match batch_join.join_type {
263            JoinType::Inner
264            | JoinType::LeftOuter
265            | JoinType::LeftSemi
266            | JoinType::LeftAnti
267            | JoinType::AsofInner
268            | JoinType::AsofLeftOuter => {}
269            _ => return Ok(None),
270        };
271
272        // Index selection for index join.
273        let right = self.right();
274        // Lookup Join only supports basic tables on the join's right side.
275        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
276            logical_scan
277        } else {
278            return Ok(None);
279        };
280
281        let mut result_plan = None;
282        // Lookup primary table.
283        if let Some(lookup_join) =
284            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
285        {
286            result_plan = Some(lookup_join);
287        }
288
289        if self
290            .core
291            .ctx()
292            .session_ctx()
293            .config()
294            .enable_index_selection()
295        {
296            let indexes = logical_scan.table_indexes();
297            for index in indexes {
298                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
299                    let index_scan: PlanRef = index_scan.into();
300                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
301                    let mut new_batch_join = batch_join.clone();
302                    new_batch_join.right =
303                        index_scan.to_batch().expect("index scan failed to batch");
304
305                    // Lookup covered index.
306                    if let Some(lookup_join) =
307                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
308                    {
309                        match &result_plan {
310                            None => result_plan = Some(lookup_join),
311                            Some(prev_lookup_join) => {
312                                // Prefer to choose lookup join with longer lookup prefix len.
313                                if prev_lookup_join.lookup_prefix_len()
314                                    < lookup_join.lookup_prefix_len()
315                                {
316                                    result_plan = Some(lookup_join)
317                                }
318                            }
319                        }
320                    }
321                }
322            }
323        }
324
325        Ok(result_plan)
326    }
327
328    /// Try to convert logical join into batch lookup join.
329    fn to_batch_lookup_join(
330        &self,
331        predicate: EqJoinPredicate,
332        logical_join: generic::Join<BatchPlanRef>,
333    ) -> Result<Option<BatchLookupJoin>> {
334        let logical_scan: &LogicalScan =
335            if let Some(logical_scan) = self.core.right.as_logical_scan() {
336                logical_scan
337            } else {
338                return Ok(None);
339            };
340        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
341    }
342
    /// Try to build a [`BatchLookupJoin`] that looks up the table behind `logical_scan`
    /// using the equal keys of `predicate`.
    ///
    /// # Arguments
    ///
    /// * `logical_scan`: the scan on the right (lookup) side of the join.
    /// * `predicate`: the eq/non-eq split of the join condition.
    /// * `logical_join`: the join with batch inputs; its left input, join type and output
    ///   indices are reused for the resulting join.
    /// * `is_as_of`: whether to additionally derive the as-of inequality description from
    ///   the non-eq part of the predicate.
    ///
    /// # Returns
    ///
    /// `Ok(None)` when a lookup join is not applicable:
    /// * the join type is not one of inner / left outer / left semi / left anti / as-of,
    /// * the table has an empty distribution key,
    /// * the right eq keys do not cover a prefix of the order key long enough to contain the
    ///   distribution key, or
    /// * no eq condition remains after the scan predicate has been pulled up.
    ///
    /// # Panics
    ///
    /// Panics if a distribution key column is not part of the order key, or if the scan's
    /// pulled-up projection contains an expression that is not an input ref.
    pub fn gen_batch_lookup_join(
        logical_scan: &LogicalScan,
        predicate: EqJoinPredicate,
        logical_join: generic::Join<BatchPlanRef>,
        is_as_of: bool,
    ) -> Result<Option<BatchLookupJoin>> {
        match logical_join.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {}
            _ => return Ok(None),
        };

        let table = logical_scan.table();
        let output_column_ids = logical_scan.output_column_ids();

        // Verify that the right join key columns are a prefix of the primary key and
        // also contain the distribution key.
        let order_col_ids = table.order_column_ids();
        let dist_key = table.distribution_key.clone();
        // Position of every distribution key column within the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = table
                .order_column_indices()
                .position(|x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // The shortest prefix of order key that contains distribution key.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // Distributed lookup join can't support lookup table with a singleton distribution.
        if shortest_prefix_len == 0 {
            return Ok(None);
        }

        // Reorder the join equal predicate to match the order key.
        // `reorder_idx[k]` is the position (within the eq keys) of the key that matches the
        // k-th order key column; the walk stops at the first order column without a match.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        // The matched prefix must at least contain the whole distribution key.
        if reorder_idx.len() < shortest_prefix_len {
            return Ok(None);
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Extract the predicate from logical scan. Only pure scan is supported.
        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
        // Construct the mapping from the old scan's output columns to the new scan's
        // (required) columns.
        let o2r = if let Some(project_expr) = project_expr {
            project_expr
                .into_iter()
                .map(|x| x.as_input_ref().unwrap().index)
                .collect_vec()
        } else {
            (0..logical_scan.output_col_idx().len()).collect_vec()
        };
        let left_schema_len = logical_join.left.schema().len();

        // Remaps right-side input refs of the join condition through `o2r`.
        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
            offset: left_schema_len,
            mapping: o2r.clone(),
        };

        let new_eq_cond = predicate
            .eq_cond()
            .rewrite_expr(&mut join_predicate_rewriter);

        // Shifts the pulled-up scan predicate so that it addresses the right side of the join.
        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
            offset: left_schema_len,
        };

        let new_other_cond = predicate
            .other_cond()
            .clone()
            .rewrite_expr(&mut join_predicate_rewriter)
            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

        let new_join_on = new_eq_cond.and(new_other_cond);
        let new_predicate = EqJoinPredicate::create(
            left_schema_len,
            new_scan.schema().len(),
            new_join_on.clone(),
        );

        // We discovered that we cannot use a lookup join after pulling up the predicate
        // from one side and simplifying the condition. Let's use some other join instead.
        if !new_predicate.has_eq() {
            return Ok(None);
        }

        // Rewrite the join output indices: every output index that referred to the old scan
        // has to be remapped through `o2r`.
        let new_join_output_indices = logical_join
            .output_indices
            .iter()
            .map(|&x| {
                if x < left_schema_len {
                    x
                } else {
                    o2r[x - left_schema_len] + left_schema_len
                }
            })
            .collect_vec();

        let new_scan_output_column_ids = new_scan.output_column_ids();
        let as_of = new_scan.as_of.clone();
        let new_logical_scan: LogicalScan = new_scan.into();

        // Construct a new logical join, because we have changed its RHS.
        let new_logical_join = generic::Join::new(
            logical_join.left,
            new_logical_scan.to_batch()?,
            new_join_on,
            logical_join.join_type,
            new_join_output_indices,
        );

        // For as-of joins, derive the inequality description from the non-eq condition of
        // the reordered predicate (i.e. before the scan predicate was merged in).
        let asof_desc = is_as_of
            .then(|| {
                Self::get_inequality_desc_from_predicate(
                    predicate.other_cond().clone(),
                    left_schema_len,
                )
            })
            .transpose()?;

        // NOTE(review): the literal `false` is presumably the "distributed" flag of
        // `BatchLookupJoin::new` — confirm against its signature.
        Ok(Some(BatchLookupJoin::new(
            new_logical_join,
            new_predicate,
            table.clone(),
            new_scan_output_column_ids,
            lookup_prefix_len,
            false,
            as_of,
            asof_desc,
        )))
    }
499
    /// Consumes the join and returns the parts of its core: the two inputs, the condition,
    /// the join type and the output indices (in the order produced by the core's
    /// `decompose`).
    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
        self.core.decompose()
    }
503}
504
impl PlanTreeNodeBinary<Logical> for LogicalJoin {
    /// Returns a clone of the left input plan.
    fn left(&self) -> PlanRef {
        self.core.left.clone()
    }

    /// Returns a clone of the right input plan.
    fn right(&self) -> PlanRef {
        self.core.right.clone()
    }

    /// Rebuilds the join on the given inputs; the `on` condition, join type and output
    /// indices are cloned unchanged from the current core.
    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
        Self::with_core(generic::Join {
            left,
            right,
            ..self.core.clone()
        })
    }

    /// Rebuilds the join on inputs whose columns have been changed.
    ///
    /// `left_col_change` and `right_col_change` are applied to the old column indices of
    /// the respective input. The `on` condition and the output indices are rewritten
    /// accordingly.
    ///
    /// Returns the new join together with the column mapping from the old join's output
    /// to the new join's output.
    fn rewrite_with_left_right(
        &self,
        left: PlanRef,
        left_col_change: ColIndexMapping,
        right: PlanRef,
        right_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        let (new_on, new_output_indices) = {
            // Concatenate both column changes into a single mapping over the join's internal
            // columns: right-side targets are shifted by the width of the new left input.
            let (mut map, _) = left_col_change.clone().into_parts();
            let (mut right_map, _) = right_col_change.clone().into_parts();
            for i in right_map.iter_mut().flatten() {
                *i += left.schema().len();
            }
            map.append(&mut right_map);
            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());

            let new_output_indices = self
                .output_indices()
                .iter()
                .map(|&i| mapping.map(i))
                .collect::<Vec<_>>();
            let new_on = self.on().clone().rewrite_expr(&mut mapping);
            (new_on, new_output_indices)
        };

        let join = Self::with_output_indices(
            left,
            right,
            self.join_type(),
            new_on,
            new_output_indices.clone(),
        );

        // New internal -> new output mapping.
        let new_i2o = ColIndexMapping::with_remaining_columns(
            &new_output_indices,
            join.internal_column_num(),
        );

        let old_o2i = self.core.o2i_col_mapping();

        // Old output -> new left / new right input columns.
        let old_o2l = old_o2i
            .composite(&self.core.i2l_col_mapping())
            .composite(&left_col_change);
        let old_o2r = old_o2i
            .composite(&self.core.i2r_col_mapping())
            .composite(&right_col_change);
        // New left / new right input columns -> new output.
        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);

        // Old output -> new output, combining the paths through the left and right inputs.
        let out_col_change = old_o2l
            .composite(&new_l2o)
            .union(&old_o2r.composite(&new_r2o));
        (join, out_col_change)
    }
}
577
578impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
579
impl ColPrunable for LogicalJoin {
    /// Prunes the join so that it only outputs `required_cols` (indices into the current
    /// output schema), and recursively prunes both inputs down to the columns needed by the
    /// output and by the `on` condition.
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        // make `required_cols` point to internal table instead of output schema.
        let required_cols = required_cols
            .iter()
            .map(|i| self.output_indices()[*i])
            .collect_vec();
        let left_len = self.left().schema().fields.len();

        let total_len = self.left().schema().len() + self.right().schema().len();
        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);

        // Express the required columns in the concatenated (left ++ right) index space.
        // For right semi/anti joins the indices are shifted by `left_len`, so that they
        // address the right input.
        required_cols.iter().for_each(|&i| {
            if self.is_right_join() {
                resized_required_cols.insert(left_len + i);
            } else {
                resized_required_cols.insert(i);
            }
        });

        // add those columns which are required in the join condition
        // to those that are required in the output
        let mut visitor = CollectInputRef::new(resized_required_cols);
        self.on().visit_expr(&mut visitor);
        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();

        // Split the required columns per input; right indices become relative to the right
        // input.
        let mut left_required_cols = Vec::new();
        let mut right_required_cols = Vec::new();
        left_right_required_cols.iter().for_each(|&i| {
            if i < left_len {
                left_required_cols.push(i);
            } else {
                right_required_cols.push(i - left_len);
            }
        });

        // Rewrite the condition to the column positions that remain after pruning.
        let mut on = self.on().clone();
        let mut mapping =
            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
        on = on.rewrite_expr(&mut mapping);

        let new_output_indices = {
            // Semi/anti joins map the output through the columns of a single input only;
            // all other join types use the columns of both inputs.
            let required_inputs_in_output = if self.is_left_join() {
                &left_required_cols
            } else if self.is_right_join() {
                &right_required_cols
            } else {
                &left_right_required_cols
            };

            let mapping =
                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
        };

        LogicalJoin::with_output_indices(
            self.left().prune_col(&left_required_cols, ctx),
            self.right().prune_col(&right_required_cols, ctx),
            self.join_type(),
            on,
            new_output_indices,
        )
        .into()
    }
}
645
646impl ExprRewritable<Logical> for LogicalJoin {
647    fn has_rewritable_expr(&self) -> bool {
648        true
649    }
650
651    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
652        let mut core = self.core.clone();
653        core.rewrite_exprs(r);
654        Self {
655            base: self.base.clone_with_new_plan_id(),
656            core,
657        }
658        .into()
659    }
660}
661
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions held by the join core with `v`.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
667
668/// We are trying to derive a predicate to apply to the other side of a join if all
669/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
670/// with the corresponding eq condition columns of the other side.
671///
672/// Strategy:
673/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
674///    then we proceed. Else abort.
675/// 2. Then, we collect `InputRef`s in the conjunction.
676/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
677/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
678///    the other side.
679///
680/// # Arguments
681///
682/// Suppose we derive a predicate from the left side to be pushed to the right side.
683/// * `expr`: An expr from the left side.
684/// * `col_num`: The number of columns in the left side.
685fn derive_predicate_from_eq_condition(
686    expr: &ExprImpl,
687    eq_condition: &EqJoinPredicate,
688    col_num: usize,
689    expr_is_left: bool,
690) -> Option<ExprImpl> {
691    if expr.is_impure() {
692        return None;
693    }
694    let eq_indices = eq_condition
695        .eq_indexes_typed()
696        .iter()
697        .filter_map(|(l, r)| {
698            if l.return_type() != r.return_type() {
699                None
700            } else if expr_is_left {
701                Some(l.index())
702            } else {
703                Some(r.index())
704            }
705        })
706        .collect_vec();
707    if expr
708        .collect_input_refs(col_num)
709        .ones()
710        .any(|index| !eq_indices.contains(&index))
711    {
712        // expr contains an InputRef not in eq_condition
713        return None;
714    }
715    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
716    // Hence, we can substitute those `InputRef`s with indices from the other side.
717    let other_side_mapping = if expr_is_left {
718        eq_condition.eq_indexes_typed().into_iter().collect()
719    } else {
720        eq_condition
721            .eq_indexes_typed()
722            .into_iter()
723            .map(|(x, y)| (y, x))
724            .collect()
725    };
726    struct InputRefsRewriter {
727        mapping: HashMap<InputRef, InputRef>,
728    }
729    impl ExprRewriter for InputRefsRewriter {
730        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
731            self.mapping[&input_ref].clone().into()
732        }
733    }
734    Some(
735        InputRefsRewriter {
736            mapping: other_side_mapping,
737        }
738        .rewrite_expr(expr.clone()),
739    )
740}
741
742/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
743struct LookupJoinPredicateRewriter {
744    offset: usize,
745    mapping: Vec<usize>,
746}
747impl ExprRewriter for LookupJoinPredicateRewriter {
748    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
749        if input_ref.index() < self.offset {
750            input_ref.into()
751        } else {
752            InputRef::new(
753                self.mapping[input_ref.index() - self.offset] + self.offset,
754                input_ref.return_type(),
755            )
756            .into()
757        }
758    }
759}
760
761/// Rewrite the scan predicate so we can add it to the join predicate.
762struct LookupJoinScanPredicateRewriter {
763    offset: usize,
764}
765impl ExprRewriter for LookupJoinScanPredicateRewriter {
766    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
767        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
768    }
769}
770
771impl PredicatePushdown for LogicalJoin {
772    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
773    ///
774    /// # Which predicates can be pushed
775    ///
776    /// For inner join, we can do all kinds of pushdown.
777    ///
778    /// For left/right semi join, we can push filter to left/right and on-clause,
779    /// and push on-clause to left/right.
780    ///
781    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
782    ///
783    /// ## Outer Join
784    ///
785    /// Preserved Row table
786    /// : The table in an Outer Join that must return all rows.
787    ///
788    /// Null Supplying table
789    /// : This is the table that has nulls filled in for its columns in unmatched rows.
790    ///
791    /// |                          | Preserved Row table | Null Supplying table |
792    /// |--------------------------|---------------------|----------------------|
793    /// | Join predicate (on)      | Not Pushed          | Pushed               |
794    /// | Where predicate (filter) | Pushed              | Not Pushed           |
795    fn predicate_pushdown(
796        &self,
797        predicate: Condition,
798        ctx: &mut PredicatePushdownContext,
799    ) -> PlanRef {
800        // rewrite output col referencing indices as internal cols
801        let mut predicate = {
802            let mut mapping = self.core.o2i_col_mapping();
803            predicate.rewrite_expr(&mut mapping)
804        };
805
806        let left_col_num = self.left().schema().len();
807        let right_col_num = self.right().schema().len();
808        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
809
810        let push_down_temporal_predicate = !self.should_be_temporal_join();
811
812        let (left_from_filter, right_from_filter, on) = push_down_into_join(
813            &mut predicate,
814            left_col_num,
815            right_col_num,
816            join_type,
817            push_down_temporal_predicate,
818        );
819
820        let mut new_on = self.on().clone().and(on);
821        let (left_from_on, right_from_on) = push_down_join_condition(
822            &mut new_on,
823            left_col_num,
824            right_col_num,
825            join_type,
826            push_down_temporal_predicate,
827        );
828
829        let left_predicate = left_from_filter.and(left_from_on);
830        let right_predicate = right_from_filter.and(right_from_on);
831
832        // Derive conditions to push to the other side based on eq condition columns
833        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
834
835        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
836        let right_from_left = if matches!(
837            join_type,
838            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
839        ) {
840            Condition {
841                conjunctions: left_predicate
842                    .conjunctions
843                    .iter()
844                    .filter_map(|expr| {
845                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
846                    })
847                    .collect(),
848            }
849        } else {
850            Condition::true_cond()
851        };
852
853        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
854        let left_from_right = if matches!(
855            join_type,
856            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
857        ) {
858            Condition {
859                conjunctions: right_predicate
860                    .conjunctions
861                    .iter()
862                    .filter_map(|expr| {
863                        derive_predicate_from_eq_condition(
864                            expr,
865                            &eq_condition,
866                            right_col_num,
867                            false,
868                        )
869                    })
870                    .collect(),
871            }
872        } else {
873            Condition::true_cond()
874        };
875
876        let left_predicate = left_predicate.and(left_from_right);
877        let right_predicate = right_predicate.and(right_from_left);
878
879        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
880        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
881        let new_join = LogicalJoin::with_output_indices(
882            new_left,
883            new_right,
884            join_type,
885            new_on,
886            self.output_indices().clone(),
887        );
888
889        let mut mapping = self.core.i2o_col_mapping();
890        predicate = predicate.rewrite_expr(&mut mapping);
891        LogicalFilter::create(new_join.into(), predicate)
892    }
893}
894
895impl LogicalJoin {
896    fn get_stream_input_for_hash_join(
897        &self,
898        predicate: &EqJoinPredicate,
899        ctx: &mut ToStreamContext,
900    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
901        use super::stream::prelude::*;
902
903        let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
904        let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();
905
906        let logical_right = self
907            .right()
908            .try_better_locality(&rhs_join_key_idx)
909            .unwrap_or_else(|| self.right());
910        let mut right = logical_right.to_stream_with_dist_required(
911            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
912            ctx,
913        )?;
914        let logical_left = self
915            .left()
916            .try_better_locality(&lhs_join_key_idx)
917            .unwrap_or_else(|| self.left());
918
919        let r2l =
920            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
921        let l2r =
922            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
923        let mut left;
924        let right_dist = right.distribution();
925        match right_dist {
926            Distribution::HashShard(_) => {
927                let left_dist = r2l
928                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
929                left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
930            }
931            Distribution::UpstreamHashShard(_, _) => {
932                left = logical_left.to_stream_with_dist_required(
933                    &RequiredDist::shard_by_key(
934                        self.left().schema().len(),
935                        &predicate.left_eq_indexes(),
936                    ),
937                    ctx,
938                )?;
939                let left_dist = left.distribution();
940                match left_dist {
941                    Distribution::HashShard(_) => {
942                        let right_dist = l2r.rewrite_required_distribution(
943                            &RequiredDist::PhysicalDist(left_dist.clone()),
944                        );
945                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
946                    }
947                    Distribution::UpstreamHashShard(_, _) => {
948                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
949                            .streaming_enforce_if_not_satisfies(left)?;
950                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
951                            .streaming_enforce_if_not_satisfies(right)?;
952                    }
953                    _ => unreachable!(),
954                }
955            }
956            _ => unreachable!(),
957        }
958        Ok((left, right))
959    }
960
961    fn to_stream_hash_join(
962        &self,
963        predicate: EqJoinPredicate,
964        ctx: &mut ToStreamContext,
965    ) -> Result<StreamPlanRef> {
966        use super::stream::prelude::*;
967
968        assert!(predicate.has_eq());
969        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
970
971        let core = self.core.clone_with_inputs(left, right);
972
973        // Convert to Hash Join for equal joins
974        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
975        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
976        // as opposed to the HashJoin, which applies the condition row-wise.
977        // However, the default behavior of pulling up non-equal conditions can be overridden by the
978        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
979        // materialization of rows only to be filtered later.
980
981        let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
982
983        let force_filter_inside_join = self
984            .base
985            .ctx()
986            .session_ctx()
987            .config()
988            .streaming_force_filter_inside_join();
989
990        let pull_filter = self.join_type() == JoinType::Inner
991            && stream_hash_join.eq_join_predicate().has_non_eq()
992            && stream_hash_join.inequality_pairs().is_empty()
993            && (!force_filter_inside_join);
994        if pull_filter {
995            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
996
997            let mut core = core.clone();
998            core.output_indices = default_indices.clone();
999            // Temporarily remove output indices.
1000            let eq_cond = EqJoinPredicate::new(
1001                Condition::true_cond(),
1002                predicate.eq_keys().to_vec(),
1003                self.left().schema().len(),
1004                self.right().schema().len(),
1005            );
1006            core.on = eq_cond.eq_cond();
1007            let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
1008            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1009            let plan = StreamFilter::new(logical_filter).into();
1010            if self.output_indices() != &default_indices {
1011                let logical_project = generic::Project::with_mapping(
1012                    plan,
1013                    ColIndexMapping::with_remaining_columns(
1014                        self.output_indices(),
1015                        self.internal_column_num(),
1016                    ),
1017                );
1018                Ok(StreamProject::new(logical_project).into())
1019            } else {
1020                Ok(plan)
1021            }
1022        } else {
1023            Ok(stream_hash_join.into())
1024        }
1025    }
1026
1027    fn should_be_temporal_join(&self) -> bool {
1028        let right = self.right();
1029        if let Some(logical_scan) = right.as_logical_scan() {
1030            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1031        } else {
1032            false
1033        }
1034    }
1035
1036    fn to_stream_temporal_join_with_index_selection(
1037        &self,
1038        predicate: EqJoinPredicate,
1039        ctx: &mut ToStreamContext,
1040    ) -> Result<StreamPlanRef> {
1041        // Index selection for temporal join.
1042        let right = self.right();
1043        // `should_be_temporal_join()` has already check right input for us.
1044        let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1045
1046        // Use primary table.
1047        let mut result_plan: Result<StreamTemporalJoin> =
1048            self.to_stream_temporal_join(predicate.clone(), ctx);
1049        // Return directly if this temporal join can match the pk of its right table.
1050        if let Ok(temporal_join) = &result_plan
1051            && temporal_join.eq_join_predicate().eq_indexes().len()
1052                == logical_scan.primary_key().len()
1053        {
1054            return result_plan.map(|x| x.into());
1055        }
1056        if self
1057            .core
1058            .ctx()
1059            .session_ctx()
1060            .config()
1061            .enable_index_selection()
1062        {
1063            let indexes = logical_scan.table_indexes();
1064            for index in indexes {
1065                // Use index table
1066                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1067                    let index_scan: PlanRef = index_scan.into();
1068                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1069                    if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx)
1070                    {
1071                        match &result_plan {
1072                            Err(_) => result_plan = Ok(temporal_join),
1073                            Ok(prev_temporal_join) => {
1074                                // Prefer to the temporal join with a longer lookup prefix len.
1075                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1076                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1077                                {
1078                                    result_plan = Ok(temporal_join)
1079                                }
1080                            }
1081                        }
1082                    }
1083                }
1084            }
1085        }
1086
1087        result_plan.map(|x| x.into())
1088    }
1089
1090    fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1091        let Some(logical_scan) = right.as_logical_scan() else {
1092            return Err(RwError::from(ErrorCode::NotSupported(
1093                "Temporal join requires a table scan as its lookup table".into(),
1094                "Please provide a table scan".into(),
1095            )));
1096        };
1097
1098        if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1099            return Err(RwError::from(ErrorCode::NotSupported(
1100                "Temporal join requires a table defined as temporal table".into(),
1101                "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1102            )));
1103        }
1104        Ok(logical_scan)
1105    }
1106
1107    fn temporal_join_scan_predicate_pull_up(
1108        logical_scan: &LogicalScan,
1109        predicate: EqJoinPredicate,
1110        output_indices: &[usize],
1111        left_schema_len: usize,
1112    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1113        // Extract the predicate from logical scan. Only pure scan is supported.
1114        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1115        // Construct output column to require column mapping
1116        let o2r = if let Some(project_expr) = project_expr {
1117            project_expr
1118                .into_iter()
1119                .map(|x| x.as_input_ref().unwrap().index)
1120                .collect_vec()
1121        } else {
1122            (0..logical_scan.output_col_idx().len()).collect_vec()
1123        };
1124        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1125            offset: left_schema_len,
1126            mapping: o2r.clone(),
1127        };
1128
1129        let new_eq_cond = predicate
1130            .eq_cond()
1131            .rewrite_expr(&mut join_predicate_rewriter);
1132
1133        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1134            offset: left_schema_len,
1135        };
1136
1137        let new_other_cond = predicate
1138            .other_cond()
1139            .clone()
1140            .rewrite_expr(&mut join_predicate_rewriter)
1141            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1142
1143        let new_join_on = new_eq_cond.and(new_other_cond);
1144
1145        let new_predicate = EqJoinPredicate::create(
1146            left_schema_len,
1147            new_scan.schema().len(),
1148            new_join_on.clone(),
1149        );
1150
1151        // Rewrite the join output indices and all output indices referred to the old scan need to
1152        // rewrite.
1153        let new_join_output_indices = output_indices
1154            .iter()
1155            .map(|&x| {
1156                if x < left_schema_len {
1157                    x
1158                } else {
1159                    o2r[x - left_schema_len] + left_schema_len
1160                }
1161            })
1162            .collect_vec();
1163
1164        // Use UpstreamOnly chain type
1165        if new_scan.cross_database() {
1166            return Err(RwError::from(ErrorCode::NotSupported(
1167                "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1168                "Please ensure both tables are in the same database".into(),
1169            )));
1170        }
1171        let new_stream_table_scan =
1172            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1173        Ok((
1174            new_stream_table_scan,
1175            new_predicate,
1176            new_join_on,
1177            new_join_output_indices,
1178        ))
1179    }
1180
1181    fn to_stream_temporal_join(
1182        &self,
1183        predicate: EqJoinPredicate,
1184        ctx: &mut ToStreamContext,
1185    ) -> Result<StreamTemporalJoin> {
1186        use super::stream::prelude::*;
1187
1188        assert!(predicate.has_eq());
1189
1190        let right = self.right();
1191
1192        let logical_scan = Self::check_temporal_rhs(&right)?;
1193
1194        let table = logical_scan.table();
1195        let output_column_ids = logical_scan.output_column_ids();
1196
1197        // Verify that the right join key columns are the the prefix of the primary key and
1198        // also contain the distribution key.
1199        let order_col_ids = table.order_column_ids();
1200        let dist_key = table.distribution_key.clone();
1201
1202        let mut dist_key_in_order_key_pos = vec![];
1203        for d in dist_key {
1204            let pos = table
1205                .order_column_indices()
1206                .position(|x| x == d)
1207                .expect("dist_key must in order_key");
1208            dist_key_in_order_key_pos.push(pos);
1209        }
1210        // The shortest prefix of order key that contains distribution key.
1211        let shortest_prefix_len = dist_key_in_order_key_pos
1212            .iter()
1213            .max()
1214            .map_or(0, |pos| pos + 1);
1215
1216        // Reorder the join equal predicate to match the order key.
1217        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1218        for order_col_id in order_col_ids {
1219            let mut found = false;
1220            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1221                if order_col_id == output_column_ids[eq_idx] {
1222                    reorder_idx.push(i);
1223                    found = true;
1224                    break;
1225                }
1226            }
1227            if !found {
1228                break;
1229            }
1230        }
1231        if reorder_idx.len() < shortest_prefix_len {
1232            // TODO: support index selection for temporal join and refine this error message.
1233            return Err(RwError::from(ErrorCode::NotSupported(
1234                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1235                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1236            )));
1237        }
1238        let lookup_prefix_len = reorder_idx.len();
1239        let predicate = predicate.reorder(&reorder_idx);
1240
1241        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1242            RequiredDist::single()
1243        } else {
1244            let left_eq_indexes = predicate.left_eq_indexes();
1245            let left_dist_key = dist_key_in_order_key_pos
1246                .iter()
1247                .map(|pos| left_eq_indexes[*pos])
1248                .collect_vec();
1249
1250            RequiredDist::hash_shard(&left_dist_key)
1251        };
1252
1253        let lhs_join_key_idx = predicate
1254            .eq_indexes()
1255            .into_iter()
1256            .map(|(l, _)| l)
1257            .collect_vec();
1258        let logical_left = self
1259            .left()
1260            .try_better_locality(&lhs_join_key_idx)
1261            .unwrap_or_else(|| self.left());
1262        let left = logical_left.to_stream(ctx)?;
1263        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1264        let left = required_dist.stream_enforce(left);
1265
1266        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1267            Self::temporal_join_scan_predicate_pull_up(
1268                logical_scan,
1269                predicate,
1270                self.output_indices(),
1271                self.left().schema().len(),
1272            )?;
1273
1274        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1275        if !new_predicate.has_eq() {
1276            return Err(RwError::from(ErrorCode::NotSupported(
1277                "Temporal join requires a non trivial join condition".into(),
1278                "Please remove the false condition of the join".into(),
1279            )));
1280        }
1281
1282        // Construct a new logical join, because we have change its RHS.
1283        let new_logical_join = generic::Join::new(
1284            left,
1285            right,
1286            new_join_on,
1287            self.join_type(),
1288            new_join_output_indices,
1289        );
1290
1291        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1292
1293        StreamTemporalJoin::new(new_logical_join, new_predicate, false)
1294    }
1295
1296    fn to_stream_nested_loop_temporal_join(
1297        &self,
1298        predicate: EqJoinPredicate,
1299        ctx: &mut ToStreamContext,
1300    ) -> Result<StreamPlanRef> {
1301        use super::stream::prelude::*;
1302        assert!(!predicate.has_eq());
1303
1304        let left = self.left().to_stream_with_dist_required(
1305            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1306            ctx,
1307        )?;
1308        assert!(left.as_stream_exchange().is_some());
1309
1310        if self.join_type() != JoinType::Inner {
1311            return Err(RwError::from(ErrorCode::NotSupported(
1312                "Temporal join requires an inner join".into(),
1313                "Please use an inner join".into(),
1314            )));
1315        }
1316
1317        if !left.append_only() {
1318            return Err(RwError::from(ErrorCode::NotSupported(
1319                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1320                "Please ensure the left hash side is append only".into(),
1321            )));
1322        }
1323
1324        let right = self.right();
1325        let logical_scan = Self::check_temporal_rhs(&right)?;
1326
1327        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1328            Self::temporal_join_scan_predicate_pull_up(
1329                logical_scan,
1330                predicate,
1331                self.output_indices(),
1332                self.left().schema().len(),
1333            )?;
1334
1335        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1336
1337        // Construct a new logical join, because we have change its RHS.
1338        let new_logical_join = generic::Join::new(
1339            left,
1340            right,
1341            new_join_on,
1342            self.join_type(),
1343            new_join_output_indices,
1344        );
1345
1346        Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1347    }
1348
1349    fn to_stream_dynamic_filter(
1350        &self,
1351        predicate: Condition,
1352        ctx: &mut ToStreamContext,
1353    ) -> Result<Option<StreamPlanRef>> {
1354        use super::stream::prelude::*;
1355
1356        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1357        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1358        // `StreamDynamicFilter`
1359
1360        // Check if `Inner`/`LeftSemi`
1361        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1362            return Ok(None);
1363        }
1364
1365        // Check if right side is a scalar
1366        if !self.right().max_one_row() {
1367            return Ok(None);
1368        }
1369        if self.right().schema().len() != 1 {
1370            return Ok(None);
1371        }
1372
1373        // Check if the join condition is a correlated comparison
1374        if predicate.conjunctions.len() > 1 {
1375            return Ok(None);
1376        }
1377        let expr: ExprImpl = predicate.into();
1378        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1379            Some(v) => v,
1380            None => return Ok(None),
1381        };
1382
1383        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1384            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1385        if !condition_cross_inputs {
1386            // Maybe we should panic here because it means some predicates are not pushed down.
1387            return Ok(None);
1388        }
1389
1390        // We align input types on all join predicates with cmp operator
1391        if self.left().schema().fields()[left_ref.index].data_type
1392            != self.right().schema().fields()[0].data_type
1393        {
1394            return Ok(None);
1395        }
1396
1397        // Check if non of the columns from the inner side is required to output
1398        let all_output_from_left = self
1399            .output_indices()
1400            .iter()
1401            .all(|i| *i < self.left().schema().len());
1402        if !all_output_from_left {
1403            return Ok(None);
1404        }
1405
1406        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1407        let right = self.right().to_stream_with_dist_required(
1408            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1409            ctx,
1410        )?;
1411
1412        assert!(right.as_stream_exchange().is_some());
1413        assert_eq!(
1414            *right.inputs().iter().exactly_one().unwrap().distribution(),
1415            Distribution::Single
1416        );
1417
1418        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1419        let plan = StreamDynamicFilter::new(core)?.into();
1420        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1421        if self
1422            .output_indices()
1423            .iter()
1424            .copied()
1425            .ne(0..self.left().schema().len())
1426        {
1427            // The schema of dynamic filter is always the same as the left side now, and we have
1428            // checked that all output columns are from the left side before.
1429            let logical_project = generic::Project::with_mapping(
1430                plan,
1431                ColIndexMapping::with_remaining_columns(
1432                    self.output_indices(),
1433                    self.left().schema().len(),
1434                ),
1435            );
1436            Ok(Some(StreamProject::new(logical_project).into()))
1437        } else {
1438            Ok(Some(plan))
1439        }
1440    }
1441
1442    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1443        let predicate = EqJoinPredicate::create(
1444            self.left().schema().len(),
1445            self.right().schema().len(),
1446            self.on().clone(),
1447        );
1448        assert!(predicate.has_eq());
1449
1450        let join = self
1451            .core
1452            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1453
1454        Ok(self
1455            .to_batch_lookup_join(predicate, join)?
1456            .expect("Fail to convert to lookup join")
1457            .into())
1458    }
1459
1460    fn to_stream_asof_join(
1461        &self,
1462        predicate: EqJoinPredicate,
1463        ctx: &mut ToStreamContext,
1464    ) -> Result<StreamPlanRef> {
1465        use super::stream::prelude::*;
1466
1467        if predicate.eq_keys().is_empty() {
1468            return Err(ErrorCode::InvalidInputSyntax(
1469                "AsOf join requires at least 1 equal condition".to_owned(),
1470            )
1471            .into());
1472        }
1473
1474        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1475        let left_len = left.schema().len();
1476        let core = self.core.clone_with_inputs(left, right);
1477
1478        let inequality_desc =
1479            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1480
1481        Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1482    }
1483
1484    /// Convert the logical join to a Hash join.
1485    fn to_batch_hash_join(
1486        &self,
1487        logical_join: generic::Join<BatchPlanRef>,
1488        predicate: EqJoinPredicate,
1489    ) -> Result<BatchPlanRef> {
1490        use super::batch::prelude::*;
1491
1492        let left_schema_len = logical_join.left.schema().len();
1493        let asof_desc = self
1494            .is_asof_join()
1495            .then(|| {
1496                Self::get_inequality_desc_from_predicate(
1497                    predicate.other_cond().clone(),
1498                    left_schema_len,
1499                )
1500            })
1501            .transpose()?;
1502
1503        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1504        Ok(batch_join.into())
1505    }
1506
1507    pub fn get_inequality_desc_from_predicate(
1508        predicate: Condition,
1509        left_input_len: usize,
1510    ) -> Result<AsOfJoinDesc> {
1511        let expr: ExprImpl = predicate.into();
1512        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1513            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1514            {
1515                Ok(AsOfJoinDesc {
1516                    left_idx: left_input_ref.index() as u32,
1517                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1518                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1519                })
1520            } else {
1521                bail!("inequal condition from the same side should be push down in optimizer");
1522            }
1523        } else {
1524            Err(ErrorCode::InvalidInputSyntax(
1525                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1526            )
1527            .into())
1528        }
1529    }
1530
1531    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1532        match expr_type {
1533            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1534            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1535            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1536            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1537            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1538                "Invalid comparison type: {}",
1539                expr_type.as_str_name()
1540            ))
1541            .into()),
1542        }
1543    }
1544}
1545
1546impl ToBatch for LogicalJoin {
1547    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1548        let predicate = EqJoinPredicate::create(
1549            self.left().schema().len(),
1550            self.right().schema().len(),
1551            self.on().clone(),
1552        );
1553
1554        let batch_join = self
1555            .core
1556            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1557
1558        let ctx = self.base.ctx();
1559        let config = ctx.session_ctx().config();
1560
1561        if predicate.has_eq() {
1562            if !predicate.eq_keys_are_type_aligned() {
1563                return Err(ErrorCode::InternalError(format!(
1564                    "Join eq keys are not aligned for predicate: {predicate:?}"
1565                ))
1566                .into());
1567            }
1568            if config.batch_enable_lookup_join()
1569                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1570                    predicate.clone(),
1571                    batch_join.clone(),
1572                )?
1573            {
1574                return Ok(lookup_join.into());
1575            }
1576            self.to_batch_hash_join(batch_join, predicate)
1577        } else if self.is_asof_join() {
1578            Err(ErrorCode::InvalidInputSyntax(
1579                "AsOf join requires at least 1 equal condition".to_owned(),
1580            )
1581            .into())
1582        } else {
1583            // Convert to Nested-loop Join for non-equal joins
1584            Ok(BatchNestedLoopJoin::new(batch_join).into())
1585        }
1586    }
1587}
1588
1589impl ToStream for LogicalJoin {
1590    fn to_stream(
1591        &self,
1592        ctx: &mut ToStreamContext,
1593    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1594        if self
1595            .on()
1596            .conjunctions
1597            .iter()
1598            .any(|cond| cond.count_nows() > 0)
1599        {
1600            return Err(ErrorCode::NotSupported(
1601                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1602                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1603        }
1604
1605        let predicate = EqJoinPredicate::create(
1606            self.left().schema().len(),
1607            self.right().schema().len(),
1608            self.on().clone(),
1609        );
1610
1611        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1612            self.to_stream_asof_join(predicate, ctx)
1613        } else if predicate.has_eq() {
1614            if !predicate.eq_keys_are_type_aligned() {
1615                return Err(ErrorCode::InternalError(format!(
1616                    "Join eq keys are not aligned for predicate: {predicate:?}"
1617                ))
1618                .into());
1619            }
1620
1621            if self.should_be_temporal_join() {
1622                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1623            } else {
1624                self.to_stream_hash_join(predicate, ctx)
1625            }
1626        } else if self.should_be_temporal_join() {
1627            self.to_stream_nested_loop_temporal_join(predicate, ctx)
1628        } else if let Some(dynamic_filter) =
1629            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1630        {
1631            Ok(dynamic_filter)
1632        } else {
1633            Err(RwError::from(ErrorCode::NotSupported(
1634                "streaming nested-loop join".to_owned(),
1635                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1636                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1637                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1638            )))
1639        }
1640    }
1641
1642    fn logical_rewrite_for_stream(
1643        &self,
1644        ctx: &mut RewriteStreamContext,
1645    ) -> Result<(PlanRef, ColIndexMapping)> {
1646        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1647        let left_len = left.schema().len();
1648        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1649        let (join, out_col_change) = self.rewrite_with_left_right(
1650            left.clone(),
1651            left_col_change,
1652            right.clone(),
1653            right_col_change,
1654        );
1655
1656        let mapping = ColIndexMapping::with_remaining_columns(
1657            join.output_indices(),
1658            join.internal_column_num(),
1659        );
1660
1661        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1662        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1663
1664        // Add missing pk indices to the logical join
1665        let mut left_to_add = left
1666            .expect_stream_key()
1667            .iter()
1668            .cloned()
1669            .filter(|i| l2o.try_map(*i).is_none())
1670            .collect_vec();
1671
1672        let mut right_to_add = right
1673            .expect_stream_key()
1674            .iter()
1675            .filter(|&&i| r2o.try_map(i).is_none())
1676            .map(|&i| i + left_len)
1677            .collect_vec();
1678
1679        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1680        // key.
1681        let right_len = right.schema().len();
1682        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1683
1684        let either_or_both = self.core.add_which_join_key_to_pk();
1685
1686        for (lk, rk) in eq_predicate.eq_indexes() {
1687            match either_or_both {
1688                EitherOrBoth::Left(_) => {
1689                    if l2o.try_map(lk).is_none() {
1690                        left_to_add.push(lk);
1691                    }
1692                }
1693                EitherOrBoth::Right(_) => {
1694                    if r2o.try_map(rk).is_none() {
1695                        right_to_add.push(rk + left_len)
1696                    }
1697                }
1698                EitherOrBoth::Both(_, _) => {
1699                    if l2o.try_map(lk).is_none() {
1700                        left_to_add.push(lk);
1701                    }
1702                    if r2o.try_map(rk).is_none() {
1703                        right_to_add.push(rk + left_len)
1704                    }
1705                }
1706            };
1707        }
1708        let left_to_add = left_to_add.into_iter().unique();
1709        let right_to_add = right_to_add.into_iter().unique();
1710        // NOTE(st1page) over
1711
1712        let mut new_output_indices = join.output_indices().clone();
1713        if !join.is_right_join() {
1714            new_output_indices.extend(left_to_add);
1715        }
1716        if !join.is_left_join() {
1717            new_output_indices.extend(right_to_add);
1718        }
1719
1720        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1721
1722        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1723            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1724
1725            let l2o = join_with_pk
1726                .core
1727                .l2i_col_mapping()
1728                .composite(&join_with_pk.core.i2o_col_mapping());
1729            let r2o = join_with_pk
1730                .core
1731                .r2i_col_mapping()
1732                .composite(&join_with_pk.core.i2o_col_mapping());
1733            let left_right_stream_keys = join_with_pk
1734                .left()
1735                .expect_stream_key()
1736                .iter()
1737                .map(|i| l2o.map(*i))
1738                .chain(
1739                    join_with_pk
1740                        .right()
1741                        .expect_stream_key()
1742                        .iter()
1743                        .map(|i| r2o.map(*i)),
1744                )
1745                .collect_vec();
1746            let plan: PlanRef = join_with_pk.into();
1747            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1748        } else {
1749            join_with_pk.into()
1750        };
1751
1752        // the added columns is at the end, so it will not change the exists column index
1753        Ok((plan, out_col_change))
1754    }
1755
1756    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1757        let mut ctx = ToStreamContext::new(false);
1758        // only pass through the locality information if it can be converted to dynamic filter
1759        if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1760            // since dynamic filter only supports left input ref in the output indices, we can safely use o2i mapping to convert the required columns.
1761            let o2i_mapping = self.core.o2i_col_mapping();
1762            let left_input_columns = columns
1763                .iter()
1764                .map(|&col| o2i_mapping.try_map(col))
1765                .collect::<Option<Vec<usize>>>()?;
1766            if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1767                return Some(
1768                    self.clone_with_left_right(better_left_plan, self.right())
1769                        .into(),
1770                );
1771            }
1772        }
1773        None
1774    }
1775}
1776
1777#[cfg(test)]
1778mod tests {
1779
1780    use std::collections::HashSet;
1781
1782    use risingwave_common::catalog::{Field, Schema};
1783    use risingwave_common::types::{DataType, Datum};
1784    use risingwave_pb::expr::expr_node::Type;
1785
1786    use super::*;
1787    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1788    use crate::optimizer::optimizer_context::OptimizerContext;
1789    use crate::optimizer::plan_node::LogicalValues;
1790    use crate::optimizer::property::FunctionalDependency;
1791
1792    /// Pruning
1793    /// ```text
1794    /// Join(on: input_ref(1)=input_ref(3))
1795    ///   TableScan(v1, v2, v3)
1796    ///   TableScan(v4, v5, v6)
1797    /// ```
1798    /// with required columns [2,3] will result in
1799    /// ```text
1800    /// Project(input_ref(1), input_ref(2))
1801    ///   Join(on: input_ref(0)=input_ref(2))
1802    ///     TableScan(v2, v3)
1803    ///     TableScan(v4)
1804    /// ```
1805    #[tokio::test]
1806    async fn test_prune_join() {
1807        let ty = DataType::Int32;
1808        let ctx = OptimizerContext::mock().await;
1809        let fields: Vec<Field> = (1..7)
1810            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1811            .collect();
1812        let left = LogicalValues::new(
1813            vec![],
1814            Schema {
1815                fields: fields[0..3].to_vec(),
1816            },
1817            ctx.clone(),
1818        );
1819        let right = LogicalValues::new(
1820            vec![],
1821            Schema {
1822                fields: fields[3..6].to_vec(),
1823            },
1824            ctx,
1825        );
1826        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1827            FunctionCall::new(
1828                Type::Equal,
1829                vec![
1830                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1831                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1832                ],
1833            )
1834            .unwrap(),
1835        ));
1836        let join_type = JoinType::Inner;
1837        let join: PlanRef = LogicalJoin::new(
1838            left.into(),
1839            right.into(),
1840            join_type,
1841            Condition::with_expr(on),
1842        )
1843        .into();
1844
1845        // Perform the prune
1846        let required_cols = vec![2, 3];
1847        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1848
1849        // Check the result
1850        let join = plan.as_logical_join().unwrap();
1851        assert_eq!(join.schema().fields().len(), 2);
1852        assert_eq!(join.schema().fields()[0], fields[2]);
1853        assert_eq!(join.schema().fields()[1], fields[3]);
1854
1855        let expr: ExprImpl = join.on().clone().into();
1856        let call = expr.as_function_call().unwrap();
1857        assert_eq_input_ref!(&call.inputs()[0], 0);
1858        assert_eq_input_ref!(&call.inputs()[1], 2);
1859
1860        let left = join.left();
1861        let left = left.as_logical_values().unwrap();
1862        assert_eq!(left.schema().fields(), &fields[1..3]);
1863        let right = join.right();
1864        let right = right.as_logical_values().unwrap();
1865        assert_eq!(right.schema().fields(), &fields[3..4]);
1866    }
1867
1868    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1869    #[tokio::test]
1870    async fn test_prune_semi_join() {
1871        let ty = DataType::Int32;
1872        let ctx = OptimizerContext::mock().await;
1873        let fields: Vec<Field> = (1..7)
1874            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1875            .collect();
1876        let left = LogicalValues::new(
1877            vec![],
1878            Schema {
1879                fields: fields[0..3].to_vec(),
1880            },
1881            ctx.clone(),
1882        );
1883        let right = LogicalValues::new(
1884            vec![],
1885            Schema {
1886                fields: fields[3..6].to_vec(),
1887            },
1888            ctx,
1889        );
1890        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1891            FunctionCall::new(
1892                Type::Equal,
1893                vec![
1894                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1895                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1896                ],
1897            )
1898            .unwrap(),
1899        ));
1900        for join_type in [
1901            JoinType::LeftSemi,
1902            JoinType::RightSemi,
1903            JoinType::LeftAnti,
1904            JoinType::RightAnti,
1905        ] {
1906            let join = LogicalJoin::new(
1907                left.clone().into(),
1908                right.clone().into(),
1909                join_type,
1910                Condition::with_expr(on.clone()),
1911            );
1912
1913            let offset = if join.is_right_join() { 3 } else { 0 };
1914            let join: PlanRef = join.into();
1915            // Perform the prune
1916            let required_cols = vec![0];
1917            // key 0 is never used in the join (always key 1)
1918            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1919            let as_plan = plan.as_logical_join().unwrap();
1920            // Check the result
1921            assert_eq!(as_plan.schema().fields().len(), 1);
1922            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1923
1924            // Perform the prune
1925            let required_cols = vec![0, 1, 2];
1926            // should not panic here
1927            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1928            let as_plan = plan.as_logical_join().unwrap();
1929            // Check the result
1930            assert_eq!(as_plan.schema().fields().len(), 3);
1931            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1932            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1933            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1934        }
1935    }
1936
1937    /// Pruning
1938    /// ```text
1939    /// Join(on: input_ref(1)=input_ref(3))
1940    ///   TableScan(v1, v2, v3)
1941    ///   TableScan(v4, v5, v6)
1942    /// ```
1943    /// with required columns [1, 3] will result in
1944    /// ```text
1945    /// Join(on: input_ref(0)=input_ref(1))
1946    ///   TableScan(v2)
1947    ///   TableScan(v4)
1948    /// ```
1949    #[tokio::test]
1950    async fn test_prune_join_no_project() {
1951        let ty = DataType::Int32;
1952        let ctx = OptimizerContext::mock().await;
1953        let fields: Vec<Field> = (1..7)
1954            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1955            .collect();
1956        let left = LogicalValues::new(
1957            vec![],
1958            Schema {
1959                fields: fields[0..3].to_vec(),
1960            },
1961            ctx.clone(),
1962        );
1963        let right = LogicalValues::new(
1964            vec![],
1965            Schema {
1966                fields: fields[3..6].to_vec(),
1967            },
1968            ctx,
1969        );
1970        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1971            FunctionCall::new(
1972                Type::Equal,
1973                vec![
1974                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1975                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1976                ],
1977            )
1978            .unwrap(),
1979        ));
1980        let join_type = JoinType::Inner;
1981        let join: PlanRef = LogicalJoin::new(
1982            left.into(),
1983            right.into(),
1984            join_type,
1985            Condition::with_expr(on),
1986        )
1987        .into();
1988
1989        // Perform the prune
1990        let required_cols = vec![1, 3];
1991        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1992
1993        // Check the result
1994        let join = plan.as_logical_join().unwrap();
1995        assert_eq!(join.schema().fields().len(), 2);
1996        assert_eq!(join.schema().fields()[0], fields[1]);
1997        assert_eq!(join.schema().fields()[1], fields[3]);
1998
1999        let expr: ExprImpl = join.on().clone().into();
2000        let call = expr.as_function_call().unwrap();
2001        assert_eq_input_ref!(&call.inputs()[0], 0);
2002        assert_eq_input_ref!(&call.inputs()[1], 1);
2003
2004        let left = join.left();
2005        let left = left.as_logical_values().unwrap();
2006        assert_eq!(left.schema().fields(), &fields[1..2]);
2007        let right = join.right();
2008        let right = right.as_logical_values().unwrap();
2009        assert_eq!(right.schema().fields(), &fields[3..4]);
2010    }
2011
2012    /// Convert
2013    /// ```text
2014    /// Join(on: ($1 = $3) AND ($2 == 42))
2015    ///   TableScan(v1, v2, v3)
2016    ///   TableScan(v4, v5, v6)
2017    /// ```
2018    /// to
2019    /// ```text
2020    /// Filter($2 == 42)
2021    ///   HashJoin(on: $1 = $3)
2022    ///     TableScan(v1, v2, v3)
2023    ///     TableScan(v4, v5, v6)
2024    /// ```
2025    #[tokio::test]
2026    async fn test_join_to_batch() {
2027        let ctx = OptimizerContext::mock().await;
2028        let fields: Vec<Field> = (1..7)
2029            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
2030            .collect();
2031        let left = LogicalValues::new(
2032            vec![],
2033            Schema {
2034                fields: fields[0..3].to_vec(),
2035            },
2036            ctx.clone(),
2037        );
2038        let right = LogicalValues::new(
2039            vec![],
2040            Schema {
2041                fields: fields[3..6].to_vec(),
2042            },
2043            ctx,
2044        );
2045
2046        fn input_ref(i: usize) -> ExprImpl {
2047            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
2048        }
2049        let eq_cond = ExprImpl::FunctionCall(Box::new(
2050            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
2051        ));
2052        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2053            FunctionCall::new(
2054                Type::Equal,
2055                vec![
2056                    input_ref(2),
2057                    ExprImpl::Literal(Box::new(Literal::new(
2058                        Datum::Some(42_i32.into()),
2059                        DataType::Int32,
2060                    ))),
2061                ],
2062            )
2063            .unwrap(),
2064        ));
2065        // Condition: ($1 = $3) AND ($2 == 42)
2066        let on_cond = ExprImpl::FunctionCall(Box::new(
2067            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2068        ));
2069
2070        let join_type = JoinType::Inner;
2071        let logical_join = LogicalJoin::new(
2072            left.into(),
2073            right.into(),
2074            join_type,
2075            Condition::with_expr(on_cond),
2076        );
2077
2078        // Perform `to_batch`
2079        let result = logical_join.to_batch().unwrap();
2080
2081        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2082        let hash_join = result.as_batch_hash_join().unwrap();
2083        assert_eq!(
2084            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2085            eq_cond
2086        );
2087        assert_eq!(
2088            *hash_join
2089                .eq_join_predicate()
2090                .non_eq_cond()
2091                .conjunctions
2092                .first()
2093                .unwrap(),
2094            non_eq_cond
2095        );
2096    }
2097
2098    /// Convert
2099    /// ```text
2100    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2101    ///   TableScan(v1, v2, v3)
2102    ///   TableScan(v4, v5, v6)
2103    /// ```
2104    /// to
2105    /// ```text
2106    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2107    ///   TableScan(v1, v2, v3)
2108    ///   TableScan(v4, v5, v6)
2109    /// ```
2110    #[tokio::test]
2111    #[ignore] // ignore due to refactor logical scan, but the test seem to duplicate with the explain test
2112    // framework, maybe we will remove it?
2113    async fn test_join_to_stream() {
2114        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
2115        // let fields: Vec<Field> = (1..7)
2116        //     .map(|i| Field {
2117        //         data_type: DataType::Int32,
2118        //         name: format!("v{}", i),
2119        //     })
2120        //     .collect();
2121        // let left = LogicalScan::new(
2122        //     "left".to_string(),
2123        //     TableId::new(0),
2124        //     vec![1.into(), 2.into(), 3.into()],
2125        //     Schema {
2126        //         fields: fields[0..3].to_vec(),
2127        //     },
2128        //     ctx.clone(),
2129        // );
2130        // let right = LogicalScan::new(
2131        //     "right".to_string(),
2132        //     TableId::new(0),
2133        //     vec![4.into(), 5.into(), 6.into()],
2134        //     Schema {
2135        //                 fields: fields[3..6].to_vec(),
2136        //     },
2137        //     ctx,
2138        // );
2139        // let eq_cond = ExprImpl::FunctionCall(Box::new(
2140        //     FunctionCall::new(
2141        //         Type::Equal,
2142        //         vec![
2143        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
2144        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
2145        //         ],
2146        //     )
2147        //     .unwrap(),
2148        // ));
2149        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2150        //     FunctionCall::new(
2151        //         Type::Equal,
2152        //         vec![
2153        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
2154        //             ExprImpl::Literal(Box::new(Literal::new(
2155        //                 Datum::Some(42_i32.into()),
2156        //                 DataType::Int32,
2157        //             ))),
2158        //         ],
2159        //     )
2160        //     .unwrap(),
2161        // ));
2162        // // Condition: ($1 = $3) AND ($2 == 42)
2163        // let on_cond = ExprImpl::FunctionCall(Box::new(
2164        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
2165        // ));
2166
2167        // let join_type = JoinType::LeftOuter;
2168        // let logical_join = LogicalJoin::new(
2169        //     left.clone().into(),
2170        //     right.clone().into(),
2171        //     join_type,
2172        //     Condition::with_expr(on_cond.clone()),
2173        // );
2174
2175        // // Perform `to_stream`
2176        // let result = logical_join.to_stream();
2177
2178        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
2179        // let hash_join = result.as_stream_hash_join().unwrap();
2180        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
2181    }
2182    /// Pruning
2183    /// ```text
2184    /// Join(on: input_ref(1)=input_ref(3))
2185    ///   TableScan(v1, v2, v3)
2186    ///   TableScan(v4, v5, v6)
2187    /// ```
2188    /// with required columns [3, 2] will result in
2189    /// ```text
2190    /// Project(input_ref(2), input_ref(1))
2191    ///   Join(on: input_ref(0)=input_ref(2))
2192    ///     TableScan(v2, v3)
2193    ///     TableScan(v4)
2194    /// ```
2195    #[tokio::test]
2196    async fn test_join_column_prune_with_order_required() {
2197        let ty = DataType::Int32;
2198        let ctx = OptimizerContext::mock().await;
2199        let fields: Vec<Field> = (1..7)
2200            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2201            .collect();
2202        let left = LogicalValues::new(
2203            vec![],
2204            Schema {
2205                fields: fields[0..3].to_vec(),
2206            },
2207            ctx.clone(),
2208        );
2209        let right = LogicalValues::new(
2210            vec![],
2211            Schema {
2212                fields: fields[3..6].to_vec(),
2213            },
2214            ctx,
2215        );
2216        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2217            FunctionCall::new(
2218                Type::Equal,
2219                vec![
2220                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2221                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2222                ],
2223            )
2224            .unwrap(),
2225        ));
2226        let join_type = JoinType::Inner;
2227        let join: PlanRef = LogicalJoin::new(
2228            left.into(),
2229            right.into(),
2230            join_type,
2231            Condition::with_expr(on),
2232        )
2233        .into();
2234
2235        // Perform the prune
2236        let required_cols = vec![3, 2];
2237        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2238
2239        // Check the result
2240        let join = plan.as_logical_join().unwrap();
2241        assert_eq!(join.schema().fields().len(), 2);
2242        assert_eq!(join.schema().fields()[0], fields[3]);
2243        assert_eq!(join.schema().fields()[1], fields[2]);
2244
2245        let expr: ExprImpl = join.on().clone().into();
2246        let call = expr.as_function_call().unwrap();
2247        assert_eq_input_ref!(&call.inputs()[0], 0);
2248        assert_eq_input_ref!(&call.inputs()[1], 2);
2249
2250        let left = join.left();
2251        let left = left.as_logical_values().unwrap();
2252        assert_eq!(left.schema().fields(), &fields[1..3]);
2253        let right = join.right();
2254        let right = right.as_logical_values().unwrap();
2255        assert_eq!(right.schema().fields(), &fields[3..4]);
2256    }
2257
2258    #[tokio::test]
2259    async fn fd_derivation_inner_outer_join() {
2260        // left: [l0, l1], right: [r0, r1, r2]
2261        // FD: l0 --> l1, r0 --> { r1, r2 }
2262        // On: l0 = 0 AND l1 = r1
2263        //
2264        // Inner Join:
2265        //  Schema: [l0, l1, r0, r1, r2]
2266        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2267        // Left Outer Join:
2268        //  Schema: [l0, l1, r0, r1, r2]
2269        //  FD: l0 --> l1
2270        // Right Outer Join:
2271        //  Schema: [l0, l1, r0, r1, r2]
2272        //  FD: r0 --> { r1, r2 }
2273        // Full Outer Join:
2274        //  Schema: [l0, l1, r0, r1, r2]
2275        //  FD: empty
2276        // Left Semi/Anti Join:
2277        //  Schema: [l0, l1]
2278        //  FD: l0 --> l1
2279        // Right Semi/Anti Join:
2280        //  Schema: [r0, r1, r2]
2281        //  FD: r0 --> {r1, r2}
2282        let ctx = OptimizerContext::mock().await;
2283        let left = {
2284            let fields: Vec<Field> = vec![
2285                Field::with_name(DataType::Int32, "l0"),
2286                Field::with_name(DataType::Int32, "l1"),
2287            ];
2288            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2289            // 0 --> 1
2290            values
2291                .base
2292                .functional_dependency_mut()
2293                .add_functional_dependency_by_column_indices(&[0], &[1]);
2294            values
2295        };
2296        let right = {
2297            let fields: Vec<Field> = vec![
2298                Field::with_name(DataType::Int32, "r0"),
2299                Field::with_name(DataType::Int32, "r1"),
2300                Field::with_name(DataType::Int32, "r2"),
2301            ];
2302            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2303            // 0 --> 1, 2
2304            values
2305                .base
2306                .functional_dependency_mut()
2307                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2308            values
2309        };
2310        // l0 = 0 AND l1 = r1
2311        let on: ExprImpl = FunctionCall::new(
2312            Type::And,
2313            vec![
2314                FunctionCall::new(
2315                    Type::Equal,
2316                    vec![
2317                        InputRef::new(0, DataType::Int32).into(),
2318                        ExprImpl::literal_int(0),
2319                    ],
2320                )
2321                .unwrap()
2322                .into(),
2323                FunctionCall::new(
2324                    Type::Equal,
2325                    vec![
2326                        InputRef::new(1, DataType::Int32).into(),
2327                        InputRef::new(3, DataType::Int32).into(),
2328                    ],
2329                )
2330                .unwrap()
2331                .into(),
2332            ],
2333        )
2334        .unwrap()
2335        .into();
2336        let expected_fd_set = [
2337            (
2338                JoinType::Inner,
2339                [
2340                    // inherit from left
2341                    FunctionalDependency::with_indices(5, &[0], &[1]),
2342                    // inherit from right
2343                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2344                    // constant column in join condition
2345                    FunctionalDependency::with_indices(5, &[], &[0]),
2346                    // eq column in join condition
2347                    FunctionalDependency::with_indices(5, &[1], &[3]),
2348                    FunctionalDependency::with_indices(5, &[3], &[1]),
2349                ]
2350                .into_iter()
2351                .collect::<HashSet<_>>(),
2352            ),
2353            (JoinType::FullOuter, HashSet::new()),
2354            (
2355                JoinType::RightOuter,
2356                [
2357                    // inherit from right
2358                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2359                ]
2360                .into_iter()
2361                .collect::<HashSet<_>>(),
2362            ),
2363            (
2364                JoinType::LeftOuter,
2365                [
2366                    // inherit from left
2367                    FunctionalDependency::with_indices(5, &[0], &[1]),
2368                ]
2369                .into_iter()
2370                .collect::<HashSet<_>>(),
2371            ),
2372            (
2373                JoinType::LeftSemi,
2374                [
2375                    // inherit from left
2376                    FunctionalDependency::with_indices(2, &[0], &[1]),
2377                ]
2378                .into_iter()
2379                .collect::<HashSet<_>>(),
2380            ),
2381            (
2382                JoinType::LeftAnti,
2383                [
2384                    // inherit from left
2385                    FunctionalDependency::with_indices(2, &[0], &[1]),
2386                ]
2387                .into_iter()
2388                .collect::<HashSet<_>>(),
2389            ),
2390            (
2391                JoinType::RightSemi,
2392                [
2393                    // inherit from right
2394                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2395                ]
2396                .into_iter()
2397                .collect::<HashSet<_>>(),
2398            ),
2399            (
2400                JoinType::RightAnti,
2401                [
2402                    // inherit from right
2403                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2404                ]
2405                .into_iter()
2406                .collect::<HashSet<_>>(),
2407            ),
2408        ];
2409
2410        for (join_type, expected_res) in expected_fd_set {
2411            let join = LogicalJoin::new(
2412                left.clone().into(),
2413                right.clone().into(),
2414                join_type,
2415                Condition::with_expr(on.clone()),
2416            );
2417            let fd_set = join
2418                .functional_dependency()
2419                .as_dependencies()
2420                .iter()
2421                .cloned()
2422                .collect::<HashSet<_>>();
2423            assert_eq!(fd_set, expected_res);
2424        }
2425    }
2426}