Skip to main content

risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32    BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33    PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34    ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right inputs, selected by the output indices provided. Repeated output indices are not allowed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Common logical plan-node properties; `with_core` builds it from `core` via
    /// `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The generic join description: `left`/`right` inputs, `join_type`, the `on` predicate
    /// (expected to be stored as a `Condition` for logical joins) and `output_indices`.
    core: generic::Join<PlanRef>,
}
62
63impl Distill for LogicalJoin {
64    fn distill<'a>(&self) -> XmlNode<'a> {
65        let verbose = self.base.ctx().is_explain_verbose();
66        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67        vec.push(("type", Pretty::debug(&self.join_type())));
68
69        let concat_schema = self.core.concat_schema();
70        let cond = Pretty::debug(&ConditionDisplay {
71            condition: self.on(),
72            input_schema: &concat_schema,
73        });
74        vec.push(("on", cond));
75
76        if verbose {
77            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78            vec.push(("output", data));
79        }
80
81        childless_record("LogicalJoin", vec)
82    }
83}
84
85impl LogicalJoin {
86    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87        let core = generic::Join::with_full_output(left, right, join_type, on);
88        Self::with_core(core)
89    }
90
91    pub(crate) fn with_output_indices(
92        left: PlanRef,
93        right: PlanRef,
94        join_type: JoinType,
95        on: Condition,
96        output_indices: Vec<usize>,
97    ) -> Self {
98        let core = generic::Join::new(left, right, on, join_type, output_indices);
99        Self::with_core(core)
100    }
101
102    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103        let base = PlanBase::new_logical_with_core(&core);
104        LogicalJoin { base, core }
105    }
106
107    pub fn create(
108        left: PlanRef,
109        right: PlanRef,
110        join_type: JoinType,
111        on_clause: ExprImpl,
112    ) -> PlanRef {
113        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114    }
115
116    pub fn internal_column_num(&self) -> usize {
117        self.core.internal_column_num()
118    }
119
120    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121        self.core.i2l_col_mapping_ignore_join_type()
122    }
123
124    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125        self.core.i2r_col_mapping_ignore_join_type()
126    }
127
128    /// Get a reference to the logical join's on.
129    pub fn on(&self) -> &Condition {
130        self.core
131            .on
132            .as_condition_ref()
133            .expect("logical join should store predicate as Condition")
134    }
135
136    pub fn core(&self) -> &generic::Join<PlanRef> {
137        &self.core
138    }
139
140    /// Collect all input ref in the on condition. And separate them into left and right.
141    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
142        let input_refs = self
143            .core
144            .on
145            .as_condition_ref()
146            .expect("logical join should store predicate as Condition")
147            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
148        let index_group = input_refs
149            .ones()
150            .chunk_by(|i| *i < self.core.left.schema().len());
151        let left_index = index_group
152            .into_iter()
153            .next()
154            .map_or(vec![], |group| group.1.collect_vec());
155        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
156            group
157                .1
158                .map(|i| i - self.core.left.schema().len())
159                .collect_vec()
160        });
161        (left_index, right_index)
162    }
163
164    /// Get the join type of the logical join.
165    pub fn join_type(&self) -> JoinType {
166        self.core.join_type
167    }
168
169    /// Get the eq join key of the logical join.
170    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
171        self.core.eq_indexes()
172    }
173
174    /// Get the output indices of the logical join.
175    pub fn output_indices(&self) -> &Vec<usize> {
176        &self.core.output_indices
177    }
178
179    /// Clone with new output indices
180    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
181        Self::with_core(generic::Join {
182            output_indices,
183            ..self.core.clone()
184        })
185    }
186
187    /// Clone with new `on` condition
188    pub fn clone_with_cond(&self, on: Condition) -> Self {
189        Self::with_core(generic::Join {
190            on: generic::JoinOn::Condition(on),
191            ..self.core.clone()
192        })
193    }
194
195    pub fn is_left_join(&self) -> bool {
196        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
197    }
198
199    pub fn is_right_join(&self) -> bool {
200        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
201    }
202
203    pub fn is_full_out(&self) -> bool {
204        self.core.is_full_out()
205    }
206
207    pub fn is_asof_join(&self) -> bool {
208        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
209    }
210
211    pub fn output_indices_are_trivial(&self) -> bool {
212        itertools::equal(
213            self.output_indices().iter().cloned(),
214            0..self.internal_column_num(),
215        )
216    }
217
218    /// Try to simplify the outer join with the predicate on the top of the join
219    ///
220    /// now it is just a naive implementation for comparison expression, we can give a more general
221    /// implementation with constant folding in future
222    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
223        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
224            JoinType::LeftOuter => (false, true),
225            JoinType::RightOuter => (true, false),
226            JoinType::FullOuter => (true, true),
227            _ => return join_type,
228        };
229
230        for expr in &predicate.conjunctions {
231            if let ExprImpl::FunctionCall(func) = expr {
232                match func.func_type() {
233                    ExprType::Equal
234                    | ExprType::NotEqual
235                    | ExprType::LessThan
236                    | ExprType::LessThanOrEqual
237                    | ExprType::GreaterThan
238                    | ExprType::GreaterThanOrEqual => {
239                        for input in func.inputs() {
240                            if let ExprImpl::InputRef(input) = input {
241                                let idx = input.index;
242                                if idx < left_col_num {
243                                    gen_null_in_left = false;
244                                } else {
245                                    gen_null_in_right = false;
246                                }
247                            }
248                        }
249                    }
250                    _ => {}
251                };
252            }
253        }
254
255        match (gen_null_in_left, gen_null_in_right) {
256            (true, true) => JoinType::FullOuter,
257            (true, false) => JoinType::RightOuter,
258            (false, true) => JoinType::LeftOuter,
259            (false, false) => JoinType::Inner,
260        }
261    }
262
263    /// Index Join:
264    /// Try to convert logical join into batch lookup join and meanwhile it will do
265    /// the index selection for the lookup table so that we can benefit from indexes.
266    fn to_batch_lookup_join_with_index_selection(
267        &self,
268        predicate: EqJoinPredicate,
269        batch_join: generic::Join<BatchPlanRef>,
270    ) -> Result<Option<BatchLookupJoin>> {
271        match batch_join.join_type {
272            JoinType::Inner
273            | JoinType::LeftOuter
274            | JoinType::LeftSemi
275            | JoinType::LeftAnti
276            | JoinType::AsofInner
277            | JoinType::AsofLeftOuter => {}
278            _ => return Ok(None),
279        };
280
281        // Index selection for index join.
282        let right = self.right();
283        // Lookup Join only supports basic tables on the join's right side.
284        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
285            logical_scan
286        } else {
287            return Ok(None);
288        };
289
290        let mut result_plan = None;
291        // Lookup primary table.
292        if let Some(lookup_join) =
293            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
294        {
295            result_plan = Some(lookup_join);
296        }
297
298        if self
299            .core
300            .ctx()
301            .session_ctx()
302            .config()
303            .enable_index_selection()
304        {
305            let indexes = logical_scan.table_indexes();
306            for index in indexes {
307                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
308                    let index_scan: PlanRef = index_scan.into();
309                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
310                    let mut new_batch_join = batch_join.clone();
311                    new_batch_join.right =
312                        index_scan.to_batch().expect("index scan failed to batch");
313
314                    // Lookup covered index.
315                    if let Some(lookup_join) =
316                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
317                    {
318                        match &result_plan {
319                            None => result_plan = Some(lookup_join),
320                            Some(prev_lookup_join) => {
321                                // Prefer to choose lookup join with longer lookup prefix len.
322                                if prev_lookup_join.lookup_prefix_len()
323                                    < lookup_join.lookup_prefix_len()
324                                {
325                                    result_plan = Some(lookup_join)
326                                }
327                            }
328                        }
329                    }
330                }
331            }
332        }
333
334        Ok(result_plan)
335    }
336
337    /// Try to convert logical join into batch lookup join.
338    fn to_batch_lookup_join(
339        &self,
340        predicate: EqJoinPredicate,
341        logical_join: generic::Join<BatchPlanRef>,
342    ) -> Result<Option<BatchLookupJoin>> {
343        let logical_scan: &LogicalScan =
344            if let Some(logical_scan) = self.core.right.as_logical_scan() {
345                logical_scan
346            } else {
347                return Ok(None);
348            };
349        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
350    }
351
352    pub fn gen_batch_lookup_join(
353        logical_scan: &LogicalScan,
354        predicate: EqJoinPredicate,
355        logical_join: generic::Join<BatchPlanRef>,
356        is_as_of: bool,
357    ) -> Result<Option<BatchLookupJoin>> {
358        match logical_join.join_type {
359            JoinType::Inner
360            | JoinType::LeftOuter
361            | JoinType::LeftSemi
362            | JoinType::LeftAnti
363            | JoinType::AsofInner
364            | JoinType::AsofLeftOuter => {}
365            _ => return Ok(None),
366        };
367
368        let table = logical_scan.table();
369        let output_column_ids = logical_scan.output_column_ids();
370
371        // Verify that the right join key columns are the the prefix of the primary key and
372        // also contain the distribution key.
373        let order_col_ids = table.order_column_ids();
374        let dist_key = table.distribution_key.clone();
375        // The at least prefix of order key that contains distribution key.
376        let mut dist_key_in_order_key_pos = vec![];
377        for d in dist_key {
378            let pos = table
379                .order_column_indices()
380                .position(|x| x == d)
381                .expect("dist_key must in order_key");
382            dist_key_in_order_key_pos.push(pos);
383        }
384        // The shortest prefix of order key that contains distribution key.
385        let shortest_prefix_len = dist_key_in_order_key_pos
386            .iter()
387            .max()
388            .map_or(0, |pos| pos + 1);
389
390        // Distributed lookup join can't support lookup table with a singleton distribution.
391        if shortest_prefix_len == 0 {
392            return Ok(None);
393        }
394
395        // Reorder the join equal predicate to match the order key.
396        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
397        for order_col_id in order_col_ids {
398            let mut found = false;
399            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
400                if order_col_id == output_column_ids[eq_idx] {
401                    reorder_idx.push(i);
402                    found = true;
403                    break;
404                }
405            }
406            if !found {
407                break;
408            }
409        }
410        if reorder_idx.len() < shortest_prefix_len {
411            return Ok(None);
412        }
413        let lookup_prefix_len = reorder_idx.len();
414        let predicate = predicate.reorder(&reorder_idx);
415
416        // Extract the predicate from logical scan. Only pure scan is supported.
417        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
418        // Construct output column to require column mapping
419        let o2r = if let Some(project_expr) = project_expr {
420            project_expr
421                .into_iter()
422                .map(|x| x.as_input_ref().unwrap().index)
423                .collect_vec()
424        } else {
425            (0..logical_scan.output_col_idx().len()).collect_vec()
426        };
427        let left_schema_len = logical_join.left.schema().len();
428
429        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
430            offset: left_schema_len,
431            mapping: o2r.clone(),
432        };
433
434        let new_eq_cond = predicate
435            .eq_cond()
436            .rewrite_expr(&mut join_predicate_rewriter);
437
438        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
439            offset: left_schema_len,
440        };
441
442        let new_other_cond = predicate
443            .other_cond()
444            .clone()
445            .rewrite_expr(&mut join_predicate_rewriter)
446            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
447
448        let new_join_on = new_eq_cond.and(new_other_cond);
449        let new_predicate =
450            EqJoinPredicate::create(left_schema_len, new_scan.schema().len(), new_join_on);
451
452        // We discovered that we cannot use a lookup join after pulling up the predicate
453        // from one side and simplifying the condition. Let's use some other join instead.
454        if !new_predicate.has_eq() {
455            return Ok(None);
456        }
457
458        // Rewrite the join output indices and all output indices referred to the old scan need to
459        // rewrite.
460        let new_join_output_indices = logical_join
461            .output_indices
462            .iter()
463            .map(|&x| {
464                if x < left_schema_len {
465                    x
466                } else {
467                    o2r[x - left_schema_len] + left_schema_len
468                }
469            })
470            .collect_vec();
471
472        let new_scan_output_column_ids = new_scan.output_column_ids();
473        let as_of = new_scan.as_of.clone();
474        let new_logical_scan: LogicalScan = new_scan.into();
475
476        // Construct a new logical join, because we have change its RHS.
477        let new_logical_join = generic::Join::new_with_eq_predicate(
478            logical_join.left,
479            new_logical_scan.to_batch()?,
480            new_predicate,
481            logical_join.join_type,
482            new_join_output_indices,
483        );
484
485        let asof_desc = is_as_of
486            .then(|| {
487                Self::get_inequality_desc_from_predicate(
488                    predicate.other_cond().clone(),
489                    left_schema_len,
490                )
491            })
492            .transpose()?;
493
494        Ok(Some(BatchLookupJoin::new(
495            new_logical_join,
496            table.clone(),
497            new_scan_output_column_ids,
498            lookup_prefix_len,
499            false,
500            as_of,
501            asof_desc,
502        )))
503    }
504
505    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
506        self.core.decompose()
507    }
508
509    fn dynamic_filter_candidate(&self, predicate: &Condition) -> Option<(usize, PbType)> {
510        // Dynamic filter only supports `Inner`/`LeftSemi`.
511        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
512            return None;
513        }
514
515        // Dynamic filter requires right side to be a scalar with one column.
516        if !self.right().max_one_row() || self.right().schema().len() != 1 {
517            return None;
518        }
519
520        // Dynamic filter only supports a single comparison predicate.
521        if predicate.conjunctions.len() > 1 {
522            return None;
523        }
524        let expr: ExprImpl = predicate.clone().into();
525        let (left_ref, comparator, right_ref) = expr.as_comparison_cond()?;
526
527        // Comparison must cross inputs: left input ref vs right scalar input ref.
528        let left_len = self.left().schema().len();
529        let condition_cross_inputs = left_ref.index < left_len && right_ref.index == left_len;
530        if !condition_cross_inputs {
531            return None;
532        }
533
534        // Comparison keys must be type aligned.
535        if self.left().schema().fields()[left_ref.index].data_type
536            != self.right().schema().fields()[0].data_type
537        {
538            return None;
539        }
540
541        // Dynamic filter output can only come from the left side.
542        if !self.output_indices().iter().all(|i| *i < left_len) {
543            return None;
544        }
545
546        Some((left_ref.index, comparator))
547    }
548
549    /// Check whether this join can be treated as a temporal filter for locality optimization.
550    ///
551    /// This is intentionally stricter than dynamic filter:
552    /// - right side must be a `LogicalNow`,
553    /// - and all dynamic filter preconditions must hold.
554    fn temporal_filter_candidate(&self) -> bool {
555        self.right().as_logical_now().is_some()
556            && self.dynamic_filter_candidate(self.on()).is_some()
557    }
558}
559
560impl PlanTreeNodeBinary<Logical> for LogicalJoin {
561    fn left(&self) -> PlanRef {
562        self.core.left.clone()
563    }
564
565    fn right(&self) -> PlanRef {
566        self.core.right.clone()
567    }
568
569    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
570        Self::with_core(generic::Join {
571            left,
572            right,
573            ..self.core.clone()
574        })
575    }
576
577    fn rewrite_with_left_right(
578        &self,
579        left: PlanRef,
580        left_col_change: ColIndexMapping,
581        right: PlanRef,
582        right_col_change: ColIndexMapping,
583    ) -> (Self, ColIndexMapping) {
584        let (new_on, new_output_indices) = {
585            let (mut map, _) = left_col_change.clone().into_parts();
586            let (mut right_map, _) = right_col_change.clone().into_parts();
587            for i in right_map.iter_mut().flatten() {
588                *i += left.schema().len();
589            }
590            map.append(&mut right_map);
591            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
592
593            let new_output_indices = self
594                .output_indices()
595                .iter()
596                .map(|&i| mapping.map(i))
597                .collect::<Vec<_>>();
598            let new_on = self.on().clone().rewrite_expr(&mut mapping);
599            (new_on, new_output_indices)
600        };
601
602        let join = Self::with_output_indices(
603            left,
604            right,
605            self.join_type(),
606            new_on,
607            new_output_indices.clone(),
608        );
609
610        let new_i2o = ColIndexMapping::with_remaining_columns(
611            &new_output_indices,
612            join.internal_column_num(),
613        );
614
615        let old_o2i = self.core.o2i_col_mapping();
616
617        let old_o2l = old_o2i
618            .composite(&self.core.i2l_col_mapping())
619            .composite(&left_col_change);
620        let old_o2r = old_o2i
621            .composite(&self.core.i2r_col_mapping())
622            .composite(&right_col_change);
623        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
624        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
625
626        let out_col_change = old_o2l
627            .composite(&new_l2o)
628            .union(&old_o2r.composite(&new_r2o));
629        (join, out_col_change)
630    }
631}
632
// Presumably generates the generic plan-tree-node impl for `LogicalJoin` from its
// `PlanTreeNodeBinary` impl; the macro is defined elsewhere — confirm there.
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
634
635impl ColPrunable for LogicalJoin {
636    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
637        // make `required_cols` point to internal table instead of output schema.
638        let required_cols = required_cols
639            .iter()
640            .map(|i| self.output_indices()[*i])
641            .collect_vec();
642        let left_len = self.left().schema().fields.len();
643
644        let total_len = self.left().schema().len() + self.right().schema().len();
645        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
646
647        required_cols.iter().for_each(|&i| {
648            if self.is_right_join() {
649                resized_required_cols.insert(left_len + i);
650            } else {
651                resized_required_cols.insert(i);
652            }
653        });
654
655        // add those columns which are required in the join condition to
656        // to those that are required in the output
657        let mut visitor = CollectInputRef::new(resized_required_cols);
658        self.on().visit_expr(&mut visitor);
659        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
660
661        let mut left_required_cols = Vec::new();
662        let mut right_required_cols = Vec::new();
663        left_right_required_cols.iter().for_each(|&i| {
664            if i < left_len {
665                left_required_cols.push(i);
666            } else {
667                right_required_cols.push(i - left_len);
668            }
669        });
670
671        let mut on = self.on().clone();
672        let mut mapping =
673            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
674        on = on.rewrite_expr(&mut mapping);
675
676        let new_output_indices = {
677            let required_inputs_in_output = if self.is_left_join() {
678                &left_required_cols
679            } else if self.is_right_join() {
680                &right_required_cols
681            } else {
682                &left_right_required_cols
683            };
684
685            let mapping =
686                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
687            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
688        };
689
690        LogicalJoin::with_output_indices(
691            self.left().prune_col(&left_required_cols, ctx),
692            self.right().prune_col(&right_required_cols, ctx),
693            self.join_type(),
694            on,
695            new_output_indices,
696        )
697        .into()
698    }
699}
700
701impl ExprRewritable<Logical> for LogicalJoin {
702    fn has_rewritable_expr(&self) -> bool {
703        true
704    }
705
706    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
707        let mut core = self.core.clone();
708        core.rewrite_exprs(r);
709        Self {
710            base: self.base.clone_with_new_plan_id(),
711            core,
712        }
713        .into()
714    }
715}
716
717impl ExprVisitable for LogicalJoin {
718    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
719        self.core.visit_exprs(v);
720    }
721}
722
723/// We are trying to derive a predicate to apply to the other side of a join if all
724/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
725/// with the corresponding eq condition columns of the other side.
726///
727/// Strategy:
728/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
729///    then we proceed. Else abort.
730/// 2. Then, we collect `InputRef`s in the conjunction.
731/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
732/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
733///    the other side.
734///
735/// # Arguments
736///
737/// Suppose we derive a predicate from the left side to be pushed to the right side.
738/// * `expr`: An expr from the left side.
739/// * `col_num`: The number of columns in the left side.
740fn derive_predicate_from_eq_condition(
741    expr: &ExprImpl,
742    eq_condition: &EqJoinPredicate,
743    col_num: usize,
744    expr_is_left: bool,
745) -> Option<ExprImpl> {
746    if expr.is_impure() {
747        return None;
748    }
749    let eq_indices = eq_condition
750        .eq_indexes_typed()
751        .iter()
752        .filter_map(|(l, r)| {
753            if l.return_type() != r.return_type() {
754                None
755            } else if expr_is_left {
756                Some(l.index())
757            } else {
758                Some(r.index())
759            }
760        })
761        .collect_vec();
762    if expr
763        .collect_input_refs(col_num)
764        .ones()
765        .any(|index| !eq_indices.contains(&index))
766    {
767        // expr contains an InputRef not in eq_condition
768        return None;
769    }
770    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
771    // Hence, we can substitute those `InputRef`s with indices from the other side.
772    let other_side_mapping = if expr_is_left {
773        eq_condition.eq_indexes_typed().into_iter().collect()
774    } else {
775        eq_condition
776            .eq_indexes_typed()
777            .into_iter()
778            .map(|(x, y)| (y, x))
779            .collect()
780    };
781    struct InputRefsRewriter {
782        mapping: HashMap<InputRef, InputRef>,
783    }
784    impl ExprRewriter for InputRefsRewriter {
785        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
786            self.mapping[&input_ref].clone().into()
787        }
788    }
789    Some(
790        InputRefsRewriter {
791            mapping: other_side_mapping,
792        }
793        .rewrite_expr(expr.clone()),
794    )
795}
796
797/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
798struct LookupJoinPredicateRewriter {
799    offset: usize,
800    mapping: Vec<usize>,
801}
802impl ExprRewriter for LookupJoinPredicateRewriter {
803    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
804        if input_ref.index() < self.offset {
805            input_ref.into()
806        } else {
807            InputRef::new(
808                self.mapping[input_ref.index() - self.offset] + self.offset,
809                input_ref.return_type(),
810            )
811            .into()
812        }
813    }
814}
815
816/// Rewrite the scan predicate so we can add it to the join predicate.
817struct LookupJoinScanPredicateRewriter {
818    offset: usize,
819}
820impl ExprRewriter for LookupJoinScanPredicateRewriter {
821    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
822        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
823    }
824}
825
826impl PredicatePushdown for LogicalJoin {
827    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
828    ///
829    /// # Which predicates can be pushed
830    ///
831    /// For inner join, we can do all kinds of pushdown.
832    ///
833    /// For left/right semi join, we can push filter to left/right and on-clause,
834    /// and push on-clause to left/right.
835    ///
836    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
837    ///
838    /// ## Outer Join
839    ///
840    /// Preserved Row table
841    /// : The table in an Outer Join that must return all rows.
842    ///
843    /// Null Supplying table
844    /// : This is the table that has nulls filled in for its columns in unmatched rows.
845    ///
846    /// |                          | Preserved Row table | Null Supplying table |
847    /// |--------------------------|---------------------|----------------------|
848    /// | Join predicate (on)      | Not Pushed          | Pushed               |
849    /// | Where predicate (filter) | Pushed              | Not Pushed           |
850    fn predicate_pushdown(
851        &self,
852        predicate: Condition,
853        ctx: &mut PredicatePushdownContext,
854    ) -> PlanRef {
855        // rewrite output col referencing indices as internal cols
856        let mut predicate = {
857            let mut mapping = self.core.o2i_col_mapping();
858            predicate.rewrite_expr(&mut mapping)
859        };
860
861        let left_col_num = self.left().schema().len();
862        let right_col_num = self.right().schema().len();
863        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
864
865        let push_down_temporal_predicate = self.temporal_join_on().is_none();
866
867        let (left_from_filter, right_from_filter, on) = push_down_into_join(
868            &mut predicate,
869            left_col_num,
870            right_col_num,
871            join_type,
872            push_down_temporal_predicate,
873        );
874
875        let mut new_on = self.on().clone().and(on);
876        let (left_from_on, right_from_on) = push_down_join_condition(
877            &mut new_on,
878            left_col_num,
879            right_col_num,
880            join_type,
881            push_down_temporal_predicate,
882        );
883
884        let left_predicate = left_from_filter.and(left_from_on);
885        let right_predicate = right_from_filter.and(right_from_on);
886
887        // Derive conditions to push to the other side based on eq condition columns
888        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
889
890        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
891        let right_from_left = if matches!(
892            join_type,
893            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
894        ) {
895            Condition {
896                conjunctions: left_predicate
897                    .conjunctions
898                    .iter()
899                    .filter_map(|expr| {
900                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
901                    })
902                    .collect(),
903            }
904        } else {
905            Condition::true_cond()
906        };
907
908        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
909        let left_from_right = if matches!(
910            join_type,
911            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
912        ) {
913            Condition {
914                conjunctions: right_predicate
915                    .conjunctions
916                    .iter()
917                    .filter_map(|expr| {
918                        derive_predicate_from_eq_condition(
919                            expr,
920                            &eq_condition,
921                            right_col_num,
922                            false,
923                        )
924                    })
925                    .collect(),
926            }
927        } else {
928            Condition::true_cond()
929        };
930
931        let left_predicate = left_predicate.and(left_from_right);
932        let right_predicate = right_predicate.and(right_from_left);
933
934        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
935        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
936        let new_join = LogicalJoin::with_output_indices(
937            new_left,
938            new_right,
939            join_type,
940            new_on,
941            self.output_indices().clone(),
942        );
943
944        let mut mapping = self.core.i2o_col_mapping();
945        predicate = predicate.rewrite_expr(&mut mapping);
946        LogicalFilter::create(new_join.into(), predicate)
947    }
948}
949
950#[derive(Clone, Copy)]
951struct TemporalJoinScan<'a>(&'a LogicalScan);
952
953impl<'a> Deref for TemporalJoinScan<'a> {
954    type Target = LogicalScan;
955
956    fn deref(&self) -> &Self::Target {
957        self.0
958    }
959}
960
961impl LogicalJoin {
962    fn get_stream_input_for_hash_join(
963        &self,
964        predicate: &EqJoinPredicate,
965        ctx: &mut ToStreamContext,
966    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
967        use super::stream::prelude::*;
968
969        let mut right = self.right().to_stream_with_dist_required(
970            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
971            ctx,
972        )?;
973        let r2l =
974            predicate.r2l_eq_columns_mapping(self.left().schema().len(), right.schema().len());
975        let l2r =
976            predicate.l2r_eq_columns_mapping(self.left().schema().len(), right.schema().len());
977        let mut left;
978        let right_dist = right.distribution();
979        match right_dist {
980            Distribution::HashShard(_) => {
981                let left_dist = r2l
982                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
983                left = self.left().to_stream_with_dist_required(&left_dist, ctx)?;
984            }
985            Distribution::UpstreamHashShard(_, _) => {
986                left = self.left().to_stream_with_dist_required(
987                    &RequiredDist::shard_by_key(
988                        self.left().schema().len(),
989                        &predicate.left_eq_indexes(),
990                    ),
991                    ctx,
992                )?;
993                let left_dist = left.distribution();
994                match left_dist {
995                    Distribution::HashShard(_) => {
996                        let right_dist = l2r.rewrite_required_distribution(
997                            &RequiredDist::PhysicalDist(left_dist.clone()),
998                        );
999                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
1000                    }
1001                    Distribution::UpstreamHashShard(_, _) => {
1002                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
1003                            .streaming_enforce_if_not_satisfies(left)?;
1004                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
1005                            .streaming_enforce_if_not_satisfies(right)?;
1006                    }
1007                    _ => unreachable!(),
1008                }
1009            }
1010            _ => unreachable!(),
1011        }
1012        Ok((left, right))
1013    }
1014
1015    fn to_stream_hash_join(
1016        &self,
1017        predicate: EqJoinPredicate,
1018        ctx: &mut ToStreamContext,
1019    ) -> Result<StreamPlanRef> {
1020        use super::stream::prelude::*;
1021
1022        assert!(predicate.has_eq());
1023        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1024
1025        let mut core = self.core.clone_with_inputs(left, right);
1026        core.on = generic::JoinOn::EqPredicate(predicate);
1027
1028        // Convert to Hash Join for equal joins
1029        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
1030        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
1031        // as opposed to the HashJoin, which applies the condition row-wise.
1032        // However, the default behavior of pulling up non-equal conditions can be overridden by the
1033        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
1034        // materialization of rows only to be filtered later.
1035
1036        let stream_hash_join = StreamHashJoin::new(core.clone())?;
1037        let predicate = stream_hash_join.eq_join_predicate().clone();
1038
1039        let force_filter_inside_join = self
1040            .base
1041            .ctx()
1042            .session_ctx()
1043            .config()
1044            .streaming_force_filter_inside_join();
1045
1046        let pull_filter = self.join_type() == JoinType::Inner
1047            && stream_hash_join.eq_join_predicate().has_non_eq()
1048            && stream_hash_join.inequality_pairs().is_empty()
1049            && (!force_filter_inside_join);
1050        if pull_filter {
1051            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1052
1053            let mut core = core;
1054            core.output_indices = default_indices.clone();
1055            // Temporarily remove output indices.
1056            let eq_cond = EqJoinPredicate::new(
1057                Condition::true_cond(),
1058                predicate.eq_keys().to_vec(),
1059                self.left().schema().len(),
1060                self.right().schema().len(),
1061            );
1062            core.on = generic::JoinOn::EqPredicate(eq_cond);
1063            let hash_join = StreamHashJoin::new(core)?.into();
1064            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1065            let plan = StreamFilter::new(logical_filter).into();
1066            if self.output_indices() != &default_indices {
1067                let logical_project = generic::Project::with_mapping(
1068                    plan,
1069                    ColIndexMapping::with_remaining_columns(
1070                        self.output_indices(),
1071                        self.internal_column_num(),
1072                    ),
1073                );
1074                Ok(StreamProject::new(logical_project).into())
1075            } else {
1076                Ok(plan)
1077            }
1078        } else {
1079            Ok(stream_hash_join.into())
1080        }
1081    }
1082
1083    pub fn should_be_temporal_join(&self) -> bool {
1084        self.temporal_join_on().is_some()
1085    }
1086
1087    fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1088        if let Some(logical_scan) = self.core.right.as_logical_scan() {
1089            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1090                .then_some(TemporalJoinScan(logical_scan))
1091        } else {
1092            None
1093        }
1094    }
1095
1096    fn should_be_stream_temporal_join<'a>(
1097        &'a self,
1098        ctx: &ToStreamContext,
1099    ) -> Result<Option<TemporalJoinScan<'a>>> {
1100        Ok(if let Some(scan) = self.temporal_join_on() {
1101            if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1102                return Err(RwError::from(ErrorCode::NotSupported(
1103                    "Temporal join with snapshot backfill not supported".into(),
1104                    "Please use arrangement backfill".into(),
1105                )));
1106            }
1107            if scan.cross_database() {
1108                return Err(RwError::from(ErrorCode::NotSupported(
1109                        "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1110                        "Please ensure both tables are in the same database".into(),
1111                    )));
1112            }
1113            Some(scan)
1114        } else {
1115            None
1116        })
1117    }
1118
1119    fn to_stream_temporal_join_with_index_selection(
1120        &self,
1121        logical_scan: TemporalJoinScan<'_>,
1122        predicate: EqJoinPredicate,
1123        ctx: &mut ToStreamContext,
1124    ) -> Result<StreamPlanRef> {
1125        // Use primary table.
1126        let mut result_plan: Result<StreamTemporalJoin> =
1127            self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1128        // Return directly if this temporal join can match the pk of its right table.
1129        if let Ok(temporal_join) = &result_plan
1130            && temporal_join.eq_join_predicate().eq_indexes().len()
1131                == logical_scan.primary_key().len()
1132        {
1133            return result_plan.map(|x| x.into());
1134        }
1135        if self
1136            .core
1137            .ctx()
1138            .session_ctx()
1139            .config()
1140            .enable_index_selection()
1141        {
1142            let indexes = logical_scan.table_indexes();
1143            for index in indexes {
1144                // Use index table
1145                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1146                    let index_scan: PlanRef = index_scan.into();
1147                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1148                    if let Ok(temporal_join) = that.to_stream_temporal_join(
1149                        that.temporal_join_on().expect(
1150                            "index scan created from temporal join scan must also be temporal join",
1151                        ),
1152                        predicate.clone(),
1153                        ctx,
1154                    ) {
1155                        match &result_plan {
1156                            Err(_) => result_plan = Ok(temporal_join),
1157                            Ok(prev_temporal_join) => {
1158                                // Prefer to the temporal join with a longer lookup prefix len.
1159                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1160                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1161                                {
1162                                    result_plan = Ok(temporal_join)
1163                                }
1164                            }
1165                        }
1166                    }
1167                }
1168            }
1169        }
1170
1171        result_plan.map(|x| x.into())
1172    }
1173
1174    fn temporal_join_scan_predicate_pull_up(
1175        logical_scan: TemporalJoinScan<'_>,
1176        predicate: EqJoinPredicate,
1177        output_indices: &[usize],
1178        left_schema_len: usize,
1179    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1180        // Extract the predicate from logical scan. Only pure scan is supported.
1181        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1182        // Construct output column to require column mapping
1183        let o2r = if let Some(project_expr) = project_expr {
1184            project_expr
1185                .into_iter()
1186                .map(|x| x.as_input_ref().unwrap().index)
1187                .collect_vec()
1188        } else {
1189            (0..logical_scan.output_col_idx().len()).collect_vec()
1190        };
1191        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1192            offset: left_schema_len,
1193            mapping: o2r.clone(),
1194        };
1195
1196        let new_eq_cond = predicate
1197            .eq_cond()
1198            .rewrite_expr(&mut join_predicate_rewriter);
1199
1200        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1201            offset: left_schema_len,
1202        };
1203
1204        let new_other_cond = predicate
1205            .other_cond()
1206            .clone()
1207            .rewrite_expr(&mut join_predicate_rewriter)
1208            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1209
1210        let new_join_on = new_eq_cond.and(new_other_cond);
1211
1212        let new_predicate = EqJoinPredicate::create(
1213            left_schema_len,
1214            new_scan.schema().len(),
1215            new_join_on.clone(),
1216        );
1217
1218        // Rewrite the join output indices and all output indices referred to the old scan need to
1219        // rewrite.
1220        let new_join_output_indices = output_indices
1221            .iter()
1222            .map(|&x| {
1223                if x < left_schema_len {
1224                    x
1225                } else {
1226                    o2r[x - left_schema_len] + left_schema_len
1227                }
1228            })
1229            .collect_vec();
1230
1231        let new_stream_table_scan =
1232            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1233        Ok((
1234            new_stream_table_scan,
1235            new_predicate,
1236            new_join_on,
1237            new_join_output_indices,
1238        ))
1239    }
1240
1241    fn to_stream_temporal_join(
1242        &self,
1243        logical_scan: TemporalJoinScan<'_>,
1244        predicate: EqJoinPredicate,
1245        ctx: &mut ToStreamContext,
1246    ) -> Result<StreamTemporalJoin> {
1247        use super::stream::prelude::*;
1248
1249        assert!(predicate.has_eq());
1250
1251        let table = logical_scan.table();
1252        let output_column_ids = logical_scan.output_column_ids();
1253
1254        // Verify that the right join key columns are the the prefix of the primary key and
1255        // also contain the distribution key.
1256        let order_col_ids = table.order_column_ids();
1257        let dist_key = table.distribution_key.clone();
1258
1259        let mut dist_key_in_order_key_pos = vec![];
1260        for d in dist_key {
1261            let pos = table
1262                .order_column_indices()
1263                .position(|x| x == d)
1264                .expect("dist_key must in order_key");
1265            dist_key_in_order_key_pos.push(pos);
1266        }
1267        // The shortest prefix of order key that contains distribution key.
1268        let shortest_prefix_len = dist_key_in_order_key_pos
1269            .iter()
1270            .max()
1271            .map_or(0, |pos| pos + 1);
1272
1273        // Reorder the join equal predicate to match the order key.
1274        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1275        for order_col_id in order_col_ids {
1276            let mut found = false;
1277            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1278                if order_col_id == output_column_ids[eq_idx] {
1279                    reorder_idx.push(i);
1280                    found = true;
1281                    break;
1282                }
1283            }
1284            if !found {
1285                break;
1286            }
1287        }
1288        if reorder_idx.len() < shortest_prefix_len {
1289            return Err(RwError::from(ErrorCode::NotSupported(
1290                "Temporal join requires the equivalence join condition includes the key columns that form the distribution key of the lookup table".into(),
1291                concat!(
1292                    "Use DESCRIBE <table_name> to view the table's key information.\n",
1293                    "You can create an index on the lookup table to facilitate the temporal join if necessary."
1294                ).into(),
1295            )));
1296        }
1297        let lookup_prefix_len = reorder_idx.len();
1298        let predicate = predicate.reorder(&reorder_idx);
1299
1300        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1301            RequiredDist::single()
1302        } else {
1303            let left_eq_indexes = predicate.left_eq_indexes();
1304            let left_dist_key = dist_key_in_order_key_pos
1305                .iter()
1306                .map(|pos| left_eq_indexes[*pos])
1307                .collect_vec();
1308
1309            RequiredDist::hash_shard(&left_dist_key)
1310        };
1311
1312        let left = self.left().to_stream(ctx)?;
1313        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1314        let left = required_dist.stream_enforce(left);
1315
1316        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1317            Self::temporal_join_scan_predicate_pull_up(
1318                logical_scan,
1319                predicate,
1320                self.output_indices(),
1321                self.left().schema().len(),
1322            )?;
1323
1324        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1325        if !new_predicate.has_eq() {
1326            return Err(RwError::from(ErrorCode::NotSupported(
1327                "Temporal join requires a non trivial join condition".into(),
1328                "Please remove the false condition of the join".into(),
1329            )));
1330        }
1331
1332        // Construct a new logical join, because we have change its RHS.
1333        let new_logical_join = generic::Join::new(
1334            left,
1335            right,
1336            new_join_on,
1337            self.join_type(),
1338            new_join_output_indices,
1339        );
1340
1341        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1342
1343        let mut new_logical_join = new_logical_join;
1344        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1345        StreamTemporalJoin::new(new_logical_join, false)
1346    }
1347
1348    fn to_stream_nested_loop_temporal_join(
1349        &self,
1350        logical_scan: TemporalJoinScan<'_>,
1351        predicate: EqJoinPredicate,
1352        ctx: &mut ToStreamContext,
1353    ) -> Result<StreamPlanRef> {
1354        use super::stream::prelude::*;
1355        assert!(!predicate.has_eq());
1356
1357        let left = self.left().to_stream_with_dist_required(
1358            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1359            ctx,
1360        )?;
1361        assert!(left.as_stream_exchange().is_some());
1362
1363        if self.join_type() != JoinType::Inner {
1364            return Err(RwError::from(ErrorCode::NotSupported(
1365                "Temporal join requires an inner join".into(),
1366                "Please use an inner join".into(),
1367            )));
1368        }
1369
1370        if !left.append_only() {
1371            return Err(RwError::from(ErrorCode::NotSupported(
1372                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1373                "Please ensure the left hash side is append only".into(),
1374            )));
1375        }
1376
1377        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1378            Self::temporal_join_scan_predicate_pull_up(
1379                logical_scan,
1380                predicate,
1381                self.output_indices(),
1382                self.left().schema().len(),
1383            )?;
1384
1385        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1386
1387        // Construct a new logical join, because we have change its RHS.
1388        let new_logical_join = generic::Join::new(
1389            left,
1390            right,
1391            new_join_on,
1392            self.join_type(),
1393            new_join_output_indices,
1394        );
1395
1396        let mut new_logical_join = new_logical_join;
1397        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1398        Ok(StreamTemporalJoin::new(new_logical_join, true)?.into())
1399    }
1400
1401    fn to_stream_dynamic_filter(
1402        &self,
1403        predicate: Condition,
1404        ctx: &mut ToStreamContext,
1405    ) -> Result<Option<StreamPlanRef>> {
1406        use super::stream::prelude::*;
1407
1408        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1409        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1410        // `StreamDynamicFilter`
1411        let Some((left_key_idx, comparator)) = self.dynamic_filter_candidate(&predicate) else {
1412            return Ok(None);
1413        };
1414
1415        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1416        let right = self.right().to_stream_with_dist_required(
1417            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1418            ctx,
1419        )?;
1420
1421        assert!(right.as_stream_exchange().is_some());
1422        assert_eq!(
1423            *Itertools::exactly_one(right.inputs().iter())
1424                .unwrap()
1425                .distribution(),
1426            Distribution::Single
1427        );
1428
1429        let core = DynamicFilter::new(comparator, left_key_idx, left, right);
1430        let plan = StreamDynamicFilter::new(core)?.into();
1431        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1432        if self
1433            .output_indices()
1434            .iter()
1435            .copied()
1436            .ne(0..self.left().schema().len())
1437        {
1438            // The schema of dynamic filter is always the same as the left side now, and we have
1439            // checked that all output columns are from the left side before.
1440            let logical_project = generic::Project::with_mapping(
1441                plan,
1442                ColIndexMapping::with_remaining_columns(
1443                    self.output_indices(),
1444                    self.left().schema().len(),
1445                ),
1446            );
1447            Ok(Some(StreamProject::new(logical_project).into()))
1448        } else {
1449            Ok(Some(plan))
1450        }
1451    }
1452
1453    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1454        let predicate = EqJoinPredicate::create(
1455            self.left().schema().len(),
1456            self.right().schema().len(),
1457            self.on().clone(),
1458        );
1459        assert!(predicate.has_eq());
1460
1461        let join = self
1462            .core
1463            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1464
1465        Ok(self
1466            .to_batch_lookup_join(predicate, join)?
1467            .expect("Fail to convert to lookup join")
1468            .into())
1469    }
1470
1471    fn to_stream_asof_join(
1472        &self,
1473        predicate: EqJoinPredicate,
1474        ctx: &mut ToStreamContext,
1475    ) -> Result<StreamPlanRef> {
1476        use super::stream::prelude::*;
1477
1478        if predicate.eq_keys().is_empty() {
1479            return Err(ErrorCode::InvalidInputSyntax(
1480                "AsOf join requires at least 1 equal condition".to_owned(),
1481            )
1482            .into());
1483        }
1484
1485        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1486        let left_len = left.schema().len();
1487        let mut core = self.core.clone_with_inputs(left, right);
1488        core.on = generic::JoinOn::EqPredicate(predicate);
1489
1490        let inequality_desc = Self::get_inequality_desc_from_predicate(
1491            core.on
1492                .as_eq_predicate_ref()
1493                .expect("core predicate must exist")
1494                .other_cond()
1495                .clone(),
1496            left_len,
1497        )?;
1498
1499        Ok(StreamAsOfJoin::new(core, inequality_desc)?.into())
1500    }
1501
1502    /// Convert the logical join to a Hash join.
1503    fn to_batch_hash_join(
1504        &self,
1505        logical_join: generic::Join<BatchPlanRef>,
1506        predicate: EqJoinPredicate,
1507    ) -> Result<BatchPlanRef> {
1508        use super::batch::prelude::*;
1509
1510        let left_schema_len = logical_join.left.schema().len();
1511        let asof_desc = self
1512            .is_asof_join()
1513            .then(|| {
1514                Self::get_inequality_desc_from_predicate(
1515                    predicate.other_cond().clone(),
1516                    left_schema_len,
1517                )
1518            })
1519            .transpose()?;
1520
1521        let logical_join = generic::Join {
1522            on: generic::JoinOn::EqPredicate(predicate),
1523            ..logical_join
1524        };
1525        let batch_join = BatchHashJoin::new(logical_join, asof_desc);
1526        Ok(batch_join.into())
1527    }
1528
1529    pub fn get_inequality_desc_from_predicate(
1530        predicate: Condition,
1531        left_input_len: usize,
1532    ) -> Result<AsOfJoinDesc> {
1533        let expr: ExprImpl = predicate.into();
1534        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1535            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1536            {
1537                Ok(AsOfJoinDesc {
1538                    left_idx: left_input_ref.index() as u32,
1539                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1540                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1541                })
1542            } else {
1543                bail!("inequal condition from the same side should be push down in optimizer");
1544            }
1545        } else {
1546            Err(ErrorCode::InvalidInputSyntax(
1547                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1548            )
1549            .into())
1550        }
1551    }
1552
1553    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1554        match expr_type {
1555            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1556            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1557            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1558            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1559            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1560                "Invalid comparison type: {}",
1561                expr_type.as_str_name()
1562            ))
1563            .into()),
1564        }
1565    }
1566}
1567
1568impl ToBatch for LogicalJoin {
1569    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1570        let predicate = EqJoinPredicate::create(
1571            self.left().schema().len(),
1572            self.right().schema().len(),
1573            self.on().clone(),
1574        );
1575
1576        let batch_join = self
1577            .core
1578            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1579
1580        let ctx = self.base.ctx();
1581        let config = ctx.session_ctx().config();
1582
1583        if predicate.has_eq() {
1584            if !predicate.eq_keys_are_type_aligned() {
1585                return Err(ErrorCode::InternalError(format!(
1586                    "Join eq keys are not aligned for predicate: {predicate:?}"
1587                ))
1588                .into());
1589            }
1590            if config.batch_enable_lookup_join()
1591                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1592                    predicate.clone(),
1593                    batch_join.clone(),
1594                )?
1595            {
1596                return Ok(lookup_join.into());
1597            }
1598            self.to_batch_hash_join(batch_join, predicate)
1599        } else if self.is_asof_join() {
1600            Err(ErrorCode::InvalidInputSyntax(
1601                "AsOf join requires at least 1 equal condition".to_owned(),
1602            )
1603            .into())
1604        } else {
1605            // Convert to Nested-loop Join for non-equal joins
1606            Ok(BatchNestedLoopJoin::new(batch_join).into())
1607        }
1608    }
1609}
1610
1611impl ToStream for LogicalJoin {
1612    fn to_stream(
1613        &self,
1614        ctx: &mut ToStreamContext,
1615    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1616        if self
1617            .on()
1618            .conjunctions
1619            .iter()
1620            .any(|cond| cond.count_nows() > 0)
1621        {
1622            return Err(ErrorCode::NotSupported(
1623                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1624                 "please refer to https://docs.risingwave.com/processing/sql/temporal-filters for more information".to_owned()).into());
1625        }
1626
1627        let predicate = EqJoinPredicate::create(
1628            self.left().schema().len(),
1629            self.right().schema().len(),
1630            self.on().clone(),
1631        );
1632
1633        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1634            self.to_stream_asof_join(predicate, ctx)
1635        } else if predicate.has_eq() {
1636            if !predicate.eq_keys_are_type_aligned() {
1637                return Err(ErrorCode::InternalError(format!(
1638                    "Join eq keys are not aligned for predicate: {predicate:?}"
1639                ))
1640                .into());
1641            }
1642
1643            if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1644                self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1645            } else {
1646                self.to_stream_hash_join(predicate, ctx)
1647            }
1648        } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1649            self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1650        } else if let Some(dynamic_filter) =
1651            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1652        {
1653            Ok(dynamic_filter)
1654        } else {
1655            Err(RwError::from(ErrorCode::NotSupported(
1656                "streaming nested-loop join".to_owned(),
1657                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1658                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1659                 See also: https://docs.risingwave.com/processing/sql/dynamic-filters".to_owned(),
1660            )))
1661        }
1662    }
1663
1664    fn logical_rewrite_for_stream(
1665        &self,
1666        ctx: &mut RewriteStreamContext,
1667    ) -> Result<(PlanRef, ColIndexMapping)> {
1668        let eq_indexes = self.eq_indexes();
1669        let (logical_left, logical_right) = if eq_indexes.is_empty() {
1670            (self.left(), self.right())
1671        } else {
1672            let lhs_join_key_idx = eq_indexes.iter().map(|(l, _)| *l).collect_vec();
1673            if self.should_be_temporal_join() {
1674                (
1675                    try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
1676                    self.right(),
1677                )
1678            } else {
1679                let rhs_join_key_idx = eq_indexes.iter().map(|(_, r)| *r).collect_vec();
1680                (
1681                    try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
1682                    try_enforce_locality_requirement(self.right(), &rhs_join_key_idx),
1683                )
1684            }
1685        };
1686
1687        let (left, left_col_change) = logical_left.logical_rewrite_for_stream(ctx)?;
1688        let left_len = left.schema().len();
1689        let (right, right_col_change) = logical_right.logical_rewrite_for_stream(ctx)?;
1690        let (join, out_col_change) = self.rewrite_with_left_right(
1691            left.clone(),
1692            left_col_change,
1693            right.clone(),
1694            right_col_change,
1695        );
1696
1697        let mapping = ColIndexMapping::with_remaining_columns(
1698            join.output_indices(),
1699            join.internal_column_num(),
1700        );
1701
1702        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1703        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1704
1705        // Add missing pk indices to the logical join
1706        let mut left_to_add = left
1707            .expect_stream_key()
1708            .iter()
1709            .cloned()
1710            .filter(|i| l2o.try_map(*i).is_none())
1711            .collect_vec();
1712
1713        let mut right_to_add = right
1714            .expect_stream_key()
1715            .iter()
1716            .filter(|&&i| r2o.try_map(i).is_none())
1717            .map(|&i| i + left_len)
1718            .collect_vec();
1719
1720        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1721        // key.
1722        let right_len = right.schema().len();
1723        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1724
1725        let either_or_both = self.core.add_which_join_key_to_pk();
1726
1727        for (lk, rk) in eq_predicate.eq_indexes() {
1728            match either_or_both {
1729                EitherOrBoth::Left(_) => {
1730                    if l2o.try_map(lk).is_none() {
1731                        left_to_add.push(lk);
1732                    }
1733                }
1734                EitherOrBoth::Right(_) => {
1735                    if r2o.try_map(rk).is_none() {
1736                        right_to_add.push(rk + left_len)
1737                    }
1738                }
1739                EitherOrBoth::Both(_, _) => {
1740                    if l2o.try_map(lk).is_none() {
1741                        left_to_add.push(lk);
1742                    }
1743                    if r2o.try_map(rk).is_none() {
1744                        right_to_add.push(rk + left_len)
1745                    }
1746                }
1747            };
1748        }
1749        let left_to_add = left_to_add.into_iter().unique();
1750        let right_to_add = right_to_add.into_iter().unique();
1751        // NOTE(st1page) over
1752
1753        let mut new_output_indices = join.output_indices().clone();
1754        if !join.is_right_join() {
1755            new_output_indices.extend(left_to_add);
1756        }
1757        if !join.is_left_join() {
1758            new_output_indices.extend(right_to_add);
1759        }
1760
1761        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1762
1763        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1764            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1765
1766            let l2o = join_with_pk
1767                .core
1768                .l2i_col_mapping()
1769                .composite(&join_with_pk.core.i2o_col_mapping());
1770            let r2o = join_with_pk
1771                .core
1772                .r2i_col_mapping()
1773                .composite(&join_with_pk.core.i2o_col_mapping());
1774            let mut left_right_keys = join_with_pk
1775                .left()
1776                .expect_stream_key()
1777                .iter()
1778                .map(|i| l2o.map(*i))
1779                .collect_vec();
1780            left_right_keys.extend(
1781                join_with_pk
1782                    .right()
1783                    .expect_stream_key()
1784                    .iter()
1785                    .map(|i| r2o.map(*i)),
1786            );
1787            left_right_keys.extend(
1788                eq_predicate
1789                    .eq_indexes()
1790                    .iter()
1791                    .flat_map(|(lk, rk)| [l2o.map(*lk), r2o.map(*rk)]),
1792            );
1793            let left_right_keys = left_right_keys.into_iter().unique().collect_vec();
1794            let plan: PlanRef = join_with_pk.into();
1795            LogicalFilter::filter_out_all_null_keys(plan, &left_right_keys)
1796        } else {
1797            join_with_pk.into()
1798        };
1799
1800        // the added columns is at the end, so it will not change the exists column index
1801        Ok((plan, out_col_change))
1802    }
1803
1804    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1805        // Only propagate locality for temporal-filter.
1806        if !self.temporal_filter_candidate() {
1807            return None;
1808        }
1809
1810        // Temporal filter only outputs columns from left input, so mapping is safe.
1811        let o2i_mapping = self.core.o2i_col_mapping();
1812        let left_input_columns = columns
1813            .iter()
1814            .map(|&col| o2i_mapping.try_map(col))
1815            .collect::<Option<Vec<usize>>>()?;
1816        if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1817            return Some(
1818                self.clone_with_left_right(better_left_plan, self.right())
1819                    .into(),
1820            );
1821        }
1822        None
1823    }
1824}
1825
1826#[cfg(test)]
1827mod tests {
1828
1829    use std::collections::HashSet;
1830
1831    use risingwave_common::catalog::{Field, Schema};
1832    use risingwave_common::types::{DataType, Datum};
1833    use risingwave_pb::expr::expr_node::Type;
1834
1835    use super::*;
1836    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1837    use crate::optimizer::optimizer_context::OptimizerContext;
1838    use crate::optimizer::plan_node::LogicalValues;
1839    use crate::optimizer::property::FunctionalDependency;
1840
1841    /// Pruning
1842    /// ```text
1843    /// Join(on: input_ref(1)=input_ref(3))
1844    ///   TableScan(v1, v2, v3)
1845    ///   TableScan(v4, v5, v6)
1846    /// ```
1847    /// with required columns [2,3] will result in
1848    /// ```text
1849    /// Project(input_ref(1), input_ref(2))
1850    ///   Join(on: input_ref(0)=input_ref(2))
1851    ///     TableScan(v2, v3)
1852    ///     TableScan(v4)
1853    /// ```
1854    #[tokio::test]
1855    async fn test_prune_join() {
1856        let ty = DataType::Int32;
1857        let ctx = OptimizerContext::mock();
1858        let fields: Vec<Field> = (1..7)
1859            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1860            .collect();
1861        let left = LogicalValues::new(
1862            vec![],
1863            Schema {
1864                fields: fields[0..3].to_vec(),
1865            },
1866            ctx.clone(),
1867        );
1868        let right = LogicalValues::new(
1869            vec![],
1870            Schema {
1871                fields: fields[3..6].to_vec(),
1872            },
1873            ctx,
1874        );
1875        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1876            FunctionCall::new(
1877                Type::Equal,
1878                vec![
1879                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1880                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1881                ],
1882            )
1883            .unwrap(),
1884        ));
1885        let join_type = JoinType::Inner;
1886        let join: PlanRef = LogicalJoin::new(
1887            left.into(),
1888            right.into(),
1889            join_type,
1890            Condition::with_expr(on),
1891        )
1892        .into();
1893
1894        // Perform the prune
1895        let required_cols = vec![2, 3];
1896        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1897
1898        // Check the result
1899        let join = plan.as_logical_join().unwrap();
1900        assert_eq!(join.schema().fields().len(), 2);
1901        assert_eq!(join.schema().fields()[0], fields[2]);
1902        assert_eq!(join.schema().fields()[1], fields[3]);
1903
1904        let expr: ExprImpl = join.on().clone().into();
1905        let call = expr.as_function_call().unwrap();
1906        assert_eq_input_ref!(&call.inputs()[0], 0);
1907        assert_eq_input_ref!(&call.inputs()[1], 2);
1908
1909        let left = join.left();
1910        let left = left.as_logical_values().unwrap();
1911        assert_eq!(left.schema().fields(), &fields[1..3]);
1912        let right = join.right();
1913        let right = right.as_logical_values().unwrap();
1914        assert_eq!(right.schema().fields(), &fields[3..4]);
1915    }
1916
1917    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1918    #[tokio::test]
1919    async fn test_prune_semi_join() {
1920        let ty = DataType::Int32;
1921        let ctx = OptimizerContext::mock();
1922        let fields: Vec<Field> = (1..7)
1923            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1924            .collect();
1925        let left = LogicalValues::new(
1926            vec![],
1927            Schema {
1928                fields: fields[0..3].to_vec(),
1929            },
1930            ctx.clone(),
1931        );
1932        let right = LogicalValues::new(
1933            vec![],
1934            Schema {
1935                fields: fields[3..6].to_vec(),
1936            },
1937            ctx,
1938        );
1939        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1940            FunctionCall::new(
1941                Type::Equal,
1942                vec![
1943                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1944                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1945                ],
1946            )
1947            .unwrap(),
1948        ));
1949        for join_type in [
1950            JoinType::LeftSemi,
1951            JoinType::RightSemi,
1952            JoinType::LeftAnti,
1953            JoinType::RightAnti,
1954        ] {
1955            let join = LogicalJoin::new(
1956                left.clone().into(),
1957                right.clone().into(),
1958                join_type,
1959                Condition::with_expr(on.clone()),
1960            );
1961
1962            let offset = if join.is_right_join() { 3 } else { 0 };
1963            let join: PlanRef = join.into();
1964            // Perform the prune
1965            let required_cols = vec![0];
1966            // key 0 is never used in the join (always key 1)
1967            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1968            let as_plan = plan.as_logical_join().unwrap();
1969            // Check the result
1970            assert_eq!(as_plan.schema().fields().len(), 1);
1971            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1972
1973            // Perform the prune
1974            let required_cols = vec![0, 1, 2];
1975            // should not panic here
1976            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1977            let as_plan = plan.as_logical_join().unwrap();
1978            // Check the result
1979            assert_eq!(as_plan.schema().fields().len(), 3);
1980            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1981            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1982            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1983        }
1984    }
1985
1986    /// Pruning
1987    /// ```text
1988    /// Join(on: input_ref(1)=input_ref(3))
1989    ///   TableScan(v1, v2, v3)
1990    ///   TableScan(v4, v5, v6)
1991    /// ```
1992    /// with required columns [1, 3] will result in
1993    /// ```text
1994    /// Join(on: input_ref(0)=input_ref(1))
1995    ///   TableScan(v2)
1996    ///   TableScan(v4)
1997    /// ```
1998    #[tokio::test]
1999    async fn test_prune_join_no_project() {
2000        let ty = DataType::Int32;
2001        let ctx = OptimizerContext::mock();
2002        let fields: Vec<Field> = (1..7)
2003            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2004            .collect();
2005        let left = LogicalValues::new(
2006            vec![],
2007            Schema {
2008                fields: fields[0..3].to_vec(),
2009            },
2010            ctx.clone(),
2011        );
2012        let right = LogicalValues::new(
2013            vec![],
2014            Schema {
2015                fields: fields[3..6].to_vec(),
2016            },
2017            ctx,
2018        );
2019        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2020            FunctionCall::new(
2021                Type::Equal,
2022                vec![
2023                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2024                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2025                ],
2026            )
2027            .unwrap(),
2028        ));
2029        let join_type = JoinType::Inner;
2030        let join: PlanRef = LogicalJoin::new(
2031            left.into(),
2032            right.into(),
2033            join_type,
2034            Condition::with_expr(on),
2035        )
2036        .into();
2037
2038        // Perform the prune
2039        let required_cols = vec![1, 3];
2040        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2041
2042        // Check the result
2043        let join = plan.as_logical_join().unwrap();
2044        assert_eq!(join.schema().fields().len(), 2);
2045        assert_eq!(join.schema().fields()[0], fields[1]);
2046        assert_eq!(join.schema().fields()[1], fields[3]);
2047
2048        let expr: ExprImpl = join.on().clone().into();
2049        let call = expr.as_function_call().unwrap();
2050        assert_eq_input_ref!(&call.inputs()[0], 0);
2051        assert_eq_input_ref!(&call.inputs()[1], 1);
2052
2053        let left = join.left();
2054        let left = left.as_logical_values().unwrap();
2055        assert_eq!(left.schema().fields(), &fields[1..2]);
2056        let right = join.right();
2057        let right = right.as_logical_values().unwrap();
2058        assert_eq!(right.schema().fields(), &fields[3..4]);
2059    }
2060
2061    /// Convert
2062    /// ```text
2063    /// Join(on: ($1 = $3) AND ($2 == 42))
2064    ///   TableScan(v1, v2, v3)
2065    ///   TableScan(v4, v5, v6)
2066    /// ```
2067    /// to
2068    /// ```text
2069    /// Filter($2 == 42)
2070    ///   HashJoin(on: $1 = $3)
2071    ///     TableScan(v1, v2, v3)
2072    ///     TableScan(v4, v5, v6)
2073    /// ```
2074    #[tokio::test]
2075    async fn test_join_to_batch() {
2076        let ctx = OptimizerContext::mock();
2077        let fields: Vec<Field> = (1..7)
2078            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
2079            .collect();
2080        let left = LogicalValues::new(
2081            vec![],
2082            Schema {
2083                fields: fields[0..3].to_vec(),
2084            },
2085            ctx.clone(),
2086        );
2087        let right = LogicalValues::new(
2088            vec![],
2089            Schema {
2090                fields: fields[3..6].to_vec(),
2091            },
2092            ctx,
2093        );
2094
2095        fn input_ref(i: usize) -> ExprImpl {
2096            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
2097        }
2098        let eq_cond = ExprImpl::FunctionCall(Box::new(
2099            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
2100        ));
2101        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2102            FunctionCall::new(
2103                Type::Equal,
2104                vec![
2105                    input_ref(2),
2106                    ExprImpl::Literal(Box::new(Literal::new(
2107                        Datum::Some(42_i32.into()),
2108                        DataType::Int32,
2109                    ))),
2110                ],
2111            )
2112            .unwrap(),
2113        ));
2114        // Condition: ($1 = $3) AND ($2 == 42)
2115        let on_cond = ExprImpl::FunctionCall(Box::new(
2116            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2117        ));
2118
2119        let join_type = JoinType::Inner;
2120        let logical_join = LogicalJoin::new(
2121            left.into(),
2122            right.into(),
2123            join_type,
2124            Condition::with_expr(on_cond),
2125        );
2126
2127        // Perform `to_batch`
2128        let result = logical_join.to_batch().unwrap();
2129
2130        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2131        let hash_join = result.as_batch_hash_join().unwrap();
2132        assert_eq!(
2133            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2134            eq_cond
2135        );
2136        assert_eq!(
2137            *hash_join
2138                .eq_join_predicate()
2139                .non_eq_cond()
2140                .conjunctions
2141                .first()
2142                .unwrap(),
2143            non_eq_cond
2144        );
2145    }
2146
2147    /// Convert
2148    /// ```text
2149    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2150    ///   TableScan(v1, v2, v3)
2151    ///   TableScan(v4, v5, v6)
2152    /// ```
2153    /// to
2154    /// ```text
2155    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
2156    ///   TableScan(v1, v2, v3)
2157    ///   TableScan(v4, v5, v6)
2158    /// ```
2159    #[tokio::test]
2160    #[ignore] // ignore due to refactor logical scan, but the test seem to duplicate with the explain test
2161    // framework, maybe we will remove it?
2162    async fn test_join_to_stream() {
2163        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
2164        // let fields: Vec<Field> = (1..7)
2165        //     .map(|i| Field {
2166        //         data_type: DataType::Int32,
2167        //         name: format!("v{}", i),
2168        //     })
2169        //     .collect();
2170        // let left = LogicalScan::new(
2171        //     "left".to_string(),
2172        //     TableId::new(0),
2173        //     vec![1.into(), 2.into(), 3.into()],
2174        //     Schema {
2175        //         fields: fields[0..3].to_vec(),
2176        //     },
2177        //     ctx.clone(),
2178        // );
2179        // let right = LogicalScan::new(
2180        //     "right".to_string(),
2181        //     TableId::new(0),
2182        //     vec![4.into(), 5.into(), 6.into()],
2183        //     Schema {
2184        //                 fields: fields[3..6].to_vec(),
2185        //     },
2186        //     ctx,
2187        // );
2188        // let eq_cond = ExprImpl::FunctionCall(Box::new(
2189        //     FunctionCall::new(
2190        //         Type::Equal,
2191        //         vec![
2192        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
2193        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
2194        //         ],
2195        //     )
2196        //     .unwrap(),
2197        // ));
2198        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2199        //     FunctionCall::new(
2200        //         Type::Equal,
2201        //         vec![
2202        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
2203        //             ExprImpl::Literal(Box::new(Literal::new(
2204        //                 Datum::Some(42_i32.into()),
2205        //                 DataType::Int32,
2206        //             ))),
2207        //         ],
2208        //     )
2209        //     .unwrap(),
2210        // ));
2211        // // Condition: ($1 = $3) AND ($2 == 42)
2212        // let on_cond = ExprImpl::FunctionCall(Box::new(
2213        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
2214        // ));
2215
2216        // let join_type = JoinType::LeftOuter;
2217        // let logical_join = LogicalJoin::new(
2218        //     left.clone().into(),
2219        //     right.clone().into(),
2220        //     join_type,
2221        //     Condition::with_expr(on_cond.clone()),
2222        // );
2223
2224        // // Perform `to_stream`
2225        // let result = logical_join.to_stream();
2226
2227        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
2228        // let hash_join = result.as_stream_hash_join().unwrap();
2229        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
2230    }
2231    /// Pruning
2232    /// ```text
2233    /// Join(on: input_ref(1)=input_ref(3))
2234    ///   TableScan(v1, v2, v3)
2235    ///   TableScan(v4, v5, v6)
2236    /// ```
2237    /// with required columns [3, 2] will result in
2238    /// ```text
2239    /// Project(input_ref(2), input_ref(1))
2240    ///   Join(on: input_ref(0)=input_ref(2))
2241    ///     TableScan(v2, v3)
2242    ///     TableScan(v4)
2243    /// ```
2244    #[tokio::test]
2245    async fn test_join_column_prune_with_order_required() {
2246        let ty = DataType::Int32;
2247        let ctx = OptimizerContext::mock();
2248        let fields: Vec<Field> = (1..7)
2249            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2250            .collect();
2251        let left = LogicalValues::new(
2252            vec![],
2253            Schema {
2254                fields: fields[0..3].to_vec(),
2255            },
2256            ctx.clone(),
2257        );
2258        let right = LogicalValues::new(
2259            vec![],
2260            Schema {
2261                fields: fields[3..6].to_vec(),
2262            },
2263            ctx,
2264        );
2265        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2266            FunctionCall::new(
2267                Type::Equal,
2268                vec![
2269                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2270                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2271                ],
2272            )
2273            .unwrap(),
2274        ));
2275        let join_type = JoinType::Inner;
2276        let join: PlanRef = LogicalJoin::new(
2277            left.into(),
2278            right.into(),
2279            join_type,
2280            Condition::with_expr(on),
2281        )
2282        .into();
2283
2284        // Perform the prune
2285        let required_cols = vec![3, 2];
2286        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2287
2288        // Check the result
2289        let join = plan.as_logical_join().unwrap();
2290        assert_eq!(join.schema().fields().len(), 2);
2291        assert_eq!(join.schema().fields()[0], fields[3]);
2292        assert_eq!(join.schema().fields()[1], fields[2]);
2293
2294        let expr: ExprImpl = join.on().clone().into();
2295        let call = expr.as_function_call().unwrap();
2296        assert_eq_input_ref!(&call.inputs()[0], 0);
2297        assert_eq_input_ref!(&call.inputs()[1], 2);
2298
2299        let left = join.left();
2300        let left = left.as_logical_values().unwrap();
2301        assert_eq!(left.schema().fields(), &fields[1..3]);
2302        let right = join.right();
2303        let right = right.as_logical_values().unwrap();
2304        assert_eq!(right.schema().fields(), &fields[3..4]);
2305    }
2306
2307    #[tokio::test]
2308    async fn fd_derivation_inner_outer_join() {
2309        // left: [l0, l1], right: [r0, r1, r2]
2310        // FD: l0 --> l1, r0 --> { r1, r2 }
2311        // On: l0 = 0 AND l1 = r1
2312        //
2313        // Inner Join:
2314        //  Schema: [l0, l1, r0, r1, r2]
2315        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2316        // Left Outer Join:
2317        //  Schema: [l0, l1, r0, r1, r2]
2318        //  FD: l0 --> l1
2319        // Right Outer Join:
2320        //  Schema: [l0, l1, r0, r1, r2]
2321        //  FD: r0 --> { r1, r2 }
2322        // Full Outer Join:
2323        //  Schema: [l0, l1, r0, r1, r2]
2324        //  FD: empty
2325        // Left Semi/Anti Join:
2326        //  Schema: [l0, l1]
2327        //  FD: l0 --> l1
2328        // Right Semi/Anti Join:
2329        //  Schema: [r0, r1, r2]
2330        //  FD: r0 --> {r1, r2}
2331        let ctx = OptimizerContext::mock();
2332        let left = {
2333            let fields: Vec<Field> = vec![
2334                Field::with_name(DataType::Int32, "l0"),
2335                Field::with_name(DataType::Int32, "l1"),
2336            ];
2337            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2338            // 0 --> 1
2339            values
2340                .base
2341                .functional_dependency_mut()
2342                .add_functional_dependency_by_column_indices(&[0], &[1]);
2343            values
2344        };
2345        let right = {
2346            let fields: Vec<Field> = vec![
2347                Field::with_name(DataType::Int32, "r0"),
2348                Field::with_name(DataType::Int32, "r1"),
2349                Field::with_name(DataType::Int32, "r2"),
2350            ];
2351            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2352            // 0 --> 1, 2
2353            values
2354                .base
2355                .functional_dependency_mut()
2356                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2357            values
2358        };
2359        // l0 = 0 AND l1 = r1
2360        let on: ExprImpl = FunctionCall::new(
2361            Type::And,
2362            vec![
2363                FunctionCall::new(
2364                    Type::Equal,
2365                    vec![
2366                        InputRef::new(0, DataType::Int32).into(),
2367                        ExprImpl::literal_int(0),
2368                    ],
2369                )
2370                .unwrap()
2371                .into(),
2372                FunctionCall::new(
2373                    Type::Equal,
2374                    vec![
2375                        InputRef::new(1, DataType::Int32).into(),
2376                        InputRef::new(3, DataType::Int32).into(),
2377                    ],
2378                )
2379                .unwrap()
2380                .into(),
2381            ],
2382        )
2383        .unwrap()
2384        .into();
2385        let expected_fd_set = [
2386            (
2387                JoinType::Inner,
2388                [
2389                    // inherit from left
2390                    FunctionalDependency::with_indices(5, &[0], &[1]),
2391                    // inherit from right
2392                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2393                    // constant column in join condition
2394                    FunctionalDependency::with_indices(5, &[], &[0]),
2395                    // eq column in join condition
2396                    FunctionalDependency::with_indices(5, &[1], &[3]),
2397                    FunctionalDependency::with_indices(5, &[3], &[1]),
2398                ]
2399                .into_iter()
2400                .collect::<HashSet<_>>(),
2401            ),
2402            (JoinType::FullOuter, HashSet::new()),
2403            (
2404                JoinType::RightOuter,
2405                [
2406                    // inherit from right
2407                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2408                ]
2409                .into_iter()
2410                .collect::<HashSet<_>>(),
2411            ),
2412            (
2413                JoinType::LeftOuter,
2414                [
2415                    // inherit from left
2416                    FunctionalDependency::with_indices(5, &[0], &[1]),
2417                ]
2418                .into_iter()
2419                .collect::<HashSet<_>>(),
2420            ),
2421            (
2422                JoinType::LeftSemi,
2423                [
2424                    // inherit from left
2425                    FunctionalDependency::with_indices(2, &[0], &[1]),
2426                ]
2427                .into_iter()
2428                .collect::<HashSet<_>>(),
2429            ),
2430            (
2431                JoinType::LeftAnti,
2432                [
2433                    // inherit from left
2434                    FunctionalDependency::with_indices(2, &[0], &[1]),
2435                ]
2436                .into_iter()
2437                .collect::<HashSet<_>>(),
2438            ),
2439            (
2440                JoinType::RightSemi,
2441                [
2442                    // inherit from right
2443                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2444                ]
2445                .into_iter()
2446                .collect::<HashSet<_>>(),
2447            ),
2448            (
2449                JoinType::RightAnti,
2450                [
2451                    // inherit from right
2452                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2453                ]
2454                .into_iter()
2455                .collect::<HashSet<_>>(),
2456            ),
2457        ];
2458
2459        for (join_type, expected_res) in expected_fd_set {
2460            let join = LogicalJoin::new(
2461                left.clone().into(),
2462                right.clone().into(),
2463                join_type,
2464                Condition::with_expr(on.clone()),
2465            );
2466            let fd_set = join
2467                .functional_dependency()
2468                .as_dependencies()
2469                .iter()
2470                .cloned()
2471                .collect::<HashSet<_>>();
2472            assert_eq!(fd_set, expected_res);
2473        }
2474    }
2475}