risingwave_frontend/optimizer/plan_node/
logical_join.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32    BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33    PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34    ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right columns, dependent on the output indices provided. A repeat output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Shared plan-node metadata, derived from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The generic join: left/right inputs, join type, `on` predicate and output indices.
    core: generic::Join<PlanRef>,
}
62
63impl Distill for LogicalJoin {
64    fn distill<'a>(&self) -> XmlNode<'a> {
65        let verbose = self.base.ctx().is_explain_verbose();
66        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67        vec.push(("type", Pretty::debug(&self.join_type())));
68
69        let concat_schema = self.core.concat_schema();
70        let cond = Pretty::debug(&ConditionDisplay {
71            condition: self.on(),
72            input_schema: &concat_schema,
73        });
74        vec.push(("on", cond));
75
76        if verbose {
77            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78            vec.push(("output", data));
79        }
80
81        childless_record("LogicalJoin", vec)
82    }
83}
84
85impl LogicalJoin {
86    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87        let core = generic::Join::with_full_output(left, right, join_type, on);
88        Self::with_core(core)
89    }
90
91    pub(crate) fn with_output_indices(
92        left: PlanRef,
93        right: PlanRef,
94        join_type: JoinType,
95        on: Condition,
96        output_indices: Vec<usize>,
97    ) -> Self {
98        let core = generic::Join::new(left, right, on, join_type, output_indices);
99        Self::with_core(core)
100    }
101
102    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103        let base = PlanBase::new_logical_with_core(&core);
104        LogicalJoin { base, core }
105    }
106
107    pub fn create(
108        left: PlanRef,
109        right: PlanRef,
110        join_type: JoinType,
111        on_clause: ExprImpl,
112    ) -> PlanRef {
113        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114    }
115
116    pub fn internal_column_num(&self) -> usize {
117        self.core.internal_column_num()
118    }
119
120    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121        self.core.i2l_col_mapping_ignore_join_type()
122    }
123
124    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125        self.core.i2r_col_mapping_ignore_join_type()
126    }
127
128    /// Get a reference to the logical join's on.
129    pub fn on(&self) -> &Condition {
130        self.core
131            .on
132            .as_condition_ref()
133            .expect("logical join should store predicate as Condition")
134    }
135
136    pub fn core(&self) -> &generic::Join<PlanRef> {
137        &self.core
138    }
139
140    /// Collect all input ref in the on condition. And separate them into left and right.
141    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
142        let input_refs = self
143            .core
144            .on
145            .as_condition_ref()
146            .expect("logical join should store predicate as Condition")
147            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
148        let index_group = input_refs
149            .ones()
150            .chunk_by(|i| *i < self.core.left.schema().len());
151        let left_index = index_group
152            .into_iter()
153            .next()
154            .map_or(vec![], |group| group.1.collect_vec());
155        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
156            group
157                .1
158                .map(|i| i - self.core.left.schema().len())
159                .collect_vec()
160        });
161        (left_index, right_index)
162    }
163
164    /// Get the join type of the logical join.
165    pub fn join_type(&self) -> JoinType {
166        self.core.join_type
167    }
168
169    /// Get the eq join key of the logical join.
170    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
171        self.core.eq_indexes()
172    }
173
174    /// Get the output indices of the logical join.
175    pub fn output_indices(&self) -> &Vec<usize> {
176        &self.core.output_indices
177    }
178
179    /// Clone with new output indices
180    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
181        Self::with_core(generic::Join {
182            output_indices,
183            ..self.core.clone()
184        })
185    }
186
187    /// Clone with new `on` condition
188    pub fn clone_with_cond(&self, on: Condition) -> Self {
189        Self::with_core(generic::Join {
190            on: generic::JoinOn::Condition(on),
191            ..self.core.clone()
192        })
193    }
194
195    pub fn is_left_join(&self) -> bool {
196        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
197    }
198
199    pub fn is_right_join(&self) -> bool {
200        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
201    }
202
203    pub fn is_full_out(&self) -> bool {
204        self.core.is_full_out()
205    }
206
207    pub fn is_asof_join(&self) -> bool {
208        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
209    }
210
211    pub fn output_indices_are_trivial(&self) -> bool {
212        itertools::equal(
213            self.output_indices().iter().cloned(),
214            0..self.internal_column_num(),
215        )
216    }
217
218    /// Try to simplify the outer join with the predicate on the top of the join
219    ///
220    /// now it is just a naive implementation for comparison expression, we can give a more general
221    /// implementation with constant folding in future
222    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
223        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
224            JoinType::LeftOuter => (false, true),
225            JoinType::RightOuter => (true, false),
226            JoinType::FullOuter => (true, true),
227            _ => return join_type,
228        };
229
230        for expr in &predicate.conjunctions {
231            if let ExprImpl::FunctionCall(func) = expr {
232                match func.func_type() {
233                    ExprType::Equal
234                    | ExprType::NotEqual
235                    | ExprType::LessThan
236                    | ExprType::LessThanOrEqual
237                    | ExprType::GreaterThan
238                    | ExprType::GreaterThanOrEqual => {
239                        for input in func.inputs() {
240                            if let ExprImpl::InputRef(input) = input {
241                                let idx = input.index;
242                                if idx < left_col_num {
243                                    gen_null_in_left = false;
244                                } else {
245                                    gen_null_in_right = false;
246                                }
247                            }
248                        }
249                    }
250                    _ => {}
251                };
252            }
253        }
254
255        match (gen_null_in_left, gen_null_in_right) {
256            (true, true) => JoinType::FullOuter,
257            (true, false) => JoinType::RightOuter,
258            (false, true) => JoinType::LeftOuter,
259            (false, false) => JoinType::Inner,
260        }
261    }
262
263    /// Index Join:
264    /// Try to convert logical join into batch lookup join and meanwhile it will do
265    /// the index selection for the lookup table so that we can benefit from indexes.
266    fn to_batch_lookup_join_with_index_selection(
267        &self,
268        predicate: EqJoinPredicate,
269        batch_join: generic::Join<BatchPlanRef>,
270    ) -> Result<Option<BatchLookupJoin>> {
271        match batch_join.join_type {
272            JoinType::Inner
273            | JoinType::LeftOuter
274            | JoinType::LeftSemi
275            | JoinType::LeftAnti
276            | JoinType::AsofInner
277            | JoinType::AsofLeftOuter => {}
278            _ => return Ok(None),
279        };
280
281        // Index selection for index join.
282        let right = self.right();
283        // Lookup Join only supports basic tables on the join's right side.
284        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
285            logical_scan
286        } else {
287            return Ok(None);
288        };
289
290        let mut result_plan = None;
291        // Lookup primary table.
292        if let Some(lookup_join) =
293            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
294        {
295            result_plan = Some(lookup_join);
296        }
297
298        if self
299            .core
300            .ctx()
301            .session_ctx()
302            .config()
303            .enable_index_selection()
304        {
305            let indexes = logical_scan.table_indexes();
306            for index in indexes {
307                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
308                    let index_scan: PlanRef = index_scan.into();
309                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
310                    let mut new_batch_join = batch_join.clone();
311                    new_batch_join.right =
312                        index_scan.to_batch().expect("index scan failed to batch");
313
314                    // Lookup covered index.
315                    if let Some(lookup_join) =
316                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
317                    {
318                        match &result_plan {
319                            None => result_plan = Some(lookup_join),
320                            Some(prev_lookup_join) => {
321                                // Prefer to choose lookup join with longer lookup prefix len.
322                                if prev_lookup_join.lookup_prefix_len()
323                                    < lookup_join.lookup_prefix_len()
324                                {
325                                    result_plan = Some(lookup_join)
326                                }
327                            }
328                        }
329                    }
330                }
331            }
332        }
333
334        Ok(result_plan)
335    }
336
337    /// Try to convert logical join into batch lookup join.
338    fn to_batch_lookup_join(
339        &self,
340        predicate: EqJoinPredicate,
341        logical_join: generic::Join<BatchPlanRef>,
342    ) -> Result<Option<BatchLookupJoin>> {
343        let logical_scan: &LogicalScan =
344            if let Some(logical_scan) = self.core.right.as_logical_scan() {
345                logical_scan
346            } else {
347                return Ok(None);
348            };
349        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
350    }
351
352    pub fn gen_batch_lookup_join(
353        logical_scan: &LogicalScan,
354        predicate: EqJoinPredicate,
355        logical_join: generic::Join<BatchPlanRef>,
356        is_as_of: bool,
357    ) -> Result<Option<BatchLookupJoin>> {
358        match logical_join.join_type {
359            JoinType::Inner
360            | JoinType::LeftOuter
361            | JoinType::LeftSemi
362            | JoinType::LeftAnti
363            | JoinType::AsofInner
364            | JoinType::AsofLeftOuter => {}
365            _ => return Ok(None),
366        };
367
368        let table = logical_scan.table();
369        let output_column_ids = logical_scan.output_column_ids();
370
371        // Verify that the right join key columns are the the prefix of the primary key and
372        // also contain the distribution key.
373        let order_col_ids = table.order_column_ids();
374        let dist_key = table.distribution_key.clone();
375        // The at least prefix of order key that contains distribution key.
376        let mut dist_key_in_order_key_pos = vec![];
377        for d in dist_key {
378            let pos = table
379                .order_column_indices()
380                .position(|x| x == d)
381                .expect("dist_key must in order_key");
382            dist_key_in_order_key_pos.push(pos);
383        }
384        // The shortest prefix of order key that contains distribution key.
385        let shortest_prefix_len = dist_key_in_order_key_pos
386            .iter()
387            .max()
388            .map_or(0, |pos| pos + 1);
389
390        // Distributed lookup join can't support lookup table with a singleton distribution.
391        if shortest_prefix_len == 0 {
392            return Ok(None);
393        }
394
395        // Reorder the join equal predicate to match the order key.
396        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
397        for order_col_id in order_col_ids {
398            let mut found = false;
399            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
400                if order_col_id == output_column_ids[eq_idx] {
401                    reorder_idx.push(i);
402                    found = true;
403                    break;
404                }
405            }
406            if !found {
407                break;
408            }
409        }
410        if reorder_idx.len() < shortest_prefix_len {
411            return Ok(None);
412        }
413        let lookup_prefix_len = reorder_idx.len();
414        let predicate = predicate.reorder(&reorder_idx);
415
416        // Extract the predicate from logical scan. Only pure scan is supported.
417        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
418        // Construct output column to require column mapping
419        let o2r = if let Some(project_expr) = project_expr {
420            project_expr
421                .into_iter()
422                .map(|x| x.as_input_ref().unwrap().index)
423                .collect_vec()
424        } else {
425            (0..logical_scan.output_col_idx().len()).collect_vec()
426        };
427        let left_schema_len = logical_join.left.schema().len();
428
429        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
430            offset: left_schema_len,
431            mapping: o2r.clone(),
432        };
433
434        let new_eq_cond = predicate
435            .eq_cond()
436            .rewrite_expr(&mut join_predicate_rewriter);
437
438        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
439            offset: left_schema_len,
440        };
441
442        let new_other_cond = predicate
443            .other_cond()
444            .clone()
445            .rewrite_expr(&mut join_predicate_rewriter)
446            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
447
448        let new_join_on = new_eq_cond.and(new_other_cond);
449        let new_predicate =
450            EqJoinPredicate::create(left_schema_len, new_scan.schema().len(), new_join_on);
451
452        // We discovered that we cannot use a lookup join after pulling up the predicate
453        // from one side and simplifying the condition. Let's use some other join instead.
454        if !new_predicate.has_eq() {
455            return Ok(None);
456        }
457
458        // Rewrite the join output indices and all output indices referred to the old scan need to
459        // rewrite.
460        let new_join_output_indices = logical_join
461            .output_indices
462            .iter()
463            .map(|&x| {
464                if x < left_schema_len {
465                    x
466                } else {
467                    o2r[x - left_schema_len] + left_schema_len
468                }
469            })
470            .collect_vec();
471
472        let new_scan_output_column_ids = new_scan.output_column_ids();
473        let as_of = new_scan.as_of.clone();
474        let new_logical_scan: LogicalScan = new_scan.into();
475
476        // Construct a new logical join, because we have change its RHS.
477        let new_logical_join = generic::Join::new_with_eq_predicate(
478            logical_join.left,
479            new_logical_scan.to_batch()?,
480            new_predicate,
481            logical_join.join_type,
482            new_join_output_indices,
483        );
484
485        let asof_desc = is_as_of
486            .then(|| {
487                Self::get_inequality_desc_from_predicate(
488                    predicate.other_cond().clone(),
489                    left_schema_len,
490                )
491            })
492            .transpose()?;
493
494        Ok(Some(BatchLookupJoin::new(
495            new_logical_join,
496            table.clone(),
497            new_scan_output_column_ids,
498            lookup_prefix_len,
499            false,
500            as_of,
501            asof_desc,
502        )))
503    }
504
505    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
506        self.core.decompose()
507    }
508
509    fn dynamic_filter_candidate(&self, predicate: &Condition) -> Option<(usize, PbType)> {
510        // Dynamic filter only supports `Inner`/`LeftSemi`.
511        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
512            return None;
513        }
514
515        // Dynamic filter requires right side to be a scalar with one column.
516        if !self.right().max_one_row() || self.right().schema().len() != 1 {
517            return None;
518        }
519
520        // Dynamic filter only supports a single comparison predicate.
521        if predicate.conjunctions.len() > 1 {
522            return None;
523        }
524        let expr: ExprImpl = predicate.clone().into();
525        let (left_ref, comparator, right_ref) = expr.as_comparison_cond()?;
526
527        // Comparison must cross inputs: left input ref vs right scalar input ref.
528        let left_len = self.left().schema().len();
529        let condition_cross_inputs = left_ref.index < left_len && right_ref.index == left_len;
530        if !condition_cross_inputs {
531            return None;
532        }
533
534        // Comparison keys must be type aligned.
535        if self.left().schema().fields()[left_ref.index].data_type
536            != self.right().schema().fields()[0].data_type
537        {
538            return None;
539        }
540
541        // Dynamic filter output can only come from the left side.
542        if !self.output_indices().iter().all(|i| *i < left_len) {
543            return None;
544        }
545
546        Some((left_ref.index, comparator))
547    }
548
549    /// Check whether this join can be treated as a temporal filter for locality optimization.
550    ///
551    /// This is intentionally stricter than dynamic filter:
552    /// - right side must be a `LogicalNow`,
553    /// - and all dynamic filter preconditions must hold.
554    fn temporal_filter_candidate(&self) -> bool {
555        self.right().as_logical_now().is_some()
556            && self.dynamic_filter_candidate(self.on()).is_some()
557    }
558}
559
560impl PlanTreeNodeBinary<Logical> for LogicalJoin {
561    fn left(&self) -> PlanRef {
562        self.core.left.clone()
563    }
564
565    fn right(&self) -> PlanRef {
566        self.core.right.clone()
567    }
568
569    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
570        Self::with_core(generic::Join {
571            left,
572            right,
573            ..self.core.clone()
574        })
575    }
576
577    fn rewrite_with_left_right(
578        &self,
579        left: PlanRef,
580        left_col_change: ColIndexMapping,
581        right: PlanRef,
582        right_col_change: ColIndexMapping,
583    ) -> (Self, ColIndexMapping) {
584        let (new_on, new_output_indices) = {
585            let (mut map, _) = left_col_change.clone().into_parts();
586            let (mut right_map, _) = right_col_change.clone().into_parts();
587            for i in right_map.iter_mut().flatten() {
588                *i += left.schema().len();
589            }
590            map.append(&mut right_map);
591            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
592
593            let new_output_indices = self
594                .output_indices()
595                .iter()
596                .map(|&i| mapping.map(i))
597                .collect::<Vec<_>>();
598            let new_on = self.on().clone().rewrite_expr(&mut mapping);
599            (new_on, new_output_indices)
600        };
601
602        let join = Self::with_output_indices(
603            left,
604            right,
605            self.join_type(),
606            new_on,
607            new_output_indices.clone(),
608        );
609
610        let new_i2o = ColIndexMapping::with_remaining_columns(
611            &new_output_indices,
612            join.internal_column_num(),
613        );
614
615        let old_o2i = self.core.o2i_col_mapping();
616
617        let old_o2l = old_o2i
618            .composite(&self.core.i2l_col_mapping())
619            .composite(&left_col_change);
620        let old_o2r = old_o2i
621            .composite(&self.core.i2r_col_mapping())
622            .composite(&right_col_change);
623        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
624        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
625
626        let out_col_change = old_o2l
627            .composite(&new_l2o)
628            .union(&old_o2r.composite(&new_r2o));
629        (join, out_col_change)
630    }
631}
632
// Generate the binary plan-tree-node boilerplate for `LogicalJoin` in the `Logical` convention
// (the macro is defined elsewhere in the crate).
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
634
635impl ColPrunable for LogicalJoin {
636    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
637        // make `required_cols` point to internal table instead of output schema.
638        let required_cols = required_cols
639            .iter()
640            .map(|i| self.output_indices()[*i])
641            .collect_vec();
642        let left_len = self.left().schema().fields.len();
643
644        let total_len = self.left().schema().len() + self.right().schema().len();
645        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
646
647        required_cols.iter().for_each(|&i| {
648            if self.is_right_join() {
649                resized_required_cols.insert(left_len + i);
650            } else {
651                resized_required_cols.insert(i);
652            }
653        });
654
655        // add those columns which are required in the join condition to
656        // to those that are required in the output
657        let mut visitor = CollectInputRef::new(resized_required_cols);
658        self.on().visit_expr(&mut visitor);
659        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
660
661        let mut left_required_cols = Vec::new();
662        let mut right_required_cols = Vec::new();
663        left_right_required_cols.iter().for_each(|&i| {
664            if i < left_len {
665                left_required_cols.push(i);
666            } else {
667                right_required_cols.push(i - left_len);
668            }
669        });
670
671        let mut on = self.on().clone();
672        let mut mapping =
673            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
674        on = on.rewrite_expr(&mut mapping);
675
676        let new_output_indices = {
677            let required_inputs_in_output = if self.is_left_join() {
678                &left_required_cols
679            } else if self.is_right_join() {
680                &right_required_cols
681            } else {
682                &left_right_required_cols
683            };
684
685            let mapping =
686                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
687            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
688        };
689
690        LogicalJoin::with_output_indices(
691            self.left().prune_col(&left_required_cols, ctx),
692            self.right().prune_col(&right_required_cols, ctx),
693            self.join_type(),
694            on,
695            new_output_indices,
696        )
697        .into()
698    }
699}
700
701impl ExprRewritable<Logical> for LogicalJoin {
702    fn has_rewritable_expr(&self) -> bool {
703        true
704    }
705
706    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
707        let mut core = self.core.clone();
708        core.rewrite_exprs(r);
709        Self {
710            base: self.base.clone_with_new_plan_id(),
711            core,
712        }
713        .into()
714    }
715}
716
717impl ExprVisitable for LogicalJoin {
718    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
719        self.core.visit_exprs(v);
720    }
721}
722
723/// We are trying to derive a predicate to apply to the other side of a join if all
724/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
725/// with the corresponding eq condition columns of the other side.
726///
727/// Strategy:
728/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
729///    then we proceed. Else abort.
730/// 2. Then, we collect `InputRef`s in the conjunction.
731/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
732/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
733///    the other side.
734///
735/// # Arguments
736///
737/// Suppose we derive a predicate from the left side to be pushed to the right side.
738/// * `expr`: An expr from the left side.
739/// * `col_num`: The number of columns in the left side.
740fn derive_predicate_from_eq_condition(
741    expr: &ExprImpl,
742    eq_condition: &EqJoinPredicate,
743    col_num: usize,
744    expr_is_left: bool,
745) -> Option<ExprImpl> {
746    if expr.is_impure() {
747        return None;
748    }
749    let eq_indices = eq_condition
750        .eq_indexes_typed()
751        .iter()
752        .filter_map(|(l, r)| {
753            if l.return_type() != r.return_type() {
754                None
755            } else if expr_is_left {
756                Some(l.index())
757            } else {
758                Some(r.index())
759            }
760        })
761        .collect_vec();
762    if expr
763        .collect_input_refs(col_num)
764        .ones()
765        .any(|index| !eq_indices.contains(&index))
766    {
767        // expr contains an InputRef not in eq_condition
768        return None;
769    }
770    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
771    // Hence, we can substitute those `InputRef`s with indices from the other side.
772    let other_side_mapping = if expr_is_left {
773        eq_condition.eq_indexes_typed().into_iter().collect()
774    } else {
775        eq_condition
776            .eq_indexes_typed()
777            .into_iter()
778            .map(|(x, y)| (y, x))
779            .collect()
780    };
781    struct InputRefsRewriter {
782        mapping: HashMap<InputRef, InputRef>,
783    }
784    impl ExprRewriter for InputRefsRewriter {
785        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
786            self.mapping[&input_ref].clone().into()
787        }
788    }
789    Some(
790        InputRefsRewriter {
791            mapping: other_side_mapping,
792        }
793        .rewrite_expr(expr.clone()),
794    )
795}
796
797/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
798struct LookupJoinPredicateRewriter {
799    offset: usize,
800    mapping: Vec<usize>,
801}
802impl ExprRewriter for LookupJoinPredicateRewriter {
803    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
804        if input_ref.index() < self.offset {
805            input_ref.into()
806        } else {
807            InputRef::new(
808                self.mapping[input_ref.index() - self.offset] + self.offset,
809                input_ref.return_type(),
810            )
811            .into()
812        }
813    }
814}
815
816/// Rewrite the scan predicate so we can add it to the join predicate.
817struct LookupJoinScanPredicateRewriter {
818    offset: usize,
819}
820impl ExprRewriter for LookupJoinScanPredicateRewriter {
821    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
822        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
823    }
824}
825
826impl PredicatePushdown for LogicalJoin {
827    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
828    ///
829    /// # Which predicates can be pushed
830    ///
831    /// For inner join, we can do all kinds of pushdown.
832    ///
833    /// For left/right semi join, we can push filter to left/right and on-clause,
834    /// and push on-clause to left/right.
835    ///
836    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
837    ///
838    /// ## Outer Join
839    ///
840    /// Preserved Row table
841    /// : The table in an Outer Join that must return all rows.
842    ///
843    /// Null Supplying table
844    /// : This is the table that has nulls filled in for its columns in unmatched rows.
845    ///
846    /// |                          | Preserved Row table | Null Supplying table |
847    /// |--------------------------|---------------------|----------------------|
848    /// | Join predicate (on)      | Not Pushed          | Pushed               |
849    /// | Where predicate (filter) | Pushed              | Not Pushed           |
850    fn predicate_pushdown(
851        &self,
852        predicate: Condition,
853        ctx: &mut PredicatePushdownContext,
854    ) -> PlanRef {
855        // rewrite output col referencing indices as internal cols
856        let mut predicate = {
857            let mut mapping = self.core.o2i_col_mapping();
858            predicate.rewrite_expr(&mut mapping)
859        };
860
861        let left_col_num = self.left().schema().len();
862        let right_col_num = self.right().schema().len();
863        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
864
865        let push_down_temporal_predicate = self.temporal_join_on().is_none();
866
867        let (left_from_filter, right_from_filter, on) = push_down_into_join(
868            &mut predicate,
869            left_col_num,
870            right_col_num,
871            join_type,
872            push_down_temporal_predicate,
873        );
874
875        let mut new_on = self.on().clone().and(on);
876        let (left_from_on, right_from_on) = push_down_join_condition(
877            &mut new_on,
878            left_col_num,
879            right_col_num,
880            join_type,
881            push_down_temporal_predicate,
882        );
883
884        let left_predicate = left_from_filter.and(left_from_on);
885        let right_predicate = right_from_filter.and(right_from_on);
886
887        // Derive conditions to push to the other side based on eq condition columns
888        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
889
890        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
891        let right_from_left = if matches!(
892            join_type,
893            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
894        ) {
895            Condition {
896                conjunctions: left_predicate
897                    .conjunctions
898                    .iter()
899                    .filter_map(|expr| {
900                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
901                    })
902                    .collect(),
903            }
904        } else {
905            Condition::true_cond()
906        };
907
908        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
909        let left_from_right = if matches!(
910            join_type,
911            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
912        ) {
913            Condition {
914                conjunctions: right_predicate
915                    .conjunctions
916                    .iter()
917                    .filter_map(|expr| {
918                        derive_predicate_from_eq_condition(
919                            expr,
920                            &eq_condition,
921                            right_col_num,
922                            false,
923                        )
924                    })
925                    .collect(),
926            }
927        } else {
928            Condition::true_cond()
929        };
930
931        let left_predicate = left_predicate.and(left_from_right);
932        let right_predicate = right_predicate.and(right_from_left);
933
934        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
935        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
936        let new_join = LogicalJoin::with_output_indices(
937            new_left,
938            new_right,
939            join_type,
940            new_on,
941            self.output_indices().clone(),
942        );
943
944        let mut mapping = self.core.i2o_col_mapping();
945        predicate = predicate.rewrite_expr(&mut mapping);
946        LogicalFilter::create(new_join.into(), predicate)
947    }
948}
949
950#[derive(Clone, Copy)]
951struct TemporalJoinScan<'a>(&'a LogicalScan);
952
953impl<'a> Deref for TemporalJoinScan<'a> {
954    type Target = LogicalScan;
955
956    fn deref(&self) -> &Self::Target {
957        self.0
958    }
959}
960
961impl LogicalJoin {
    /// Converts both inputs into stream plans whose distributions are compatible with a hash
    /// join on `predicate`'s equality keys, and returns them as `(left, right)`.
    ///
    /// The right input is converted first, with a requirement to be sharded by its equality
    /// keys. What happens next depends on the distribution it ends up with:
    /// - `HashShard`: the left input must have the same physical distribution, translated onto
    ///   the left equality columns through `r2l`.
    /// - `UpstreamHashShard`: the left input is converted with a shard-by-eq-keys requirement.
    ///   - If the left then reports `HashShard`, the right input is enforced to the left's
    ///     distribution translated through `l2r`.
    ///   - If both report `UpstreamHashShard`, both inputs are enforced to `hash_shard` on their
    ///     own equality keys.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!()` if an input converted with a shard requirement reports any other
    /// distribution.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        let mut right = self.right().to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        // Column mappings between the two sides' equality columns (right -> left and
        // left -> right, per their names), used to translate a distribution from one side
        // to the other.
        let r2l =
            predicate.r2l_eq_columns_mapping(self.left().schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(self.left().schema().len(), right.schema().len());
        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // Make the left side follow the right side's concrete hash distribution.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = self.left().to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                // The right side kept an upstream table distribution; convert the left side
                // independently and reconcile afterwards.
                left = self.left().to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Left is plainly hash-sharded: make the right side follow it.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Both sides kept upstream distributions. They are not matched against
                        // each other here; both are enforced to a hash shard on their own eq
                        // keys instead (presumably because two upstream distributions cannot be
                        // assumed to line up — NOTE(review): confirm).
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
1014
1015    fn to_stream_hash_join(
1016        &self,
1017        predicate: EqJoinPredicate,
1018        ctx: &mut ToStreamContext,
1019    ) -> Result<StreamPlanRef> {
1020        use super::stream::prelude::*;
1021
1022        assert!(predicate.has_eq());
1023        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1024
1025        let mut core = self.core.clone_with_inputs(left, right);
1026        core.on = generic::JoinOn::EqPredicate(predicate);
1027
1028        // Convert to Hash Join for equal joins
1029        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
1030        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
1031        // as opposed to the HashJoin, which applies the condition row-wise.
1032        // However, the default behavior of pulling up non-equal conditions can be overridden by the
1033        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
1034        // materialization of rows only to be filtered later.
1035
1036        let stream_hash_join = StreamHashJoin::new(core.clone())?;
1037        let predicate = stream_hash_join.eq_join_predicate().clone();
1038
1039        let force_filter_inside_join = self
1040            .base
1041            .ctx()
1042            .session_ctx()
1043            .config()
1044            .streaming_force_filter_inside_join();
1045
1046        let pull_filter = self.join_type() == JoinType::Inner
1047            && stream_hash_join.eq_join_predicate().has_non_eq()
1048            && stream_hash_join.inequality_pairs().is_empty()
1049            && (!force_filter_inside_join);
1050        if pull_filter {
1051            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1052
1053            let mut core = core;
1054            core.output_indices = default_indices.clone();
1055            // Temporarily remove output indices.
1056            let eq_cond = EqJoinPredicate::new(
1057                Condition::true_cond(),
1058                predicate.eq_keys().to_vec(),
1059                self.left().schema().len(),
1060                self.right().schema().len(),
1061            );
1062            core.on = generic::JoinOn::EqPredicate(eq_cond);
1063            let hash_join = StreamHashJoin::new(core)?.into();
1064            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1065            let plan = StreamFilter::new(logical_filter).into();
1066            if self.output_indices() != &default_indices {
1067                let logical_project = generic::Project::with_mapping(
1068                    plan,
1069                    ColIndexMapping::with_remaining_columns(
1070                        self.output_indices(),
1071                        self.internal_column_num(),
1072                    ),
1073                );
1074                Ok(StreamProject::new(logical_project).into())
1075            } else {
1076                Ok(plan)
1077            }
1078        } else {
1079            Ok(stream_hash_join.into())
1080        }
1081    }
1082
1083    pub fn should_be_temporal_join(&self) -> bool {
1084        self.temporal_join_on().is_some()
1085    }
1086
1087    fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1088        if let Some(logical_scan) = self.core.right.as_logical_scan() {
1089            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1090                .then_some(TemporalJoinScan(logical_scan))
1091        } else {
1092            None
1093        }
1094    }
1095
1096    fn should_be_stream_temporal_join<'a>(
1097        &'a self,
1098        ctx: &ToStreamContext,
1099    ) -> Result<Option<TemporalJoinScan<'a>>> {
1100        Ok(if let Some(scan) = self.temporal_join_on() {
1101            if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1102                return Err(RwError::from(ErrorCode::NotSupported(
1103                    "Temporal join with snapshot backfill not supported".into(),
1104                    "Please use arrangement backfill".into(),
1105                )));
1106            }
1107            if scan.cross_database() {
1108                return Err(RwError::from(ErrorCode::NotSupported(
1109                        "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1110                        "Please ensure both tables are in the same database".into(),
1111                    )));
1112            }
1113            Some(scan)
1114        } else {
1115            None
1116        })
1117    }
1118
1119    fn to_stream_temporal_join_with_index_selection(
1120        &self,
1121        logical_scan: TemporalJoinScan<'_>,
1122        predicate: EqJoinPredicate,
1123        ctx: &mut ToStreamContext,
1124    ) -> Result<StreamPlanRef> {
1125        // Use primary table.
1126        let mut result_plan: Result<StreamTemporalJoin> =
1127            self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1128        // Return directly if this temporal join can match the pk of its right table.
1129        if let Ok(temporal_join) = &result_plan
1130            && temporal_join.eq_join_predicate().eq_indexes().len()
1131                == logical_scan.primary_key().len()
1132        {
1133            return result_plan.map(|x| x.into());
1134        }
1135        if self
1136            .core
1137            .ctx()
1138            .session_ctx()
1139            .config()
1140            .enable_index_selection()
1141        {
1142            let indexes = logical_scan.table_indexes();
1143            for index in indexes {
1144                // Use index table
1145                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1146                    let index_scan: PlanRef = index_scan.into();
1147                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1148                    if let Ok(temporal_join) = that.to_stream_temporal_join(
1149                        that.temporal_join_on().expect(
1150                            "index scan created from temporal join scan must also be temporal join",
1151                        ),
1152                        predicate.clone(),
1153                        ctx,
1154                    ) {
1155                        match &result_plan {
1156                            Err(_) => result_plan = Ok(temporal_join),
1157                            Ok(prev_temporal_join) => {
1158                                // Prefer to the temporal join with a longer lookup prefix len.
1159                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1160                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1161                                {
1162                                    result_plan = Ok(temporal_join)
1163                                }
1164                            }
1165                        }
1166                    }
1167                }
1168            }
1169        }
1170
1171        result_plan.map(|x| x.into())
1172    }
1173
1174    fn temporal_join_scan_predicate_pull_up(
1175        logical_scan: TemporalJoinScan<'_>,
1176        predicate: EqJoinPredicate,
1177        output_indices: &[usize],
1178        left_schema_len: usize,
1179    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1180        // Extract the predicate from logical scan. Only pure scan is supported.
1181        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1182        // Construct output column to require column mapping
1183        let o2r = if let Some(project_expr) = project_expr {
1184            project_expr
1185                .into_iter()
1186                .map(|x| x.as_input_ref().unwrap().index)
1187                .collect_vec()
1188        } else {
1189            (0..logical_scan.output_col_idx().len()).collect_vec()
1190        };
1191        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1192            offset: left_schema_len,
1193            mapping: o2r.clone(),
1194        };
1195
1196        let new_eq_cond = predicate
1197            .eq_cond()
1198            .rewrite_expr(&mut join_predicate_rewriter);
1199
1200        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1201            offset: left_schema_len,
1202        };
1203
1204        let new_other_cond = predicate
1205            .other_cond()
1206            .clone()
1207            .rewrite_expr(&mut join_predicate_rewriter)
1208            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1209
1210        let new_join_on = new_eq_cond.and(new_other_cond);
1211
1212        let new_predicate = EqJoinPredicate::create(
1213            left_schema_len,
1214            new_scan.schema().len(),
1215            new_join_on.clone(),
1216        );
1217
1218        // Rewrite the join output indices and all output indices referred to the old scan need to
1219        // rewrite.
1220        let new_join_output_indices = output_indices
1221            .iter()
1222            .map(|&x| {
1223                if x < left_schema_len {
1224                    x
1225                } else {
1226                    o2r[x - left_schema_len] + left_schema_len
1227                }
1228            })
1229            .collect_vec();
1230
1231        let new_stream_table_scan =
1232            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1233        Ok((
1234            new_stream_table_scan,
1235            new_predicate,
1236            new_join_on,
1237            new_join_output_indices,
1238        ))
1239    }
1240
1241    fn to_stream_temporal_join(
1242        &self,
1243        logical_scan: TemporalJoinScan<'_>,
1244        predicate: EqJoinPredicate,
1245        ctx: &mut ToStreamContext,
1246    ) -> Result<StreamTemporalJoin> {
1247        use super::stream::prelude::*;
1248
1249        assert!(predicate.has_eq());
1250
1251        let table = logical_scan.table();
1252        let output_column_ids = logical_scan.output_column_ids();
1253
1254        // Verify that the right join key columns are the the prefix of the primary key and
1255        // also contain the distribution key.
1256        let order_col_ids = table.order_column_ids();
1257        let dist_key = table.distribution_key.clone();
1258
1259        let mut dist_key_in_order_key_pos = vec![];
1260        for d in dist_key {
1261            let pos = table
1262                .order_column_indices()
1263                .position(|x| x == d)
1264                .expect("dist_key must in order_key");
1265            dist_key_in_order_key_pos.push(pos);
1266        }
1267        // The shortest prefix of order key that contains distribution key.
1268        let shortest_prefix_len = dist_key_in_order_key_pos
1269            .iter()
1270            .max()
1271            .map_or(0, |pos| pos + 1);
1272
1273        // Reorder the join equal predicate to match the order key.
1274        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1275        for order_col_id in order_col_ids {
1276            let mut found = false;
1277            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1278                if order_col_id == output_column_ids[eq_idx] {
1279                    reorder_idx.push(i);
1280                    found = true;
1281                    break;
1282                }
1283            }
1284            if !found {
1285                break;
1286            }
1287        }
1288        if reorder_idx.len() < shortest_prefix_len {
1289            return Err(RwError::from(ErrorCode::NotSupported(
1290                "Temporal join requires the equivalence join condition includes the key columns that form the distribution key of the lookup table".into(),
1291                concat!(
1292                    "Use DESCRIBE <table_name> to view the table's key information.\n",
1293                    "You can create an index on the lookup table to facilitate the temporal join if necessary."
1294                ).into(),
1295            )));
1296        }
1297        let lookup_prefix_len = reorder_idx.len();
1298        let predicate = predicate.reorder(&reorder_idx);
1299
1300        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1301            RequiredDist::single()
1302        } else {
1303            let left_eq_indexes = predicate.left_eq_indexes();
1304            let left_dist_key = dist_key_in_order_key_pos
1305                .iter()
1306                .map(|pos| left_eq_indexes[*pos])
1307                .collect_vec();
1308
1309            RequiredDist::hash_shard(&left_dist_key)
1310        };
1311
1312        let left = self.left().to_stream(ctx)?;
1313        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1314        let left = required_dist.stream_enforce(left);
1315
1316        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1317            Self::temporal_join_scan_predicate_pull_up(
1318                logical_scan,
1319                predicate,
1320                self.output_indices(),
1321                self.left().schema().len(),
1322            )?;
1323
1324        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1325        if !new_predicate.has_eq() {
1326            return Err(RwError::from(ErrorCode::NotSupported(
1327                "Temporal join requires a non trivial join condition".into(),
1328                "Please remove the false condition of the join".into(),
1329            )));
1330        }
1331
1332        // Construct a new logical join, because we have change its RHS.
1333        let new_logical_join = generic::Join::new(
1334            left,
1335            right,
1336            new_join_on,
1337            self.join_type(),
1338            new_join_output_indices,
1339        );
1340
1341        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1342
1343        let mut new_logical_join = new_logical_join;
1344        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1345        StreamTemporalJoin::new(new_logical_join, false)
1346    }
1347
1348    fn to_stream_nested_loop_temporal_join(
1349        &self,
1350        logical_scan: TemporalJoinScan<'_>,
1351        predicate: EqJoinPredicate,
1352        ctx: &mut ToStreamContext,
1353    ) -> Result<StreamPlanRef> {
1354        use super::stream::prelude::*;
1355        assert!(!predicate.has_eq());
1356
1357        let left = self.left().to_stream_with_dist_required(
1358            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1359            ctx,
1360        )?;
1361        assert!(left.as_stream_exchange().is_some());
1362
1363        if self.join_type() != JoinType::Inner {
1364            return Err(RwError::from(ErrorCode::NotSupported(
1365                "Temporal join requires an inner join".into(),
1366                "Please use an inner join".into(),
1367            )));
1368        }
1369
1370        if !left.append_only() {
1371            return Err(RwError::from(ErrorCode::NotSupported(
1372                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1373                "Please ensure the left hash side is append only".into(),
1374            )));
1375        }
1376
1377        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1378            Self::temporal_join_scan_predicate_pull_up(
1379                logical_scan,
1380                predicate,
1381                self.output_indices(),
1382                self.left().schema().len(),
1383            )?;
1384
1385        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1386
1387        // Construct a new logical join, because we have change its RHS.
1388        let new_logical_join = generic::Join::new(
1389            left,
1390            right,
1391            new_join_on,
1392            self.join_type(),
1393            new_join_output_indices,
1394        );
1395
1396        let mut new_logical_join = new_logical_join;
1397        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1398        Ok(StreamTemporalJoin::new(new_logical_join, true)?.into())
1399    }
1400
1401    fn to_stream_dynamic_filter(
1402        &self,
1403        predicate: Condition,
1404        ctx: &mut ToStreamContext,
1405    ) -> Result<Option<StreamPlanRef>> {
1406        use super::stream::prelude::*;
1407
1408        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1409        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1410        // `StreamDynamicFilter`
1411        let Some((left_key_idx, comparator)) = self.dynamic_filter_candidate(&predicate) else {
1412            return Ok(None);
1413        };
1414
1415        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1416        let right = self.right().to_stream_with_dist_required(
1417            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1418            ctx,
1419        )?;
1420
1421        assert!(right.as_stream_exchange().is_some());
1422        assert_eq!(
1423            *right.inputs().iter().exactly_one().unwrap().distribution(),
1424            Distribution::Single
1425        );
1426
1427        let core = DynamicFilter::new(comparator, left_key_idx, left, right);
1428        let plan = StreamDynamicFilter::new(core)?.into();
1429        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1430        if self
1431            .output_indices()
1432            .iter()
1433            .copied()
1434            .ne(0..self.left().schema().len())
1435        {
1436            // The schema of dynamic filter is always the same as the left side now, and we have
1437            // checked that all output columns are from the left side before.
1438            let logical_project = generic::Project::with_mapping(
1439                plan,
1440                ColIndexMapping::with_remaining_columns(
1441                    self.output_indices(),
1442                    self.left().schema().len(),
1443                ),
1444            );
1445            Ok(Some(StreamProject::new(logical_project).into()))
1446        } else {
1447            Ok(Some(plan))
1448        }
1449    }
1450
1451    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1452        let predicate = EqJoinPredicate::create(
1453            self.left().schema().len(),
1454            self.right().schema().len(),
1455            self.on().clone(),
1456        );
1457        assert!(predicate.has_eq());
1458
1459        let join = self
1460            .core
1461            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1462
1463        Ok(self
1464            .to_batch_lookup_join(predicate, join)?
1465            .expect("Fail to convert to lookup join")
1466            .into())
1467    }
1468
1469    fn to_stream_asof_join(
1470        &self,
1471        predicate: EqJoinPredicate,
1472        ctx: &mut ToStreamContext,
1473    ) -> Result<StreamPlanRef> {
1474        use super::stream::prelude::*;
1475
1476        if predicate.eq_keys().is_empty() {
1477            return Err(ErrorCode::InvalidInputSyntax(
1478                "AsOf join requires at least 1 equal condition".to_owned(),
1479            )
1480            .into());
1481        }
1482
1483        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1484        let left_len = left.schema().len();
1485        let mut core = self.core.clone_with_inputs(left, right);
1486        core.on = generic::JoinOn::EqPredicate(predicate);
1487
1488        let inequality_desc = Self::get_inequality_desc_from_predicate(
1489            core.on
1490                .as_eq_predicate_ref()
1491                .expect("core predicate must exist")
1492                .other_cond()
1493                .clone(),
1494            left_len,
1495        )?;
1496
1497        Ok(StreamAsOfJoin::new(core, inequality_desc)?.into())
1498    }
1499
1500    /// Convert the logical join to a Hash join.
1501    fn to_batch_hash_join(
1502        &self,
1503        logical_join: generic::Join<BatchPlanRef>,
1504        predicate: EqJoinPredicate,
1505    ) -> Result<BatchPlanRef> {
1506        use super::batch::prelude::*;
1507
1508        let left_schema_len = logical_join.left.schema().len();
1509        let asof_desc = self
1510            .is_asof_join()
1511            .then(|| {
1512                Self::get_inequality_desc_from_predicate(
1513                    predicate.other_cond().clone(),
1514                    left_schema_len,
1515                )
1516            })
1517            .transpose()?;
1518
1519        let logical_join = generic::Join {
1520            on: generic::JoinOn::EqPredicate(predicate),
1521            ..logical_join
1522        };
1523        let batch_join = BatchHashJoin::new(logical_join, asof_desc);
1524        Ok(batch_join.into())
1525    }
1526
1527    pub fn get_inequality_desc_from_predicate(
1528        predicate: Condition,
1529        left_input_len: usize,
1530    ) -> Result<AsOfJoinDesc> {
1531        let expr: ExprImpl = predicate.into();
1532        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1533            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1534            {
1535                Ok(AsOfJoinDesc {
1536                    left_idx: left_input_ref.index() as u32,
1537                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1538                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1539                })
1540            } else {
1541                bail!("inequal condition from the same side should be push down in optimizer");
1542            }
1543        } else {
1544            Err(ErrorCode::InvalidInputSyntax(
1545                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1546            )
1547            .into())
1548        }
1549    }
1550
1551    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1552        match expr_type {
1553            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1554            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1555            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1556            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1557            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1558                "Invalid comparison type: {}",
1559                expr_type.as_str_name()
1560            ))
1561            .into()),
1562        }
1563    }
1564}
1565
1566impl ToBatch for LogicalJoin {
1567    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1568        let predicate = EqJoinPredicate::create(
1569            self.left().schema().len(),
1570            self.right().schema().len(),
1571            self.on().clone(),
1572        );
1573
1574        let batch_join = self
1575            .core
1576            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1577
1578        let ctx = self.base.ctx();
1579        let config = ctx.session_ctx().config();
1580
1581        if predicate.has_eq() {
1582            if !predicate.eq_keys_are_type_aligned() {
1583                return Err(ErrorCode::InternalError(format!(
1584                    "Join eq keys are not aligned for predicate: {predicate:?}"
1585                ))
1586                .into());
1587            }
1588            if config.batch_enable_lookup_join()
1589                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1590                    predicate.clone(),
1591                    batch_join.clone(),
1592                )?
1593            {
1594                return Ok(lookup_join.into());
1595            }
1596            self.to_batch_hash_join(batch_join, predicate)
1597        } else if self.is_asof_join() {
1598            Err(ErrorCode::InvalidInputSyntax(
1599                "AsOf join requires at least 1 equal condition".to_owned(),
1600            )
1601            .into())
1602        } else {
1603            // Convert to Nested-loop Join for non-equal joins
1604            Ok(BatchNestedLoopJoin::new(batch_join).into())
1605        }
1606    }
1607}
1608
impl ToStream for LogicalJoin {
    /// Converts this logical join into a streaming plan.
    ///
    /// First returns `NotSupported` if any `ON` conjunct still contains `now()`; such temporal
    /// predicates are expected to have been separated out by the optimizer before this point.
    /// Otherwise the plan is chosen from the join type and the shape of the `ON` condition:
    /// - AsOf inner / AsOf left outer join: `to_stream_asof_join`.
    /// - At least one equi-condition (eq keys must be type-aligned, else an internal error):
    ///   a temporal join with index selection when `should_be_stream_temporal_join` yields a
    ///   scan, otherwise a hash join.
    /// - No equi-condition: a nested-loop temporal join when `should_be_stream_temporal_join`
    ///   yields a scan, else a dynamic filter if `to_stream_dynamic_filter` can build one, else
    ///   a `NotSupported` error because a streaming nested-loop join is not offered.
    fn to_stream(
        &self,
        ctx: &mut ToStreamContext,
    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
        // Reject any `now()` that is still part of the ON clause.
        if self
            .on()
            .conjunctions
            .iter()
            .any(|cond| cond.count_nows() > 0)
        {
            return Err(ErrorCode::NotSupported(
                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
                 "please refer to https://docs.risingwave.com/processing/sql/temporal-filters for more information".to_owned()).into());
        }

        // Classify the ON condition into equi-join keys and the remaining conjuncts.
        let predicate = EqJoinPredicate::create(
            self.left().schema().len(),
            self.right().schema().len(),
            self.on().clone(),
        );

        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
            self.to_stream_asof_join(predicate, ctx)
        } else if predicate.has_eq() {
            if !predicate.eq_keys_are_type_aligned() {
                return Err(ErrorCode::InternalError(format!(
                    "Join eq keys are not aligned for predicate: {predicate:?}"
                ))
                .into());
            }

            // Prefer a temporal join when the right side qualifies; otherwise a hash join.
            if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
                self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
            } else {
                self.to_stream_hash_join(predicate, ctx)
            }
        } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
            // No equi-condition, but the right side qualifies for a temporal join.
            self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
        } else if let Some(dynamic_filter) =
            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
        {
            Ok(dynamic_filter)
        } else {
            // Nothing else fits: a streaming nested-loop join is not supported.
            Err(RwError::from(ErrorCode::NotSupported(
                "streaming nested-loop join".to_owned(),
                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
                 See also: https://docs.risingwave.com/processing/sql/dynamic-filters".to_owned(),
            )))
        }
    }

    /// Rewrites the join for streaming and ensures its output carries the columns needed for a
    /// stream key.
    ///
    /// Steps:
    /// 1. If there are equi-join keys, request locality on them from the inputs via
    ///    `try_enforce_locality_requirement`. When `should_be_temporal_join()` holds, only the
    ///    left input is asked; otherwise both are.
    /// 2. Recursively rewrite both inputs and rebuild the join on top of them.
    /// 3. Append to the output indices every input stream-key column that is not already
    ///    output. Depending on `add_which_join_key_to_pk`, also append the left and/or right
    ///    equi-join key columns. Left-side additions are skipped when `is_right_join()`, and
    ///    right-side additions are skipped when `is_left_join()`.
    /// 4. For a full outer join, wrap the result with
    ///    `LogicalFilter::filter_out_all_null_keys` over the input stream keys and eq keys.
    ///
    /// Returns the rewritten plan and the column mapping from `rewrite_with_left_right`. That
    /// mapping remains valid because extra columns are only appended after the existing ones.
    fn logical_rewrite_for_stream(
        &self,
        ctx: &mut RewriteStreamContext,
    ) -> Result<(PlanRef, ColIndexMapping)> {
        let eq_indexes = self.eq_indexes();
        let (logical_left, logical_right) = if eq_indexes.is_empty() {
            (self.left(), self.right())
        } else {
            let lhs_join_key_idx = eq_indexes.iter().map(|(l, _)| *l).collect_vec();
            if self.should_be_temporal_join() {
                // Temporal join: only the left (streaming) side gets a locality requirement.
                (
                    try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
                    self.right(),
                )
            } else {
                let rhs_join_key_idx = eq_indexes.iter().map(|(_, r)| *r).collect_vec();
                (
                    try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
                    try_enforce_locality_requirement(self.right(), &rhs_join_key_idx),
                )
            }
        };

        let (left, left_col_change) = logical_left.logical_rewrite_for_stream(ctx)?;
        let left_len = left.schema().len();
        let (right, right_col_change) = logical_right.logical_rewrite_for_stream(ctx)?;
        let (join, out_col_change) = self.rewrite_with_left_right(
            left.clone(),
            left_col_change,
            right.clone(),
            right_col_change,
        );

        // Internal (left ++ right) column index -> position in the join's current output.
        let mapping = ColIndexMapping::with_remaining_columns(
            join.output_indices(),
            join.internal_column_num(),
        );

        // Left/right input column index -> output position; `try_map` is `None` when the
        // column is not currently output.
        let l2o = join.core.l2i_col_mapping().composite(&mapping);
        let r2o = join.core.r2i_col_mapping().composite(&mapping);

        // Add missing pk indices to the logical join
        let mut left_to_add = left
            .expect_stream_key()
            .iter()
            .cloned()
            .filter(|i| l2o.try_map(*i).is_none())
            .collect_vec();

        // Right-side indices are shifted by `left_len` to address the internal schema.
        let mut right_to_add = right
            .expect_stream_key()
            .iter()
            .filter(|&&i| r2o.try_map(i).is_none())
            .map(|&i| i + left_len)
            .collect_vec();

        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
        // key.
        let right_len = right.schema().len();
        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());

        // Decides whether the left keys, the right keys, or both are added for each eq pair.
        let either_or_both = self.core.add_which_join_key_to_pk();

        for (lk, rk) in eq_predicate.eq_indexes() {
            match either_or_both {
                EitherOrBoth::Left(_) => {
                    if l2o.try_map(lk).is_none() {
                        left_to_add.push(lk);
                    }
                }
                EitherOrBoth::Right(_) => {
                    if r2o.try_map(rk).is_none() {
                        right_to_add.push(rk + left_len)
                    }
                }
                EitherOrBoth::Both(_, _) => {
                    if l2o.try_map(lk).is_none() {
                        left_to_add.push(lk);
                    }
                    if r2o.try_map(rk).is_none() {
                        right_to_add.push(rk + left_len)
                    }
                }
            };
        }
        // A column can be both a stream-key column and an eq key; keep it only once.
        let left_to_add = left_to_add.into_iter().unique();
        let right_to_add = right_to_add.into_iter().unique();
        // NOTE(st1page) over

        // Extra columns go after the existing outputs; sides the join type does not output
        // are skipped.
        let mut new_output_indices = join.output_indices().clone();
        if !join.is_right_join() {
            new_output_indices.extend(left_to_add);
        }
        if !join.is_left_join() {
            new_output_indices.extend(right_to_add);
        }

        let join_with_pk = join.clone_with_output_indices(new_output_indices);

        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information

            // Recompute input -> output mappings against the widened output indices.
            let l2o = join_with_pk
                .core
                .l2i_col_mapping()
                .composite(&join_with_pk.core.i2o_col_mapping());
            let r2o = join_with_pk
                .core
                .r2i_col_mapping()
                .composite(&join_with_pk.core.i2o_col_mapping());
            let mut left_right_keys = join_with_pk
                .left()
                .expect_stream_key()
                .iter()
                .map(|i| l2o.map(*i))
                .collect_vec();
            left_right_keys.extend(
                join_with_pk
                    .right()
                    .expect_stream_key()
                    .iter()
                    .map(|i| r2o.map(*i)),
            );
            left_right_keys.extend(
                eq_predicate
                    .eq_indexes()
                    .iter()
                    .flat_map(|(lk, rk)| [l2o.map(*lk), r2o.map(*rk)]),
            );
            let left_right_keys = left_right_keys.into_iter().unique().collect_vec();
            let plan: PlanRef = join_with_pk.into();
            LogicalFilter::filter_out_all_null_keys(plan, &left_right_keys)
        } else {
            join_with_pk.into()
        };

        // The added columns are appended at the end, so existing column indices are unchanged
        // and `out_col_change` is still valid.
        Ok((plan, out_col_change))
    }

    /// Tries to push a locality request on the given output `columns` down to the left input.
    ///
    /// Returns `None` unless `temporal_filter_candidate()` holds. It also returns `None` if any
    /// requested column does not map to an input column, or if the left input cannot offer
    /// better locality. On success, returns this join rebuilt over the improved left input.
    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
        // Only propagate locality for temporal-filter.
        if !self.temporal_filter_candidate() {
            return None;
        }

        // Temporal filter only outputs columns from left input, so mapping is safe.
        let o2i_mapping = self.core.o2i_col_mapping();
        let left_input_columns = columns
            .iter()
            .map(|&col| o2i_mapping.try_map(col))
            .collect::<Option<Vec<usize>>>()?;
        if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
            return Some(
                self.clone_with_left_right(better_left_plan, self.right())
                    .into(),
            );
        }
        None
    }
}
1823
1824#[cfg(test)]
1825mod tests {
1826
1827    use std::collections::HashSet;
1828
1829    use risingwave_common::catalog::{Field, Schema};
1830    use risingwave_common::types::{DataType, Datum};
1831    use risingwave_pb::expr::expr_node::Type;
1832
1833    use super::*;
1834    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1835    use crate::optimizer::optimizer_context::OptimizerContext;
1836    use crate::optimizer::plan_node::LogicalValues;
1837    use crate::optimizer::property::FunctionalDependency;
1838
1839    /// Pruning
1840    /// ```text
1841    /// Join(on: input_ref(1)=input_ref(3))
1842    ///   TableScan(v1, v2, v3)
1843    ///   TableScan(v4, v5, v6)
1844    /// ```
1845    /// with required columns [2,3] will result in
1846    /// ```text
1847    /// Project(input_ref(1), input_ref(2))
1848    ///   Join(on: input_ref(0)=input_ref(2))
1849    ///     TableScan(v2, v3)
1850    ///     TableScan(v4)
1851    /// ```
1852    #[tokio::test]
1853    async fn test_prune_join() {
1854        let ty = DataType::Int32;
1855        let ctx = OptimizerContext::mock().await;
1856        let fields: Vec<Field> = (1..7)
1857            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1858            .collect();
1859        let left = LogicalValues::new(
1860            vec![],
1861            Schema {
1862                fields: fields[0..3].to_vec(),
1863            },
1864            ctx.clone(),
1865        );
1866        let right = LogicalValues::new(
1867            vec![],
1868            Schema {
1869                fields: fields[3..6].to_vec(),
1870            },
1871            ctx,
1872        );
1873        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1874            FunctionCall::new(
1875                Type::Equal,
1876                vec![
1877                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1878                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1879                ],
1880            )
1881            .unwrap(),
1882        ));
1883        let join_type = JoinType::Inner;
1884        let join: PlanRef = LogicalJoin::new(
1885            left.into(),
1886            right.into(),
1887            join_type,
1888            Condition::with_expr(on),
1889        )
1890        .into();
1891
1892        // Perform the prune
1893        let required_cols = vec![2, 3];
1894        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1895
1896        // Check the result
1897        let join = plan.as_logical_join().unwrap();
1898        assert_eq!(join.schema().fields().len(), 2);
1899        assert_eq!(join.schema().fields()[0], fields[2]);
1900        assert_eq!(join.schema().fields()[1], fields[3]);
1901
1902        let expr: ExprImpl = join.on().clone().into();
1903        let call = expr.as_function_call().unwrap();
1904        assert_eq_input_ref!(&call.inputs()[0], 0);
1905        assert_eq_input_ref!(&call.inputs()[1], 2);
1906
1907        let left = join.left();
1908        let left = left.as_logical_values().unwrap();
1909        assert_eq!(left.schema().fields(), &fields[1..3]);
1910        let right = join.right();
1911        let right = right.as_logical_values().unwrap();
1912        assert_eq!(right.schema().fields(), &fields[3..4]);
1913    }
1914
1915    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1916    #[tokio::test]
1917    async fn test_prune_semi_join() {
1918        let ty = DataType::Int32;
1919        let ctx = OptimizerContext::mock().await;
1920        let fields: Vec<Field> = (1..7)
1921            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1922            .collect();
1923        let left = LogicalValues::new(
1924            vec![],
1925            Schema {
1926                fields: fields[0..3].to_vec(),
1927            },
1928            ctx.clone(),
1929        );
1930        let right = LogicalValues::new(
1931            vec![],
1932            Schema {
1933                fields: fields[3..6].to_vec(),
1934            },
1935            ctx,
1936        );
1937        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1938            FunctionCall::new(
1939                Type::Equal,
1940                vec![
1941                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1942                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1943                ],
1944            )
1945            .unwrap(),
1946        ));
1947        for join_type in [
1948            JoinType::LeftSemi,
1949            JoinType::RightSemi,
1950            JoinType::LeftAnti,
1951            JoinType::RightAnti,
1952        ] {
1953            let join = LogicalJoin::new(
1954                left.clone().into(),
1955                right.clone().into(),
1956                join_type,
1957                Condition::with_expr(on.clone()),
1958            );
1959
1960            let offset = if join.is_right_join() { 3 } else { 0 };
1961            let join: PlanRef = join.into();
1962            // Perform the prune
1963            let required_cols = vec![0];
1964            // key 0 is never used in the join (always key 1)
1965            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1966            let as_plan = plan.as_logical_join().unwrap();
1967            // Check the result
1968            assert_eq!(as_plan.schema().fields().len(), 1);
1969            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1970
1971            // Perform the prune
1972            let required_cols = vec![0, 1, 2];
1973            // should not panic here
1974            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1975            let as_plan = plan.as_logical_join().unwrap();
1976            // Check the result
1977            assert_eq!(as_plan.schema().fields().len(), 3);
1978            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1979            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1980            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1981        }
1982    }
1983
1984    /// Pruning
1985    /// ```text
1986    /// Join(on: input_ref(1)=input_ref(3))
1987    ///   TableScan(v1, v2, v3)
1988    ///   TableScan(v4, v5, v6)
1989    /// ```
1990    /// with required columns [1, 3] will result in
1991    /// ```text
1992    /// Join(on: input_ref(0)=input_ref(1))
1993    ///   TableScan(v2)
1994    ///   TableScan(v4)
1995    /// ```
1996    #[tokio::test]
1997    async fn test_prune_join_no_project() {
1998        let ty = DataType::Int32;
1999        let ctx = OptimizerContext::mock().await;
2000        let fields: Vec<Field> = (1..7)
2001            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2002            .collect();
2003        let left = LogicalValues::new(
2004            vec![],
2005            Schema {
2006                fields: fields[0..3].to_vec(),
2007            },
2008            ctx.clone(),
2009        );
2010        let right = LogicalValues::new(
2011            vec![],
2012            Schema {
2013                fields: fields[3..6].to_vec(),
2014            },
2015            ctx,
2016        );
2017        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2018            FunctionCall::new(
2019                Type::Equal,
2020                vec![
2021                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2022                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2023                ],
2024            )
2025            .unwrap(),
2026        ));
2027        let join_type = JoinType::Inner;
2028        let join: PlanRef = LogicalJoin::new(
2029            left.into(),
2030            right.into(),
2031            join_type,
2032            Condition::with_expr(on),
2033        )
2034        .into();
2035
2036        // Perform the prune
2037        let required_cols = vec![1, 3];
2038        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2039
2040        // Check the result
2041        let join = plan.as_logical_join().unwrap();
2042        assert_eq!(join.schema().fields().len(), 2);
2043        assert_eq!(join.schema().fields()[0], fields[1]);
2044        assert_eq!(join.schema().fields()[1], fields[3]);
2045
2046        let expr: ExprImpl = join.on().clone().into();
2047        let call = expr.as_function_call().unwrap();
2048        assert_eq_input_ref!(&call.inputs()[0], 0);
2049        assert_eq_input_ref!(&call.inputs()[1], 1);
2050
2051        let left = join.left();
2052        let left = left.as_logical_values().unwrap();
2053        assert_eq!(left.schema().fields(), &fields[1..2]);
2054        let right = join.right();
2055        let right = right.as_logical_values().unwrap();
2056        assert_eq!(right.schema().fields(), &fields[3..4]);
2057    }
2058
2059    /// Convert
2060    /// ```text
2061    /// Join(on: ($1 = $3) AND ($2 == 42))
2062    ///   TableScan(v1, v2, v3)
2063    ///   TableScan(v4, v5, v6)
2064    /// ```
2065    /// to
2066    /// ```text
2067    /// Filter($2 == 42)
2068    ///   HashJoin(on: $1 = $3)
2069    ///     TableScan(v1, v2, v3)
2070    ///     TableScan(v4, v5, v6)
2071    /// ```
2072    #[tokio::test]
2073    async fn test_join_to_batch() {
2074        let ctx = OptimizerContext::mock().await;
2075        let fields: Vec<Field> = (1..7)
2076            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
2077            .collect();
2078        let left = LogicalValues::new(
2079            vec![],
2080            Schema {
2081                fields: fields[0..3].to_vec(),
2082            },
2083            ctx.clone(),
2084        );
2085        let right = LogicalValues::new(
2086            vec![],
2087            Schema {
2088                fields: fields[3..6].to_vec(),
2089            },
2090            ctx,
2091        );
2092
2093        fn input_ref(i: usize) -> ExprImpl {
2094            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
2095        }
2096        let eq_cond = ExprImpl::FunctionCall(Box::new(
2097            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
2098        ));
2099        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2100            FunctionCall::new(
2101                Type::Equal,
2102                vec![
2103                    input_ref(2),
2104                    ExprImpl::Literal(Box::new(Literal::new(
2105                        Datum::Some(42_i32.into()),
2106                        DataType::Int32,
2107                    ))),
2108                ],
2109            )
2110            .unwrap(),
2111        ));
2112        // Condition: ($1 = $3) AND ($2 == 42)
2113        let on_cond = ExprImpl::FunctionCall(Box::new(
2114            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2115        ));
2116
2117        let join_type = JoinType::Inner;
2118        let logical_join = LogicalJoin::new(
2119            left.into(),
2120            right.into(),
2121            join_type,
2122            Condition::with_expr(on_cond),
2123        );
2124
2125        // Perform `to_batch`
2126        let result = logical_join.to_batch().unwrap();
2127
2128        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2129        let hash_join = result.as_batch_hash_join().unwrap();
2130        assert_eq!(
2131            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2132            eq_cond
2133        );
2134        assert_eq!(
2135            *hash_join
2136                .eq_join_predicate()
2137                .non_eq_cond()
2138                .conjunctions
2139                .first()
2140                .unwrap(),
2141            non_eq_cond
2142        );
2143    }
2144
    /// Convert
    /// ```text
    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    ///
    /// NOTE(review): this test is `#[ignore]`d and its body is entirely commented out, so it
    /// currently checks nothing. The commented-out code is stale; for example, it calls
    /// `logical_join.to_stream()` without the `&mut ToStreamContext` argument that
    /// `ToStream::to_stream` now requires. It is kept only as a reference until the test is
    /// rewritten or removed.
    #[tokio::test]
    #[ignore] // ignored since the logical scan refactor; this also seems to duplicate the explain test
    // framework, so it may be removed.
    async fn test_join_to_stream() {
        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
        // let fields: Vec<Field> = (1..7)
        //     .map(|i| Field {
        //         data_type: DataType::Int32,
        //         name: format!("v{}", i),
        //     })
        //     .collect();
        // let left = LogicalScan::new(
        //     "left".to_string(),
        //     TableId::new(0),
        //     vec![1.into(), 2.into(), 3.into()],
        //     Schema {
        //         fields: fields[0..3].to_vec(),
        //     },
        //     ctx.clone(),
        // );
        // let right = LogicalScan::new(
        //     "right".to_string(),
        //     TableId::new(0),
        //     vec![4.into(), 5.into(), 6.into()],
        //     Schema {
        //                 fields: fields[3..6].to_vec(),
        //     },
        //     ctx,
        // );
        // let eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
        //             ExprImpl::Literal(Box::new(Literal::new(
        //                 Datum::Some(42_i32.into()),
        //                 DataType::Int32,
        //             ))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // // Condition: ($1 = $3) AND ($2 == 42)
        // let on_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
        // ));

        // let join_type = JoinType::LeftOuter;
        // let logical_join = LogicalJoin::new(
        //     left.clone().into(),
        //     right.clone().into(),
        //     join_type,
        //     Condition::with_expr(on_cond.clone()),
        // );

        // // Perform `to_stream`
        // let result = logical_join.to_stream();

        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
        // let hash_join = result.as_stream_hash_join().unwrap();
        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
    }
