risingwave_frontend/optimizer/plan_node/
logical_join.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28    GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32    BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33    PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34    ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43    BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44    LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45    StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
/// `LogicalJoin` combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right inputs, selected by the output indices provided. A repeated output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan-node base; `with_core` builds it from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// Generic join description: both inputs, the `on` predicate, the join type and the
    /// output indices.
    core: generic::Join<PlanRef>,
}
62
63impl Distill for LogicalJoin {
64    fn distill<'a>(&self) -> XmlNode<'a> {
65        let verbose = self.base.ctx().is_explain_verbose();
66        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67        vec.push(("type", Pretty::debug(&self.join_type())));
68
69        let concat_schema = self.core.concat_schema();
70        let cond = Pretty::debug(&ConditionDisplay {
71            condition: self.on(),
72            input_schema: &concat_schema,
73        });
74        vec.push(("on", cond));
75
76        if verbose {
77            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78            vec.push(("output", data));
79        }
80
81        childless_record("LogicalJoin", vec)
82    }
83}
84
85impl LogicalJoin {
86    pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87        let core = generic::Join::with_full_output(left, right, join_type, on);
88        Self::with_core(core)
89    }
90
91    pub(crate) fn with_output_indices(
92        left: PlanRef,
93        right: PlanRef,
94        join_type: JoinType,
95        on: Condition,
96        output_indices: Vec<usize>,
97    ) -> Self {
98        let core = generic::Join::new(left, right, on, join_type, output_indices);
99        Self::with_core(core)
100    }
101
102    pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103        let base = PlanBase::new_logical_with_core(&core);
104        LogicalJoin { base, core }
105    }
106
107    pub fn create(
108        left: PlanRef,
109        right: PlanRef,
110        join_type: JoinType,
111        on_clause: ExprImpl,
112    ) -> PlanRef {
113        Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114    }
115
116    pub fn internal_column_num(&self) -> usize {
117        self.core.internal_column_num()
118    }
119
120    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121        self.core.i2l_col_mapping_ignore_join_type()
122    }
123
124    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125        self.core.i2r_col_mapping_ignore_join_type()
126    }
127
128    /// Get a reference to the logical join's on.
129    pub fn on(&self) -> &Condition {
130        self.core
131            .on
132            .as_condition_ref()
133            .expect("logical join should store predicate as Condition")
134    }
135
136    pub fn core(&self) -> &generic::Join<PlanRef> {
137        &self.core
138    }
139
140    /// Collect all input ref in the on condition. And separate them into left and right.
141    pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
142        let input_refs = self
143            .core
144            .on
145            .as_condition_ref()
146            .expect("logical join should store predicate as Condition")
147            .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
148        let index_group = input_refs
149            .ones()
150            .chunk_by(|i| *i < self.core.left.schema().len());
151        let left_index = index_group
152            .into_iter()
153            .next()
154            .map_or(vec![], |group| group.1.collect_vec());
155        let right_index = index_group.into_iter().next().map_or(vec![], |group| {
156            group
157                .1
158                .map(|i| i - self.core.left.schema().len())
159                .collect_vec()
160        });
161        (left_index, right_index)
162    }
163
164    /// Get the join type of the logical join.
165    pub fn join_type(&self) -> JoinType {
166        self.core.join_type
167    }
168
169    /// Get the eq join key of the logical join.
170    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
171        self.core.eq_indexes()
172    }
173
174    /// Get the output indices of the logical join.
175    pub fn output_indices(&self) -> &Vec<usize> {
176        &self.core.output_indices
177    }
178
179    /// Clone with new output indices
180    pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
181        Self::with_core(generic::Join {
182            output_indices,
183            ..self.core.clone()
184        })
185    }
186
187    /// Clone with new `on` condition
188    pub fn clone_with_cond(&self, on: Condition) -> Self {
189        Self::with_core(generic::Join {
190            on: generic::JoinOn::Condition(on),
191            ..self.core.clone()
192        })
193    }
194
195    pub fn is_left_join(&self) -> bool {
196        matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
197    }
198
199    pub fn is_right_join(&self) -> bool {
200        matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
201    }
202
203    pub fn is_full_out(&self) -> bool {
204        self.core.is_full_out()
205    }
206
207    pub fn is_asof_join(&self) -> bool {
208        self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
209    }
210
211    pub fn output_indices_are_trivial(&self) -> bool {
212        itertools::equal(
213            self.output_indices().iter().cloned(),
214            0..self.internal_column_num(),
215        )
216    }
217
218    /// Try to simplify the outer join with the predicate on the top of the join
219    ///
220    /// now it is just a naive implementation for comparison expression, we can give a more general
221    /// implementation with constant folding in future
222    fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
223        let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
224            JoinType::LeftOuter => (false, true),
225            JoinType::RightOuter => (true, false),
226            JoinType::FullOuter => (true, true),
227            _ => return join_type,
228        };
229
230        for expr in &predicate.conjunctions {
231            if let ExprImpl::FunctionCall(func) = expr {
232                match func.func_type() {
233                    ExprType::Equal
234                    | ExprType::NotEqual
235                    | ExprType::LessThan
236                    | ExprType::LessThanOrEqual
237                    | ExprType::GreaterThan
238                    | ExprType::GreaterThanOrEqual => {
239                        for input in func.inputs() {
240                            if let ExprImpl::InputRef(input) = input {
241                                let idx = input.index;
242                                if idx < left_col_num {
243                                    gen_null_in_left = false;
244                                } else {
245                                    gen_null_in_right = false;
246                                }
247                            }
248                        }
249                    }
250                    _ => {}
251                };
252            }
253        }
254
255        match (gen_null_in_left, gen_null_in_right) {
256            (true, true) => JoinType::FullOuter,
257            (true, false) => JoinType::RightOuter,
258            (false, true) => JoinType::LeftOuter,
259            (false, false) => JoinType::Inner,
260        }
261    }
262
263    /// Index Join:
264    /// Try to convert logical join into batch lookup join and meanwhile it will do
265    /// the index selection for the lookup table so that we can benefit from indexes.
266    fn to_batch_lookup_join_with_index_selection(
267        &self,
268        predicate: EqJoinPredicate,
269        batch_join: generic::Join<BatchPlanRef>,
270    ) -> Result<Option<BatchLookupJoin>> {
271        match batch_join.join_type {
272            JoinType::Inner
273            | JoinType::LeftOuter
274            | JoinType::LeftSemi
275            | JoinType::LeftAnti
276            | JoinType::AsofInner
277            | JoinType::AsofLeftOuter => {}
278            _ => return Ok(None),
279        };
280
281        // Index selection for index join.
282        let right = self.right();
283        // Lookup Join only supports basic tables on the join's right side.
284        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
285            logical_scan
286        } else {
287            return Ok(None);
288        };
289
290        let mut result_plan = None;
291        // Lookup primary table.
292        if let Some(lookup_join) =
293            self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
294        {
295            result_plan = Some(lookup_join);
296        }
297
298        if self
299            .core
300            .ctx()
301            .session_ctx()
302            .config()
303            .enable_index_selection()
304        {
305            let indexes = logical_scan.table_indexes();
306            for index in indexes {
307                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
308                    let index_scan: PlanRef = index_scan.into();
309                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
310                    let mut new_batch_join = batch_join.clone();
311                    new_batch_join.right =
312                        index_scan.to_batch().expect("index scan failed to batch");
313
314                    // Lookup covered index.
315                    if let Some(lookup_join) =
316                        that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
317                    {
318                        match &result_plan {
319                            None => result_plan = Some(lookup_join),
320                            Some(prev_lookup_join) => {
321                                // Prefer to choose lookup join with longer lookup prefix len.
322                                if prev_lookup_join.lookup_prefix_len()
323                                    < lookup_join.lookup_prefix_len()
324                                {
325                                    result_plan = Some(lookup_join)
326                                }
327                            }
328                        }
329                    }
330                }
331            }
332        }
333
334        Ok(result_plan)
335    }
336
337    /// Try to convert logical join into batch lookup join.
338    fn to_batch_lookup_join(
339        &self,
340        predicate: EqJoinPredicate,
341        logical_join: generic::Join<BatchPlanRef>,
342    ) -> Result<Option<BatchLookupJoin>> {
343        let logical_scan: &LogicalScan =
344            if let Some(logical_scan) = self.core.right.as_logical_scan() {
345                logical_scan
346            } else {
347                return Ok(None);
348            };
349        Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
350    }
351
    /// Tries to build a `BatchLookupJoin` that looks rows up in the table behind
    /// `logical_scan`, which is the right side of `logical_join`.
    ///
    /// Returns `Ok(None)` (meaning "use some other join") when:
    /// - the join type is not inner / left outer / left semi / left anti / as-of inner /
    ///   as-of left outer;
    /// - the table has an empty distribution key;
    /// - the right-side equi-join keys do not cover the shortest order-key prefix that
    ///   contains the distribution key;
    /// - no equality condition is left after the scan's own predicate has been merged into
    ///   the join condition.
    ///
    /// When `is_as_of` is set, an as-of inequality descriptor is additionally derived from the
    /// non-equi part of `predicate`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting the rewritten scan to batch and from deriving the
    /// as-of inequality descriptor.
    pub fn gen_batch_lookup_join(
        logical_scan: &LogicalScan,
        predicate: EqJoinPredicate,
        logical_join: generic::Join<BatchPlanRef>,
        is_as_of: bool,
    ) -> Result<Option<BatchLookupJoin>> {
        match logical_join.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {}
            _ => return Ok(None),
        };

        let table = logical_scan.table();
        let output_column_ids = logical_scan.output_column_ids();

        // Verify that the right join key columns are a prefix of the primary key and
        // also contain the distribution key.
        let order_col_ids = table.order_column_ids();
        let dist_key = table.distribution_key.clone();
        // Position of each distribution key column within the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = table
                .order_column_indices()
                .position(|x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // The shortest prefix of order key that contains distribution key
        // (0 when the distribution key is empty).
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // Distributed lookup join can't support lookup table with a singleton distribution.
        if shortest_prefix_len == 0 {
            return Ok(None);
        }

        // Reorder the join equal predicate to match the order key: walk the order key from the
        // front and stop at the first order column that has no matching equi-join key.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        // The matched prefix must at least cover the distribution key.
        if reorder_idx.len() < shortest_prefix_len {
            return Ok(None);
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Extract the predicate from logical scan. Only pure scan is supported.
        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
        // Construct the mapping from the old scan's output columns to the new scan's columns.
        // NOTE(review): the `unwrap` assumes every pulled-up projection expression is a plain
        // input ref — confirm against `predicate_pull_up`.
        let o2r = if let Some(project_expr) = project_expr {
            project_expr
                .into_iter()
                .map(|x| x.as_input_ref().unwrap().index)
                .collect_vec()
        } else {
            (0..logical_scan.output_col_idx().len()).collect_vec()
        };
        let left_schema_len = logical_join.left.schema().len();

        // Rewrites right-side refs of the join condition through `o2r`.
        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
            offset: left_schema_len,
            mapping: o2r.clone(),
        };

        let new_eq_cond = predicate
            .eq_cond()
            .rewrite_expr(&mut join_predicate_rewriter);

        // Shifts refs of the pulled-up scan predicate past the left columns.
        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
            offset: left_schema_len,
        };

        let new_other_cond = predicate
            .other_cond()
            .clone()
            .rewrite_expr(&mut join_predicate_rewriter)
            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

        let new_join_on = new_eq_cond.and(new_other_cond);
        let new_predicate =
            EqJoinPredicate::create(left_schema_len, new_scan.schema().len(), new_join_on);

        // We discovered that we cannot use a lookup join after pulling up the predicate
        // from one side and simplifying the condition. Let's use some other join instead.
        if !new_predicate.has_eq() {
            return Ok(None);
        }

        // Rewrite the join output indices: every output index that referred to the old scan
        // has to be remapped onto the new scan.
        let new_join_output_indices = logical_join
            .output_indices
            .iter()
            .map(|&x| {
                if x < left_schema_len {
                    x
                } else {
                    o2r[x - left_schema_len] + left_schema_len
                }
            })
            .collect_vec();

        let new_scan_output_column_ids = new_scan.output_column_ids();
        let as_of = new_scan.as_of.clone();
        let new_logical_scan: LogicalScan = new_scan.into();

        // Construct a new logical join, because we have changed its RHS.
        let new_logical_join = generic::Join::new_with_eq_predicate(
            logical_join.left,
            new_logical_scan.to_batch()?,
            new_predicate,
            logical_join.join_type,
            new_join_output_indices,
        );

        // The inequality descriptor is derived from the reordered predicate's non-equi part
        // (i.e. before the scan rewrite above), and only for as-of joins.
        let asof_desc = is_as_of
            .then(|| {
                Self::get_inequality_desc_from_predicate(
                    predicate.other_cond().clone(),
                    left_schema_len,
                )
            })
            .transpose()?;

        // NOTE(review): the meaning of the literal `false` argument is not visible in this
        // file — check `BatchLookupJoin::new`.
        Ok(Some(BatchLookupJoin::new(
            new_logical_join,
            table.clone(),
            new_scan_output_column_ids,
            lookup_prefix_len,
            false,
            as_of,
            asof_desc,
        )))
    }
