risingwave_frontend/optimizer/plan_node/
logical_multi_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::{BTreeMap, BTreeSet, VecDeque};
17
18use itertools::Itertools;
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_common::catalog::Schema;
21use risingwave_pb::plan_common::JoinType;
22
23use super::utils::{Distill, childless_record};
24use super::{
25    ColPrunable, ExprRewritable, Logical, LogicalFilter, LogicalJoin, LogicalProject, PlanBase,
26    PlanNodeType, PlanRef, PlanTreeNodeBinary, PlanTreeNodeUnary, PredicatePushdown, ToBatch,
27    ToStream,
28};
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall};
31use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
32use crate::optimizer::plan_node::{
33    ColumnPruningContext, PlanTreeNode, PredicatePushdownContext, RewriteStreamContext,
34    ToStreamContext,
35};
36use crate::optimizer::plan_visitor::TemporalJoinValidator;
37use crate::optimizer::property::FunctionalDependencySet;
38use crate::utils::{
39    ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay,
40    ConnectedComponentLabeller,
41};
42
/// `LogicalMultiJoin` combines two or more relations according to some condition.
///
/// Each output row has fields from the inputs. The set of output rows is a subset
/// of the cartesian product of all the inputs; `LogicalMultiJoin` is only supported
/// for inner joins as it implicitly assumes commutativity. Non-inner joins should be
/// expressed as 2-way `LogicalJoin`s.
///
/// This is a placeholder node with a temporary lifetime: it only facilitates join reordering
/// during logical planning and cannot be converted to a stream plan directly.
///
/// Internally, all input columns are concatenated in input order into one *internal* column
/// space; `on` and `output_indices` are both expressed in that space.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalMultiJoin {
    pub base: PlanBase<Logical>,
    /// The joined relations; their columns, concatenated in this order, form the internal
    /// column space.
    inputs: Vec<PlanRef>,
    /// The join predicate, referencing internal column indices.
    on: Condition,
    /// For each output column, the internal column index it is taken from.
    output_indices: Vec<usize>,
    /// Mapping from internal column index to output column index, built from `output_indices`.
    inner2output: ColIndexMapping,
    // NOTE(st1page): these fields will be used in prune_col and
    // pk_derive soon.
    /// the mapping `output_col_idx` -> (`input_idx`, `input_col_idx`), **"`output_col_idx`" is internal,
    /// it does not take `output_indices` into account**
    inner_o2i_mapping: Vec<(usize, usize)>,
    /// One mapping per input: column index within that input -> internal column index.
    inner_i2o_mappings: Vec<ColIndexMapping>,
}
63
64impl Distill for LogicalMultiJoin {
65    fn distill<'a>(&self) -> XmlNode<'a> {
66        let fields = (self.inputs.iter())
67            .flat_map(|input| input.schema().fields.clone())
68            .collect();
69        let input_schema = Schema { fields };
70        let cond = Pretty::display(&ConditionDisplay {
71            condition: self.on(),
72            input_schema: &input_schema,
73        });
74        childless_record("LogicalMultiJoin", vec![("on", cond)])
75    }
76}
77
/// Flattens a tree of inner `LogicalJoin`s, `LogicalFilter`s and pure-projection
/// `LogicalProject`s into the parts needed to build a [`LogicalMultiJoin`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalMultiJoinBuilder {
    /// For each output column of the flattened plan, the internal column index it is taken from.
    output_indices: Vec<usize>,
    /// the predicates in the on condition, we do not use Condition here to avoid unnecessary
    /// simplification.
    conjunctions: Vec<ExprImpl>,
    /// The leaf inputs collected so far, in left-to-right order.
    inputs: Vec<PlanRef>,
    /// Total number of columns over all of `inputs`.
    tot_input_col_num: usize,
}
87
88impl LogicalMultiJoinBuilder {
89    /// add a predicate above the plan, so they will be rewritten from the `output_indices` to the
90    /// input indices
91    pub fn add_predicate_above(&mut self, exprs: impl Iterator<Item = ExprImpl>) {
92        let mut mapping = ColIndexMapping::new(
93            self.output_indices.iter().map(|i| Some(*i)).collect(),
94            self.tot_input_col_num,
95        );
96        self.conjunctions
97            .extend(exprs.map(|expr| mapping.rewrite_expr(expr)));
98    }
99
100    pub fn build(self) -> LogicalMultiJoin {
101        LogicalMultiJoin::new(
102            self.inputs,
103            Condition {
104                conjunctions: self.conjunctions,
105            },
106            self.output_indices,
107        )
108    }
109
110    pub fn into_parts(self) -> (Vec<usize>, Vec<ExprImpl>, Vec<PlanRef>, usize) {
111        (
112            self.output_indices,
113            self.conjunctions,
114            self.inputs,
115            self.tot_input_col_num,
116        )
117    }
118
119    pub fn new(plan: PlanRef) -> LogicalMultiJoinBuilder {
120        match plan.node_type() {
121            PlanNodeType::LogicalJoin => Self::with_join(plan),
122            PlanNodeType::LogicalFilter => Self::with_filter(plan),
123            PlanNodeType::LogicalProject => Self::with_project(plan),
124            _ => Self::with_input(plan),
125        }
126    }
127
128    fn with_join(plan: PlanRef) -> LogicalMultiJoinBuilder {
129        let join: &LogicalJoin = plan.as_logical_join().unwrap();
130        if join.join_type() != JoinType::Inner {
131            return Self::with_input(plan);
132        }
133        let left = join.left();
134        let right = join.right();
135
136        if TemporalJoinValidator::exist_dangling_temporal_scan(left.clone()) {
137            return Self::with_input(plan);
138        }
139        if TemporalJoinValidator::exist_dangling_temporal_scan(right.clone()) {
140            return Self::with_input(plan);
141        }
142
143        let mut builder = Self::new(left);
144
145        let (r_output_indices, r_conjunctions, mut r_inputs, r_tot_input_col_num) =
146            Self::new(right).into_parts();
147
148        // the mapping from the right's column index to the current multi join's internal column
149        // index
150        let mut shift_mapping = ColIndexMapping::with_shift_offset(
151            r_tot_input_col_num,
152            builder.tot_input_col_num as isize,
153        );
154        builder.inputs.append(&mut r_inputs);
155        builder.tot_input_col_num += r_tot_input_col_num;
156
157        builder.conjunctions.extend(
158            r_conjunctions
159                .into_iter()
160                .map(|expr| shift_mapping.rewrite_expr(expr)),
161        );
162
163        builder.output_indices.extend(
164            r_output_indices
165                .into_iter()
166                .map(|idx| shift_mapping.map(idx)),
167        );
168        builder.add_predicate_above(join.on().conjunctions.iter().cloned());
169
170        builder.output_indices = join
171            .output_indices()
172            .iter()
173            .map(|idx| builder.output_indices[*idx])
174            .collect();
175        builder
176    }
177
178    fn with_filter(plan: PlanRef) -> LogicalMultiJoinBuilder {
179        let filter: &LogicalFilter = plan.as_logical_filter().unwrap();
180        let mut builder = Self::new(filter.input());
181        builder.add_predicate_above(filter.predicate().conjunctions.iter().cloned());
182        builder
183    }
184
185    fn with_project(plan: PlanRef) -> LogicalMultiJoinBuilder {
186        let proj: &LogicalProject = plan.as_logical_project().unwrap();
187        let output_indices = match proj.try_as_projection() {
188            Some(output_indices) => output_indices,
189            None => return Self::with_input(plan),
190        };
191        let mut builder = Self::new(proj.input());
192        builder.output_indices = output_indices
193            .into_iter()
194            .map(|i| builder.output_indices[i])
195            .collect();
196        builder
197    }
198
199    fn with_input(input: PlanRef) -> LogicalMultiJoinBuilder {
200        LogicalMultiJoinBuilder {
201            output_indices: (0..input.schema().len()).collect_vec(),
202            conjunctions: vec![],
203            tot_input_col_num: input.schema().len(),
204            inputs: vec![input],
205        }
206    }
207
208    pub fn inputs(&self) -> &[PlanRef] {
209        self.inputs.as_ref()
210    }
211}
212impl LogicalMultiJoin {
213    pub(crate) fn new(inputs: Vec<PlanRef>, on: Condition, output_indices: Vec<usize>) -> Self {
214        let input_schemas = inputs
215            .iter()
216            .map(|input| input.schema().clone())
217            .collect_vec();
218
219        let (inner_o2i_mapping, tot_col_num) = {
220            let mut inner_o2i_mapping = vec![];
221            let mut tot_col_num = 0;
222            for (input_idx, input_schema) in input_schemas.iter().enumerate() {
223                tot_col_num += input_schema.len();
224                for (col_idx, _field) in input_schema.fields().iter().enumerate() {
225                    inner_o2i_mapping.push((input_idx, col_idx));
226                }
227            }
228            (inner_o2i_mapping, tot_col_num)
229        };
230        let inner2output = ColIndexMapping::with_remaining_columns(&output_indices, tot_col_num);
231
232        let schema = Schema {
233            fields: output_indices
234                .iter()
235                .map(|idx| inner_o2i_mapping[*idx])
236                .map(|(input_idx, col_idx)| input_schemas[input_idx].fields()[col_idx].clone())
237                .collect(),
238        };
239
240        let inner_i2o_mappings = {
241            let mut i2o_maps = vec![];
242            for input_schema in &input_schemas {
243                let map = vec![None; input_schema.len()];
244                i2o_maps.push(map);
245            }
246            for (out_idx, (input_idx, in_idx)) in inner_o2i_mapping.iter().enumerate() {
247                i2o_maps[*input_idx][*in_idx] = Some(out_idx);
248            }
249
250            i2o_maps
251                .into_iter()
252                .map(|map| ColIndexMapping::new(map, tot_col_num))
253                .collect_vec()
254        };
255
256        let pk_indices = Self::derive_stream_key(&inputs, &inner_i2o_mappings, &inner2output);
257        let functional_dependency = {
258            let mut fd_set = FunctionalDependencySet::new(tot_col_num);
259            let mut column_cnt: usize = 0;
260            let id_mapping = ColIndexMapping::identity(tot_col_num);
261            for i in &inputs {
262                let mapping =
263                    ColIndexMapping::with_shift_offset(i.schema().len(), column_cnt as isize)
264                        .composite(&id_mapping);
265                mapping
266                    .rewrite_functional_dependency_set(i.functional_dependency().clone())
267                    .into_dependencies()
268                    .into_iter()
269                    .for_each(|fd| fd_set.add_functional_dependency(fd));
270                column_cnt += i.schema().len();
271            }
272            for i in &on.conjunctions {
273                if let Some((col, _)) = i.as_eq_const() {
274                    fd_set.add_constant_columns(&[col.index()])
275                } else if let Some((left, right)) = i.as_eq_cond() {
276                    fd_set.add_functional_dependency_by_column_indices(
277                        &[left.index()],
278                        &[right.index()],
279                    );
280                    fd_set.add_functional_dependency_by_column_indices(
281                        &[right.index()],
282                        &[left.index()],
283                    );
284                }
285            }
286            ColIndexMapping::with_remaining_columns(&output_indices, tot_col_num)
287                .rewrite_functional_dependency_set(fd_set)
288        };
289        let base =
290            PlanBase::new_logical(inputs[0].ctx(), schema, pk_indices, functional_dependency);
291
292        Self {
293            base,
294            inputs,
295            on,
296            output_indices,
297            inner2output,
298            inner_o2i_mapping,
299            inner_i2o_mappings,
300        }
301    }
302
303    fn derive_stream_key(
304        inputs: &[PlanRef],
305        inner_i2o_mappings: &[ColIndexMapping],
306        inner2output: &ColIndexMapping,
307    ) -> Option<Vec<usize>> {
308        // TODO(st1page): add JOIN key
309        let mut pk_indices = vec![];
310        for (i, input) in inputs.iter().enumerate() {
311            let input_stream_key = input.stream_key()?;
312            for input_pk_idx in input_stream_key {
313                pk_indices.push(inner_i2o_mappings[i].map(*input_pk_idx));
314            }
315        }
316        pk_indices
317            .into_iter()
318            .map(|col_idx| inner2output.try_map(col_idx))
319            .collect::<Option<Vec<_>>>()
320    }
321
322    /// Get a reference to the logical join's on.
323    pub fn on(&self) -> &Condition {
324        &self.on
325    }
326
327    /// Clone with new `on` condition
328    pub fn clone_with_cond(&self, cond: Condition) -> Self {
329        Self::new(self.inputs.clone(), cond, self.output_indices.clone())
330    }
331}
332
333impl PlanTreeNode for LogicalMultiJoin {
334    fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
335        let mut vec = smallvec::SmallVec::new();
336        vec.extend(self.inputs.clone());
337        vec
338    }
339
340    fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
341        Self::new(
342            inputs.to_vec(),
343            self.on().clone(),
344            self.output_indices.clone(),
345        )
346        .into()
347    }
348}
349
350impl LogicalMultiJoin {
351    pub fn as_reordered_left_deep_join(&self, join_ordering: &[usize]) -> PlanRef {
352        assert_eq!(join_ordering.len(), self.inputs.len());
353        assert!(!join_ordering.is_empty());
354
355        let base_plan = self.inputs[join_ordering[0]].clone();
356
357        // Express as a cross join, we will rely on filter pushdown to push all of the join
358        // conditions to convert into inner joins.
359        let mut output = join_ordering[1..]
360            .iter()
361            .fold(base_plan, |join_chain, &index| {
362                LogicalJoin::new(
363                    join_chain,
364                    self.inputs[index].clone(),
365                    JoinType::Inner,
366                    Condition::true_cond(),
367                )
368                .into()
369            });
370
371        let total_col_num = self.inner2output.source_size();
372        let reorder_mapping = {
373            let mut reorder_mapping = vec![None; total_col_num];
374            join_ordering
375                .iter()
376                .cloned()
377                .flat_map(|input_idx| {
378                    (0..self.inputs[input_idx].schema().len())
379                        .map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
380                })
381                .enumerate()
382                .for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
383            reorder_mapping
384        };
385        output =
386            LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
387                .into();
388
389        // We will later push down all of the filters back to the individual joins via the
390        // `FilterJoinRule`.
391        output = LogicalFilter::create(output, self.on.clone());
392        output =
393            LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
394
395        output
396    }
397
398    #[allow(clippy::doc_overindented_list_items)]
399    /// Our heuristic join reordering algorithm will try to perform a left-deep join.
400    /// It will try to do the following:
401    ///
402    /// 1. First, split the join graph, with eq join conditions as graph edges, into their connected
403    ///    components. Repeat the procedure in 2. with the largest connected components down to
404    ///    the smallest.
405    ///
406    /// 2. For each connected component, add joins to the chain, prioritizing adding those
407    ///    joins to the bottom of the chain if their join conditions have:
408    ///
409    ///      a. eq joins between primary keys on both sides
410    ///      b. eq joins with primary keys on one side
411    ///      c. more equijoin conditions
412    ///
413    ///    in that order. This forms our selectivity heuristic.
414    ///
415    /// 3. Thirdly, we will emit a left-deep cross-join of each of the left-deep joins of the
416    ///    connected components. Depending on the type of plan, this may result in a planner failure
417    ///    (e.g. for streaming). No cross-join will be emitted for a single connected component.
418    ///
419    /// 4. Finally, we will emit, above the left-deep join tree:
420    ///    a. a filter with the non eq conditions
421    ///    b. a projection which reorders the output column ordering to agree with the original ordering of the joins.
422    ///    The filter will then be pushed down by another filter pushdown pass.
423    pub(crate) fn heuristic_ordering(&self) -> Result<Vec<usize>> {
424        let mut labeller = ConnectedComponentLabeller::new(self.inputs.len());
425
426        let (eq_join_conditions, _) = self.on.clone().split_by_input_col_nums(
427            &self.input_col_nums(),
428            // only_eq=
429            true,
430        );
431
432        // Iterate over all join conditions, whose keys represent edges on the join graph
433        for k in eq_join_conditions.keys() {
434            labeller.add_edge(k.0, k.1);
435        }
436
437        let mut edge_sets: Vec<_> = labeller.into_edge_sets();
438
439        // Sort in decreasing order of len
440        edge_sets.sort_by_key(|a| std::cmp::Reverse(a.len()));
441
442        let mut join_ordering = vec![];
443
444        for component in edge_sets {
445            let mut eq_cond_edges: Vec<(usize, usize)> = component.into_iter().collect();
446
447            // TODO(jon-chuang): add sorting of eq_cond_edges based on selectivity here
448            eq_cond_edges.sort();
449
450            if eq_cond_edges.is_empty() {
451                // There is nothing to join in this connected component
452                break;
453            };
454
455            let edge = eq_cond_edges.remove(0);
456            join_ordering.extend(&vec![edge.0, edge.1]);
457
458            while !eq_cond_edges.is_empty() {
459                let mut found = vec![];
460                for (idx, edge) in eq_cond_edges.iter().enumerate() {
461                    // If the eq join condition is on the existing join, we don't add any new
462                    // inputs to the join
463                    if join_ordering.contains(&edge.1) && join_ordering.contains(&edge.0) {
464                        found.push(idx);
465                    } else {
466                        // Else, the eq join condition involves a new input, or is not connected to
467                        // the existing left deep tree. Handle accordingly.
468                        let new_input = if join_ordering.contains(&edge.0) {
469                            edge.1
470                        } else if join_ordering.contains(&edge.1) {
471                            edge.0
472                        } else {
473                            continue;
474                        };
475                        join_ordering.push(new_input);
476                        found.push(idx);
477                    }
478                }
479                // This ensures eq_cond_edges.len() is strictly decreasing per iteration
480                // Since the graph is connected, it is always possible to find at least one edge
481                // remaining that can be connected to the current join result.
482                if found.is_empty() {
483                    return Err(RwError::from(ErrorCode::InternalError(
484                        "Connecting edge not found in join connected subgraph".into(),
485                    )));
486                }
487                let mut idx = 0;
488                eq_cond_edges.retain(|_| {
489                    let keep = !found.contains(&idx);
490                    idx += 1;
491                    keep
492                });
493            }
494        }
495        // Deal with singleton inputs (with no eq condition joins between them whatsoever)
496        for i in 0..self.inputs.len() {
497            if !join_ordering.contains(&i) {
498                join_ordering.push(i);
499            }
500        }
501        Ok(join_ordering)
502    }
503
504    #[allow(clippy::doc_overindented_list_items)]
505    /// transform multijoin into bushy tree join.
506    ///
507    /// 1. First, use equivalent condition derivation to get derive join relation.
508    /// 2. Second, for every isolated node will create connection to every other nodes.
509    /// 3. Third, select and merge one node for a iteration, and use a bfs policy for which node the
510    ///    selected node merged with.
511    ///    i. The select node mentioned above is the node with least number of relations and the lowerst join tree.
512    ///    ii. nodes with a join tree higher than the temporal optimal join tree will be pruned.
513    pub fn as_bushy_tree_join(&self) -> Result<PlanRef> {
514        let (nodes, condition) = self.get_join_graph()?;
515
516        if nodes.is_empty() {
517            return Err(RwError::from(ErrorCode::InternalError(
518                "empty multi-join graph".into(),
519            )));
520        }
521
522        let mut optimized_bushy_tree: Option<(GraphNode, Vec<GraphNode>)> = None;
523        let mut que: VecDeque<(BTreeMap<usize, GraphNode>, Vec<GraphNode>)> =
524            VecDeque::from([(nodes, vec![])]);
525
526        while let Some((mut nodes, mut isolated)) = que.pop_front() {
527            if nodes.len() == 1 {
528                let node = nodes.into_values().next().unwrap();
529
530                if let Some((old, _)) = &optimized_bushy_tree {
531                    if node.join_tree.height < old.join_tree.height {
532                        optimized_bushy_tree = Some((node, isolated));
533                    }
534                } else {
535                    optimized_bushy_tree = Some((node, isolated));
536                }
537                continue;
538            } else if nodes.is_empty() {
539                if optimized_bushy_tree.is_none() {
540                    let base = isolated.pop().unwrap();
541                    optimized_bushy_tree = Some((base, isolated));
542                }
543                continue;
544            }
545
546            let (idx, _) = nodes
547                .iter()
548                .min_by(
549                    |(_, x), (_, y)| match x.relations.len().cmp(&y.relations.len()) {
550                        Ordering::Less => Ordering::Less,
551                        Ordering::Greater => Ordering::Greater,
552                        Ordering::Equal => x.join_tree.height.cmp(&y.join_tree.height),
553                    },
554                )
555                .unwrap();
556            let n_id = *idx;
557
558            let n = nodes.get(&n_id).unwrap();
559            if n.relations.is_empty() {
560                let n = nodes.remove(&n_id).unwrap();
561                isolated.push(n);
562                que.push_back((nodes, isolated));
563                continue;
564            }
565
566            let mut relations = nodes
567                .get_mut(&n_id)
568                .unwrap()
569                .relations
570                .iter()
571                .cloned()
572                .collect_vec();
573            relations.sort_by(|a, b| {
574                let a = nodes.get(a).unwrap();
575                let b = nodes.get(b).unwrap();
576                match a.join_tree.height.cmp(&b.join_tree.height) {
577                    Ordering::Equal => a.id.cmp(&b.id),
578                    other => other,
579                }
580            });
581
582            for merge_node_id in &relations {
583                let mut nodes = nodes.clone();
584                let n = nodes.remove(&n_id).unwrap();
585
586                for adj_node_id in &n.relations {
587                    if adj_node_id != merge_node_id {
588                        let adj_node = nodes.get_mut(adj_node_id).unwrap();
589                        adj_node.relations.remove(&n_id);
590                        adj_node.relations.insert(*merge_node_id);
591                        let merge_node = nodes.get_mut(merge_node_id).unwrap();
592                        merge_node.relations.insert(*adj_node_id);
593                    }
594                }
595
596                let merge_node = nodes.get_mut(merge_node_id).unwrap();
597                merge_node.relations.remove(&n_id);
598                let l_tree = n.join_tree.clone();
599                let r_tree = std::mem::take(&mut merge_node.join_tree);
600                let new_height = usize::max(l_tree.height, r_tree.height) + 1;
601
602                if let Some(min_height) = optimized_bushy_tree
603                    .as_ref()
604                    .map(|(t, _)| t.join_tree.height)
605                    && min_height < new_height
606                {
607                    continue;
608                }
609
610                merge_node.join_tree = JoinTreeNode {
611                    idx: None,
612                    left: Some(Box::new(l_tree)),
613                    right: Some(Box::new(r_tree)),
614                    height: new_height,
615                };
616
617                que.push_back((nodes, isolated.clone()));
618            }
619        }
620
621        // maintain join order to mapping columns.
622        let mut join_ordering = vec![];
623        let mut output = if let Some((optimized_bushy_tree, isolated)) = optimized_bushy_tree {
624            let optimized_bushy_tree =
625                isolated
626                    .into_iter()
627                    .fold(optimized_bushy_tree, |chain, n| GraphNode {
628                        id: n.id,
629                        relations: BTreeSet::default(),
630                        join_tree: JoinTreeNode {
631                            height: chain.join_tree.height.max(n.join_tree.height) + 1,
632                            idx: None,
633                            left: Some(Box::new(chain.join_tree)),
634                            right: Some(Box::new(n.join_tree)),
635                        },
636                    });
637            self.create_logical_join(optimized_bushy_tree.join_tree, &mut join_ordering)?
638        } else {
639            return Err(RwError::from(ErrorCode::InternalError(
640                "no plan remain".into(),
641            )));
642        };
643
644        let total_col_num = self.inner2output.source_size();
645        let reorder_mapping = {
646            let mut reorder_mapping = vec![None; total_col_num];
647
648            join_ordering
649                .iter()
650                .cloned()
651                .flat_map(|input_idx| {
652                    (0..self.inputs[input_idx].schema().len())
653                        .map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
654                })
655                .enumerate()
656                .for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
657            reorder_mapping
658        };
659        output =
660            LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
661                .into();
662
663        output = LogicalFilter::create(output, condition);
664        output =
665            LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
666        Ok(output)
667    }
668
669    pub(crate) fn input_col_nums(&self) -> Vec<usize> {
670        self.inputs.iter().map(|i| i.schema().len()).collect()
671    }
672
673    /// get join graph from `self.on`, return the join graph and the new join condition.
674    fn get_join_graph(&self) -> Result<(BTreeMap<usize, GraphNode>, Condition)> {
675        let mut nodes: BTreeMap<_, _> = (0..self.inputs.len())
676            .map(|idx| GraphNode {
677                id: idx,
678                relations: BTreeSet::default(),
679                join_tree: JoinTreeNode {
680                    idx: Some(idx),
681                    left: None,
682                    right: None,
683                    height: 0,
684                },
685            })
686            .enumerate()
687            .collect();
688
689        let condition = self.on.clone();
690        let condition = self.eq_condition_derivation(condition)?;
691        let (eq_join_conditions, _) = condition
692            .clone()
693            .split_by_input_col_nums(&self.input_col_nums(), true);
694
695        for ((src, dst), _) in eq_join_conditions {
696            nodes.get_mut(&src).unwrap().relations.insert(dst);
697            nodes.get_mut(&dst).unwrap().relations.insert(src);
698        }
699
700        Ok((nodes, condition))
701    }
702
703    ///  equivalent condition derivation by `a = b && a = c` ==> `b = c`
704    fn eq_condition_derivation(&self, mut condition: Condition) -> Result<Condition> {
705        let (eq_join_conditions, _) = condition
706            .clone()
707            .split_by_input_col_nums(&self.input_col_nums(), true);
708
709        let mut new_conj: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
710        let mut input_ref_map = BTreeMap::new();
711
712        for con in eq_join_conditions.values() {
713            for conj in &con.conjunctions {
714                let (l, r) = conj.as_eq_cond().unwrap();
715                new_conj.entry(l.index).or_default().insert(r.index);
716                new_conj.entry(r.index).or_default().insert(l.index);
717                input_ref_map.insert(l.index, Some(l));
718                input_ref_map.insert(r.index, Some(r));
719            }
720        }
721
722        let mut new_pairs = BTreeSet::new();
723
724        for conjs in new_conj.values() {
725            if conjs.len() < 2 {
726                continue;
727            }
728
729            let conjs = conjs.iter().copied().collect_vec();
730            for i in 0..conjs.len() {
731                for j in i + 1..conjs.len() {
732                    if !new_conj.get(&conjs[i]).unwrap().contains(&conjs[j]) {
733                        if conjs[i] < conjs[j] {
734                            new_pairs.insert((conjs[i], conjs[j]));
735                        } else {
736                            new_pairs.insert((conjs[j], conjs[i]));
737                        }
738                    }
739                }
740            }
741        }
742        for (i, j) in new_pairs {
743            condition
744                .conjunctions
745                .push(ExprImpl::FunctionCall(Box::new(FunctionCall::new(
746                    ExprType::Equal,
747                    vec![
748                        ExprImpl::InputRef(Box::new(
749                            input_ref_map.get(&i).unwrap().as_ref().unwrap().clone(),
750                        )),
751                        ExprImpl::InputRef(Box::new(
752                            input_ref_map.get(&j).unwrap().as_ref().unwrap().clone(),
753                        )),
754                    ],
755                )?)));
756        }
757        Ok(condition)
758    }
759
760    /// create logical plan by recursively travase `JoinTreeNode`
761    fn create_logical_join(
762        &self,
763        mut join_tree: JoinTreeNode,
764        join_ordering: &mut Vec<usize>,
765    ) -> Result<PlanRef> {
766        Ok(match (join_tree.left.take(), join_tree.right.take()) {
767            (Some(l), Some(r)) => LogicalJoin::new(
768                self.create_logical_join(*l, join_ordering)?,
769                self.create_logical_join(*r, join_ordering)?,
770                JoinType::Inner,
771                Condition::true_cond(),
772            )
773            .into(),
774            (None, None) => {
775                if let Some(idx) = join_tree.idx {
776                    join_ordering.push(idx);
777                    self.inputs[idx].clone()
778                } else {
779                    return Err(RwError::from(ErrorCode::InternalError(
780                        "id of the leaf node not found in the join tree".into(),
781                    )));
782                }
783            }
784            (_, _) => {
785                return Err(RwError::from(ErrorCode::InternalError(
786                    "only leaf node can have None subtree".into(),
787                )));
788            }
789        })
790    }
791}
792
/// Join tree internal representation: a binary tree describing a join order.
///
/// As consumed by `create_logical_join`, a leaf has both `left` and `right` set to `None`
/// and is expected to carry `idx`. An internal node has both children set and is turned
/// into an inner `LogicalJoin` of the two subtrees.
#[derive(Clone, Default, Debug)]
struct JoinTreeNode {
    /// For a leaf, the index of the multi-join input it stands for (used to index
    /// `self.inputs` in `create_logical_join`). A leaf with `None` here is reported as an
    /// internal error.
    idx: Option<usize>,
    /// Left subtree; `None` only for leaves.
    left: Option<Box<JoinTreeNode>>,
    /// Right subtree; `None` only for leaves.
    right: Option<Box<JoinTreeNode>>,
    /// NOTE(review): presumably the height of this subtree, used when choosing join
    /// orders — TODO confirm against the code that builds these nodes.
    height: usize,
}
801
/// Join graph internal representation: one node of the graph explored while reordering joins.
///
/// NOTE(review): the field descriptions below are inferred from names and types. Confirm
/// them against the reordering code that creates and merges these nodes.
#[derive(Clone, Debug)]
struct GraphNode {
    /// Identifier of this node within the join graph.
    id: usize,
    /// Presumably the join tree built so far for the relations covered by this node.
    join_tree: JoinTreeNode,
    /// Presumably the set of multi-join input indices merged into this node — TODO confirm.
    relations: BTreeSet<usize>,
}
809
810impl ToStream for LogicalMultiJoin {
811    fn logical_rewrite_for_stream(
812        &self,
813        _ctx: &mut RewriteStreamContext,
814    ) -> Result<(PlanRef, ColIndexMapping)> {
815        panic!(
816            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
817             a temporary lifetime. It only facilitates join reordering during logical planning."
818        )
819    }
820
821    fn to_stream(&self, _ctx: &mut ToStreamContext) -> Result<PlanRef> {
822        panic!(
823            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
824             a temporary lifetime. It only facilitates join reordering during logical planning."
825        )
826    }
827}
828
829impl ToBatch for LogicalMultiJoin {
830    fn to_batch(&self) -> Result<PlanRef> {
831        panic!(
832            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
833             a temporary lifetime. It only facilitates join reordering during logical planning."
834        )
835    }
836}
837
838impl ColPrunable for LogicalMultiJoin {
839    fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
840        panic!(
841            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
842             a temporary lifetime. It only facilitates join reordering during logical planning."
843        )
844    }
845}
846
847impl ExprRewritable for LogicalMultiJoin {
848    fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef {
849        panic!(
850            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
851             a temporary lifetime. It only facilitates join reordering during logical planning."
852        )
853    }
854}
855
856impl ExprVisitable for LogicalMultiJoin {
857    fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {
858        panic!(
859            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
860             a temporary lifetime. It only facilitates join reordering during logical planning."
861        )
862    }
863}
864
865impl PredicatePushdown for LogicalMultiJoin {
866    fn predicate_pushdown(
867        &self,
868        _predicate: Condition,
869        _ctx: &mut PredicatePushdownContext,
870    ) -> PlanRef {
871        panic!(
872            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
873             a temporary lifetime. It only facilitates join reordering during logical planning."
874        )
875    }
876}
877
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::InputRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::optimizer::property::FunctionalDependency;

    /// Checks the functional dependencies a `LogicalMultiJoin` derives from its inputs,
    /// its join condition and its output column selection.
    ///
    /// Inputs (join-schema column indices in parentheses):
    /// - t1: [v0 (0), v1 (1)]          with FD v0 --> v1
    /// - t2: [v2 (2), v3 (3), v4 (4)]  with FD v2 --> { v3, v4 }
    /// - t3: [v5 (5), v6 (6)]          with FD {} --> v5
    ///
    /// On: v0 = 0 AND v1 = v3 AND v4 = v5
    ///
    /// Output: [v0, v1, v4, v5]
    /// Expected FD: v0 --> v1, {} --> { v0, v5 }, v4 --> v5, v5 --> v4
    #[tokio::test]
    async fn fd_derivation_multi_join() {
        let ctx = OptimizerContext::mock().await;

        // Builds an empty `LogicalValues` with `Int32` columns named `names`, carrying the
        // single functional dependency `from --> to` (indices local to that table).
        let table_with_fd = |names: &[&str], from: &[usize], to: &[usize]| {
            let fields: Vec<Field> = names
                .iter()
                .map(|name| Field::with_name(DataType::Int32, *name))
                .collect();
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(from, to);
            values
        };

        let t1 = table_with_fd(&["v0", "v1"], &[0], &[1]);
        let t2 = table_with_fd(&["v2", "v3", "v4"], &[0], &[1, 2]);
        let t3 = table_with_fd(&["v5", "v6"], &[], &[0]);

        let col = |index: usize| -> ExprImpl { InputRef::new(index, DataType::Int32).into() };
        let call = |func: Type, args: Vec<ExprImpl>| -> ExprImpl {
            FunctionCall::new(func, args).unwrap().into()
        };

        // On: v0 = 0 AND (v1 = v3 AND v4 = v5)
        let on = call(
            Type::And,
            vec![
                call(Type::Equal, vec![col(0), ExprImpl::literal_int(0)]),
                call(
                    Type::And,
                    vec![
                        call(Type::Equal, vec![col(1), col(3)]),
                        call(Type::Equal, vec![col(4), col(5)]),
                    ],
                ),
            ],
        );

        let multi_join = LogicalMultiJoin::new(
            vec![t1.into(), t2.into(), t3.into()],
            Condition::with_expr(on),
            vec![0, 1, 4, 5],
        );

        // Output column indices: v0 -> 0, v1 -> 1, v4 -> 2, v5 -> 3.
        let expected_fd_set = HashSet::from([
            FunctionalDependency::with_indices(4, &[0], &[1]),
            FunctionalDependency::with_indices(4, &[], &[0, 3]),
            FunctionalDependency::with_indices(4, &[2], &[3]),
            FunctionalDependency::with_indices(4, &[3], &[2]),
        ]);
        let fd_set: HashSet<_> = multi_join
            .functional_dependency()
            .as_dependencies()
            .iter()
            .cloned()
            .collect();
        assert_eq!(expected_fd_set, fd_set);
    }
}