risingwave_frontend/optimizer/plan_node/
logical_multi_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::{BTreeMap, BTreeSet, VecDeque};
17
18use itertools::Itertools;
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_common::catalog::Schema;
21use risingwave_pb::plan_common::JoinType;
22
23use super::utils::{Distill, childless_record};
24use super::{
25    ColPrunable, ExprRewritable, Logical, LogicalFilter, LogicalJoin, LogicalProject, PlanBase,
26    PlanNodeType, PlanRef, PlanTreeNodeBinary, PlanTreeNodeUnary, PredicatePushdown, ToBatch,
27    ToStream,
28};
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall};
31use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
32use crate::optimizer::plan_node::{
33    ColumnPruningContext, PlanTreeNode, PredicatePushdownContext, RewriteStreamContext,
34    ToStreamContext,
35};
36use crate::optimizer::property::FunctionalDependencySet;
37use crate::utils::{
38    ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay,
39    ConnectedComponentLabeller,
40};
41
/// `LogicalMultiJoin` combines two or more relations according to some condition.
///
/// Each output row is made up of fields from the inputs. The set of output rows is a subset
/// of the cartesian product of all the inputs. `LogicalMultiJoin` is only supported
/// for inner joins as it implicitly assumes commutativity. Non-inner joins should be
/// expressed as 2-way `LogicalJoin`s.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalMultiJoin {
    pub base: PlanBase<Logical>,
    /// The relations being joined.
    inputs: Vec<PlanRef>,
    /// The join condition. It refers to the internal column indices (the columns of all inputs
    /// laid out one input after another), not to the output columns.
    on: Condition,
    /// For each output column, the internal column index it is taken from.
    output_indices: Vec<usize>,
    /// Mapping from internal column indices to output column indices, built from
    /// `output_indices`.
    inner2output: ColIndexMapping,
    // NOTE(st1page): these fields will be used in prune_col and
    // pk_derive soon.
    /// The mapping `output_col_idx` -> (`input_idx`, `input_col_idx`). **Here `output_col_idx` is
    /// the internal column index; `output_indices` is not taken into account.**
    inner_o2i_mapping: Vec<(usize, usize)>,
    /// One mapping per input, from that input's column index to the internal column index.
    inner_i2o_mappings: Vec<ColIndexMapping>,
}
62
63impl Distill for LogicalMultiJoin {
64    fn distill<'a>(&self) -> XmlNode<'a> {
65        let fields = (self.inputs.iter())
66            .flat_map(|input| input.schema().fields.clone())
67            .collect();
68        let input_schema = Schema { fields };
69        let cond = Pretty::display(&ConditionDisplay {
70            condition: self.on(),
71            input_schema: &input_schema,
72        });
73        childless_record("LogicalMultiJoin", vec![("on", cond)])
74    }
75}
76
/// Incrementally flattens a tree of inner joins, filters and pure projections into the parts
/// needed to construct a [`LogicalMultiJoin`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalMultiJoinBuilder {
    /// For each output column of the plan collected so far, the internal column index (position
    /// within the concatenation of all inputs' columns) it refers to.
    output_indices: Vec<usize>,
    /// The predicates in the on condition. We do not use `Condition` here in order to avoid
    /// unnecessary simplification.
    conjunctions: Vec<ExprImpl>,
    /// The relations collected so far.
    inputs: Vec<PlanRef>,
    /// Total number of columns over all collected inputs.
    tot_input_col_num: usize,
}
86
87impl LogicalMultiJoinBuilder {
88    /// add a predicate above the plan, so they will be rewritten from the `output_indices` to the
89    /// input indices
90    pub fn add_predicate_above(&mut self, exprs: impl Iterator<Item = ExprImpl>) {
91        let mut mapping = ColIndexMapping::new(
92            self.output_indices.iter().map(|i| Some(*i)).collect(),
93            self.tot_input_col_num,
94        );
95        self.conjunctions
96            .extend(exprs.map(|expr| mapping.rewrite_expr(expr)));
97    }
98
99    pub fn build(self) -> LogicalMultiJoin {
100        LogicalMultiJoin::new(
101            self.inputs,
102            Condition {
103                conjunctions: self.conjunctions,
104            },
105            self.output_indices,
106        )
107    }
108
109    pub fn into_parts(self) -> (Vec<usize>, Vec<ExprImpl>, Vec<PlanRef>, usize) {
110        (
111            self.output_indices,
112            self.conjunctions,
113            self.inputs,
114            self.tot_input_col_num,
115        )
116    }
117
118    pub fn new(plan: PlanRef) -> LogicalMultiJoinBuilder {
119        match plan.node_type() {
120            PlanNodeType::LogicalJoin => Self::with_join(plan),
121            PlanNodeType::LogicalFilter => Self::with_filter(plan),
122            PlanNodeType::LogicalProject => Self::with_project(plan),
123            _ => Self::with_input(plan),
124        }
125    }
126
127    fn with_join(plan: PlanRef) -> LogicalMultiJoinBuilder {
128        let join: &LogicalJoin = plan.as_logical_join().unwrap();
129        if join.join_type() != JoinType::Inner {
130            return Self::with_input(plan);
131        }
132        let left = join.left();
133        let right = join.right();
134
135        let mut builder = Self::new(left);
136
137        let (r_output_indices, r_conjunctions, mut r_inputs, r_tot_input_col_num) =
138            Self::new(right).into_parts();
139
140        // the mapping from the right's column index to the current multi join's internal column
141        // index
142        let mut shift_mapping = ColIndexMapping::with_shift_offset(
143            r_tot_input_col_num,
144            builder.tot_input_col_num as isize,
145        );
146        builder.inputs.append(&mut r_inputs);
147        builder.tot_input_col_num += r_tot_input_col_num;
148
149        builder.conjunctions.extend(
150            r_conjunctions
151                .into_iter()
152                .map(|expr| shift_mapping.rewrite_expr(expr)),
153        );
154
155        builder.output_indices.extend(
156            r_output_indices
157                .into_iter()
158                .map(|idx| shift_mapping.map(idx)),
159        );
160        builder.add_predicate_above(join.on().conjunctions.iter().cloned());
161
162        builder.output_indices = join
163            .output_indices()
164            .iter()
165            .map(|idx| builder.output_indices[*idx])
166            .collect();
167        builder
168    }
169
170    fn with_filter(plan: PlanRef) -> LogicalMultiJoinBuilder {
171        let filter: &LogicalFilter = plan.as_logical_filter().unwrap();
172        let mut builder = Self::new(filter.input());
173        builder.add_predicate_above(filter.predicate().conjunctions.iter().cloned());
174        builder
175    }
176
177    fn with_project(plan: PlanRef) -> LogicalMultiJoinBuilder {
178        let proj: &LogicalProject = plan.as_logical_project().unwrap();
179        let output_indices = match proj.try_as_projection() {
180            Some(output_indices) => output_indices,
181            None => return Self::with_input(plan),
182        };
183        let mut builder = Self::new(proj.input());
184        builder.output_indices = output_indices
185            .into_iter()
186            .map(|i| builder.output_indices[i])
187            .collect();
188        builder
189    }
190
191    fn with_input(input: PlanRef) -> LogicalMultiJoinBuilder {
192        LogicalMultiJoinBuilder {
193            output_indices: (0..input.schema().len()).collect_vec(),
194            conjunctions: vec![],
195            tot_input_col_num: input.schema().len(),
196            inputs: vec![input],
197        }
198    }
199
200    pub fn inputs(&self) -> &[PlanRef] {
201        self.inputs.as_ref()
202    }
203}
204impl LogicalMultiJoin {
205    pub(crate) fn new(inputs: Vec<PlanRef>, on: Condition, output_indices: Vec<usize>) -> Self {
206        let input_schemas = inputs
207            .iter()
208            .map(|input| input.schema().clone())
209            .collect_vec();
210
211        let (inner_o2i_mapping, tot_col_num) = {
212            let mut inner_o2i_mapping = vec![];
213            let mut tot_col_num = 0;
214            for (input_idx, input_schema) in input_schemas.iter().enumerate() {
215                tot_col_num += input_schema.len();
216                for (col_idx, _field) in input_schema.fields().iter().enumerate() {
217                    inner_o2i_mapping.push((input_idx, col_idx));
218                }
219            }
220            (inner_o2i_mapping, tot_col_num)
221        };
222        let inner2output = ColIndexMapping::with_remaining_columns(&output_indices, tot_col_num);
223
224        let schema = Schema {
225            fields: output_indices
226                .iter()
227                .map(|idx| inner_o2i_mapping[*idx])
228                .map(|(input_idx, col_idx)| input_schemas[input_idx].fields()[col_idx].clone())
229                .collect(),
230        };
231
232        let inner_i2o_mappings = {
233            let mut i2o_maps = vec![];
234            for input_schema in &input_schemas {
235                let map = vec![None; input_schema.len()];
236                i2o_maps.push(map);
237            }
238            for (out_idx, (input_idx, in_idx)) in inner_o2i_mapping.iter().enumerate() {
239                i2o_maps[*input_idx][*in_idx] = Some(out_idx);
240            }
241
242            i2o_maps
243                .into_iter()
244                .map(|map| ColIndexMapping::new(map, tot_col_num))
245                .collect_vec()
246        };
247
248        let pk_indices = Self::derive_stream_key(&inputs, &inner_i2o_mappings, &inner2output);
249        let functional_dependency = {
250            let mut fd_set = FunctionalDependencySet::new(tot_col_num);
251            let mut column_cnt: usize = 0;
252            let id_mapping = ColIndexMapping::identity(tot_col_num);
253            for i in &inputs {
254                let mapping =
255                    ColIndexMapping::with_shift_offset(i.schema().len(), column_cnt as isize)
256                        .composite(&id_mapping);
257                mapping
258                    .rewrite_functional_dependency_set(i.functional_dependency().clone())
259                    .into_dependencies()
260                    .into_iter()
261                    .for_each(|fd| fd_set.add_functional_dependency(fd));
262                column_cnt += i.schema().len();
263            }
264            for i in &on.conjunctions {
265                if let Some((col, _)) = i.as_eq_const() {
266                    fd_set.add_constant_columns(&[col.index()])
267                } else if let Some((left, right)) = i.as_eq_cond() {
268                    fd_set.add_functional_dependency_by_column_indices(
269                        &[left.index()],
270                        &[right.index()],
271                    );
272                    fd_set.add_functional_dependency_by_column_indices(
273                        &[right.index()],
274                        &[left.index()],
275                    );
276                }
277            }
278            ColIndexMapping::with_remaining_columns(&output_indices, tot_col_num)
279                .rewrite_functional_dependency_set(fd_set)
280        };
281        let base =
282            PlanBase::new_logical(inputs[0].ctx(), schema, pk_indices, functional_dependency);
283
284        Self {
285            base,
286            inputs,
287            on,
288            output_indices,
289            inner2output,
290            inner_o2i_mapping,
291            inner_i2o_mappings,
292        }
293    }
294
295    fn derive_stream_key(
296        inputs: &[PlanRef],
297        inner_i2o_mappings: &[ColIndexMapping],
298        inner2output: &ColIndexMapping,
299    ) -> Option<Vec<usize>> {
300        // TODO(st1page): add JOIN key
301        let mut pk_indices = vec![];
302        for (i, input) in inputs.iter().enumerate() {
303            let input_stream_key = input.stream_key()?;
304            for input_pk_idx in input_stream_key {
305                pk_indices.push(inner_i2o_mappings[i].map(*input_pk_idx));
306            }
307        }
308        pk_indices
309            .into_iter()
310            .map(|col_idx| inner2output.try_map(col_idx))
311            .collect::<Option<Vec<_>>>()
312    }
313
    /// Returns a reference to the multi join's `on` condition.
    pub fn on(&self) -> &Condition {
        &self.on
    }
318
319    /// Clone with new `on` condition
320    pub fn clone_with_cond(&self, cond: Condition) -> Self {
321        Self::new(self.inputs.clone(), cond, self.output_indices.clone())
322    }
323}
324
325impl PlanTreeNode for LogicalMultiJoin {
326    fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
327        let mut vec = smallvec::SmallVec::new();
328        vec.extend(self.inputs.clone());
329        vec
330    }
331
332    fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
333        Self::new(
334            inputs.to_vec(),
335            self.on().clone(),
336            self.output_indices.clone(),
337        )
338        .into()
339    }
340}
341
342impl LogicalMultiJoin {
343    pub fn as_reordered_left_deep_join(&self, join_ordering: &[usize]) -> PlanRef {
344        assert_eq!(join_ordering.len(), self.inputs.len());
345        assert!(!join_ordering.is_empty());
346
347        let base_plan = self.inputs[join_ordering[0]].clone();
348
349        // Express as a cross join, we will rely on filter pushdown to push all of the join
350        // conditions to convert into inner joins.
351        let mut output = join_ordering[1..]
352            .iter()
353            .fold(base_plan, |join_chain, &index| {
354                LogicalJoin::new(
355                    join_chain,
356                    self.inputs[index].clone(),
357                    JoinType::Inner,
358                    Condition::true_cond(),
359                )
360                .into()
361            });
362
363        let total_col_num = self.inner2output.source_size();
364        let reorder_mapping = {
365            let mut reorder_mapping = vec![None; total_col_num];
366            join_ordering
367                .iter()
368                .cloned()
369                .flat_map(|input_idx| {
370                    (0..self.inputs[input_idx].schema().len())
371                        .map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
372                })
373                .enumerate()
374                .for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
375            reorder_mapping
376        };
377        output =
378            LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
379                .into();
380
381        // We will later push down all of the filters back to the individual joins via the
382        // `FilterJoinRule`.
383        output = LogicalFilter::create(output, self.on.clone());
384        output =
385            LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
386
387        output
388    }
389
390    #[allow(clippy::doc_overindented_list_items)]
391    /// Our heuristic join reordering algorithm will try to perform a left-deep join.
392    /// It will try to do the following:
393    ///
394    /// 1. First, split the join graph, with eq join conditions as graph edges, into their connected
395    ///    components. Repeat the procedure in 2. with the largest connected components down to
396    ///    the smallest.
397    ///
398    /// 2. For each connected component, add joins to the chain, prioritizing adding those
399    ///    joins to the bottom of the chain if their join conditions have:
400    ///
401    ///      a. eq joins between primary keys on both sides
402    ///      b. eq joins with primary keys on one side
403    ///      c. more equijoin conditions
404    ///
405    ///    in that order. This forms our selectivity heuristic.
406    ///
407    /// 3. Thirdly, we will emit a left-deep cross-join of each of the left-deep joins of the
408    ///    connected components. Depending on the type of plan, this may result in a planner failure
409    ///    (e.g. for streaming). No cross-join will be emitted for a single connected component.
410    ///
411    /// 4. Finally, we will emit, above the left-deep join tree:
412    ///    a. a filter with the non eq conditions
413    ///    b. a projection which reorders the output column ordering to agree with the original ordering of the joins.
414    ///    The filter will then be pushed down by another filter pushdown pass.
415    pub(crate) fn heuristic_ordering(&self) -> Result<Vec<usize>> {
416        let mut labeller = ConnectedComponentLabeller::new(self.inputs.len());
417
418        let (eq_join_conditions, _) = self.on.clone().split_by_input_col_nums(
419            &self.input_col_nums(),
420            // only_eq=
421            true,
422        );
423
424        // Iterate over all join conditions, whose keys represent edges on the join graph
425        for k in eq_join_conditions.keys() {
426            labeller.add_edge(k.0, k.1);
427        }
428
429        let mut edge_sets: Vec<_> = labeller.into_edge_sets();
430
431        // Sort in decreasing order of len
432        edge_sets.sort_by_key(|a| std::cmp::Reverse(a.len()));
433
434        let mut join_ordering = vec![];
435
436        for component in edge_sets {
437            let mut eq_cond_edges: Vec<(usize, usize)> = component.into_iter().collect();
438
439            // TODO(jon-chuang): add sorting of eq_cond_edges based on selectivity here
440            eq_cond_edges.sort();
441
442            if eq_cond_edges.is_empty() {
443                // There is nothing to join in this connected component
444                break;
445            };
446
447            let edge = eq_cond_edges.remove(0);
448            join_ordering.extend(&vec![edge.0, edge.1]);
449
450            while !eq_cond_edges.is_empty() {
451                let mut found = vec![];
452                for (idx, edge) in eq_cond_edges.iter().enumerate() {
453                    // If the eq join condition is on the existing join, we don't add any new
454                    // inputs to the join
455                    if join_ordering.contains(&edge.1) && join_ordering.contains(&edge.0) {
456                        found.push(idx);
457                    } else {
458                        // Else, the eq join condition involves a new input, or is not connected to
459                        // the existing left deep tree. Handle accordingly.
460                        let new_input = if join_ordering.contains(&edge.0) {
461                            edge.1
462                        } else if join_ordering.contains(&edge.1) {
463                            edge.0
464                        } else {
465                            continue;
466                        };
467                        join_ordering.push(new_input);
468                        found.push(idx);
469                    }
470                }
471                // This ensures eq_cond_edges.len() is strictly decreasing per iteration
472                // Since the graph is connected, it is always possible to find at least one edge
473                // remaining that can be connected to the current join result.
474                if found.is_empty() {
475                    return Err(RwError::from(ErrorCode::InternalError(
476                        "Connecting edge not found in join connected subgraph".into(),
477                    )));
478                }
479                let mut idx = 0;
480                eq_cond_edges.retain(|_| {
481                    let keep = !found.contains(&idx);
482                    idx += 1;
483                    keep
484                });
485            }
486        }
487        // Deal with singleton inputs (with no eq condition joins between them whatsoever)
488        for i in 0..self.inputs.len() {
489            if !join_ordering.contains(&i) {
490                join_ordering.push(i);
491            }
492        }
493        Ok(join_ordering)
494    }
495
    #[allow(clippy::doc_overindented_list_items)]
    /// Transforms the multi join into a bushy join tree.
    ///
    /// 1. First, use equivalent condition derivation to derive the join relations.
    /// 2. Second, nodes without any join relation are set aside as isolated nodes; they are joined
    ///    on top of the selected tree at the end.
    /// 3. Third, select and merge one node per iteration, and use a BFS policy for which node the
    ///    selected node is merged with.
    ///    i. The selected node mentioned above is the node with the least number of relations and the lowest join tree.
    ///    ii. Nodes with a join tree higher than the current best join tree will be pruned.
    ///
    /// The resulting plan is, bottom-up: the join tree (all joins with a `true` condition), a
    /// projection restoring the original internal column order, a filter with the derived
    /// condition, and a projection applying `output_indices`.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the join graph is empty, if no join tree could be selected,
    /// or if building the join graph or the logical joins fails.
    pub fn as_bushy_tree_join(&self) -> Result<PlanRef> {
        let (nodes, condition) = self.get_join_graph()?;

        if nodes.is_empty() {
            return Err(RwError::from(ErrorCode::InternalError(
                "empty multi-join graph".into(),
            )));
        }

        // Best complete tree found so far, together with the isolated nodes of that search state.
        let mut optimized_bushy_tree: Option<(GraphNode, Vec<GraphNode>)> = None;
        // BFS queue of search states: (nodes still in the graph, nodes set aside as isolated).
        let mut que: VecDeque<(BTreeMap<usize, GraphNode>, Vec<GraphNode>)> =
            VecDeque::from([(nodes, vec![])]);

        while let Some((mut nodes, mut isolated)) = que.pop_front() {
            if nodes.len() == 1 {
                // Only one node left in the graph: it is a candidate result. Keep it only if it
                // is strictly lower than the best one found so far.
                let node = nodes.into_values().next().unwrap();

                if let Some((old, _)) = &optimized_bushy_tree {
                    if node.join_tree.height < old.join_tree.height {
                        optimized_bushy_tree = Some((node, isolated));
                    }
                } else {
                    optimized_bushy_tree = Some((node, isolated));
                }
                continue;
            } else if nodes.is_empty() {
                // Every node ended up isolated: use the last isolated node as the base, unless a
                // result has already been found.
                if optimized_bushy_tree.is_none() {
                    let base = isolated.pop().unwrap();
                    optimized_bushy_tree = Some((base, isolated));
                }
                continue;
            }

            // Select the node with the fewest relations, breaking ties by the lower join tree.
            let (idx, _) = nodes
                .iter()
                .min_by(
                    |(_, x), (_, y)| match x.relations.len().cmp(&y.relations.len()) {
                        Ordering::Less => Ordering::Less,
                        Ordering::Greater => Ordering::Greater,
                        Ordering::Equal => x.join_tree.height.cmp(&y.join_tree.height),
                    },
                )
                .unwrap();
            let n_id = *idx;

            let n = nodes.get(&n_id).unwrap();
            if n.relations.is_empty() {
                // The selected node has nothing to merge with: set it aside and continue the
                // search with the remaining nodes.
                let n = nodes.remove(&n_id).unwrap();
                isolated.push(n);
                que.push_back((nodes, isolated));
                continue;
            }

            // Candidate merge partners, ordered by join tree height and then by id.
            let mut relations = nodes
                .get_mut(&n_id)
                .unwrap()
                .relations
                .iter()
                .cloned()
                .collect_vec();
            relations.sort_by(|a, b| {
                let a = nodes.get(a).unwrap();
                let b = nodes.get(b).unwrap();
                match a.join_tree.height.cmp(&b.join_tree.height) {
                    Ordering::Equal => a.id.cmp(&b.id),
                    other => other,
                }
            });

            for merge_node_id in &relations {
                // Each candidate merge is explored on its own copy of the graph.
                let mut nodes = nodes.clone();
                let n = nodes.remove(&n_id).unwrap();

                // Rewire every other neighbour of the selected node to the merge node.
                for adj_node_id in &n.relations {
                    if adj_node_id != merge_node_id {
                        let adj_node = nodes.get_mut(adj_node_id).unwrap();
                        adj_node.relations.remove(&n_id);
                        adj_node.relations.insert(*merge_node_id);
                        let merge_node = nodes.get_mut(merge_node_id).unwrap();
                        merge_node.relations.insert(*adj_node_id);
                    }
                }

                let merge_node = nodes.get_mut(merge_node_id).unwrap();
                merge_node.relations.remove(&n_id);
                let l_tree = n.join_tree.clone();
                let r_tree = std::mem::take(&mut merge_node.join_tree);
                let new_height = usize::max(l_tree.height, r_tree.height) + 1;

                // Prune: this merge already produces a tree higher than the best result so far.
                if let Some(min_height) = optimized_bushy_tree
                    .as_ref()
                    .map(|(t, _)| t.join_tree.height)
                    && min_height < new_height
                {
                    continue;
                }

                merge_node.join_tree = JoinTreeNode {
                    idx: None,
                    left: Some(Box::new(l_tree)),
                    right: Some(Box::new(r_tree)),
                    height: new_height,
                };

                que.push_back((nodes, isolated.clone()));
            }
        }

        // maintain join order to mapping columns.
        let mut join_ordering = vec![];
        let mut output = if let Some((optimized_bushy_tree, isolated)) = optimized_bushy_tree {
            // Join the isolated nodes one by one on top of the selected tree.
            let optimized_bushy_tree =
                isolated
                    .into_iter()
                    .fold(optimized_bushy_tree, |chain, n| GraphNode {
                        id: n.id,
                        relations: BTreeSet::default(),
                        join_tree: JoinTreeNode {
                            height: chain.join_tree.height.max(n.join_tree.height) + 1,
                            idx: None,
                            left: Some(Box::new(chain.join_tree)),
                            right: Some(Box::new(n.join_tree)),
                        },
                    });
            self.create_logical_join(optimized_bushy_tree.join_tree, &mut join_ordering)?
        } else {
            return Err(RwError::from(ErrorCode::InternalError(
                "no plan remain".into(),
            )));
        };

        // For every internal column, record its position in the output of the join tree, whose
        // inputs appear in `join_ordering` order.
        let total_col_num = self.inner2output.source_size();
        let reorder_mapping = {
            let mut reorder_mapping = vec![None; total_col_num];

            join_ordering
                .iter()
                .cloned()
                .flat_map(|input_idx| {
                    (0..self.inputs[input_idx].schema().len())
                        .map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
                })
                .enumerate()
                .for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
            reorder_mapping
        };
        // Restore the original internal column order so that `condition` and `output_indices`
        // can be applied unchanged.
        output =
            LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
                .into();

        output = LogicalFilter::create(output, condition);
        output =
            LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
        Ok(output)
    }
660
661    pub(crate) fn input_col_nums(&self) -> Vec<usize> {
662        self.inputs.iter().map(|i| i.schema().len()).collect()
663    }
664
665    /// get join graph from `self.on`, return the join graph and the new join condition.
666    fn get_join_graph(&self) -> Result<(BTreeMap<usize, GraphNode>, Condition)> {
667        let mut nodes: BTreeMap<_, _> = (0..self.inputs.len())
668            .map(|idx| GraphNode {
669                id: idx,
670                relations: BTreeSet::default(),
671                join_tree: JoinTreeNode {
672                    idx: Some(idx),
673                    left: None,
674                    right: None,
675                    height: 0,
676                },
677            })
678            .enumerate()
679            .collect();
680
681        let condition = self.on.clone();
682        let condition = self.eq_condition_derivation(condition)?;
683        let (eq_join_conditions, _) = condition
684            .clone()
685            .split_by_input_col_nums(&self.input_col_nums(), true);
686
687        for ((src, dst), _) in eq_join_conditions {
688            nodes.get_mut(&src).unwrap().relations.insert(dst);
689            nodes.get_mut(&dst).unwrap().relations.insert(src);
690        }
691
692        Ok((nodes, condition))
693    }
694
695    ///  equivalent condition derivation by `a = b && a = c` ==> `b = c`
696    fn eq_condition_derivation(&self, mut condition: Condition) -> Result<Condition> {
697        let (eq_join_conditions, _) = condition
698            .clone()
699            .split_by_input_col_nums(&self.input_col_nums(), true);
700
701        let mut new_conj: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
702        let mut input_ref_map = BTreeMap::new();
703
704        for con in eq_join_conditions.values() {
705            for conj in &con.conjunctions {
706                let (l, r) = conj.as_eq_cond().unwrap();
707                new_conj.entry(l.index).or_default().insert(r.index);
708                new_conj.entry(r.index).or_default().insert(l.index);
709                input_ref_map.insert(l.index, Some(l));
710                input_ref_map.insert(r.index, Some(r));
711            }
712        }
713
714        let mut new_pairs = BTreeSet::new();
715
716        for conjs in new_conj.values() {
717            if conjs.len() < 2 {
718                continue;
719            }
720
721            let conjs = conjs.iter().copied().collect_vec();
722            for i in 0..conjs.len() {
723                for j in i + 1..conjs.len() {
724                    if !new_conj.get(&conjs[i]).unwrap().contains(&conjs[j]) {
725                        if conjs[i] < conjs[j] {
726                            new_pairs.insert((conjs[i], conjs[j]));
727                        } else {
728                            new_pairs.insert((conjs[j], conjs[i]));
729                        }
730                    }
731                }
732            }
733        }
734        for (i, j) in new_pairs {
735            condition
736                .conjunctions
737                .push(ExprImpl::FunctionCall(Box::new(FunctionCall::new(
738                    ExprType::Equal,
739                    vec![
740                        ExprImpl::InputRef(Box::new(
741                            input_ref_map.get(&i).unwrap().as_ref().unwrap().clone(),
742                        )),
743                        ExprImpl::InputRef(Box::new(
744                            input_ref_map.get(&j).unwrap().as_ref().unwrap().clone(),
745                        )),
746                    ],
747                )?)));
748        }
749        Ok(condition)
750    }
751
752    /// create logical plan by recursively travase `JoinTreeNode`
753    fn create_logical_join(
754        &self,
755        mut join_tree: JoinTreeNode,
756        join_ordering: &mut Vec<usize>,
757    ) -> Result<PlanRef> {
758        Ok(match (join_tree.left.take(), join_tree.right.take()) {
759            (Some(l), Some(r)) => LogicalJoin::new(
760                self.create_logical_join(*l, join_ordering)?,
761                self.create_logical_join(*r, join_ordering)?,
762                JoinType::Inner,
763                Condition::true_cond(),
764            )
765            .into(),
766            (None, None) => {
767                if let Some(idx) = join_tree.idx {
768                    join_ordering.push(idx);
769                    self.inputs[idx].clone()
770                } else {
771                    return Err(RwError::from(ErrorCode::InternalError(
772                        "id of the leaf node not found in the join tree".into(),
773                    )));
774                }
775            }
776            (_, _) => {
777                return Err(RwError::from(ErrorCode::InternalError(
778                    "only leaf node can have None subtree".into(),
779                )));
780            }
781        })
782    }
783}
784
/// Internal representation of a join tree.
///
/// `create_logical_join` treats a node with both subtrees as a join of the two, and a node
/// with neither subtree as a leaf; a node with exactly one subtree is rejected as invalid.
#[derive(Debug, Clone, Default)]
struct JoinTreeNode {
    /// Index into the multi-join's `inputs`; expected to be `Some` on leaf nodes.
    idx: Option<usize>,
    /// Left subtree; `None` on a leaf.
    left: Option<Box<JoinTreeNode>>,
    /// Right subtree; `None` on a leaf.
    right: Option<Box<JoinTreeNode>>,
    /// Height of this node. NOTE(review): how it is maintained is not visible here — confirm
    /// against the code that builds the tree.
    height: usize,
}
793
794// join graph internal representation
795#[derive(Clone, Debug)]
796struct GraphNode {
797    id: usize,
798    join_tree: JoinTreeNode,
799    relations: BTreeSet<usize>,
800}
801
802impl ToStream for LogicalMultiJoin {
803    fn logical_rewrite_for_stream(
804        &self,
805        _ctx: &mut RewriteStreamContext,
806    ) -> Result<(PlanRef, ColIndexMapping)> {
807        panic!(
808            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
809             a temporary lifetime. It only facilitates join reordering during logical planning."
810        )
811    }
812
813    fn to_stream(&self, _ctx: &mut ToStreamContext) -> Result<PlanRef> {
814        panic!(
815            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
816             a temporary lifetime. It only facilitates join reordering during logical planning."
817        )
818    }
819}
820
821impl ToBatch for LogicalMultiJoin {
822    fn to_batch(&self) -> Result<PlanRef> {
823        panic!(
824            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
825             a temporary lifetime. It only facilitates join reordering during logical planning."
826        )
827    }
828}
829
830impl ColPrunable for LogicalMultiJoin {
831    fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
832        panic!(
833            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
834             a temporary lifetime. It only facilitates join reordering during logical planning."
835        )
836    }
837}
838
839impl ExprRewritable for LogicalMultiJoin {
840    fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef {
841        panic!(
842            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
843             a temporary lifetime. It only facilitates join reordering during logical planning."
844        )
845    }
846}
847
848impl ExprVisitable for LogicalMultiJoin {
849    fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {
850        panic!(
851            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
852             a temporary lifetime. It only facilitates join reordering during logical planning."
853        )
854    }
855}
856
857impl PredicatePushdown for LogicalMultiJoin {
858    fn predicate_pushdown(
859        &self,
860        _predicate: Condition,
861        _ctx: &mut PredicatePushdownContext,
862    ) -> PlanRef {
863        panic!(
864            "Method not available for `LogicalMultiJoin` which is a placeholder node with \
865             a temporary lifetime. It only facilitates join reordering during logical planning."
866        )
867    }
868}
869
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::InputRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::optimizer::property::FunctionalDependency;

    /// Checks the functional dependencies derived for a three-way multi join.
    ///
    /// Inputs:  t1: [v0, v1], t2: [v2, v3, v4], t3: [v5, v6]
    /// FDs:     v0 --> v1, v2 --> { v3, v4 }, {} --> v5
    /// On:      v0 = 0 AND v1 = v3 AND v4 = v5
    ///
    /// Output:  [v0, v1, v4, v5]
    /// FDs:     v0 --> v1, {} --> v0, {} --> v5, v4 --> v5, v5 --> v4
    #[tokio::test]
    async fn fd_derivation_multi_join() {
        let ctx = OptimizerContext::mock().await;

        // Builds an empty `LogicalValues` whose `Int32` columns are named `names` and which
        // carries the single functional dependency `from --> to` (relation-local indices).
        let relation = |names: &[&str], from: &[usize], to: &[usize]| {
            let fields: Vec<Field> = names
                .iter()
                .map(|name| Field::with_name(DataType::Int32, *name))
                .collect();
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(from, to);
            values
        };

        // 0 --> 1
        let t1 = relation(&["v0", "v1"], &[0], &[1]);
        // 0 --> 1, 2
        let t2 = relation(&["v2", "v3", "v4"], &[0], &[1, 2]);
        // {} --> 0
        let t3 = relation(&["v5", "v6"], &[], &[0]);

        // Small expression builders over the concatenated input schema.
        let col = |idx: usize| -> ExprImpl { InputRef::new(idx, DataType::Int32).into() };
        let call = |ty: Type, args: Vec<ExprImpl>| -> ExprImpl {
            FunctionCall::new(ty, args).unwrap().into()
        };

        // On: v0 = 0 AND (v1 = v3 AND v4 = v5)
        let v0_is_zero = call(Type::Equal, vec![col(0), ExprImpl::literal_int(0)]);
        let v1_eq_v3 = call(Type::Equal, vec![col(1), col(3)]);
        let v4_eq_v5 = call(Type::Equal, vec![col(4), col(5)]);
        let rest = call(Type::And, vec![v1_eq_v3, v4_eq_v5]);
        let on = call(Type::And, vec![v0_is_zero, rest]);

        let multi_join = LogicalMultiJoin::new(
            vec![t1.into(), t2.into(), t3.into()],
            Condition::with_expr(on),
            vec![0, 1, 4, 5],
        );

        // Indices below refer to the four output columns [v0, v1, v4, v5].
        let expected: HashSet<_> = HashSet::from([
            FunctionalDependency::with_indices(4, &[0], &[1]),
            FunctionalDependency::with_indices(4, &[], &[0, 3]),
            FunctionalDependency::with_indices(4, &[2], &[3]),
            FunctionalDependency::with_indices(4, &[3], &[2]),
        ]);
        let actual: HashSet<_> = multi_join
            .functional_dependency()
            .as_dependencies()
            .iter()
            .cloned()
            .collect();
        assert_eq!(expected, actual);
    }
}