504
505    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
506        self.core.decompose()
507    }
508
509    fn dynamic_filter_candidate(&self, predicate: &Condition) -> Option<(usize, PbType)> {
510        // Dynamic filter only supports `Inner`/`LeftSemi`.
511        if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
512            return None;
513        }
514
515        // Dynamic filter requires right side to be a scalar with one column.
516        if !self.right().max_one_row() || self.right().schema().len() != 1 {
517            return None;
518        }
519
520        // Dynamic filter only supports a single comparison predicate.
521        if predicate.conjunctions.len() > 1 {
522            return None;
523        }
524        let expr: ExprImpl = predicate.clone().into();
525        let (left_ref, comparator, right_ref) = expr.as_comparison_cond()?;
526
527        // Comparison must cross inputs: left input ref vs right scalar input ref.
528        let left_len = self.left().schema().len();
529        let condition_cross_inputs = left_ref.index < left_len && right_ref.index == left_len;
530        if !condition_cross_inputs {
531            return None;
532        }
533
534        // Comparison keys must be type aligned.
535        if self.left().schema().fields()[left_ref.index].data_type
536            != self.right().schema().fields()[0].data_type
537        {
538            return None;
539        }
540
541        // Dynamic filter output can only come from the left side.
542        if !self.output_indices().iter().all(|i| *i < left_len) {
543            return None;
544        }
545
546        Some((left_ref.index, comparator))
547    }
548
549    /// Check whether this join can be treated as a temporal filter for locality optimization.
550    ///
551    /// This is intentionally stricter than dynamic filter:
552    /// - right side must be a `LogicalNow`,
553    /// - and all dynamic filter preconditions must hold.
554    fn temporal_filter_candidate(&self) -> bool {
555        self.right().as_logical_now().is_some()
556            && self.dynamic_filter_candidate(self.on()).is_some()
557    }
558}
559
560impl PlanTreeNodeBinary<Logical> for LogicalJoin {
561    fn left(&self) -> PlanRef {
562        self.core.left.clone()
563    }
564
565    fn right(&self) -> PlanRef {
566        self.core.right.clone()
567    }
568
569    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
570        Self::with_core(generic::Join {
571            left,
572            right,
573            ..self.core.clone()
574        })
575    }
576
577    fn rewrite_with_left_right(
578        &self,
579        left: PlanRef,
580        left_col_change: ColIndexMapping,
581        right: PlanRef,
582        right_col_change: ColIndexMapping,
583    ) -> (Self, ColIndexMapping) {
584        let (new_on, new_output_indices) = {
585            let (mut map, _) = left_col_change.clone().into_parts();
586            let (mut right_map, _) = right_col_change.clone().into_parts();
587            for i in right_map.iter_mut().flatten() {
588                *i += left.schema().len();
589            }
590            map.append(&mut right_map);
591            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
592
593            let new_output_indices = self
594                .output_indices()
595                .iter()
596                .map(|&i| mapping.map(i))
597                .collect::<Vec<_>>();
598            let new_on = self.on().clone().rewrite_expr(&mut mapping);
599            (new_on, new_output_indices)
600        };
601
602        let join = Self::with_output_indices(
603            left,
604            right,
605            self.join_type(),
606            new_on,
607            new_output_indices.clone(),
608        );
609
610        let new_i2o = ColIndexMapping::with_remaining_columns(
611            &new_output_indices,
612            join.internal_column_num(),
613        );
614
615        let old_o2i = self.core.o2i_col_mapping();
616
617        let old_o2l = old_o2i
618            .composite(&self.core.i2l_col_mapping())
619            .composite(&left_col_change);
620        let old_o2r = old_o2i
621            .composite(&self.core.i2r_col_mapping())
622            .composite(&right_col_change);
623        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
624        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
625
626        let out_col_change = old_o2l
627            .composite(&new_l2o)
628            .union(&old_o2r.composite(&new_r2o));
629        (join, out_col_change)
630    }
631}
632
633impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
634
635impl ColPrunable for LogicalJoin {
636    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
637        // make `required_cols` point to internal table instead of output schema.
638        let required_cols = required_cols
639            .iter()
640            .map(|i| self.output_indices()[*i])
641            .collect_vec();
642        let left_len = self.left().schema().fields.len();
643
644        let total_len = self.left().schema().len() + self.right().schema().len();
645        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
646
647        required_cols.iter().for_each(|&i| {
648            if self.is_right_join() {
649                resized_required_cols.insert(left_len + i);
650            } else {
651                resized_required_cols.insert(i);
652            }
653        });
654
655        // add those columns which are required in the join condition to
656        // to those that are required in the output
657        let mut visitor = CollectInputRef::new(resized_required_cols);
658        self.on().visit_expr(&mut visitor);
659        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
660
661        let mut left_required_cols = Vec::new();
662        let mut right_required_cols = Vec::new();
663        left_right_required_cols.iter().for_each(|&i| {
664            if i < left_len {
665                left_required_cols.push(i);
666            } else {
667                right_required_cols.push(i - left_len);
668            }
669        });
670
671        let mut on = self.on().clone();
672        let mut mapping =
673            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
674        on = on.rewrite_expr(&mut mapping);
675
676        let new_output_indices = {
677            let required_inputs_in_output = if self.is_left_join() {
678                &left_required_cols
679            } else if self.is_right_join() {
680                &right_required_cols
681            } else {
682                &left_right_required_cols
683            };
684
685            let mapping =
686                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
687            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
688        };
689
690        LogicalJoin::with_output_indices(
691            self.left().prune_col(&left_required_cols, ctx),
692            self.right().prune_col(&right_required_cols, ctx),
693            self.join_type(),
694            on,
695            new_output_indices,
696        )
697        .into()
698    }
699}
700
701impl ExprRewritable<Logical> for LogicalJoin {
702    fn has_rewritable_expr(&self) -> bool {
703        true
704    }
705
706    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
707        let mut core = self.core.clone();
708        core.rewrite_exprs(r);
709        Self {
710            base: self.base.clone_with_new_plan_id(),
711            core,
712        }
713        .into()
714    }
715}
716
717impl ExprVisitable for LogicalJoin {
718    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
719        self.core.visit_exprs(v);
720    }
721}
722
723/// We are trying to derive a predicate to apply to the other side of a join if all
724/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
725/// with the corresponding eq condition columns of the other side.
726///
727/// Strategy:
728/// 1. If the function is pure except for any `InputRef` (which may refer to impure computation),
729///    then we proceed. Else abort.
730/// 2. Then, we collect `InputRef`s in the conjunction.
731/// 3. If they are all columns in the given side of join eq condition, then we proceed. Else abort.
732/// 4. We then rewrite the `ExprImpl`, by replacing `InputRef` column indices with the equivalent in
733///    the other side.
734///
735/// # Arguments
736///
737/// Suppose we derive a predicate from the left side to be pushed to the right side.
738/// * `expr`: An expr from the left side.
739/// * `col_num`: The number of columns in the left side.
740fn derive_predicate_from_eq_condition(
741    expr: &ExprImpl,
742    eq_condition: &EqJoinPredicate,
743    col_num: usize,
744    expr_is_left: bool,
745) -> Option<ExprImpl> {
746    if expr.is_impure() {
747        return None;
748    }
749    let eq_indices = eq_condition
750        .eq_indexes_typed()
751        .iter()
752        .filter_map(|(l, r)| {
753            if l.return_type() != r.return_type() {
754                None
755            } else if expr_is_left {
756                Some(l.index())
757            } else {
758                Some(r.index())
759            }
760        })
761        .collect_vec();
762    if expr
763        .collect_input_refs(col_num)
764        .ones()
765        .any(|index| !eq_indices.contains(&index))
766    {
767        // expr contains an InputRef not in eq_condition
768        return None;
769    }
770    // The function is pure except for `InputRef` and all `InputRef`s are `eq_condition` indices.
771    // Hence, we can substitute those `InputRef`s with indices from the other side.
772    let other_side_mapping = if expr_is_left {
773        eq_condition.eq_indexes_typed().into_iter().collect()
774    } else {
775        eq_condition
776            .eq_indexes_typed()
777            .into_iter()
778            .map(|(x, y)| (y, x))
779            .collect()
780    };
781    struct InputRefsRewriter {
782        mapping: HashMap<InputRef, InputRef>,
783    }
784    impl ExprRewriter for InputRefsRewriter {
785        fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
786            self.mapping[&input_ref].clone().into()
787        }
788    }
789    Some(
790        InputRefsRewriter {
791            mapping: other_side_mapping,
792        }
793        .rewrite_expr(expr.clone()),
794    )
795}
796
797/// Rewrite the join predicate and all columns referred to the scan side need to rewrite.
798struct LookupJoinPredicateRewriter {
799    offset: usize,
800    mapping: Vec<usize>,
801}
802impl ExprRewriter for LookupJoinPredicateRewriter {
803    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
804        if input_ref.index() < self.offset {
805            input_ref.into()
806        } else {
807            InputRef::new(
808                self.mapping[input_ref.index() - self.offset] + self.offset,
809                input_ref.return_type(),
810            )
811            .into()
812        }
813    }
814}
815
816/// Rewrite the scan predicate so we can add it to the join predicate.
817struct LookupJoinScanPredicateRewriter {
818    offset: usize,
819}
820impl ExprRewriter for LookupJoinScanPredicateRewriter {
821    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
822        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
823    }
824}
825
826impl PredicatePushdown for LogicalJoin {
827    /// Pushes predicates above and within a join node into the join node and/or its children nodes.
828    ///
829    /// # Which predicates can be pushed
830    ///
831    /// For inner join, we can do all kinds of pushdown.
832    ///
833    /// For left/right semi join, we can push filter to left/right and on-clause,
834    /// and push on-clause to left/right.
835    ///
836    /// For left/right anti join, we can push filter to left/right, but on-clause can not be pushed
837    ///
838    /// ## Outer Join
839    ///
840    /// Preserved Row table
841    /// : The table in an Outer Join that must return all rows.
842    ///
843    /// Null Supplying table
844    /// : This is the table that has nulls filled in for its columns in unmatched rows.
845    ///
846    /// |                          | Preserved Row table | Null Supplying table |
847    /// |--------------------------|---------------------|----------------------|
848    /// | Join predicate (on)      | Not Pushed          | Pushed               |
849    /// | Where predicate (filter) | Pushed              | Not Pushed           |
850    fn predicate_pushdown(
851        &self,
852        predicate: Condition,
853        ctx: &mut PredicatePushdownContext,
854    ) -> PlanRef {
855        // rewrite output col referencing indices as internal cols
856        let mut predicate = {
857            let mut mapping = self.core.o2i_col_mapping();
858            predicate.rewrite_expr(&mut mapping)
859        };
860
861        let left_col_num = self.left().schema().len();
862        let right_col_num = self.right().schema().len();
863        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
864
865        let push_down_temporal_predicate = self.temporal_join_on().is_none();
866
867        let (left_from_filter, right_from_filter, on) = push_down_into_join(
868            &mut predicate,
869            left_col_num,
870            right_col_num,
871            join_type,
872            push_down_temporal_predicate,
873        );
874
875        let mut new_on = self.on().clone().and(on);
876        let (left_from_on, right_from_on) = push_down_join_condition(
877            &mut new_on,
878            left_col_num,
879            right_col_num,
880            join_type,
881            push_down_temporal_predicate,
882        );
883
884        let left_predicate = left_from_filter.and(left_from_on);
885        let right_predicate = right_from_filter.and(right_from_on);
886
887        // Derive conditions to push to the other side based on eq condition columns
888        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
889
890        // Only push to RHS if RHS is inner side of a join (RHS requires match on LHS)
891        let right_from_left = if matches!(
892            join_type,
893            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
894        ) {
895            Condition {
896                conjunctions: left_predicate
897                    .conjunctions
898                    .iter()
899                    .filter_map(|expr| {
900                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
901                    })
902                    .collect(),
903            }
904        } else {
905            Condition::true_cond()
906        };
907
908        // Only push to LHS if LHS is inner side of a join (LHS requires match on RHS)
909        let left_from_right = if matches!(
910            join_type,
911            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
912        ) {
913            Condition {
914                conjunctions: right_predicate
915                    .conjunctions
916                    .iter()
917                    .filter_map(|expr| {
918                        derive_predicate_from_eq_condition(
919                            expr,
920                            &eq_condition,
921                            right_col_num,
922                            false,
923                        )
924                    })
925                    .collect(),
926            }
927        } else {
928            Condition::true_cond()
929        };
930
931        let left_predicate = left_predicate.and(left_from_right);
932        let right_predicate = right_predicate.and(right_from_left);
933
934        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
935        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
936        let new_join = LogicalJoin::with_output_indices(
937            new_left,
938            new_right,
939            join_type,
940            new_on,
941            self.output_indices().clone(),
942        );
943
944        let mut mapping = self.core.i2o_col_mapping();
945        predicate = predicate.rewrite_expr(&mut mapping);
946        LogicalFilter::create(new_join.into(), predicate)
947    }
948}
949
950#[derive(Clone, Copy)]
951struct TemporalJoinScan<'a>(&'a LogicalScan);
952
953impl<'a> Deref for TemporalJoinScan<'a> {
954    type Target = LogicalScan;
955
956    fn deref(&self) -> &Self::Target {
957        self.0
958    }
959}
960
961impl LogicalJoin {
962    fn get_stream_input_for_hash_join(
963        &self,
964        predicate: &EqJoinPredicate,
965        ctx: &mut ToStreamContext,
966    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
967        use super::stream::prelude::*;
968
969        let mut right = self.right().to_stream_with_dist_required(
970            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
971            ctx,
972        )?;
973        let r2l =
974            predicate.r2l_eq_columns_mapping(self.left().schema().len(), right.schema().len());
975        let l2r =
976            predicate.l2r_eq_columns_mapping(self.left().schema().len(), right.schema().len());
977        let mut left;
978        let right_dist = right.distribution();
979        match right_dist {
980            Distribution::HashShard(_) => {
981                let left_dist = r2l
982                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
983                left = self.left().to_stream_with_dist_required(&left_dist, ctx)?;
984            }
985            Distribution::UpstreamHashShard(_, _) => {
986                left = self.left().to_stream_with_dist_required(
987                    &RequiredDist::shard_by_key(
988                        self.left().schema().len(),
989                        &predicate.left_eq_indexes(),
990                    ),
991                    ctx,
992                )?;
993                let left_dist = left.distribution();
994                match left_dist {
995                    Distribution::HashShard(_) => {
996                        let right_dist = l2r.rewrite_required_distribution(
997                            &RequiredDist::PhysicalDist(left_dist.clone()),
998                        );
999                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
1000                    }
1001                    Distribution::UpstreamHashShard(_, _) => {
1002                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
1003                            .streaming_enforce_if_not_satisfies(left)?;
1004                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
1005                            .streaming_enforce_if_not_satisfies(right)?;
1006                    }
1007                    _ => unreachable!(),
1008                }
1009            }
1010            _ => unreachable!(),
1011        }
1012        Ok((left, right))
1013    }
1014
1015    fn to_stream_hash_join(
1016        &self,
1017        predicate: EqJoinPredicate,
1018        ctx: &mut ToStreamContext,
1019    ) -> Result<StreamPlanRef> {
1020        use super::stream::prelude::*;
1021
1022        assert!(predicate.has_eq());
1023        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1024
1025        let mut core = self.core.clone_with_inputs(left, right);
1026        core.on = generic::JoinOn::EqPredicate(predicate);
1027
1028        // Convert to Hash Join for equal joins
1029        // For inner joins, pull non-equal conditions to a filter operator on top of it by default.
1030        // We do so as the filter operator can apply the non-equal condition batch-wise (vectorized)
1031        // as opposed to the HashJoin, which applies the condition row-wise.
1032        // However, the default behavior of pulling up non-equal conditions can be overridden by the
1033        // session variable `streaming_force_filter_inside_join` as it can save unnecessary
1034        // materialization of rows only to be filtered later.
1035
1036        let stream_hash_join = StreamHashJoin::new(core.clone())?;
1037        let predicate = stream_hash_join.eq_join_predicate().clone();
1038
1039        let force_filter_inside_join = self
1040            .base
1041            .ctx()
1042            .session_ctx()
1043            .config()
1044            .streaming_force_filter_inside_join();
1045
1046        let pull_filter = self.join_type() == JoinType::Inner
1047            && stream_hash_join.eq_join_predicate().has_non_eq()
1048            && stream_hash_join.inequality_pairs().is_empty()
1049            && (!force_filter_inside_join);
1050        if pull_filter {
1051            let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1052
1053            let mut core = core;
1054            core.output_indices = default_indices.clone();
1055            // Temporarily remove output indices.
1056            let eq_cond = EqJoinPredicate::new(
1057                Condition::true_cond(),
1058                predicate.eq_keys().to_vec(),
1059                self.left().schema().len(),
1060                self.right().schema().len(),
1061            );
1062            core.on = generic::JoinOn::EqPredicate(eq_cond);
1063            let hash_join = StreamHashJoin::new(core)?.into();
1064            let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1065            let plan = StreamFilter::new(logical_filter).into();
1066            if self.output_indices() != &default_indices {
1067                let logical_project = generic::Project::with_mapping(
1068                    plan,
1069                    ColIndexMapping::with_remaining_columns(
1070                        self.output_indices(),
1071                        self.internal_column_num(),
1072                    ),
1073                );
1074                Ok(StreamProject::new(logical_project).into())
1075            } else {
1076                Ok(plan)
1077            }
1078        } else {
1079            Ok(stream_hash_join.into())
1080        }
1081    }
1082
1083    pub fn should_be_temporal_join(&self) -> bool {
1084        self.temporal_join_on().is_some()
1085    }
1086
1087    fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1088        if let Some(logical_scan) = self.core.right.as_logical_scan() {
1089            matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1090                .then_some(TemporalJoinScan(logical_scan))
1091        } else {
1092            None
1093        }
1094    }
1095
1096    fn should_be_stream_temporal_join<'a>(
1097        &'a self,
1098        ctx: &ToStreamContext,
1099    ) -> Result<Option<TemporalJoinScan<'a>>> {
1100        Ok(if let Some(scan) = self.temporal_join_on() {
1101            if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1102                return Err(RwError::from(ErrorCode::NotSupported(
1103                    "Temporal join with snapshot backfill not supported".into(),
1104                    "Please use arrangement backfill".into(),
1105                )));
1106            }
1107            if scan.cross_database() {
1108                return Err(RwError::from(ErrorCode::NotSupported(
1109                        "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1110                        "Please ensure both tables are in the same database".into(),
1111                    )));
1112            }
1113            Some(scan)
1114        } else {
1115            None
1116        })
1117    }
1118
1119    fn to_stream_temporal_join_with_index_selection(
1120        &self,
1121        logical_scan: TemporalJoinScan<'_>,
1122        predicate: EqJoinPredicate,
1123        ctx: &mut ToStreamContext,
1124    ) -> Result<StreamPlanRef> {
1125        // Use primary table.
1126        let mut result_plan: Result<StreamTemporalJoin> =
1127            self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1128        // Return directly if this temporal join can match the pk of its right table.
1129        if let Ok(temporal_join) = &result_plan
1130            && temporal_join.eq_join_predicate().eq_indexes().len()
1131                == logical_scan.primary_key().len()
1132        {
1133            return result_plan.map(|x| x.into());
1134        }
1135        if self
1136            .core
1137            .ctx()
1138            .session_ctx()
1139            .config()
1140            .enable_index_selection()
1141        {
1142            let indexes = logical_scan.table_indexes();
1143            for index in indexes {
1144                // Use index table
1145                if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1146                    let index_scan: PlanRef = index_scan.into();
1147                    let that = self.clone_with_left_right(self.left(), index_scan.clone());
1148                    if let Ok(temporal_join) = that.to_stream_temporal_join(
1149                        that.temporal_join_on().expect(
1150                            "index scan created from temporal join scan must also be temporal join",
1151                        ),
1152                        predicate.clone(),
1153                        ctx,
1154                    ) {
1155                        match &result_plan {
1156                            Err(_) => result_plan = Ok(temporal_join),
1157                            Ok(prev_temporal_join) => {
1158                                // Prefer to the temporal join with a longer lookup prefix len.
1159                                if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1160                                    < temporal_join.eq_join_predicate().eq_indexes().len()
1161                                {
1162                                    result_plan = Ok(temporal_join)
1163                                }
1164                            }
1165                        }
1166                    }
1167                }
1168            }
1169        }
1170
1171        result_plan.map(|x| x.into())
1172    }
1173
1174    fn temporal_join_scan_predicate_pull_up(
1175        logical_scan: TemporalJoinScan<'_>,
1176        predicate: EqJoinPredicate,
1177        output_indices: &[usize],
1178        left_schema_len: usize,
1179    ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1180        // Extract the predicate from logical scan. Only pure scan is supported.
1181        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1182        // Construct output column to require column mapping
1183        let o2r = if let Some(project_expr) = project_expr {
1184            project_expr
1185                .into_iter()
1186                .map(|x| x.as_input_ref().unwrap().index)
1187                .collect_vec()
1188        } else {
1189            (0..logical_scan.output_col_idx().len()).collect_vec()
1190        };
1191        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1192            offset: left_schema_len,
1193            mapping: o2r.clone(),
1194        };
1195
1196        let new_eq_cond = predicate
1197            .eq_cond()
1198            .rewrite_expr(&mut join_predicate_rewriter);
1199
1200        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1201            offset: left_schema_len,
1202        };
1203
1204        let new_other_cond = predicate
1205            .other_cond()
1206            .clone()
1207            .rewrite_expr(&mut join_predicate_rewriter)
1208            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1209
1210        let new_join_on = new_eq_cond.and(new_other_cond);
1211
1212        let new_predicate = EqJoinPredicate::create(
1213            left_schema_len,
1214            new_scan.schema().len(),
1215            new_join_on.clone(),
1216        );
1217
1218        // Rewrite the join output indices and all output indices referred to the old scan need to
1219        // rewrite.
1220        let new_join_output_indices = output_indices
1221            .iter()
1222            .map(|&x| {
1223                if x < left_schema_len {
1224                    x
1225                } else {
1226                    o2r[x - left_schema_len] + left_schema_len
1227                }
1228            })
1229            .collect_vec();
1230
1231        let new_stream_table_scan =
1232            StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1233        Ok((
1234            new_stream_table_scan,
1235            new_predicate,
1236            new_join_on,
1237            new_join_output_indices,
1238        ))
1239    }
1240
1241    fn to_stream_temporal_join(
1242        &self,
1243        logical_scan: TemporalJoinScan<'_>,
1244        predicate: EqJoinPredicate,
1245        ctx: &mut ToStreamContext,
1246    ) -> Result<StreamTemporalJoin> {
1247        use super::stream::prelude::*;
1248
1249        assert!(predicate.has_eq());
1250
1251        let table = logical_scan.table();
1252        let output_column_ids = logical_scan.output_column_ids();
1253
1254        // Verify that the right join key columns are the the prefix of the primary key and
1255        // also contain the distribution key.
1256        let order_col_ids = table.order_column_ids();
1257        let dist_key = table.distribution_key.clone();
1258
1259        let mut dist_key_in_order_key_pos = vec![];
1260        for d in dist_key {
1261            let pos = table
1262                .order_column_indices()
1263                .position(|x| x == d)
1264                .expect("dist_key must in order_key");
1265            dist_key_in_order_key_pos.push(pos);
1266        }
1267        // The shortest prefix of order key that contains distribution key.
1268        let shortest_prefix_len = dist_key_in_order_key_pos
1269            .iter()
1270            .max()
1271            .map_or(0, |pos| pos + 1);
1272
1273        // Reorder the join equal predicate to match the order key.
1274        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1275        for order_col_id in order_col_ids {
1276            let mut found = false;
1277            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1278                if order_col_id == output_column_ids[eq_idx] {
1279                    reorder_idx.push(i);
1280                    found = true;
1281                    break;
1282                }
1283            }
1284            if !found {
1285                break;
1286            }
1287        }
1288        if reorder_idx.len() < shortest_prefix_len {
1289            return Err(RwError::from(ErrorCode::NotSupported(
1290                "Temporal join requires the equivalence join condition includes the key columns that form the distribution key of the lookup table".into(),
1291                concat!(
1292                    "Use DESCRIBE <table_name> to view the table's key information.\n",
1293                    "You can create an index on the lookup table to facilitate the temporal join if necessary."
1294                ).into(),
1295            )));
1296        }
1297        let lookup_prefix_len = reorder_idx.len();
1298        let predicate = predicate.reorder(&reorder_idx);
1299
1300        let required_dist = if dist_key_in_order_key_pos.is_empty() {
1301            RequiredDist::single()
1302        } else {
1303            let left_eq_indexes = predicate.left_eq_indexes();
1304            let left_dist_key = dist_key_in_order_key_pos
1305                .iter()
1306                .map(|pos| left_eq_indexes[*pos])
1307                .collect_vec();
1308
1309            RequiredDist::hash_shard(&left_dist_key)
1310        };
1311
1312        let left = self.left().to_stream(ctx)?;
1313        // Enforce a shuffle for the temporal join LHS to let the scheduler be able to schedule the join fragment together with the RHS with a `no_shuffle` exchange.
1314        let left = required_dist.stream_enforce(left);
1315
1316        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1317            Self::temporal_join_scan_predicate_pull_up(
1318                logical_scan,
1319                predicate,
1320                self.output_indices(),
1321                self.left().schema().len(),
1322            )?;
1323
1324        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1325        if !new_predicate.has_eq() {
1326            return Err(RwError::from(ErrorCode::NotSupported(
1327                "Temporal join requires a non trivial join condition".into(),
1328                "Please remove the false condition of the join".into(),
1329            )));
1330        }
1331
1332        // Construct a new logical join, because we have change its RHS.
1333        let new_logical_join = generic::Join::new(
1334            left,
1335            right,
1336            new_join_on,
1337            self.join_type(),
1338            new_join_output_indices,
1339        );
1340
1341        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1342
1343        let mut new_logical_join = new_logical_join;
1344        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1345        StreamTemporalJoin::new(new_logical_join, false)
1346    }
1347
1348    fn to_stream_nested_loop_temporal_join(
1349        &self,
1350        logical_scan: TemporalJoinScan<'_>,
1351        predicate: EqJoinPredicate,
1352        ctx: &mut ToStreamContext,
1353    ) -> Result<StreamPlanRef> {
1354        use super::stream::prelude::*;
1355        assert!(!predicate.has_eq());
1356
1357        let left = self.left().to_stream_with_dist_required(
1358            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1359            ctx,
1360        )?;
1361        assert!(left.as_stream_exchange().is_some());
1362
1363        if self.join_type() != JoinType::Inner {
1364            return Err(RwError::from(ErrorCode::NotSupported(
1365                "Temporal join requires an inner join".into(),
1366                "Please use an inner join".into(),
1367            )));
1368        }
1369
1370        if !left.append_only() {
1371            return Err(RwError::from(ErrorCode::NotSupported(
1372                "Nested-loop Temporal join requires the left hash side to be append only".into(),
1373                "Please ensure the left hash side is append only".into(),
1374            )));
1375        }
1376
1377        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1378            Self::temporal_join_scan_predicate_pull_up(
1379                logical_scan,
1380                predicate,
1381                self.output_indices(),
1382                self.left().schema().len(),
1383            )?;
1384
1385        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1386
1387        // Construct a new logical join, because we have change its RHS.
1388        let new_logical_join = generic::Join::new(
1389            left,
1390            right,
1391            new_join_on,
1392            self.join_type(),
1393            new_join_output_indices,
1394        );
1395
1396        let mut new_logical_join = new_logical_join;
1397        new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1398        Ok(StreamTemporalJoin::new(new_logical_join, true)?.into())
1399    }
1400
1401    fn to_stream_dynamic_filter(
1402        &self,
1403        predicate: Condition,
1404        ctx: &mut ToStreamContext,
1405    ) -> Result<Option<StreamPlanRef>> {
1406        use super::stream::prelude::*;
1407
1408        // If there is exactly one predicate, it is a comparison (<, <=, >, >=), and the
1409        // join is a `Inner` or `LeftSemi` join, we can convert the scalar subquery into a
1410        // `StreamDynamicFilter`
1411        let Some((left_key_idx, comparator)) = self.dynamic_filter_candidate(&predicate) else {
1412            return Ok(None);
1413        };
1414
1415        let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1416        let right = self.right().to_stream_with_dist_required(
1417            &RequiredDist::PhysicalDist(Distribution::Broadcast),
1418            ctx,
1419        )?;
1420
1421        assert!(right.as_stream_exchange().is_some());
1422        assert_eq!(
1423            *right.inputs().iter().exactly_one().unwrap().distribution(),
1424            Distribution::Single
1425        );
1426
1427        let core = DynamicFilter::new(comparator, left_key_idx, left, right);
1428        let plan = StreamDynamicFilter::new(core)?.into();
1429        // TODO: `DynamicFilterExecutor` should support `output_indices` in `ChunkBuilder`
1430        if self
1431            .output_indices()
1432            .iter()
1433            .copied()
1434            .ne(0..self.left().schema().len())
1435        {
1436            // The schema of dynamic filter is always the same as the left side now, and we have
1437            // checked that all output columns are from the left side before.
1438            let logical_project = generic::Project::with_mapping(
1439                plan,
1440                ColIndexMapping::with_remaining_columns(
1441                    self.output_indices(),
1442                    self.left().schema().len(),
1443                ),
1444            );
1445            Ok(Some(StreamProject::new(logical_project).into()))
1446        } else {
1447            Ok(Some(plan))
1448        }
1449    }
1450
1451    pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1452        let predicate = EqJoinPredicate::create(
1453            self.left().schema().len(),
1454            self.right().schema().len(),
1455            self.on().clone(),
1456        );
1457        assert!(predicate.has_eq());
1458
1459        let join = self
1460            .core
1461            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1462
1463        Ok(self
1464            .to_batch_lookup_join(predicate, join)?
1465            .expect("Fail to convert to lookup join")
1466            .into())
1467    }
1468
1469    fn to_stream_asof_join(
1470        &self,
1471        predicate: EqJoinPredicate,
1472        ctx: &mut ToStreamContext,
1473    ) -> Result<StreamPlanRef> {
1474        use super::stream::prelude::*;
1475
1476        if predicate.eq_keys().is_empty() {
1477            return Err(ErrorCode::InvalidInputSyntax(
1478                "AsOf join requires at least 1 equal condition".to_owned(),
1479            )
1480            .into());
1481        }
1482
1483        let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1484        let left_len = left.schema().len();
1485        let mut core = self.core.clone_with_inputs(left, right);
1486        core.on = generic::JoinOn::EqPredicate(predicate);
1487
1488        let inequality_desc = Self::get_inequality_desc_from_predicate(
1489            core.on
1490                .as_eq_predicate_ref()
1491                .expect("core predicate must exist")
1492                .other_cond()
1493                .clone(),
1494            left_len,
1495        )?;
1496
1497        Ok(StreamAsOfJoin::new(core, inequality_desc)?.into())
1498    }
1499
1500    /// Convert the logical join to a Hash join.
1501    fn to_batch_hash_join(
1502        &self,
1503        logical_join: generic::Join<BatchPlanRef>,
1504        predicate: EqJoinPredicate,
1505    ) -> Result<BatchPlanRef> {
1506        use super::batch::prelude::*;
1507
1508        let left_schema_len = logical_join.left.schema().len();
1509        let asof_desc = self
1510            .is_asof_join()
1511            .then(|| {
1512                Self::get_inequality_desc_from_predicate(
1513                    predicate.other_cond().clone(),
1514                    left_schema_len,
1515                )
1516            })
1517            .transpose()?;
1518
1519        let logical_join = generic::Join {
1520            on: generic::JoinOn::EqPredicate(predicate),
1521            ..logical_join
1522        };
1523        let batch_join = BatchHashJoin::new(logical_join, asof_desc);
1524        Ok(batch_join.into())
1525    }
1526
1527    pub fn get_inequality_desc_from_predicate(
1528        predicate: Condition,
1529        left_input_len: usize,
1530    ) -> Result<AsOfJoinDesc> {
1531        let expr: ExprImpl = predicate.into();
1532        if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1533            if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1534            {
1535                Ok(AsOfJoinDesc {
1536                    left_idx: left_input_ref.index() as u32,
1537                    right_idx: (right_input_ref.index() - left_input_len) as u32,
1538                    inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1539                })
1540            } else {
1541                bail!("inequal condition from the same side should be push down in optimizer");
1542            }
1543        } else {
1544            Err(ErrorCode::InvalidInputSyntax(
1545                "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1546            )
1547            .into())
1548        }
1549    }
1550
1551    fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1552        match expr_type {
1553            PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1554            PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1555            PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1556            PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1557            _ => Err(ErrorCode::InvalidInputSyntax(format!(
1558                "Invalid comparison type: {}",
1559                expr_type.as_str_name()
1560            ))
1561            .into()),
1562        }
1563    }
1564}
1565
1566impl ToBatch for LogicalJoin {
1567    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1568        let predicate = EqJoinPredicate::create(
1569            self.left().schema().len(),
1570            self.right().schema().len(),
1571            self.on().clone(),
1572        );
1573
1574        let batch_join = self
1575            .core
1576            .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1577
1578        let ctx = self.base.ctx();
1579        let config = ctx.session_ctx().config();
1580
1581        if predicate.has_eq() {
1582            if !predicate.eq_keys_are_type_aligned() {
1583                return Err(ErrorCode::InternalError(format!(
1584                    "Join eq keys are not aligned for predicate: {predicate:?}"
1585                ))
1586                .into());
1587            }
1588            if config.batch_enable_lookup_join()
1589                && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1590                    predicate.clone(),
1591                    batch_join.clone(),
1592                )?
1593            {
1594                return Ok(lookup_join.into());
1595            }
1596            self.to_batch_hash_join(batch_join, predicate)
1597        } else if self.is_asof_join() {
1598            Err(ErrorCode::InvalidInputSyntax(
1599                "AsOf join requires at least 1 equal condition".to_owned(),
1600            )
1601            .into())
1602        } else {
1603            // Convert to Nested-loop Join for non-equal joins
1604            Ok(BatchNestedLoopJoin::new(batch_join).into())
1605        }
1606    }
1607}
1608
1609impl ToStream for LogicalJoin {
1610    fn to_stream(
1611        &self,
1612        ctx: &mut ToStreamContext,
1613    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1614        if self
1615            .on()
1616            .conjunctions
1617            .iter()
1618            .any(|cond| cond.count_nows() > 0)
1619        {
1620            return Err(ErrorCode::NotSupported(
1621                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1622                 "please refer to https://docs.risingwave.com/processing/sql/temporal-filters for more information".to_owned()).into());
1623        }
1624
1625        let predicate = EqJoinPredicate::create(
1626            self.left().schema().len(),
1627            self.right().schema().len(),
1628            self.on().clone(),
1629        );
1630
1631        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1632            self.to_stream_asof_join(predicate, ctx)
1633        } else if predicate.has_eq() {
1634            if !predicate.eq_keys_are_type_aligned() {
1635                return Err(ErrorCode::InternalError(format!(
1636                    "Join eq keys are not aligned for predicate: {predicate:?}"
1637                ))
1638                .into());
1639            }
1640
1641            if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1642                self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1643            } else {
1644                self.to_stream_hash_join(predicate, ctx)
1645            }
1646        } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1647            self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1648        } else if let Some(dynamic_filter) =
1649            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1650        {
1651            Ok(dynamic_filter)
1652        } else {
1653            Err(RwError::from(ErrorCode::NotSupported(
1654                "streaming nested-loop join".to_owned(),
1655                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1656                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1657                 See also: https://docs.risingwave.com/processing/sql/dynamic-filters".to_owned(),
1658            )))
1659        }
1660    }
1661
1662    fn logical_rewrite_for_stream(
1663        &self,
1664        ctx: &mut RewriteStreamContext,
1665    ) -> Result<(PlanRef, ColIndexMapping)> {
1666        let eq_indexes = self.eq_indexes();
1667        let (logical_left, logical_right) = if eq_indexes.is_empty() {
1668            (self.left(), self.right())
1669        } else {
1670            let lhs_join_key_idx = eq_indexes.iter().map(|(l, _)| *l).collect_vec();
1671            if self.should_be_temporal_join() {
1672                (
1673                    try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
1674                    self.right(),
1675                )
1676            } else {
1677                let rhs_join_key_idx = eq_indexes.iter().map(|(_, r)| *r).collect_vec();
1678                (
1679                    try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
1680                    try_enforce_locality_requirement(self.right(), &rhs_join_key_idx),
1681                )
1682            }
1683        };
1684
1685        let (left, left_col_change) = logical_left.logical_rewrite_for_stream(ctx)?;
1686        let left_len = left.schema().len();
1687        let (right, right_col_change) = logical_right.logical_rewrite_for_stream(ctx)?;
1688        let (join, out_col_change) = self.rewrite_with_left_right(
1689            left.clone(),
1690            left_col_change,
1691            right.clone(),
1692            right_col_change,
1693        );
1694
1695        let mapping = ColIndexMapping::with_remaining_columns(
1696            join.output_indices(),
1697            join.internal_column_num(),
1698        );
1699
1700        let l2o = join.core.l2i_col_mapping().composite(&mapping);
1701        let r2o = join.core.r2i_col_mapping().composite(&mapping);
1702
1703        // Add missing pk indices to the logical join
1704        let mut left_to_add = left
1705            .expect_stream_key()
1706            .iter()
1707            .cloned()
1708            .filter(|i| l2o.try_map(*i).is_none())
1709            .collect_vec();
1710
1711        let mut right_to_add = right
1712            .expect_stream_key()
1713            .iter()
1714            .filter(|&&i| r2o.try_map(i).is_none())
1715            .map(|&i| i + left_len)
1716            .collect_vec();
1717
1718        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
1719        // key.
1720        let right_len = right.schema().len();
1721        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1722
1723        let either_or_both = self.core.add_which_join_key_to_pk();
1724
1725        for (lk, rk) in eq_predicate.eq_indexes() {
1726            match either_or_both {
1727                EitherOrBoth::Left(_) => {
1728                    if l2o.try_map(lk).is_none() {
1729                        left_to_add.push(lk);
1730                    }
1731                }
1732                EitherOrBoth::Right(_) => {
1733                    if r2o.try_map(rk).is_none() {
1734                        right_to_add.push(rk + left_len)
1735                    }
1736                }
1737                EitherOrBoth::Both(_, _) => {
1738                    if l2o.try_map(lk).is_none() {
1739                        left_to_add.push(lk);
1740                    }
1741                    if r2o.try_map(rk).is_none() {
1742                        right_to_add.push(rk + left_len)
1743                    }
1744                }
1745            };
1746        }
1747        let left_to_add = left_to_add.into_iter().unique();
1748        let right_to_add = right_to_add.into_iter().unique();
1749        // NOTE(st1page) over
1750
1751        let mut new_output_indices = join.output_indices().clone();
1752        if !join.is_right_join() {
1753            new_output_indices.extend(left_to_add);
1754        }
1755        if !join.is_left_join() {
1756            new_output_indices.extend(right_to_add);
1757        }
1758
1759        let join_with_pk = join.clone_with_output_indices(new_output_indices);
1760
1761        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1762            // ignore the all NULL to maintain the stream key's uniqueness, see https://github.com/risingwavelabs/risingwave/issues/8084 for more information
1763
1764            let l2o = join_with_pk
1765                .core
1766                .l2i_col_mapping()
1767                .composite(&join_with_pk.core.i2o_col_mapping());
1768            let r2o = join_with_pk
1769                .core
1770                .r2i_col_mapping()
1771                .composite(&join_with_pk.core.i2o_col_mapping());
1772            let left_right_stream_keys = join_with_pk
1773                .left()
1774                .expect_stream_key()
1775                .iter()
1776                .map(|i| l2o.map(*i))
1777                .chain(
1778                    join_with_pk
1779                        .right()
1780                        .expect_stream_key()
1781                        .iter()
1782                        .map(|i| r2o.map(*i)),
1783                )
1784                .collect_vec();
1785            let plan: PlanRef = join_with_pk.into();
1786            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1787        } else {
1788            join_with_pk.into()
1789        };
1790
1791        // the added columns is at the end, so it will not change the exists column index
1792        Ok((plan, out_col_change))
1793    }
1794
1795    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1796        // Only propagate locality for temporal-filter.
1797        if !self.temporal_filter_candidate() {
1798            return None;
1799        }
1800
1801        // Temporal filter only outputs columns from left input, so mapping is safe.
1802        let o2i_mapping = self.core.o2i_col_mapping();
1803        let left_input_columns = columns
1804            .iter()
1805            .map(|&col| o2i_mapping.try_map(col))
1806            .collect::<Option<Vec<usize>>>()?;
1807        if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1808            return Some(
1809                self.clone_with_left_right(better_left_plan, self.right())
1810                    .into(),
1811            );
1812        }
1813        None
1814    }
1815}
1816
1817#[cfg(test)]
1818mod tests {
1819
1820    use std::collections::HashSet;
1821
1822    use risingwave_common::catalog::{Field, Schema};
1823    use risingwave_common::types::{DataType, Datum};
1824    use risingwave_pb::expr::expr_node::Type;
1825
1826    use super::*;
1827    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1828    use crate::optimizer::optimizer_context::OptimizerContext;
1829    use crate::optimizer::plan_node::LogicalValues;
1830    use crate::optimizer::property::FunctionalDependency;
1831
1832    /// Pruning
1833    /// ```text
1834    /// Join(on: input_ref(1)=input_ref(3))
1835    ///   TableScan(v1, v2, v3)
1836    ///   TableScan(v4, v5, v6)
1837    /// ```
1838    /// with required columns [2,3] will result in
1839    /// ```text
1840    /// Project(input_ref(1), input_ref(2))
1841    ///   Join(on: input_ref(0)=input_ref(2))
1842    ///     TableScan(v2, v3)
1843    ///     TableScan(v4)
1844    /// ```
1845    #[tokio::test]
1846    async fn test_prune_join() {
1847        let ty = DataType::Int32;
1848        let ctx = OptimizerContext::mock().await;
1849        let fields: Vec<Field> = (1..7)
1850            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1851            .collect();
1852        let left = LogicalValues::new(
1853            vec![],
1854            Schema {
1855                fields: fields[0..3].to_vec(),
1856            },
1857            ctx.clone(),
1858        );
1859        let right = LogicalValues::new(
1860            vec![],
1861            Schema {
1862                fields: fields[3..6].to_vec(),
1863            },
1864            ctx,
1865        );
1866        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1867            FunctionCall::new(
1868                Type::Equal,
1869                vec![
1870                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1871                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1872                ],
1873            )
1874            .unwrap(),
1875        ));
1876        let join_type = JoinType::Inner;
1877        let join: PlanRef = LogicalJoin::new(
1878            left.into(),
1879            right.into(),
1880            join_type,
1881            Condition::with_expr(on),
1882        )
1883        .into();
1884
1885        // Perform the prune
1886        let required_cols = vec![2, 3];
1887        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1888
1889        // Check the result
1890        let join = plan.as_logical_join().unwrap();
1891        assert_eq!(join.schema().fields().len(), 2);
1892        assert_eq!(join.schema().fields()[0], fields[2]);
1893        assert_eq!(join.schema().fields()[1], fields[3]);
1894
1895        let expr: ExprImpl = join.on().clone().into();
1896        let call = expr.as_function_call().unwrap();
1897        assert_eq_input_ref!(&call.inputs()[0], 0);
1898        assert_eq_input_ref!(&call.inputs()[1], 2);
1899
1900        let left = join.left();
1901        let left = left.as_logical_values().unwrap();
1902        assert_eq!(left.schema().fields(), &fields[1..3]);
1903        let right = join.right();
1904        let right = right.as_logical_values().unwrap();
1905        assert_eq!(right.schema().fields(), &fields[3..4]);
1906    }
1907
1908    /// Semi join panicked previously at `prune_col`. Add test to prevent regression.
1909    #[tokio::test]
1910    async fn test_prune_semi_join() {
1911        let ty = DataType::Int32;
1912        let ctx = OptimizerContext::mock().await;
1913        let fields: Vec<Field> = (1..7)
1914            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1915            .collect();
1916        let left = LogicalValues::new(
1917            vec![],
1918            Schema {
1919                fields: fields[0..3].to_vec(),
1920            },
1921            ctx.clone(),
1922        );
1923        let right = LogicalValues::new(
1924            vec![],
1925            Schema {
1926                fields: fields[3..6].to_vec(),
1927            },
1928            ctx,
1929        );
1930        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1931            FunctionCall::new(
1932                Type::Equal,
1933                vec![
1934                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1935                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1936                ],
1937            )
1938            .unwrap(),
1939        ));
1940        for join_type in [
1941            JoinType::LeftSemi,
1942            JoinType::RightSemi,
1943            JoinType::LeftAnti,
1944            JoinType::RightAnti,
1945        ] {
1946            let join = LogicalJoin::new(
1947                left.clone().into(),
1948                right.clone().into(),
1949                join_type,
1950                Condition::with_expr(on.clone()),
1951            );
1952
1953            let offset = if join.is_right_join() { 3 } else { 0 };
1954            let join: PlanRef = join.into();
1955            // Perform the prune
1956            let required_cols = vec![0];
1957            // key 0 is never used in the join (always key 1)
1958            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1959            let as_plan = plan.as_logical_join().unwrap();
1960            // Check the result
1961            assert_eq!(as_plan.schema().fields().len(), 1);
1962            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1963
1964            // Perform the prune
1965            let required_cols = vec![0, 1, 2];
1966            // should not panic here
1967            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1968            let as_plan = plan.as_logical_join().unwrap();
1969            // Check the result
1970            assert_eq!(as_plan.schema().fields().len(), 3);
1971            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1972            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1973            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1974        }
1975    }
1976
1977    /// Pruning
1978    /// ```text
1979    /// Join(on: input_ref(1)=input_ref(3))
1980    ///   TableScan(v1, v2, v3)
1981    ///   TableScan(v4, v5, v6)
1982    /// ```
1983    /// with required columns [1, 3] will result in
1984    /// ```text
1985    /// Join(on: input_ref(0)=input_ref(1))
1986    ///   TableScan(v2)
1987    ///   TableScan(v4)
1988    /// ```
1989    #[tokio::test]
1990    async fn test_prune_join_no_project() {
1991        let ty = DataType::Int32;
1992        let ctx = OptimizerContext::mock().await;
1993        let fields: Vec<Field> = (1..7)
1994            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1995            .collect();
1996        let left = LogicalValues::new(
1997            vec![],
1998            Schema {
1999                fields: fields[0..3].to_vec(),
2000            },
2001            ctx.clone(),
2002        );
2003        let right = LogicalValues::new(
2004            vec![],
2005            Schema {
2006                fields: fields[3..6].to_vec(),
2007            },
2008            ctx,
2009        );
2010        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2011            FunctionCall::new(
2012                Type::Equal,
2013                vec![
2014                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2015                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2016                ],
2017            )
2018            .unwrap(),
2019        ));
2020        let join_type = JoinType::Inner;
2021        let join: PlanRef = LogicalJoin::new(
2022            left.into(),
2023            right.into(),
2024            join_type,
2025            Condition::with_expr(on),
2026        )
2027        .into();
2028
2029        // Perform the prune
2030        let required_cols = vec![1, 3];
2031        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2032
2033        // Check the result
2034        let join = plan.as_logical_join().unwrap();
2035        assert_eq!(join.schema().fields().len(), 2);
2036        assert_eq!(join.schema().fields()[0], fields[1]);
2037        assert_eq!(join.schema().fields()[1], fields[3]);
2038
2039        let expr: ExprImpl = join.on().clone().into();
2040        let call = expr.as_function_call().unwrap();
2041        assert_eq_input_ref!(&call.inputs()[0], 0);
2042        assert_eq_input_ref!(&call.inputs()[1], 1);
2043
2044        let left = join.left();
2045        let left = left.as_logical_values().unwrap();
2046        assert_eq!(left.schema().fields(), &fields[1..2]);
2047        let right = join.right();
2048        let right = right.as_logical_values().unwrap();
2049        assert_eq!(right.schema().fields(), &fields[3..4]);
2050    }
2051
2052    /// Convert
2053    /// ```text
2054    /// Join(on: ($1 = $3) AND ($2 == 42))
2055    ///   TableScan(v1, v2, v3)
2056    ///   TableScan(v4, v5, v6)
2057    /// ```
2058    /// to
2059    /// ```text
2060    /// Filter($2 == 42)
2061    ///   HashJoin(on: $1 = $3)
2062    ///     TableScan(v1, v2, v3)
2063    ///     TableScan(v4, v5, v6)
2064    /// ```
2065    #[tokio::test]
2066    async fn test_join_to_batch() {
2067        let ctx = OptimizerContext::mock().await;
2068        let fields: Vec<Field> = (1..7)
2069            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
2070            .collect();
2071        let left = LogicalValues::new(
2072            vec![],
2073            Schema {
2074                fields: fields[0..3].to_vec(),
2075            },
2076            ctx.clone(),
2077        );
2078        let right = LogicalValues::new(
2079            vec![],
2080            Schema {
2081                fields: fields[3..6].to_vec(),
2082            },
2083            ctx,
2084        );
2085
2086        fn input_ref(i: usize) -> ExprImpl {
2087            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
2088        }
2089        let eq_cond = ExprImpl::FunctionCall(Box::new(
2090            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
2091        ));
2092        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2093            FunctionCall::new(
2094                Type::Equal,
2095                vec![
2096                    input_ref(2),
2097                    ExprImpl::Literal(Box::new(Literal::new(
2098                        Datum::Some(42_i32.into()),
2099                        DataType::Int32,
2100                    ))),
2101                ],
2102            )
2103            .unwrap(),
2104        ));
2105        // Condition: ($1 = $3) AND ($2 == 42)
2106        let on_cond = ExprImpl::FunctionCall(Box::new(
2107            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2108        ));
2109
2110        let join_type = JoinType::Inner;
2111        let logical_join = LogicalJoin::new(
2112            left.into(),
2113            right.into(),
2114            join_type,
2115            Condition::with_expr(on_cond),
2116        );
2117
2118        // Perform `to_batch`
2119        let result = logical_join.to_batch().unwrap();
2120
2121        // Expected plan:  HashJoin($1 = $3 AND $2 == 42)
2122        let hash_join = result.as_batch_hash_join().unwrap();
2123        assert_eq!(
2124            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2125            eq_cond
2126        );
2127        assert_eq!(
2128            *hash_join
2129                .eq_join_predicate()
2130                .non_eq_cond()
2131                .conjunctions
2132                .first()
2133                .unwrap(),
2134            non_eq_cond
2135        );
2136    }
2137
    /// Convert
    /// ```text
    /// Join(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    /// to
    /// ```text
    /// HashJoin(join_type: left outer, on: ($1 = $3) AND ($2 == 42))
    ///   TableScan(v1, v2, v3)
    ///   TableScan(v4, v5, v6)
    /// ```
    ///
    /// NOTE(review): the whole body below is commented out, so this test currently
    /// asserts nothing and is additionally marked `#[ignore]`. The commented code is
    /// kept only as a reference for what the test used to check.
    #[tokio::test]
    #[ignore] // ignore due to refactor logical scan, but the test seem to duplicate with the explain test
    // framework, maybe we will remove it?
    async fn test_join_to_stream() {
        // let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
        // let fields: Vec<Field> = (1..7)
        //     .map(|i| Field {
        //         data_type: DataType::Int32,
        //         name: format!("v{}", i),
        //     })
        //     .collect();
        // let left = LogicalScan::new(
        //     "left".to_string(),
        //     TableId::new(0),
        //     vec![1.into(), 2.into(), 3.into()],
        //     Schema {
        //         fields: fields[0..3].to_vec(),
        //     },
        //     ctx.clone(),
        // );
        // let right = LogicalScan::new(
        //     "right".to_string(),
        //     TableId::new(0),
        //     vec![4.into(), 5.into(), 6.into()],
        //     Schema {
        //         fields: fields[3..6].to_vec(),
        //     },
        //     ctx,
        // );
        // let eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(1, DataType::Int32))),
        //             ExprImpl::InputRef(Box::new(InputRef::new(3, DataType::Int32))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // let non_eq_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(
        //         Type::Equal,
        //         vec![
        //             ExprImpl::InputRef(Box::new(InputRef::new(2, DataType::Int32))),
        //             ExprImpl::Literal(Box::new(Literal::new(
        //                 Datum::Some(42_i32.into()),
        //                 DataType::Int32,
        //             ))),
        //         ],
        //     )
        //     .unwrap(),
        // ));
        // // Condition: ($1 = $3) AND ($2 == 42)
        // let on_cond = ExprImpl::FunctionCall(Box::new(
        //     FunctionCall::new(Type::And, vec![eq_cond, non_eq_cond]).unwrap(),
        // ));

        // let join_type = JoinType::LeftOuter;
        // let logical_join = LogicalJoin::new(
        //     left.clone().into(),
        //     right.clone().into(),
        //     join_type,
        //     Condition::with_expr(on_cond.clone()),
        // );

        // // Perform `to_stream`
        // let result = logical_join.to_stream();

        // // Expected plan: HashJoin(($1 = $3) AND ($2 == 42))
        // let hash_join = result.as_stream_hash_join().unwrap();
        // assert_eq!(hash_join.eq_join_predicate().all_cond().as_expr(), on_cond);
    }
