risingwave_frontend/optimizer/plan_node/logical_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31    ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeBinary, PredicatePushdown,
32    StreamHashJoin, StreamProject, ToBatch, ToStream, generic,
33};
34use crate::error::{ErrorCode, Result, RwError};
35use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::DynamicFilter;
38use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
39use crate::optimizer::plan_node::utils::IndicesDisplay;
40use crate::optimizer::plan_node::{
41    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
42    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
43    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
44};
45use crate::optimizer::plan_visitor::LogicalCardinalityExt;
46use crate::optimizer::property::{Distribution, Order, RequiredDist};
47use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
48
49/// `LogicalJoin` combines two relations according to some condition.
50///
51/// Each output row has fields from the left and right inputs. The set of output rows is a subset
52/// of the cartesian product of the two inputs; precisely which subset depends on the join
53/// condition. In addition, the output columns are a subset of the columns of the left and
54/// right columns, dependent on the output indices provided. A repeat output index is illegal.
55#[derive(Debug, Clone, PartialEq, Eq, Hash)]
56pub struct LogicalJoin {
57    pub base: PlanBase<Logical>,
58    core: generic::Join<PlanRef>,
59}
60
61impl Distill for LogicalJoin {
62    fn distill<'a>(&self) -> XmlNode<'a> {
63        let verbose = self.base.ctx().is_explain_verbose();
64        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
65        vec.push(("type", Pretty::debug(&self.join_type())));
66
67        let concat_schema = self.core.concat_schema();
68        let cond = Pretty::debug(&ConditionDisplay {
69            condition: self.on(),
70            input_schema: &concat_schema,
71        });
72        vec.push(("on", cond));
73
74        if verbose {
75            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
76            vec.push(("output", data));
77        }
78
79        childless_record("LogicalJoin", vec)
80    }
81}
82
83impl LogicalJoin {
84    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
85        let core = generic::Join::with_full_output(left, right, join_type, on);
86        Self::with_core(core)
87    }
88
89    pub(crate) fn with_output_indices(
90        left: PlanRef,
91        right: PlanRef,
92        join_type: JoinType,
93        on: Condition,
94        output_indices: Vec<usize>,
95    ) -> Self {
96        let core = generic::Join::new(left, right, on, join_type, output_indices);
97        Self::with_core(core)
98    }
99
100    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
101        let base = PlanBase::new_logical_with_core(&core);
102        LogicalJoin { base, core }
103    }
104
105    pub fn create(
106        left: PlanRef,
107        right: PlanRef,
108        join_type: JoinType,
109        on_clause: ExprImpl,
110    ) -> PlanRef {
111        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
112    }
113
114    pub fn internal_column_num(&self) -> usize {
115        self.core.internal_column_num()
116    }
117
118    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
119        self.core.i2l_col_mapping_ignore_join_type()
120    }
121
122    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
123        self.core.i2r_col_mapping_ignore_join_type()
124    }
125
126    /// Get a reference to the logical join's on.
127    pub fn on(&self) -> &Condition {
128        &self.core.on
129    }
130
131    /// Collect all input ref in the on condition. And separate them into left and right.
132    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
133        let input_refs = self
134            .core
135            .on
136            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
137        let index_group = input_refs
138            .ones()
139            .chunk_by(|i| *i < self.core.left.schema().len());
140        let left_index = index_group
141            .into_iter()
142            .next()
143            .map_or(vec![], |group| group.1.collect_vec());
144        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
145            group
146                .1
147                .map(|i| i - self.core.left.schema().len())
148                .collect_vec()
149        });
150        (left_index, right_index)
151    }
152
153    /// Get the join type of the logical join.
154    pub fn join_type(&self) -> JoinType {
155        self.core.join_type
156    }
157
158    /// Get the eq join key of the logical join.
159    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
160        self.core.eq_indexes()
161    }
162
163    /// Get the output indices of the logical join.
164    pub fn output_indices(&self) -> &Vec<usize> {
165        &self.core.output_indices
166    }
167
168    /// Clone with new output indices
169    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
170        Self::with_core(generic::Join {
171            output_indices,
172            ..self.core.clone()
173        })
174    }
175
176    /// Clone with new `on` condition
177    pub fn clone_with_cond(&self, on: Condition) -> Self {
178        Self::with_core(generic::Join {
179            on,
180            ..self.core.clone()
181        })
182    }
183
184    pub fn is_left_join(&self) -> bool {
185        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
186    }
187
188    pub fn is_right_join(&self) -> bool {
189        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
190    }
191
192    pub fn is_full_out(&self) -> bool {
193        self.core.is_full_out()
194    }
195
196    pub fn is_asof_join(&self) -> bool {
197        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
198    }
199
200    pub fn output_indices_are_trivial(&self) -> bool {
201        self.output_indices() == &(0..self.internal_column_num()).collect_vec()
202    }
203
204    /// Try to simplify the outer join with the predicate on the top of the join
205    ///
206    /// now it is just a naive implementation for comparison expression, we can give a more general
207    /// implementation with constant folding in future
208    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
209        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
210            JoinType::LeftOuter => (false, true),
211            JoinType::RightOuter => (true, false),
212            JoinType::FullOuter => (true, true),
213            _ => return join_type,
214        };
215
216        for expr in &predicate.conjunctions {
217            if let ExprImpl::FunctionCall(func) = expr {
218                match func.func_type() {
219                    ExprType::Equal
220                    | ExprType::NotEqual
221                    | ExprType::LessThan
222                    | ExprType::LessThanOrEqual
223                    | ExprType::GreaterThan
224                    | ExprType::GreaterThanOrEqual => {
225                        for input in func.inputs() {
226                            if let ExprImpl::InputRef(input) = input {
227                                let idx = input.index;
228                                if idx < left_col_num {
229                                    gen_null_in_left = false;
230                                } else {
231                                    gen_null_in_right = false;
232                                }
233                            }
234                        }
235                    }
236                    _ => {}
237                };
238            }
239        }
240
241        match (gen_null_in_left, gen_null_in_right) {
242            (true, true) => JoinType::FullOuter,
243            (true, false) => JoinType::RightOuter,
244            (false, true) => JoinType::LeftOuter,
245            (false, false) => JoinType::Inner,
246        }
247    }
248
249    /// Index Join:
250    /// Try to convert logical join into batch lookup join and meanwhile it will do
251    /// the index selection for the lookup table so that we can benefit from indexes.
252    fn to_batch_lookup_join_with_index_selection(
253        &self,
254        predicate: EqJoinPredicate,
255        logical_join: generic::Join<PlanRef>,
256    ) -> Result<Option<BatchLookupJoin>> {
257        match logical_join.join_type {
258            JoinType::Inner
259            | JoinType::LeftOuter
260            | JoinType::LeftSemi
261            | JoinType::LeftAnti
262            | JoinType::AsofInner
263            | JoinType::AsofLeftOuter => {}
264            _ => return Ok(None),
265        };
266
267        // Index selection for index join.
268        let right = self.right();
269        // Lookup Join only supports basic tables on the join's right side.
270        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
271            logical_scan
272        } else {
273            return Ok(None);
274        };
275
276        let mut result_plan = None;
277        // Lookup primary table.
278        if let Some(lookup_join) =
279            self.to_batch_lookup_join(predicate.clone(), logical_join.clone())?
280        {
281            result_plan = Some(lookup_join);
282        }
283
284        let indexes = logical_scan.indexes();
285        for index in indexes {
286            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
287                let index_scan: PlanRef = index_scan.into();
288                let that = self.clone_with_left_right(self.left(), index_scan.clone());
289                let mut new_logical_join = logical_join.clone();
290                new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");
291
292                // Lookup covered index.
293                if let Some(lookup_join) =
294                    that.to_batch_lookup_join(predicate.clone(), new_logical_join)?
295                {
296                    match &result_plan {
297                        None => result_plan = Some(lookup_join),
298                        Some(prev_lookup_join) => {
299                            // Prefer to choose lookup join with longer lookup prefix len.
300                            if prev_lookup_join.lookup_prefix_len()
301                                < lookup_join.lookup_prefix_len()
302                            {
303                                result_plan = Some(lookup_join)
304                            }
305                        }
306                    }
307                }
308            }
309        }
310
311        Ok(result_plan)
312    }
313
314    /// Try to convert logical join into batch lookup join.
315    fn to_batch_lookup_join(
316        &self,
317        predicate: EqJoinPredicate,
318        logical_join: generic::Join<PlanRef>,
319    ) -> Result<Option<BatchLookupJoin>> {
320        match logical_join.join_type {
321            JoinType::Inner
322            | JoinType::LeftOuter
323            | JoinType::LeftSemi
324            | JoinType::LeftAnti
325            | JoinType::AsofInner
326            | JoinType::AsofLeftOuter => {}
327            _ => return Ok(None),
328        };
329
330        let right = self.right();
331        // Lookup Join only supports basic tables on the join's right side.
332        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
333            logical_scan
334        } else {
335            return Ok(None);
336        };
337        let table_desc = logical_scan.table_desc().clone();
338        let output_column_ids = logical_scan.output_column_ids();
339
340        // Verify that the right join key columns are the the prefix of the primary key and
341        // also contain the distribution key.
342        let order_col_ids = table_desc.order_column_ids();
343        let order_key = table_desc.order_column_indices();
344        let dist_key = table_desc.distribution_key.clone();
345        // The at least prefix of order key that contains distribution key.
346        let mut dist_key_in_order_key_pos = vec![];
347        for d in dist_key {
348            let pos = order_key
349                .iter()
350                .position(|&x| x == d)
351                .expect("dist_key must in order_key");
352            dist_key_in_order_key_pos.push(pos);
353        }
354        // The shortest prefix of order key that contains distribution key.
355        let shortest_prefix_len = dist_key_in_order_key_pos
356            .iter()
357            .max()
358            .map_or(0, |pos| pos + 1);
359
360        // Distributed lookup join can't support lookup table with a singleton distribution.
361        if shortest_prefix_len == 0 {
362            return Ok(None);
363        }
364
365        // Reorder the join equal predicate to match the order key.
366        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
367        for order_col_id in order_col_ids {
368            let mut found = false;
369            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
370                if order_col_id == output_column_ids[eq_idx] {
371                    reorder_idx.push(i);
372                    found = true;
373                    break;
374                }
375            }
376            if !found {
377                break;
378            }
379        }
380        if reorder_idx.len() < shortest_prefix_len {
381            return Ok(None);
382        }
383        let lookup_prefix_len = reorder_idx.len();
384        let predicate = predicate.reorder(&reorder_idx);
385
386        // Extract the predicate from logical scan. Only pure scan is supported.
387        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
388        // Construct output column to require column mapping
389        let o2r = if let Some(project_expr) = project_expr {
390            project_expr
391                .into_iter()
392                .map(|x| x.as_input_ref().unwrap().index)
393                .collect_vec()
394        } else {
395            (0..logical_scan.output_col_idx().len()).collect_vec()
396        };
397        let left_schema_len = logical_join.left.schema().len();
398
399        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
400            offset: left_schema_len,
401            mapping: o2r.clone(),
402        };
403
404        let new_eq_cond = predicate
405            .eq_cond()
406            .rewrite_expr(&mut join_predicate_rewriter);
407
408        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
409            offset: left_schema_len,
410        };
411
412        let new_other_cond = predicate
413            .other_cond()
414            .clone()
415            .rewrite_expr(&mut join_predicate_rewriter)
416            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
417
418        let new_join_on = new_eq_cond.and(new_other_cond);
419        let new_predicate = EqJoinPredicate::create(
420            left_schema_len,
421            new_scan.schema().len(),
422            new_join_on.clone(),
423        );
424
425        // We discovered that we cannot use a lookup join after pulling up the predicate
426        // from one side and simplifying the condition. Let's use some other join instead.
427        if !new_predicate.has_eq() {
428            return Ok(None);
429        }
430
431        // Rewrite the join output indices and all output indices referred to the old scan need to
432        // rewrite.
433        let new_join_output_indices = logical_join
434            .output_indices
435            .iter()
436            .map(|&x| {
437                if x < left_schema_len {
438                    x
439                } else {
440                    o2r[x - left_schema_len] + left_schema_len
441                }
442            })
443            .collect_vec();
444
445        let new_scan_output_column_ids = new_scan.output_column_ids();
446        let as_of = new_scan.as_of.clone();
447
448        // Construct a new logical join, because we have change its RHS.
449        let new_logical_join = generic::Join::new(
450            logical_join.left,
451            new_scan.into(),
452            new_join_on,
453            logical_join.join_type,
454            new_join_output_indices,
455        );
456
457        let asof_desc = self
458            .is_asof_join()
459            .then(|| {
460                Self::get_inequality_desc_from_predicate(
461                    predicate.other_cond().clone(),
462                    left_schema_len,
463                )
464            })
465            .transpose()?;
466
467        Ok(Some(BatchLookupJoin::new(
468            new_logical_join,
469            new_predicate,
470            table_desc,
471            new_scan_output_column_ids,
472            lookup_prefix_len,
473            false,
474            as_of,
475            asof_desc,
476        )))
477    }
478
479    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
480        self.core.decompose()
481    }
482}
483
484impl PlanTreeNodeBinary for LogicalJoin {
485    fn left(&self) -> PlanRef {
486        self.core.left.clone()
487    }
488
489    fn right(&self) -> PlanRef {
490        self.core.right.clone()
491    }
492
493    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
494        Self::with_core(generic::Join {
495            left,
496            right,
497            ..self.core.clone()
498        })
499    }
500
501    fn rewrite_with_left_right(
502        &self,
503        left: PlanRef,
504        left_col_change: ColIndexMapping,
505        right: PlanRef,
506        right_col_change: ColIndexMapping,
507    ) -> (Self, ColIndexMapping) {
508        let (new_on, new_output_indices) = {
509            let (mut map, _) = left_col_change.clone().into_parts();
510            let (mut right_map, _) = right_col_change.clone().into_parts();
511            for i in right_map.iter_mut().flatten() {
512                *i += left.schema().len();
513            }
514            map.append(&mut right_map);
515            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
516
517            let new_output_indices = self
518                .output_indices()
519                .iter()
520                .map(|&i| mapping.map(i))
521                .collect::<Vec<_>>();
522            let new_on = self.on().clone().rewrite_expr(&mut mapping);
523            (new_on, new_output_indices)
524        };
525
526        let join = Self::with_output_indices(
527            left,
528            right,
529            self.join_type(),
530            new_on,
531            new_output_indices.clone(),
532        );
533
534        let new_i2o = ColIndexMapping::with_remaining_columns(
535            &new_output_indices,
536            join.internal_column_num(),
537        );
538
539        let old_o2i = self.core.o2i_col_mapping();
540
541        let old_o2l = old_o2i
542            .composite(&self.core.i2l_col_mapping())
543            .composite(&left_col_change);
544        let old_o2r = old_o2i
545            .composite(&self.core.i2r_col_mapping())
546            .composite(&right_col_change);
547        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
548        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
549
550        let out_col_change = old_o2l
551            .composite(&new_l2o)
552            .union(&old_o2r.composite(&new_r2o));
553        (join, out_col_change)
554    }
555}
556
557impl_plan_tree_node_for_binary! { LogicalJoin }
558
559impl ColPrunable for LogicalJoin {
560    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
561        // make `required_cols` point to internal table instead of output schema.
562        let required_cols = required_cols
563            .iter()
564            .map(|i| self.output_indices()[*i])
565            .collect_vec();
566        let left_len = self.left().schema().fields.len();
567
568        let total_len = self.left().schema().len() + self.right().schema().len();
569        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
570
571        required_cols.iter().for_each(|&i| {
572            if self.is_right_join() {
573                resized_required_cols.insert(left_len + i);
574            } else {
575                resized_required_cols.insert(i);
576            }
577        });
578
579        // add those columns which are required in the join condition to
580        // to those that are required in the output
581        let mut visitor = CollectInputRef::new(resized_required_cols);
582        self.on().visit_expr(&mut visitor);
583        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
584
585        let mut left_required_cols = Vec::new();
586        let mut right_required_cols = Vec::new();
587        left_right_required_cols.iter().for_each(|&i| {
588            if i < left_len {
589                left_required_cols.push(i);
590            } else {
591                right_required_cols.push(i - left_len);
592            }
593        });
594
595        let mut on = self.on().clone();
596        let mut mapping =
597            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
598        on = on.rewrite_expr(&mut mapping);
599
600        let new_output_indices = {
601            let required_inputs_in_output = if self.is_left_join() {
602                &left_required_cols
603            } else if self.is_right_join() {
604                &right_required_cols
605            } else {
606                &left_right_required_cols
607            };
608
609            let mapping =
610                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
611            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
612        };
613
614        LogicalJoin::with_output_indices(
615            self.left().prune_col(&left_required_cols, ctx),
616            self.right().prune_col(&right_required_cols, ctx),
617            self.join_type(),
618            on,
619            new_output_indices,
620        )
621        .into()
622    }
623}
624
625impl ExprRewritable for LogicalJoin {
626    fn has_rewritable_expr(&self) -> bool {
627        true
628    }
629
630    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
631        let mut core = self.core.clone();
632        core.rewrite_exprs(r);
633        Self {
634            base: self.base.clone_with_new_plan_id(),
635            core,
636        }
637        .into()
638    }
639}
640
641impl ExprVisitable for LogicalJoin {
642    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
643        self.core.visit_exprs(v);
644    }
645}
646
647/// We are trying to derive a predicate to apply to the other side of a join if all
648/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
649/// with the corresponding eq condition columns of the other side.
650///
651/// Strategy:
652/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
653///    then we proceed. Else abort.
654/// 2. Then, we collect `InputRef`s in the conjunction.
655/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
656/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
657///    the other side.
658///
659/// # Arguments
660///
661/// Suppose we derive a predicate from the left side to be pushed to the right side.
662/// * `expr`: An expr from the left side.
663/// * `col_num`: The number of columns in the left side.
664fn derive_predicate_from_eq_condition(
665    expr: &ExprImpl,
666    eq_condition: &EqJoinPredicate,
667    col_num: usize,
668    expr_is_left: bool,
669) -> Option<ExprImpl> {
670    if expr.is_impure() {
671        return None;
672    }
673    let eq_indices = eq_condition
674        .eq_indexes_typed()
675        .iter()
676        .filter_map(|(l, r)| {
677            if l.return_type() != r.return_type() {
678                None
679            } else if expr_is_left {
680                Some(l.index())
681            } else {
682                Some(r.index())
683            }
684        })
685        .collect_vec();
686    if expr
687        .collect_input_refs(col_num)
688        .ones()
689        .any(|index| !eq_indices.contains(&index))
690    {
691        // expr contains an InputRef not in eq_condition
692        return None;
693    }
694    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
695    // Hence, we can substitute those `InputRef`s with indices from the other side.
696    let other_side_mapping = if expr_is_left {
697        eq_condition.eq_indexes_typed().into_iter().collect()
698    } else {
699        eq_condition
700            .eq_indexes_typed()
701            .into_iter()
702            .map(|(x, y)| (y, x))
703            .collect()
704    };
705    struct InputRefsRewriter {
706        mapping: HashMap<InputRef, InputRef>,
707    }
708    impl ExprRewriter for InputRefsRewriter {
709        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
710            self.mapping[&input_ref].clone().into()
711        }
712    }
713    Some(
714        InputRefsRewriter {
715            mapping: other_side_mapping,
716        }
717        .rewrite_expr(expr.clone()),
718    )
719}
720
721/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
722struct LookupJoinPredicateRewriter {
723    offset: usize,
724    mapping: Vec<usize>,
725}
726impl ExprRewriter for LookupJoinPredicateRewriter {
727    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
728        if input_ref.index() < self.offset {
729            input_ref.into()
730        } else {
731            InputRef::new(
732                self.mapping[input_ref.index() - self.offset] + self.offset,
733                input_ref.return_type(),
734            )
735            .into()
736        }
737    }
738}
739
740/// Rewrite the scan predicate so we can add it to the join predicate.
741struct LookupJoinScanPredicateRewriter {
742    offset: usize,
743}
744impl ExprRewriter for LookupJoinScanPredicateRewriter {
745    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
746        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
747    }
748}
749
750impl PredicatePushdown for LogicalJoin {
751    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
752    ///
753    /// # Which predicates can be pushed
754    ///
755    /// For inner join, we can do all kinds of pushdown.
756    ///
757    /// For left/right semi join, we can push filter to left/right and on-clause,
758    /// and push on-clause to left/right.
759    ///
760    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
761    ///
762    /// ## Outer Join
763    ///
764    /// Preserved Row table
765    /// : The table in an Outer Join that must return all rows.
766    ///
767    /// Null Supplying table
768    /// : This is the table that has nulls filled in for its columns in unmatched rows.
769    ///
770    /// |                          | Preserved Row table | Null Supplying table |
771    /// |--------------------------|---------------------|----------------------|
772    /// | Join predicate (on)      | Not Pushed          | Pushed               |
773    /// | Where predicate (filter) | Pushed              | Not Pushed           |
774    fn predicate_pushdown(
775        &self,
776        predicate: Condition,
777        ctx: &mut PredicatePushdownContext,
778    ) -> PlanRef {
779        // rewrite output col referencing indices as internal cols
780        let mut predicate = {
781            let mut mapping = self.core.o2i_col_mapping();
782            predicate.rewrite_expr(&mut mapping)
783        };
784
785        let left_col_num = self.left().schema().len();
786        let right_col_num = self.right().schema().len();
787        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
788
789        let push_down_temporal_predicate = !self.should_be_temporal_join();
790
791        let (left_from_filter, right_from_filter, on) = push_down_into_join(
792            &mut predicate,
793            left_col_num,
794            right_col_num,
795            join_type,
796            push_down_temporal_predicate,
797        );
798
799        let mut new_on = self.on().clone().and(on);
800        let (left_from_on, right_from_on) = push_down_join_condition(
801            &mut new_on,
802            left_col_num,
803            right_col_num,
804            join_type,
805            push_down_temporal_predicate,
806        );
807
808        let left_predicate = left_from_filter.and(left_from_on);
809        let right_predicate = right_from_filter.and(right_from_on);
810
811        // Derive conditions to push to the other side based on eq condition columns
812        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
813
814        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
815        let right_from_left = if matches!(
816            join_type,
817            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
818        ) {
819            Condition {
820                conjunctions: left_predicate
821                    .conjunctions
822                    .iter()
823                    .filter_map(|expr| {
824                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
825                    })
826                    .collect(),
827            }
828        } else {
829            Condition::true_cond()
830        };
831
832        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
833        let left_from_right = if matches!(
834            join_type,
835            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
836        ) {
837            Condition {
838                conjunctions: right_predicate
839                    .conjunctions
840                    .iter()
841                    .filter_map(|expr| {
842                        derive_predicate_from_eq_condition(
843                            expr,
844                            &eq_condition,
845                            right_col_num,
846                            false,
847                        )
848                    })
849                    .collect(),
850            }
851        } else {
852            Condition::true_cond()
853        };
854
855        let left_predicate = left_predicate.and(left_from_right);
856        let right_predicate = right_predicate.and(right_from_left);
857
858        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
859        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
860        let new_join = LogicalJoin::with_output_indices(
861            new_left,
862            new_right,
863            join_type,
864            new_on,
865            self.output_indices().clone(),
866        );
867
868        let mut mapping = self.core.i2o_col_mapping();
869        predicate = predicate.rewrite_expr(&mut mapping);
870        LogicalFilter::create(new_join.into(), predicate)
871    }
872}
873
874impl LogicalJoin {
875    fn get_stream_input_for_hash_join(
876        &self,
877        predicate: &EqJoinPredicate,
878        ctx: &mut ToStreamContext,
879    ) -> Result<(PlanRef, PlanRef)> {
880        use super::stream::prelude::*;
881
882        let mut right = self.right().to_stream_with_dist_required(
883            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
884            ctx,
885        )?;
886        let mut left = self.left();
887
888        let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
889        let l2r = predicate.l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
890
891        let right_dist = right.distribution();
892        match right_dist {
893            Distribution::HashShard(_) => {
894                let left_dist = r2l
895                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
896                left = left.to_stream_with_dist_required(&left_dist, ctx)?;
897            }
898            Distribution::UpstreamHashShard(_, _) => {
899                left = left.to_stream_with_dist_required(
900                    &RequiredDist::shard_by_key(
901                        self.left().schema().len(),
902                        &predicate.left_eq_indexes(),
903                    ),
904                    ctx,
905                )?;
906                let left_dist = left.distribution();
907                match left_dist {
908                    Distribution::HashShard(_) => {
909                        let right_dist = l2r.rewrite_required_distribution(
910                            &RequiredDist::PhysicalDist(left_dist.clone()),
911                        );
912                        right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
913                    }
914                    Distribution::UpstreamHashShard(_, _) => {
915                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
916                            .enforce_if_not_satisfies(left, &Order::any())?;
917                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
918                            .enforce_if_not_satisfies(right, &Order::any())?;
919                    }
920                    _ => unreachable!(),
921                }
922            }
923            _ => unreachable!(),
924        }
925        Ok((left, right))
926    }
927
928    fn to_stream_hash_join(
929        &self,
930        predicate: EqJoinPredicate,
931        ctx: &mut ToStreamContext,
932    ) -> Result<PlanRef> {
933        use super::stream::prelude::*;
934
935        assert!(predicate.has_eq());
936        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
937
938        let logical_join = self.clone_with_left_right(left, right);
939
940        // Convert to Hash Join for equal joins
941        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
942        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
943        // as opposed to the HashJoin, which applies the condition row-wise.
944        // However, the default behavior of pulling up non-equal conditions can be overridden by the
945        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
946        // materialization of rows only to be filtered later.
947
948        let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
949
950        let force_filter_inside_join = self
951            .base
952            .ctx()
953            .session_ctx()
954            .config()
955            .streaming_force_filter_inside_join();
956
957        let pull_filter = self.join_type() == JoinType::Inner
958            && stream_hash_join.eq_join_predicate().has_non_eq()
959            && stream_hash_join.inequality_pairs().is_empty()
960            && (!force_filter_inside_join);
961        if pull_filter {
962            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
963
964            // Temporarily remove output indices.
965            let logical_join = logical_join.clone_with_output_indices(default_indices.clone());
966            let eq_cond = EqJoinPredicate::new(
967                Condition::true_cond(),
968                predicate.eq_keys().to_vec(),
969                self.left().schema().len(),
970                self.right().schema().len(),
971            );
972            let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
973            let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
974            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
975            let plan = StreamFilter::new(logical_filter).into();
976            if self.output_indices() != &default_indices {
977                let logical_project = generic::Project::with_mapping(
978                    plan,
979                    ColIndexMapping::with_remaining_columns(
980                        self.output_indices(),
981                        self.internal_column_num(),
982                    ),
983                );
984                Ok(StreamProject::new(logical_project).into())
985            } else {
986                Ok(plan)
987            }
988        } else {
989            Ok(stream_hash_join.into())
990        }
991    }
992
993    fn should_be_temporal_join(&self) -> bool {
994        let right = self.right();
995        if let Some(logical_scan) = right.as_logical_scan() {
996            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
997        } else {
998            false
999        }
1000    }
1001
1002    fn to_stream_temporal_join_with_index_selection(
1003        &self,
1004        predicate: EqJoinPredicate,
1005        ctx: &mut ToStreamContext,
1006    ) -> Result<PlanRef> {
1007        // Index selection for temporal join.
1008        let right = self.right();
1009        // `should_be_temporal_join()` has already check right input for us.
1010        let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1011
1012        // Use primary table.
1013        let mut result_plan: Result<StreamTemporalJoin> =
1014            self.to_stream_temporal_join(predicate.clone(), ctx);
1015        // Return directly if this temporal join can match the pk of its right table.
1016        if let Ok(temporal_join) = &result_plan
1017            && temporal_join.eq_join_predicate().eq_indexes().len()
1018                == logical_scan.primary_key().len()
1019        {
1020            return result_plan.map(|x| x.into());
1021        }
1022        let indexes = logical_scan.indexes();
1023        for index in indexes {
1024            // Use index table
1025            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1026                let index_scan: PlanRef = index_scan.into();
1027                let that = self.clone_with_left_right(self.left(), index_scan.clone());
1028                if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1029                    match &result_plan {
1030                        Err(_) => result_plan = Ok(temporal_join),
1031                        Ok(prev_temporal_join) => {
1032                            // Prefer to the temporal join with a longer lookup prefix len.
1033                            if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1034                                < temporal_join.eq_join_predicate().eq_indexes().len()
1035                            {
1036                                result_plan = Ok(temporal_join)
1037                            }
1038                        }
1039                    }
1040                }
1041            }
1042        }
1043
1044        result_plan.map(|x| x.into())
1045    }
1046
1047    fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1048        let Some(logical_scan) = right.as_logical_scan() else {
1049            return Err(RwError::from(ErrorCode::NotSupported(
1050                "Temporal join requires a table scan as its lookup table".into(),
1051                "Please provide a table scan".into(),
1052            )));
1053        };
1054
1055        if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1056            return Err(RwError::from(ErrorCode::NotSupported(
1057                "Temporal join requires a table defined as temporal table".into(),
1058                "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1059            )));
1060        }
1061        Ok(logical_scan)
1062    }
1063
1064    fn temporal_join_scan_predicate_pull_up(
1065        logical_scan: &LogicalScan,
1066        predicate: EqJoinPredicate,
1067        output_indices: &[usize],
1068        left_schema_len: usize,
1069    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1070        // Extract the predicate from logical scan. Only pure scan is supported.
1071        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1072        // Construct output column to require column mapping
1073        let o2r = if let Some(project_expr) = project_expr {
1074            project_expr
1075                .into_iter()
1076                .map(|x| x.as_input_ref().unwrap().index)
1077                .collect_vec()
1078        } else {
1079            (0..logical_scan.output_col_idx().len()).collect_vec()
1080        };
1081        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1082            offset: left_schema_len,
1083            mapping: o2r.clone(),
1084        };
1085
1086        let new_eq_cond = predicate
1087            .eq_cond()
1088            .rewrite_expr(&mut join_predicate_rewriter);
1089
1090        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1091            offset: left_schema_len,
1092        };
1093
1094        let new_other_cond = predicate
1095            .other_cond()
1096            .clone()
1097            .rewrite_expr(&mut join_predicate_rewriter)
1098            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1099
1100        let new_join_on = new_eq_cond.and(new_other_cond);
1101
1102        let new_predicate = EqJoinPredicate::create(
1103            left_schema_len,
1104            new_scan.schema().len(),
1105            new_join_on.clone(),
1106        );
1107
1108        // Rewrite the join output indices and all output indices referred to the old scan need to
1109        // rewrite.
1110        let new_join_output_indices = output_indices
1111            .iter()
1112            .map(|&x| {
1113                if x < left_schema_len {
1114                    x
1115                } else {
1116                    o2r[x - left_schema_len] + left_schema_len
1117                }
1118            })
1119            .collect_vec();
1120        // Use UpstreamOnly chain type
1121        let new_stream_table_scan =
1122            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1123        Ok((
1124            new_stream_table_scan,
1125            new_predicate,
1126            new_join_on,
1127            new_join_output_indices,
1128        ))
1129    }
1130
1131    fn to_stream_temporal_join(
1132        &self,
1133        predicate: EqJoinPredicate,
1134        ctx: &mut ToStreamContext,
1135    ) -> Result<StreamTemporalJoin> {
1136        use super::stream::prelude::*;
1137
1138        assert!(predicate.has_eq());
1139
1140        let right = self.right();
1141
1142        let logical_scan = Self::check_temporal_rhs(&right)?;
1143
1144        let table_desc = logical_scan.table_desc();
1145        let output_column_ids = logical_scan.output_column_ids();
1146
1147        // Verify that the right join key columns are the the prefix of the primary key and
1148        // also contain the distribution key.
1149        let order_col_ids = table_desc.order_column_ids();
1150        let order_key = table_desc.order_column_indices();
1151        let dist_key = table_desc.distribution_key.clone();
1152
1153        let mut dist_key_in_order_key_pos = vec![];
1154        for d in dist_key {
1155            let pos = order_key
1156                .iter()
1157                .position(|&x| x == d)
1158                .expect("dist_key must in order_key");
1159            dist_key_in_order_key_pos.push(pos);
1160        }
1161        // The shortest prefix of order key that contains distribution key.
1162        let shortest_prefix_len = dist_key_in_order_key_pos
1163            .iter()
1164            .max()
1165            .map_or(0, |pos| pos + 1);
1166
1167        // Reorder the join equal predicate to match the order key.
1168        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1169        for order_col_id in order_col_ids {
1170            let mut found = false;
1171            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1172                if order_col_id == output_column_ids[eq_idx] {
1173                    reorder_idx.push(i);
1174                    found = true;
1175                    break;
1176                }
1177            }
1178            if !found {
1179                break;
1180            }
1181        }
1182        if reorder_idx.len() < shortest_prefix_len {
1183            // TODO: support index selection for temporal join and refine this error message.
1184            return Err(RwError::from(ErrorCode::NotSupported(
1185                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1186                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1187            )));
1188        }
1189        let lookup_prefix_len = reorder_idx.len();
1190        let predicate = predicate.reorder(&reorder_idx);
1191
1192        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1193            RequiredDist::single()
1194        } else {
1195            let left_eq_indexes = predicate.left_eq_indexes();
1196            let left_dist_key = dist_key_in_order_key_pos
1197                .iter()
1198                .map(|pos| left_eq_indexes[*pos])
1199                .collect_vec();
1200
1201            RequiredDist::hash_shard(&left_dist_key)
1202        };
1203
1204        let left = self.left().to_stream(ctx)?;
1205        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1206        let left = required_dist.enforce(left, &Order::any());
1207
1208        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1209            Self::temporal_join_scan_predicate_pull_up(
1210                logical_scan,
1211                predicate,
1212                self.output_indices(),
1213                self.left().schema().len(),
1214            )?;
1215
1216        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1217        if !new_predicate.has_eq() {
1218            return Err(RwError::from(ErrorCode::NotSupported(
1219                "Temporal join requires a non trivial join condition".into(),
1220                "Please remove the false condition of the join".into(),
1221            )));
1222        }
1223
1224        // Construct a new logical join, because we have change its RHS.
1225        let new_logical_join = generic::Join::new(
1226            left,
1227            right,
1228            new_join_on,
1229            self.join_type(),
1230            new_join_output_indices,
1231        );
1232
1233        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1234
1235        Ok(StreamTemporalJoin::new(
1236            new_logical_join,
1237            new_predicate,
1238            false,
1239        ))
1240    }
1241
1242    fn to_stream_nested_loop_temporal_join(
1243        &self,
1244        predicate: EqJoinPredicate,
1245        ctx: &mut ToStreamContext,
1246    ) -> Result<PlanRef> {
1247        use super::stream::prelude::*;
1248        assert!(!predicate.has_eq());
1249
1250        let left = self.left().to_stream_with_dist_required(
1251            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1252            ctx,
1253        )?;
1254        assert!(left.as_stream_exchange().is_some());
1255
1256        if self.join_type() != JoinType::Inner {
1257            return Err(RwError::from(ErrorCode::NotSupported(
1258                "Temporal join requires an inner join".into(),
1259                "Please use an inner join".into(),
1260            )));
1261        }
1262
1263        if !left.append_only() {
1264            return Err(RwError::from(ErrorCode::NotSupported(
1265                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1266                "Please ensure the left hash side is append only".into(),
1267            )));
1268        }
1269
1270        let right = self.right();
1271        let logical_scan = Self::check_temporal_rhs(&right)?;
1272
1273        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1274            Self::temporal_join_scan_predicate_pull_up(
1275                logical_scan,
1276                predicate,
1277                self.output_indices(),
1278                self.left().schema().len(),
1279            )?;
1280
1281        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1282
1283        // Construct a new logical join, because we have change its RHS.
1284        let new_logical_join = generic::Join::new(
1285            left,
1286            right,
1287            new_join_on,
1288            self.join_type(),
1289            new_join_output_indices,
1290        );
1291
1292        Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true).into())
1293    }
1294
1295    fn to_stream_dynamic_filter(
1296        &self,
1297        predicate: Condition,
1298        ctx: &mut ToStreamContext,
1299    ) -> Result<Option<PlanRef>> {
1300        use super::stream::prelude::*;
1301
1302        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1303        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1304        // `StreamDynamicFilter`
1305
1306        // Check if `Inner`/`LeftSemi`
1307        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1308            return Ok(None);
1309        }
1310
1311        // Check if right side is a scalar
1312        if !self.right().max_one_row() {
1313            return Ok(None);
1314        }
1315        if self.right().schema().len() != 1 {
1316            return Ok(None);
1317        }
1318
1319        // Check if the join condition is a correlated comparison
1320        if predicate.conjunctions.len() > 1 {
1321            return Ok(None);
1322        }
1323        let expr: ExprImpl = predicate.into();
1324        let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1325            Some(v) => v,
1326            None => return Ok(None),
1327        };
1328
1329        let condition_cross_inputs = left_ref.index < self.left().schema().len()
1330            && right_ref.index == self.left().schema().len() /* right side has only one column */;
1331        if !condition_cross_inputs {
1332            // Maybe we should panic here because it means some predicates are not pushed down.
1333            return Ok(None);
1334        }
1335
1336        // We align input types on all join predicates with cmp operator
1337        if self.left().schema().fields()[left_ref.index].data_type
1338            != self.right().schema().fields()[0].data_type
1339        {
1340            return Ok(None);
1341        }
1342
1343        // Check if non of the columns from the inner side is required to output
1344        let all_output_from_left = self
1345            .output_indices()
1346            .iter()
1347            .all(|i| *i < self.left().schema().len());
1348        if !all_output_from_left {
1349            return Ok(None);
1350        }
1351
1352        let left = self.left().to_stream(ctx)?;
1353        let right = self.right().to_stream_with_dist_required(
1354            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1355            ctx,
1356        )?;
1357
1358        assert!(right.as_stream_exchange().is_some());
1359        assert_eq!(
1360            *right.inputs().iter().exactly_one().unwrap().distribution(),
1361            Distribution::Single
1362        );
1363
1364        let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1365        let plan = StreamDynamicFilter::new(core).into();
1366        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1367        if self
1368            .output_indices()
1369            .iter()
1370            .copied()
1371            .ne(0..self.left().schema().len())
1372        {
1373            // The schema of dynamic filter is always the same as the left side now, and we have
1374            // checked that all output columns are from the left side before.
1375            let logical_project = generic::Project::with_mapping(
1376                plan,
1377                ColIndexMapping::with_remaining_columns(
1378                    self.output_indices(),
1379                    self.left().schema().len(),
1380                ),
1381            );
1382            Ok(Some(StreamProject::new(logical_project).into()))
1383        } else {
1384            Ok(Some(plan))
1385        }
1386    }
1387
1388    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
1389        let predicate = EqJoinPredicate::create(
1390            self.left().schema().len(),
1391            self.right().schema().len(),
1392            self.on().clone(),
1393        );
1394        assert!(predicate.has_eq());
1395
1396        let mut logical_join = self.core.clone();
1397        logical_join.left = logical_join.left.to_batch()?;
1398        logical_join.right = logical_join.right.to_batch()?;
1399
1400        Ok(self
1401            .to_batch_lookup_join(predicate, logical_join)?
1402            .expect("Fail to convert to lookup join")
1403            .into())
1404    }
1405
1406    fn to_stream_asof_join(
1407        &self,
1408        predicate: EqJoinPredicate,
1409        ctx: &mut ToStreamContext,
1410    ) -> Result<PlanRef> {
1411        use super::stream::prelude::*;
1412
1413        if predicate.eq_keys().is_empty() {
1414            return Err(ErrorCode::InvalidInputSyntax(
1415                "AsOf join requires at least 1 equal condition".to_owned(),
1416            )
1417            .into());
1418        }
1419
1420        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1421        let left_len = left.schema().len();
1422        let logical_join = self.clone_with_left_right(left, right);
1423
1424        let inequality_desc =
1425            Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1426
1427        Ok(StreamAsOfJoin::new(logical_join.core.clone(), predicate, inequality_desc).into())
1428    }
1429
1430    /// Convert the logical join to a Hash join.
1431    fn to_batch_hash_join(
1432        &self,
1433        logical_join: generic::Join<PlanRef>,
1434        predicate: EqJoinPredicate,
1435    ) -> Result<PlanRef> {
1436        use super::batch::prelude::*;
1437
1438        let left_schema_len = logical_join.left.schema().len();
1439        let asof_desc = self
1440            .is_asof_join()
1441            .then(|| {
1442                Self::get_inequality_desc_from_predicate(
1443                    predicate.other_cond().clone(),
1444                    left_schema_len,
1445                )
1446            })
1447            .transpose()?;
1448
1449        let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1450        Ok(batch_join.into())
1451    }
1452
1453    pub fn get_inequality_desc_from_predicate(
1454        predicate: Condition,
1455        left_input_len: usize,
1456    ) -> Result<AsOfJoinDesc> {
1457        let expr: ExprImpl = predicate.into();
1458        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1459            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1460            {
1461                Ok(AsOfJoinDesc {
1462                    left_idx: left_input_ref.index() as u32,
1463                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1464                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1465                })
1466            } else {
1467                bail!("inequal condition from the same side should be push down in optimizer");
1468            }
1469        } else {
1470            Err(ErrorCode::InvalidInputSyntax(
1471                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1472            )
1473            .into())
1474        }
1475    }
1476
1477    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1478        match expr_type {
1479            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1480            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1481            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1482            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1483            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1484                "Invalid comparison type: {}",
1485                expr_type.as_str_name()
1486            ))
1487            .into()),
1488        }
1489    }
1490}
1491
1492impl ToBatch for LogicalJoin {
1493    fn to_batch(&self) -> Result<PlanRef> {
1494        let predicate = EqJoinPredicate::create(
1495            self.left().schema().len(),
1496            self.right().schema().len(),
1497            self.on().clone(),
1498        );
1499
1500        let mut logical_join = self.core.clone();
1501        logical_join.left = logical_join.left.to_batch()?;
1502        logical_join.right = logical_join.right.to_batch()?;
1503
1504        let ctx = self.base.ctx();
1505        let config = ctx.session_ctx().config();
1506
1507        if predicate.has_eq() {
1508            if !predicate.eq_keys_are_type_aligned() {
1509                return Err(ErrorCode::InternalError(format!(
1510                    "Join eq keys are not aligned for predicate: {predicate:?}"
1511                ))
1512                .into());
1513            }
1514            if config.batch_enable_lookup_join() {
1515                if let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1516                    predicate.clone(),
1517                    logical_join.clone(),
1518                )? {
1519                    return Ok(lookup_join.into());
1520                }
1521            }
1522            self.to_batch_hash_join(logical_join, predicate)
1523        } else if self.is_asof_join() {
1524            return Err(ErrorCode::InvalidInputSyntax(
1525                "AsOf join requires at least 1 equal condition".to_owned(),
1526            )
1527            .into());
1528        } else {
1529            // Convert to Nested-loop Join for non-equal joins
1530            Ok(BatchNestedLoopJoin::new(logical_join).into())
1531        }
1532    }
1533}
1534
1535impl ToStream for LogicalJoin {
1536    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1537        if self
1538            .on()
1539            .conjunctions
1540            .iter()
1541            .any(|cond| cond.count_nows() > 0)
1542        {
1543            return Err(ErrorCode::NotSupported(
1544                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1545                 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1546        }
1547
1548        let predicate = EqJoinPredicate::create(
1549            self.left().schema().len(),
1550            self.right().schema().len(),
1551            self.on().clone(),
1552        );
1553
1554        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1555            self.to_stream_asof_join(predicate, ctx)
1556        } else if predicate.has_eq() {
1557            if !predicate.eq_keys_are_type_aligned() {
1558                return Err(ErrorCode::InternalError(format!(
1559                    "Join eq keys are not aligned for predicate: {predicate:?}"
1560                ))
1561                .into());
1562            }
1563
1564            if self.should_be_temporal_join() {
1565                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1566            } else {
1567                self.to_stream_hash_join(predicate, ctx)
1568            }
1569        } else if self.should_be_temporal_join() {
1570            self.to_stream_nested_loop_temporal_join(predicate, ctx)
1571        } else if let Some(dynamic_filter) =
1572            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1573        {
1574            Ok(dynamic_filter)
1575        } else {
1576            Err(RwError::from(ErrorCode::NotSupported(
1577                "streaming nested-loop join".to_owned(),
1578                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1579                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1580                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1581            )))
1582        }
1583    }
1584
1585    fn logical_rewrite_for_stream(
1586        &self,
1587        ctx: &mut RewriteStreamContext,
1588    ) -> Result<(PlanRef, ColIndexMapping)> {
1589        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1590        let left_len = left.schema().len();
1591        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1592        let (join, out_col_change) = self.rewrite_with_left_right(
1593            left.clone(),
1594            left_col_change,
1595            right.clone(),
1596            right_col_change,
1597        );
1598
1599        let mapping = ColIndexMapping::with_remaining_columns(
1600            join.output_indices(),
1601            join.internal_column_num(),
1602        );
1603
1604        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1605        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1606
1607        // Add missing pk indices to the logical join
1608        let mut left_to_add = left
1609            .expect_stream_key()
1610            .iter()
1611            .cloned()
1612            .filter(|i| l2o.try_map(*i).is_none())
1613            .collect_vec();
1614
1615        let mut right_to_add = right
1616            .expect_stream_key()
1617            .iter()
1618            .filter(|&&i| r2o.try_map(i).is_none())
1619            .map(|&i| i + left_len)
1620            .collect_vec();
1621
1622        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1623        // key.
1624        let right_len = right.schema().len();
1625        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1626
1627        let either_or_both = self.core.add_which_join_key_to_pk();
1628
1629        for (lk, rk) in eq_predicate.eq_indexes() {
1630            match either_or_both {
1631                EitherOrBoth::Left(_) => {
1632                    if l2o.try_map(lk).is_none() {
1633                        left_to_add.push(lk);
1634                    }
1635                }
1636                EitherOrBoth::Right(_) => {
1637                    if r2o.try_map(rk).is_none() {
1638                        right_to_add.push(rk + left_len)
1639                    }
1640                }
1641                EitherOrBoth::Both(_, _) => {
1642                    if l2o.try_map(lk).is_none() {
1643                        left_to_add.push(lk);
1644                    }
1645                    if r2o.try_map(rk).is_none() {
1646                        right_to_add.push(rk + left_len)
1647                    }
1648                }
1649            };
1650        }
1651        let left_to_add = left_to_add.into_iter().unique();
1652        let right_to_add = right_to_add.into_iter().unique();
1653        // NOTE(st1page) over
1654
1655        let mut new_output_indices = join.output_indices().clone();
1656        if !join.is_right_join() {
1657            new_output_indices.extend(left_to_add);
1658        }
1659        if !join.is_left_join() {
1660            new_output_indices.extend(right_to_add);
1661        }
1662
1663        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1664
1665        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1666            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1667
1668            let l2o = join_with_pk
1669                .core
1670                .l2i_col_mapping()
1671                .composite(&join_with_pk.core.i2o_col_mapping());
1672            let r2o = join_with_pk
1673                .core
1674                .r2i_col_mapping()
1675                .composite(&join_with_pk.core.i2o_col_mapping());
1676            let left_right_stream_keys = join_with_pk
1677                .left()
1678                .expect_stream_key()
1679                .iter()
1680                .map(|i| l2o.map(*i))
1681                .chain(
1682                    join_with_pk
1683                        .right()
1684                        .expect_stream_key()
1685                        .iter()
1686                        .map(|i| r2o.map(*i)),
1687                )
1688                .collect_vec();
1689            let plan: PlanRef = join_with_pk.into();
1690            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1691        } else {
1692            join_with_pk.into()
1693        };
1694
1695        // the added columns is at the end, so it will not change the exists column index
1696        Ok((plan, out_col_change))
1697    }
1698}
1699
1700#[cfg(test)]
1701mod tests {
1702
1703    use std::collections::HashSet;
1704
1705    use risingwave_common::catalog::{Field, Schema};
1706    use risingwave_common::types::{DataType, Datum};
1707    use risingwave_pb::expr::expr_node::Type;
1708
1709    use super::*;
1710    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1711    use crate::optimizer::optimizer_context::OptimizerContext;
1712    use crate::optimizer::plan_node::LogicalValues;
1713    use crate::optimizer::property::FunctionalDependency;
1714
1715    /// Pruning
1716    /// ```text
1717    /// Join(on: input_ref(1)=input_ref(3))
1718    ///   TableScan(v1, v2, v3)
1719    ///   TableScan(v4, v5, v6)
1720    /// ```
1721    /// with required columns [2,3] will result in
1722    /// ```text
1723    /// Project(input_ref(1), input_ref(2))
1724    ///   Join(on: input_ref(0)=input_ref(2))
1725    ///     TableScan(v2, v3)
1726    ///     TableScan(v4)
1727    /// ```
1728    #[tokio::test]
1729    async fn test_prune_join() {
1730        let ty = DataType::Int32;
1731        let ctx = OptimizerContext::mock().await;
1732        let fields: Vec<Field> = (1..7)
1733            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1734            .collect();
1735        let left = LogicalValues::new(
1736            vec![],
1737            Schema {
1738                fields: fields[0..3].to_vec(),
1739            },
1740            ctx.clone(),
1741        );
1742        let right = LogicalValues::new(
1743            vec![],
1744            Schema {
1745                fields: fields[3..6].to_vec(),
1746            },
1747            ctx,
1748        );
1749        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1750            FunctionCall::new(
1751                Type::Equal,
1752                vec![
1753                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1754                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1755                ],
1756            )
1757            .unwrap(),
1758        ));
1759        let join_type = JoinType::Inner;
1760        let join: PlanRef = LogicalJoin::new(
1761            left.into(),
1762            right.into(),
1763            join_type,
1764            Condition::with_expr(on),
1765        )
1766        .into();
1767
1768        // Perform the prune
1769        let required_cols = vec![2, 3];
1770        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1771
1772        // Check the result
1773        let join = plan.as_logical_join().unwrap();
1774        assert_eq!(join.schema().fields().len(), 2);
1775        assert_eq!(join.schema().fields()[0], fields[2]);
1776        assert_eq!(join.schema().fields()[1], fields[3]);
1777
1778        let expr: ExprImpl = join.on().clone().into();
1779        let call = expr.as_function_call().unwrap();
1780        assert_eq_input_ref!(&call.inputs()[0], 0);
1781        assert_eq_input_ref!(&call.inputs()[1], 2);
1782
1783        let left = join.left();
1784        let left = left.as_logical_values().unwrap();
1785        assert_eq!(left.schema().fields(), &fields[1..3]);
1786        let right = join.right();
1787        let right = right.as_logical_values().unwrap();
1788        assert_eq!(right.schema().fields(), &fields[3..4]);
1789    }
1790
1791    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1792    #[tokio::test]
1793    async fn test_prune_semi_join() {
1794        let ty = DataType::Int32;
1795        let ctx = OptimizerContext::mock().await;
1796        let fields: Vec<Field> = (1..7)
1797            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1798            .collect();
1799        let left = LogicalValues::new(
1800            vec![],
1801            Schema {
1802                fields: fields[0..3].to_vec(),
1803            },
1804            ctx.clone(),
1805        );
1806        let right = LogicalValues::new(
1807            vec![],
1808            Schema {
1809                fields: fields[3..6].to_vec(),
1810            },
1811            ctx,
1812        );
1813        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1814            FunctionCall::new(
1815                Type::Equal,
1816                vec![
1817                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1818                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1819                ],
1820            )
1821            .unwrap(),
1822        ));
1823        for join_type in [
1824            JoinType::LeftSemi,
1825            JoinType::RightSemi,
1826            JoinType::LeftAnti,
1827            JoinType::RightAnti,
1828        ] {
1829            let join = LogicalJoin::new(
1830                left.clone().into(),
1831                right.clone().into(),
1832                join_type,
1833                Condition::with_expr(on.clone()),
1834            );
1835
1836            let offset = if join.is_right_join() { 3 } else { 0 };
1837            let join: PlanRef = join.into();
1838            // Perform the prune
1839            let required_cols = vec![0];
1840            // key 0 is never used in the join (always key 1)
1841            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1842            let as_plan = plan.as_logical_join().unwrap();
1843            // Check the result
1844            assert_eq!(as_plan.schema().fields().len(), 1);
1845            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1846
1847            // Perform the prune
1848            let required_cols = vec![0, 1, 2];
1849            // should not panic here
1850            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1851            let as_plan = plan.as_logical_join().unwrap();
1852            // Check the result
1853            assert_eq!(as_plan.schema().fields().len(), 3);
1854            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1855            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1856            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1857        }
1858    }
1859
1860    /// Pruning
1861    /// ```text
1862    /// Join(on: input_ref(1)=input_ref(3))
1863    ///   TableScan(v1, v2, v3)
1864    ///   TableScan(v4, v5, v6)
1865    /// ```
1866    /// with required columns [1,3] will result in
1867    /// ```text
1868    /// Join(on: input_ref(0)=input_ref(1))
1869    ///   TableScan(v2)
1870    ///   TableScan(v4)
1871    /// ```
1872    #[tokio::test]
1873    async fn test_prune_join_no_project() {
1874        let ty = DataType::Int32;
1875        let ctx = OptimizerContext::mock().await;
1876        let fields: Vec<Field> = (1..7)
1877            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1878            .collect();
1879        let left = LogicalValues::new(
1880            vec![],
1881            Schema {
1882                fields: fields[0..3].to_vec(),
1883            },
1884            ctx.clone(),
1885        );
1886        let right = LogicalValues::new(
1887            vec![],
1888            Schema {
1889                fields: fields[3..6].to_vec(),
1890            },
1891            ctx,
1892        );
1893        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1894            FunctionCall::new(
1895                Type::Equal,
1896                vec![
1897                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1898                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1899                ],
1900            )
1901            .unwrap(),
1902        ));
1903        let join_type = JoinType::Inner;
1904        let join: PlanRef = LogicalJoin::new(
1905            left.into(),
1906            right.into(),
1907            join_type,
1908            Condition::with_expr(on),
1909        )
1910        .into();
1911
1912        // Perform the prune
1913        let required_cols = vec![1, 3];
1914        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1915
1916        // Check the result
1917        let join = plan.as_logical_join().unwrap();
1918        assert_eq!(join.schema().fields().len(), 2);
1919        assert_eq!(join.schema().fields()[0], fields[1]);
1920        assert_eq!(join.schema().fields()[1], fields[3]);
1921
1922        let expr: ExprImpl = join.on().clone().into();
1923        let call = expr.as_function_call().unwrap();
1924        assert_eq_input_ref!(&call.inputs()[0], 0);
1925        assert_eq_input_ref!(&call.inputs()[1], 1);
1926
1927        let left = join.left();
1928        let left = left.as_logical_values().unwrap();
1929        assert_eq!(left.schema().fields(), &fields[1..2]);
1930        let right = join.right();
1931        let right = right.as_logical_values().unwrap();
1932        assert_eq!(right.schema().fields(), &fields[3..4]);
1933    }
1934
1935    /// Convert
1936    /// ```text
1937    /// Join(on: ($1 = $3) AND ($2 == 42))
1938    ///   TableScan(v1, v2, v3)
1939    ///   TableScan(v4, v5, v6)
1940    /// ```
1941    /// to
1942    /// ```text
1943    /// Filter($2 == 42)
1944    ///   HashJoin(on: $1 = $3)
1945    ///     TableScan(v1, v2, v3)
1946    ///     TableScan(v4, v5, v6)
1947    /// ```
1948    #[tokio::test]
1949    async fn test_join_to_batch() {
1950        let ctx = OptimizerContext::mock().await;
1951        let fields: Vec<Field> = (1..7)
1952            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
1953            .collect();
1954        let left = LogicalValues::new(
1955            vec![],
1956            Schema {
1957                fields: fields[0..3].to_vec(),
1958            },
1959            ctx.clone(),
1960        );
1961        let right = LogicalValues::new(
1962            vec![],
1963            Schema {
1964                fields: fields[3..6].to_vec(),
1965            },
1966            ctx,
1967        );
1968
1969        fn input_ref(i: usize) -> ExprImpl {
1970            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
1971        }
1972        let eq_cond = ExprImpl::FunctionCall(Box::new(
1973            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
1974        ));
1975        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
1976            FunctionCall::new(
1977                Type::Equal,
1978                vec![
1979                    input_ref(2),
1980                    ExprImpl::Literal(Box::new(Literal::new(
1981                        Datum::Some(42_i32.into()),
1982                        DataType::Int32,
1983                    ))),
1984                ],
1985            )
1986            .unwrap(),
1987        ));
1988        // Condition: ($1 = $3) AND ($2 == 42)
1989        let on_cond = ExprImpl::FunctionCall(Box::new(
1990            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
1991        ));
1992
1993        let join_type = JoinType::Inner;
1994        let logical_join = LogicalJoin::new(
1995            left.into(),
1996            right.into(),
1997            join_type,
1998            Condition::with_expr(on_cond),
1999        );
2000
2001        // Perform `to_batch`
2002        let result = logical_join.to_batch().unwrap();
2003
2004        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2005        let hash_join = result.as_batch_hash_join().unwrap();
2006        assert_eq!(
2007            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2008            eq_cond
2009        );
2010        assert_eq!(
2011            *hash_join
2012                .eq_join_predicate()
2013                .non_eq_cond()
2014                .conjunctions
2015                .first()
2016                .unwrap(),
2017            non_eq_cond
2018        );
2019    }
2020
2021    /// Convert
2022    /// ```text
2023    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2024    ///   TableScan(v1, v2, v3)
2025    ///   TableScan(v4, v5, v6)
2026    /// ```
2027    /// to
2028    /// ```text
2029    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2030    ///   TableScan(v1, v2, v3)
2031    ///   TableScan(v4, v5, v6)
2032    /// ```
2033    #[tokio::test]
2034    #[ignore] // ignore due to refactor logical scan, but the test seem to duplicate with the explain test
2035    // framework, maybe we will remove it?
2036    async fn test_join_to_stream() {
2037        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
2038        // let fields: Vec<Field> = (1..7)
2039        //     .map(|i| Field {
2040        //         data_type: DataType::Int32,
2041        //         name: format!("v{}", i),
2042        //     })
2043        //     .collect();
2044        // let left = LogicalScan::new(
2045        //     "left".to_string(),
2046        //     TableId::new(0),
2047        //     vec![1.into(), 2.into(), 3.into()],
2048        //     Schema {
2049        //         fields: fields[0..3].to_vec(),
2050        //     },
2051        //     ctx.clone(),
2052        // );
2053        // let right = LogicalScan::new(
2054        //     "right".to_string(),
2055        //     TableId::new(0),
2056        //     vec![4.into(), 5.into(), 6.into()],
2057        //     Schema {
2058        //         fields: fields[3..6].to_vec(),
2059        //     },
2060        //     ctx,
2061        // );
2062        // let eq_cond = ExprImpl::FunctionCall(Box::new(
2063        //     FunctionCall::new(
2064        //         Type::Equal,
2065        //         vec![
2066        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
2067        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
2068        //         ],
2069        //     )
2070        //     .unwrap(),
2071        // ));
2072        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2073        //     FunctionCall::new(
2074        //         Type::Equal,
2075        //         vec![
2076        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
2077        //             ExprImpl::Literal(Box::new(Literal::new(
2078        //                 Datum::Some(42_i32.into()),
2079        //                 DataType::Int32,
2080        //             ))),
2081        //         ],
2082        //     )
2083        //     .unwrap(),
2084        // ));
2085        // // Condition: ($1 = $3) AND ($2 == 42)
2086        // let on_cond = ExprImpl::FunctionCall(Box::new(
2087        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
2088        // ));
2089
2090        // let join_type = JoinType::LeftOuter;
2091        // let logical_join = LogicalJoin::new(
2092        //     left.into(),
2093        //     right.into(),
2094        //     join_type,
2095        //     Condition::with_expr(on_cond.clone()),
2096        // );
2097
2098        // // Perform `to_stream`
2099        // let result = logical_join.to_stream();
2100
2101        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
2102        // let hash_join = result.as_stream_hash_join().unwrap();
2103        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
2104    }
2105    /// Pruning
2106    /// ```text
2107    /// Join(on: input_ref(1)=input_ref(3))
2108    ///   TableScan(v1, v2, v3)
2109    ///   TableScan(v4, v5, v6)
2110    /// ```
2111    /// with required columns [3, 2] will result in
2112    /// ```text
2113    /// Project(input_ref(2), input_ref(1))
2114    ///   Join(on: input_ref(0)=input_ref(2))
2115    ///     TableScan(v2, v3)
2116    ///     TableScan(v4)
2117    /// ```
2118    #[tokio::test]
2119    async fn test_join_column_prune_with_order_required() {
2120        let ty = DataType::Int32;
2121        let ctx = OptimizerContext::mock().await;
2122        let fields: Vec<Field> = (1..7)
2123            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2124            .collect();
2125        let left = LogicalValues::new(
2126            vec![],
2127            Schema {
2128                fields: fields[0..3].to_vec(),
2129            },
2130            ctx.clone(),
2131        );
2132        let right = LogicalValues::new(
2133            vec![],
2134            Schema {
2135                fields: fields[3..6].to_vec(),
2136            },
2137            ctx,
2138        );
2139        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2140            FunctionCall::new(
2141                Type::Equal,
2142                vec![
2143                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2144                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2145                ],
2146            )
2147            .unwrap(),
2148        ));
2149        let join_type = JoinType::Inner;
2150        let join: PlanRef = LogicalJoin::new(
2151            left.into(),
2152            right.into(),
2153            join_type,
2154            Condition::with_expr(on),
2155        )
2156        .into();
2157
2158        // Perform the prune
2159        let required_cols = vec![3, 2];
2160        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2161
2162        // Check the result
2163        let join = plan.as_logical_join().unwrap();
2164        assert_eq!(join.schema().fields().len(), 2);
2165        assert_eq!(join.schema().fields()[0], fields[3]);
2166        assert_eq!(join.schema().fields()[1], fields[2]);
2167
2168        let expr: ExprImpl = join.on().clone().into();
2169        let call = expr.as_function_call().unwrap();
2170        assert_eq_input_ref!(&call.inputs()[0], 0);
2171        assert_eq_input_ref!(&call.inputs()[1], 2);
2172
2173        let left = join.left();
2174        let left = left.as_logical_values().unwrap();
2175        assert_eq!(left.schema().fields(), &fields[1..3]);
2176        let right = join.right();
2177        let right = right.as_logical_values().unwrap();
2178        assert_eq!(right.schema().fields(), &fields[3..4]);
2179    }
2180
2181    #[tokio::test]
2182    async fn fd_derivation_inner_outer_join() {
2183        // left: [l0, l1], right: [r0, r1, r2]
2184        // FD: l0 --> l1, r0 --> { r1, r2 }
2185        // On: l0 = 0 AND l1 = r1
2186        //
2187        // Inner Join:
2188        //  Schema: [l0, l1, r0, r1, r2]
2189        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2190        // Left Outer Join:
2191        //  Schema: [l0, l1, r0, r1, r2]
2192        //  FD: l0 --> l1
2193        // Right Outer Join:
2194        //  Schema: [l0, l1, r0, r1, r2]
2195        //  FD: r0 --> { r1, r2 }
2196        // Full Outer Join:
2197        //  Schema: [l0, l1, r0, r1, r2]
2198        //  FD: empty
2199        // Left Semi/Anti Join:
2200        //  Schema: [l0, l1]
2201        //  FD: l0 --> l1
2202        // Right Semi/Anti Join:
2203        //  Schema: [r0, r1, r2]
2204        //  FD: r0 --> {r1, r2}
2205        let ctx = OptimizerContext::mock().await;
2206        let left = {
2207            let fields: Vec<Field> = vec![
2208                Field::with_name(DataType::Int32, "l0"),
2209                Field::with_name(DataType::Int32, "l1"),
2210            ];
2211            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2212            // 0 --> 1
2213            values
2214                .base
2215                .functional_dependency_mut()
2216                .add_functional_dependency_by_column_indices(&[0], &[1]);
2217            values
2218        };
2219        let right = {
2220            let fields: Vec<Field> = vec![
2221                Field::with_name(DataType::Int32, "r0"),
2222                Field::with_name(DataType::Int32, "r1"),
2223                Field::with_name(DataType::Int32, "r2"),
2224            ];
2225            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2226            // 0 --> 1, 2
2227            values
2228                .base
2229                .functional_dependency_mut()
2230                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2231            values
2232        };
2233        // l0 = 0 AND l1 = r1
2234        let on: ExprImpl = FunctionCall::new(
2235            Type::And,
2236            vec![
2237                FunctionCall::new(
2238                    Type::Equal,
2239                    vec![
2240                        InputRef::new(0, DataType::Int32).into(),
2241                        ExprImpl::literal_int(0),
2242                    ],
2243                )
2244                .unwrap()
2245                .into(),
2246                FunctionCall::new(
2247                    Type::Equal,
2248                    vec![
2249                        InputRef::new(1, DataType::Int32).into(),
2250                        InputRef::new(3, DataType::Int32).into(),
2251                    ],
2252                )
2253                .unwrap()
2254                .into(),
2255            ],
2256        )
2257        .unwrap()
2258        .into();
2259        let expected_fd_set = [
2260            (
2261                JoinType::Inner,
2262                [
2263                    // inherit from left
2264                    FunctionalDependency::with_indices(5, &[0], &[1]),
2265                    // inherit from right
2266                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2267                    // constant column in join condition
2268                    FunctionalDependency::with_indices(5, &[], &[0]),
2269                    // eq column in join condition
2270                    FunctionalDependency::with_indices(5, &[1], &[3]),
2271                    FunctionalDependency::with_indices(5, &[3], &[1]),
2272                ]
2273                .into_iter()
2274                .collect::<HashSet<_>>(),
2275            ),
2276            (JoinType::FullOuter, HashSet::new()),
2277            (
2278                JoinType::RightOuter,
2279                [
2280                    // inherit from right
2281                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2282                ]
2283                .into_iter()
2284                .collect::<HashSet<_>>(),
2285            ),
2286            (
2287                JoinType::LeftOuter,
2288                [
2289                    // inherit from left
2290                    FunctionalDependency::with_indices(5, &[0], &[1]),
2291                ]
2292                .into_iter()
2293                .collect::<HashSet<_>>(),
2294            ),
2295            (
2296                JoinType::LeftSemi,
2297                [
2298                    // inherit from left
2299                    FunctionalDependency::with_indices(2, &[0], &[1]),
2300                ]
2301                .into_iter()
2302                .collect::<HashSet<_>>(),
2303            ),
2304            (
2305                JoinType::LeftAnti,
2306                [
2307                    // inherit from left
2308                    FunctionalDependency::with_indices(2, &[0], &[1]),
2309                ]
2310                .into_iter()
2311                .collect::<HashSet<_>>(),
2312            ),
2313            (
2314                JoinType::RightSemi,
2315                [
2316                    // inherit from right
2317                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2318                ]
2319                .into_iter()
2320                .collect::<HashSet<_>>(),
2321            ),
2322            (
2323                JoinType::RightAnti,
2324                [
2325                    // inherit from right
2326                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2327                ]
2328                .into_iter()
2329                .collect::<HashSet<_>>(),
2330            ),
2331        ];
2332
2333        for (join_type, expected_res) in expected_fd_set {
2334            let join = LogicalJoin::new(
2335                left.clone().into(),
2336                right.clone().into(),
2337                join_type,
2338                Condition::with_expr(on.clone()),
2339            );
2340            let fd_set = join
2341                .functional_dependency()
2342                .as_dependencies()
2343                .iter()
2344                .cloned()
2345                .collect::<HashSet<_>>();
2346            assert_eq!(fd_set, expected_res);
2347        }
2348    }
2349}