2229    /// Pruning
2230    /// ```text
2231    /// Join(on: input_ref(1)=input_ref(3))
2232    ///   TableScan(v1, v2, v3)
2233    ///   TableScan(v4, v5, v6)
2234    /// ```
2235    /// with required columns [3, 2] will result in
2236    /// ```text
2237    /// Project(input_ref(2), input_ref(1))
2238    ///   Join(on: input_ref(0)=input_ref(2))
2239    ///     TableScan(v2, v3)
2240    ///     TableScan(v4)
2241    /// ```
2242    #[tokio::test]
2243    async fn test_join_column_prune_with_order_required() {
2244        let ty = DataType::Int32;
2245        let ctx = OptimizerContext::mock().await;
2246        let fields: Vec<Field> = (1..7)
2247            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2248            .collect();
2249        let left = LogicalValues::new(
2250            vec![],
2251            Schema {
2252                fields: fields[0..3].to_vec(),
2253            },
2254            ctx.clone(),
2255        );
2256        let right = LogicalValues::new(
2257            vec![],
2258            Schema {
2259                fields: fields[3..6].to_vec(),
2260            },
2261            ctx,
2262        );
2263        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2264            FunctionCall::new(
2265                Type::Equal,
2266                vec![
2267                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2268                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2269                ],
2270            )
2271            .unwrap(),
2272        ));
2273        let join_type = JoinType::Inner;
2274        let join: PlanRef = LogicalJoin::new(
2275            left.into(),
2276            right.into(),
2277            join_type,
2278            Condition::with_expr(on),
2279        )
2280        .into();
2281
2282        // Perform the prune
2283        let required_cols = vec![3, 2];
2284        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2285
2286        // Check the result
2287        let join = plan.as_logical_join().unwrap();
2288        assert_eq!(join.schema().fields().len(), 2);
2289        assert_eq!(join.schema().fields()[0], fields[3]);
2290        assert_eq!(join.schema().fields()[1], fields[2]);
2291
2292        let expr: ExprImpl = join.on().clone().into();
2293        let call = expr.as_function_call().unwrap();
2294        assert_eq_input_ref!(&call.inputs()[0], 0);
2295        assert_eq_input_ref!(&call.inputs()[1], 2);
2296
2297        let left = join.left();
2298        let left = left.as_logical_values().unwrap();
2299        assert_eq!(left.schema().fields(), &fields[1..3]);
2300        let right = join.right();
2301        let right = right.as_logical_values().unwrap();
2302        assert_eq!(right.schema().fields(), &fields[3..4]);
2303    }
2304
2305    #[tokio::test]
2306    async fn fd_derivation_inner_outer_join() {
2307        // left: [l0, l1], right: [r0, r1, r2]
2308        // FD: l0 --> l1, r0 --> { r1, r2 }
2309        // On: l0 = 0 AND l1 = r1
2310        //
2311        // Inner Join:
2312        //  Schema: [l0, l1, r0, r1, r2]
2313        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2314        // Left Outer Join:
2315        //  Schema: [l0, l1, r0, r1, r2]
2316        //  FD: l0 --> l1
2317        // Right Outer Join:
2318        //  Schema: [l0, l1, r0, r1, r2]
2319        //  FD: r0 --> { r1, r2 }
2320        // Full Outer Join:
2321        //  Schema: [l0, l1, r0, r1, r2]
2322        //  FD: empty
2323        // Left Semi/Anti Join:
2324        //  Schema: [l0, l1]
2325        //  FD: l0 --> l1
2326        // Right Semi/Anti Join:
2327        //  Schema: [r0, r1, r2]
2328        //  FD: r0 --> {r1, r2}
2329        let ctx = OptimizerContext::mock().await;
2330        let left = {
2331            let fields: Vec<Field> = vec![
2332                Field::with_name(DataType::Int32, "l0"),
2333                Field::with_name(DataType::Int32, "l1"),
2334            ];
2335            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2336            // 0 --> 1
2337            values
2338                .base
2339                .functional_dependency_mut()
2340                .add_functional_dependency_by_column_indices(&[0], &[1]);
2341            values
2342        };
2343        let right = {
2344            let fields: Vec<Field> = vec![
2345                Field::with_name(DataType::Int32, "r0"),
2346                Field::with_name(DataType::Int32, "r1"),
2347                Field::with_name(DataType::Int32, "r2"),
2348            ];
2349            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2350            // 0 --> 1, 2
2351            values
2352                .base
2353                .functional_dependency_mut()
2354                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2355            values
2356        };
2357        // l0 = 0 AND l1 = r1
2358        let on: ExprImpl = FunctionCall::new(
2359            Type::And,
2360            vec![
2361                FunctionCall::new(
2362                    Type::Equal,
2363                    vec![
2364                        InputRef::new(0, DataType::Int32).into(),
2365                        ExprImpl::literal_int(0),
2366                    ],
2367                )
2368                .unwrap()
2369                .into(),
2370                FunctionCall::new(
2371                    Type::Equal,
2372                    vec![
2373                        InputRef::new(1, DataType::Int32).into(),
2374                        InputRef::new(3, DataType::Int32).into(),
2375                    ],
2376                )
2377                .unwrap()
2378                .into(),
2379            ],
2380        )
2381        .unwrap()
2382        .into();
2383        let expected_fd_set = [
2384            (
2385                JoinType::Inner,
2386                [
2387                    // inherit from left
2388                    FunctionalDependency::with_indices(5, &[0], &[1]),
2389                    // inherit from right
2390                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2391                    // constant column in join condition
2392                    FunctionalDependency::with_indices(5, &[], &[0]),
2393                    // eq column in join condition
2394                    FunctionalDependency::with_indices(5, &[1], &[3]),
2395                    FunctionalDependency::with_indices(5, &[3], &[1]),
2396                ]
2397                .into_iter()
2398                .collect::<HashSet<_>>(),
2399            ),
2400            (JoinType::FullOuter, HashSet::new()),
2401            (
2402                JoinType::RightOuter,
2403                [
2404                    // inherit from right
2405                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2406                ]
2407                .into_iter()
2408                .collect::<HashSet<_>>(),
2409            ),
2410            (
2411                JoinType::LeftOuter,
2412                [
2413                    // inherit from left
2414                    FunctionalDependency::with_indices(5, &[0], &[1]),
2415                ]
2416                .into_iter()
2417                .collect::<HashSet<_>>(),
2418            ),
2419            (
2420                JoinType::LeftSemi,
2421                [
2422                    // inherit from left
2423                    FunctionalDependency::with_indices(2, &[0], &[1]),
2424                ]
2425                .into_iter()
2426                .collect::<HashSet<_>>(),
2427            ),
2428            (
2429                JoinType::LeftAnti,
2430                [
2431                    // inherit from left
2432                    FunctionalDependency::with_indices(2, &[0], &[1]),
2433                ]
2434                .into_iter()
2435                .collect::<HashSet<_>>(),
2436            ),
2437            (
2438                JoinType::RightSemi,
2439                [
2440                    // inherit from right
2441                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2442                ]
2443                .into_iter()
2444                .collect::<HashSet<_>>(),
2445            ),
2446            (
2447                JoinType::RightAnti,
2448                [
2449                    // inherit from right
2450                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2451                ]
2452                .into_iter()
2453                .collect::<HashSet<_>>(),
2454            ),
2455        ];
2456
2457        for (join_type, expected_res) in expected_fd_set {
2458            let join = LogicalJoin::new(
2459                left.clone().into(),
2460                right.clone().into(),
2461                join_type,
2462                Condition::with_expr(on.clone()),
2463            );
2464            let fd_set = join
2465                .functional_dependency()
2466                .as_dependencies()
2467                .iter()
2468                .cloned()
2469                .collect::<HashSet<_>>();
2470            assert_eq!(fd_set, expected_res);
2471        }
2472    }
2473}