2222    /// Pruning
2223    /// ```text
2224    /// Join(on: input_ref(1)=input_ref(3))
2225    ///   TableScan(v1, v2, v3)
2226    ///   TableScan(v4, v5, v6)
2227    /// ```
2228    /// with required columns [3, 2] will result in
2229    /// ```text
2230    /// Project(input_ref(2), input_ref(1))
2231    ///   Join(on: input_ref(0)=input_ref(2))
2232    ///     TableScan(v2, v3)
2233    ///     TableScan(v4)
2234    /// ```
2235    #[tokio::test]
2236    async fn test_join_column_prune_with_order_required() {
2237        let ty = DataType::Int32;
2238        let ctx = OptimizerContext::mock().await;
2239        let fields: Vec<Field> = (1..7)
2240            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2241            .collect();
2242        let left = LogicalValues::new(
2243            vec![],
2244            Schema {
2245                fields: fields[0..3].to_vec(),
2246            },
2247            ctx.clone(),
2248        );
2249        let right = LogicalValues::new(
2250            vec![],
2251            Schema {
2252                fields: fields[3..6].to_vec(),
2253            },
2254            ctx,
2255        );
2256        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2257            FunctionCall::new(
2258                Type::Equal,
2259                vec![
2260                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2261                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2262                ],
2263            )
2264            .unwrap(),
2265        ));
2266        let join_type = JoinType::Inner;
2267        let join: PlanRef = LogicalJoin::new(
2268            left.into(),
2269            right.into(),
2270            join_type,
2271            Condition::with_expr(on),
2272        )
2273        .into();
2274
2275        // Perform the prune
2276        let required_cols = vec![3, 2];
2277        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2278
2279        // Check the result
2280        let join = plan.as_logical_join().unwrap();
2281        assert_eq!(join.schema().fields().len(), 2);
2282        assert_eq!(join.schema().fields()[0], fields[3]);
2283        assert_eq!(join.schema().fields()[1], fields[2]);
2284
2285        let expr: ExprImpl = join.on().clone().into();
2286        let call = expr.as_function_call().unwrap();
2287        assert_eq_input_ref!(&call.inputs()[0], 0);
2288        assert_eq_input_ref!(&call.inputs()[1], 2);
2289
2290        let left = join.left();
2291        let left = left.as_logical_values().unwrap();
2292        assert_eq!(left.schema().fields(), &fields[1..3]);
2293        let right = join.right();
2294        let right = right.as_logical_values().unwrap();
2295        assert_eq!(right.schema().fields(), &fields[3..4]);
2296    }
2297
2298    #[tokio::test]
2299    async fn fd_derivation_inner_outer_join() {
2300        // left: [l0, l1], right: [r0, r1, r2]
2301        // FD: l0 --> l1, r0 --> { r1, r2 }
2302        // On: l0 = 0 AND l1 = r1
2303        //
2304        // Inner Join:
2305        //  Schema: [l0, l1, r0, r1, r2]
2306        //  FD: l0 --> l1, r0 --> { r1, r2 }, {} --> l0, l1 --> r1, r1 --> l1
2307        // Left Outer Join:
2308        //  Schema: [l0, l1, r0, r1, r2]
2309        //  FD: l0 --> l1
2310        // Right Outer Join:
2311        //  Schema: [l0, l1, r0, r1, r2]
2312        //  FD: r0 --> { r1, r2 }
2313        // Full Outer Join:
2314        //  Schema: [l0, l1, r0, r1, r2]
2315        //  FD: empty
2316        // Left Semi/Anti Join:
2317        //  Schema: [l0, l1]
2318        //  FD: l0 --> l1
2319        // Right Semi/Anti Join:
2320        //  Schema: [r0, r1, r2]
2321        //  FD: r0 --> {r1, r2}
2322        let ctx = OptimizerContext::mock().await;
2323        let left = {
2324            let fields: Vec<Field> = vec![
2325                Field::with_name(DataType::Int32, "l0"),
2326                Field::with_name(DataType::Int32, "l1"),
2327            ];
2328            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2329            // 0 --> 1
2330            values
2331                .base
2332                .functional_dependency_mut()
2333                .add_functional_dependency_by_column_indices(&[0], &[1]);
2334            values
2335        };
2336        let right = {
2337            let fields: Vec<Field> = vec![
2338                Field::with_name(DataType::Int32, "r0"),
2339                Field::with_name(DataType::Int32, "r1"),
2340                Field::with_name(DataType::Int32, "r2"),
2341            ];
2342            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2343            // 0 --> 1, 2
2344            values
2345                .base
2346                .functional_dependency_mut()
2347                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2348            values
2349        };
2350        // l0 = 0 AND l1 = r1
2351        let on: ExprImpl = FunctionCall::new(
2352            Type::And,
2353            vec![
2354                FunctionCall::new(
2355                    Type::Equal,
2356                    vec![
2357                        InputRef::new(0, DataType::Int32).into(),
2358                        ExprImpl::literal_int(0),
2359                    ],
2360                )
2361                .unwrap()
2362                .into(),
2363                FunctionCall::new(
2364                    Type::Equal,
2365                    vec![
2366                        InputRef::new(1, DataType::Int32).into(),
2367                        InputRef::new(3, DataType::Int32).into(),
2368                    ],
2369                )
2370                .unwrap()
2371                .into(),
2372            ],
2373        )
2374        .unwrap()
2375        .into();
2376        let expected_fd_set = [
2377            (
2378                JoinType::Inner,
2379                [
2380                    // inherit from left
2381                    FunctionalDependency::with_indices(5, &[0], &[1]),
2382                    // inherit from right
2383                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2384                    // constant column in join condition
2385                    FunctionalDependency::with_indices(5, &[], &[0]),
2386                    // eq column in join condition
2387                    FunctionalDependency::with_indices(5, &[1], &[3]),
2388                    FunctionalDependency::with_indices(5, &[3], &[1]),
2389                ]
2390                .into_iter()
2391                .collect::<HashSet<_>>(),
2392            ),
2393            (JoinType::FullOuter, HashSet::new()),
2394            (
2395                JoinType::RightOuter,
2396                [
2397                    // inherit from right
2398                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2399                ]
2400                .into_iter()
2401                .collect::<HashSet<_>>(),
2402            ),
2403            (
2404                JoinType::LeftOuter,
2405                [
2406                    // inherit from left
2407                    FunctionalDependency::with_indices(5, &[0], &[1]),
2408                ]
2409                .into_iter()
2410                .collect::<HashSet<_>>(),
2411            ),
2412            (
2413                JoinType::LeftSemi,
2414                [
2415                    // inherit from left
2416                    FunctionalDependency::with_indices(2, &[0], &[1]),
2417                ]
2418                .into_iter()
2419                .collect::<HashSet<_>>(),
2420            ),
2421            (
2422                JoinType::LeftAnti,
2423                [
2424                    // inherit from left
2425                    FunctionalDependency::with_indices(2, &[0], &[1]),
2426                ]
2427                .into_iter()
2428                .collect::<HashSet<_>>(),
2429            ),
2430            (
2431                JoinType::RightSemi,
2432                [
2433                    // inherit from right
2434                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2435                ]
2436                .into_iter()
2437                .collect::<HashSet<_>>(),
2438            ),
2439            (
2440                JoinType::RightAnti,
2441                [
2442                    // inherit from right
2443                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2444                ]
2445                .into_iter()
2446                .collect::<HashSet<_>>(),
2447            ),
2448        ];
2449
2450        for (join_type, expected_res) in expected_fd_set {
2451            let join = LogicalJoin::new(
2452                left.clone().into(),
2453                right.clone().into(),
2454                join_type,
2455                Condition::with_expr(on.clone()),
2456            );
2457            let fd_set = join
2458                .functional_dependency()
2459                .as_dependencies()
2460                .iter()
2461                .cloned()
2462                .collect::<HashSet<_>>();
2463            assert_eq!(fd_set, expected_res);
2464        }
2465    }
2466}