1use std::cmp::Ordering;
16use std::collections::{BTreeMap, BTreeSet, VecDeque};
17
18use itertools::Itertools;
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_common::catalog::Schema;
21use risingwave_pb::plan_common::JoinType;
22
23use super::utils::{Distill, childless_record};
24use super::{
25 ColPrunable, ExprRewritable, Logical, LogicalFilter, LogicalJoin, LogicalProject, PlanBase,
26 PlanNodeType, PlanRef, PlanTreeNodeBinary, PlanTreeNodeUnary, PredicatePushdown, ToBatch,
27 ToStream,
28};
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall};
31use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
32use crate::optimizer::plan_node::{
33 ColumnPruningContext, PlanTreeNode, PredicatePushdownContext, RewriteStreamContext,
34 ToStreamContext,
35};
36use crate::optimizer::plan_visitor::TemporalJoinValidator;
37use crate::optimizer::property::FunctionalDependencySet;
38use crate::utils::{
39 ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay,
40 ConnectedComponentLabeller,
41};
42
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
50pub struct LogicalMultiJoin {
51 pub base: PlanBase<Logical>,
52 inputs: Vec<PlanRef>,
53 on: Condition,
54 output_indices: Vec<usize>,
55 inner2output: ColIndexMapping,
56 inner_o2i_mapping: Vec<(usize, usize)>,
61 inner_i2o_mappings: Vec<ColIndexMapping>,
62}
63
64impl Distill for LogicalMultiJoin {
65 fn distill<'a>(&self) -> XmlNode<'a> {
66 let fields = (self.inputs.iter())
67 .flat_map(|input| input.schema().fields.clone())
68 .collect();
69 let input_schema = Schema { fields };
70 let cond = Pretty::display(&ConditionDisplay {
71 condition: self.on(),
72 input_schema: &input_schema,
73 });
74 childless_record("LogicalMultiJoin", vec![("on", cond)])
75 }
76}
77
78#[derive(Debug, Clone, PartialEq, Eq, Hash)]
79pub struct LogicalMultiJoinBuilder {
80 output_indices: Vec<usize>,
81 conjunctions: Vec<ExprImpl>,
84 inputs: Vec<PlanRef>,
85 tot_input_col_num: usize,
86}
87
88impl LogicalMultiJoinBuilder {
89 pub fn add_predicate_above(&mut self, exprs: impl Iterator<Item = ExprImpl>) {
92 let mut mapping = ColIndexMapping::new(
93 self.output_indices.iter().map(|i| Some(*i)).collect(),
94 self.tot_input_col_num,
95 );
96 self.conjunctions
97 .extend(exprs.map(|expr| mapping.rewrite_expr(expr)));
98 }
99
100 pub fn build(self) -> LogicalMultiJoin {
101 LogicalMultiJoin::new(
102 self.inputs,
103 Condition {
104 conjunctions: self.conjunctions,
105 },
106 self.output_indices,
107 )
108 }
109
110 pub fn into_parts(self) -> (Vec<usize>, Vec<ExprImpl>, Vec<PlanRef>, usize) {
111 (
112 self.output_indices,
113 self.conjunctions,
114 self.inputs,
115 self.tot_input_col_num,
116 )
117 }
118
119 pub fn new(plan: PlanRef) -> LogicalMultiJoinBuilder {
120 match plan.node_type() {
121 PlanNodeType::LogicalJoin => Self::with_join(plan),
122 PlanNodeType::LogicalFilter => Self::with_filter(plan),
123 PlanNodeType::LogicalProject => Self::with_project(plan),
124 _ => Self::with_input(plan),
125 }
126 }
127
128 fn with_join(plan: PlanRef) -> LogicalMultiJoinBuilder {
129 let join: &LogicalJoin = plan.as_logical_join().unwrap();
130 if join.join_type() != JoinType::Inner {
131 return Self::with_input(plan);
132 }
133 let left = join.left();
134 let right = join.right();
135
136 if TemporalJoinValidator::exist_dangling_temporal_scan(left.clone()) {
137 return Self::with_input(plan);
138 }
139 if TemporalJoinValidator::exist_dangling_temporal_scan(right.clone()) {
140 return Self::with_input(plan);
141 }
142
143 let mut builder = Self::new(left);
144
145 let (r_output_indices, r_conjunctions, mut r_inputs, r_tot_input_col_num) =
146 Self::new(right).into_parts();
147
148 let mut shift_mapping = ColIndexMapping::with_shift_offset(
151 r_tot_input_col_num,
152 builder.tot_input_col_num as isize,
153 );
154 builder.inputs.append(&mut r_inputs);
155 builder.tot_input_col_num += r_tot_input_col_num;
156
157 builder.conjunctions.extend(
158 r_conjunctions
159 .into_iter()
160 .map(|expr| shift_mapping.rewrite_expr(expr)),
161 );
162
163 builder.output_indices.extend(
164 r_output_indices
165 .into_iter()
166 .map(|idx| shift_mapping.map(idx)),
167 );
168 builder.add_predicate_above(join.on().conjunctions.iter().cloned());
169
170 builder.output_indices = join
171 .output_indices()
172 .iter()
173 .map(|idx| builder.output_indices[*idx])
174 .collect();
175 builder
176 }
177
178 fn with_filter(plan: PlanRef) -> LogicalMultiJoinBuilder {
179 let filter: &LogicalFilter = plan.as_logical_filter().unwrap();
180 let mut builder = Self::new(filter.input());
181 builder.add_predicate_above(filter.predicate().conjunctions.iter().cloned());
182 builder
183 }
184
185 fn with_project(plan: PlanRef) -> LogicalMultiJoinBuilder {
186 let proj: &LogicalProject = plan.as_logical_project().unwrap();
187 let output_indices = match proj.try_as_projection() {
188 Some(output_indices) => output_indices,
189 None => return Self::with_input(plan),
190 };
191 let mut builder = Self::new(proj.input());
192 builder.output_indices = output_indices
193 .into_iter()
194 .map(|i| builder.output_indices[i])
195 .collect();
196 builder
197 }
198
199 fn with_input(input: PlanRef) -> LogicalMultiJoinBuilder {
200 LogicalMultiJoinBuilder {
201 output_indices: (0..input.schema().len()).collect_vec(),
202 conjunctions: vec![],
203 tot_input_col_num: input.schema().len(),
204 inputs: vec![input],
205 }
206 }
207
208 pub fn inputs(&self) -> &[PlanRef] {
209 self.inputs.as_ref()
210 }
211}
212impl LogicalMultiJoin {
213 pub(crate) fn new(inputs: Vec<PlanRef>, on: Condition, output_indices: Vec<usize>) -> Self {
214 let input_schemas = inputs
215 .iter()
216 .map(|input| input.schema().clone())
217 .collect_vec();
218
219 let (inner_o2i_mapping, tot_col_num) = {
220 let mut inner_o2i_mapping = vec![];
221 let mut tot_col_num = 0;
222 for (input_idx, input_schema) in input_schemas.iter().enumerate() {
223 tot_col_num += input_schema.len();
224 for (col_idx, _field) in input_schema.fields().iter().enumerate() {
225 inner_o2i_mapping.push((input_idx, col_idx));
226 }
227 }
228 (inner_o2i_mapping, tot_col_num)
229 };
230 let inner2output = ColIndexMapping::with_remaining_columns(&output_indices, tot_col_num);
231
232 let schema = Schema {
233 fields: output_indices
234 .iter()
235 .map(|idx| inner_o2i_mapping[*idx])
236 .map(|(input_idx, col_idx)| input_schemas[input_idx].fields()[col_idx].clone())
237 .collect(),
238 };
239
240 let inner_i2o_mappings = {
241 let mut i2o_maps = vec![];
242 for input_schema in &input_schemas {
243 let map = vec![None; input_schema.len()];
244 i2o_maps.push(map);
245 }
246 for (out_idx, (input_idx, in_idx)) in inner_o2i_mapping.iter().enumerate() {
247 i2o_maps[*input_idx][*in_idx] = Some(out_idx);
248 }
249
250 i2o_maps
251 .into_iter()
252 .map(|map| ColIndexMapping::new(map, tot_col_num))
253 .collect_vec()
254 };
255
256 let pk_indices = Self::derive_stream_key(&inputs, &inner_i2o_mappings, &inner2output);
257 let functional_dependency = {
258 let mut fd_set = FunctionalDependencySet::new(tot_col_num);
259 let mut column_cnt: usize = 0;
260 let id_mapping = ColIndexMapping::identity(tot_col_num);
261 for i in &inputs {
262 let mapping =
263 ColIndexMapping::with_shift_offset(i.schema().len(), column_cnt as isize)
264 .composite(&id_mapping);
265 mapping
266 .rewrite_functional_dependency_set(i.functional_dependency().clone())
267 .into_dependencies()
268 .into_iter()
269 .for_each(|fd| fd_set.add_functional_dependency(fd));
270 column_cnt += i.schema().len();
271 }
272 for i in &on.conjunctions {
273 if let Some((col, _)) = i.as_eq_const() {
274 fd_set.add_constant_columns(&[col.index()])
275 } else if let Some((left, right)) = i.as_eq_cond() {
276 fd_set.add_functional_dependency_by_column_indices(
277 &[left.index()],
278 &[right.index()],
279 );
280 fd_set.add_functional_dependency_by_column_indices(
281 &[right.index()],
282 &[left.index()],
283 );
284 }
285 }
286 ColIndexMapping::with_remaining_columns(&output_indices, tot_col_num)
287 .rewrite_functional_dependency_set(fd_set)
288 };
289 let base =
290 PlanBase::new_logical(inputs[0].ctx(), schema, pk_indices, functional_dependency);
291
292 Self {
293 base,
294 inputs,
295 on,
296 output_indices,
297 inner2output,
298 inner_o2i_mapping,
299 inner_i2o_mappings,
300 }
301 }
302
303 fn derive_stream_key(
304 inputs: &[PlanRef],
305 inner_i2o_mappings: &[ColIndexMapping],
306 inner2output: &ColIndexMapping,
307 ) -> Option<Vec<usize>> {
308 let mut pk_indices = vec![];
310 for (i, input) in inputs.iter().enumerate() {
311 let input_stream_key = input.stream_key()?;
312 for input_pk_idx in input_stream_key {
313 pk_indices.push(inner_i2o_mappings[i].map(*input_pk_idx));
314 }
315 }
316 pk_indices
317 .into_iter()
318 .map(|col_idx| inner2output.try_map(col_idx))
319 .collect::<Option<Vec<_>>>()
320 }
321
322 pub fn on(&self) -> &Condition {
324 &self.on
325 }
326
327 pub fn clone_with_cond(&self, cond: Condition) -> Self {
329 Self::new(self.inputs.clone(), cond, self.output_indices.clone())
330 }
331}
332
333impl PlanTreeNode for LogicalMultiJoin {
334 fn inputs(&self) -> smallvec::SmallVec<[crate::optimizer::PlanRef; 2]> {
335 let mut vec = smallvec::SmallVec::new();
336 vec.extend(self.inputs.clone());
337 vec
338 }
339
340 fn clone_with_inputs(&self, inputs: &[crate::optimizer::PlanRef]) -> PlanRef {
341 Self::new(
342 inputs.to_vec(),
343 self.on().clone(),
344 self.output_indices.clone(),
345 )
346 .into()
347 }
348}
349
350impl LogicalMultiJoin {
351 pub fn as_reordered_left_deep_join(&self, join_ordering: &[usize]) -> PlanRef {
352 assert_eq!(join_ordering.len(), self.inputs.len());
353 assert!(!join_ordering.is_empty());
354
355 let base_plan = self.inputs[join_ordering[0]].clone();
356
357 let mut output = join_ordering[1..]
360 .iter()
361 .fold(base_plan, |join_chain, &index| {
362 LogicalJoin::new(
363 join_chain,
364 self.inputs[index].clone(),
365 JoinType::Inner,
366 Condition::true_cond(),
367 )
368 .into()
369 });
370
371 let total_col_num = self.inner2output.source_size();
372 let reorder_mapping = {
373 let mut reorder_mapping = vec![None; total_col_num];
374 join_ordering
375 .iter()
376 .cloned()
377 .flat_map(|input_idx| {
378 (0..self.inputs[input_idx].schema().len())
379 .map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
380 })
381 .enumerate()
382 .for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
383 reorder_mapping
384 };
385 output =
386 LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
387 .into();
388
389 output = LogicalFilter::create(output, self.on.clone());
392 output =
393 LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
394
395 output
396 }
397
398 #[allow(clippy::doc_overindented_list_items)]
399 pub(crate) fn heuristic_ordering(&self) -> Result<Vec<usize>> {
424 let mut labeller = ConnectedComponentLabeller::new(self.inputs.len());
425
426 let (eq_join_conditions, _) = self.on.clone().split_by_input_col_nums(
427 &self.input_col_nums(),
428 true,
430 );
431
432 for k in eq_join_conditions.keys() {
434 labeller.add_edge(k.0, k.1);
435 }
436
437 let mut edge_sets: Vec<_> = labeller.into_edge_sets();
438
439 edge_sets.sort_by_key(|a| std::cmp::Reverse(a.len()));
441
442 let mut join_ordering = vec![];
443
444 for component in edge_sets {
445 let mut eq_cond_edges: Vec<(usize, usize)> = component.into_iter().collect();
446
447 eq_cond_edges.sort();
449
450 if eq_cond_edges.is_empty() {
451 break;
453 };
454
455 let edge = eq_cond_edges.remove(0);
456 join_ordering.extend(&vec![edge.0, edge.1]);
457
458 while !eq_cond_edges.is_empty() {
459 let mut found = vec![];
460 for (idx, edge) in eq_cond_edges.iter().enumerate() {
461 if join_ordering.contains(&edge.1) && join_ordering.contains(&edge.0) {
464 found.push(idx);
465 } else {
466 let new_input = if join_ordering.contains(&edge.0) {
469 edge.1
470 } else if join_ordering.contains(&edge.1) {
471 edge.0
472 } else {
473 continue;
474 };
475 join_ordering.push(new_input);
476 found.push(idx);
477 }
478 }
479 if found.is_empty() {
483 return Err(RwError::from(ErrorCode::InternalError(
484 "Connecting edge not found in join connected subgraph".into(),
485 )));
486 }
487 let mut idx = 0;
488 eq_cond_edges.retain(|_| {
489 let keep = !found.contains(&idx);
490 idx += 1;
491 keep
492 });
493 }
494 }
495 for i in 0..self.inputs.len() {
497 if !join_ordering.contains(&i) {
498 join_ordering.push(i);
499 }
500 }
501 Ok(join_ordering)
502 }
503
504 #[allow(clippy::doc_overindented_list_items)]
505 pub fn as_bushy_tree_join(&self) -> Result<PlanRef> {
514 let (nodes, condition) = self.get_join_graph()?;
515
516 if nodes.is_empty() {
517 return Err(RwError::from(ErrorCode::InternalError(
518 "empty multi-join graph".into(),
519 )));
520 }
521
522 let mut optimized_bushy_tree: Option<(GraphNode, Vec<GraphNode>)> = None;
523 let mut que: VecDeque<(BTreeMap<usize, GraphNode>, Vec<GraphNode>)> =
524 VecDeque::from([(nodes, vec![])]);
525
526 while let Some((mut nodes, mut isolated)) = que.pop_front() {
527 if nodes.len() == 1 {
528 let node = nodes.into_values().next().unwrap();
529
530 if let Some((old, _)) = &optimized_bushy_tree {
531 if node.join_tree.height < old.join_tree.height {
532 optimized_bushy_tree = Some((node, isolated));
533 }
534 } else {
535 optimized_bushy_tree = Some((node, isolated));
536 }
537 continue;
538 } else if nodes.is_empty() {
539 if optimized_bushy_tree.is_none() {
540 let base = isolated.pop().unwrap();
541 optimized_bushy_tree = Some((base, isolated));
542 }
543 continue;
544 }
545
546 let (idx, _) = nodes
547 .iter()
548 .min_by(
549 |(_, x), (_, y)| match x.relations.len().cmp(&y.relations.len()) {
550 Ordering::Less => Ordering::Less,
551 Ordering::Greater => Ordering::Greater,
552 Ordering::Equal => x.join_tree.height.cmp(&y.join_tree.height),
553 },
554 )
555 .unwrap();
556 let n_id = *idx;
557
558 let n = nodes.get(&n_id).unwrap();
559 if n.relations.is_empty() {
560 let n = nodes.remove(&n_id).unwrap();
561 isolated.push(n);
562 que.push_back((nodes, isolated));
563 continue;
564 }
565
566 let mut relations = nodes
567 .get_mut(&n_id)
568 .unwrap()
569 .relations
570 .iter()
571 .cloned()
572 .collect_vec();
573 relations.sort_by(|a, b| {
574 let a = nodes.get(a).unwrap();
575 let b = nodes.get(b).unwrap();
576 match a.join_tree.height.cmp(&b.join_tree.height) {
577 Ordering::Equal => a.id.cmp(&b.id),
578 other => other,
579 }
580 });
581
582 for merge_node_id in &relations {
583 let mut nodes = nodes.clone();
584 let n = nodes.remove(&n_id).unwrap();
585
586 for adj_node_id in &n.relations {
587 if adj_node_id != merge_node_id {
588 let adj_node = nodes.get_mut(adj_node_id).unwrap();
589 adj_node.relations.remove(&n_id);
590 adj_node.relations.insert(*merge_node_id);
591 let merge_node = nodes.get_mut(merge_node_id).unwrap();
592 merge_node.relations.insert(*adj_node_id);
593 }
594 }
595
596 let merge_node = nodes.get_mut(merge_node_id).unwrap();
597 merge_node.relations.remove(&n_id);
598 let l_tree = n.join_tree.clone();
599 let r_tree = std::mem::take(&mut merge_node.join_tree);
600 let new_height = usize::max(l_tree.height, r_tree.height) + 1;
601
602 if let Some(min_height) = optimized_bushy_tree
603 .as_ref()
604 .map(|(t, _)| t.join_tree.height)
605 && min_height < new_height
606 {
607 continue;
608 }
609
610 merge_node.join_tree = JoinTreeNode {
611 idx: None,
612 left: Some(Box::new(l_tree)),
613 right: Some(Box::new(r_tree)),
614 height: new_height,
615 };
616
617 que.push_back((nodes, isolated.clone()));
618 }
619 }
620
621 let mut join_ordering = vec![];
623 let mut output = if let Some((optimized_bushy_tree, isolated)) = optimized_bushy_tree {
624 let optimized_bushy_tree =
625 isolated
626 .into_iter()
627 .fold(optimized_bushy_tree, |chain, n| GraphNode {
628 id: n.id,
629 relations: BTreeSet::default(),
630 join_tree: JoinTreeNode {
631 height: chain.join_tree.height.max(n.join_tree.height) + 1,
632 idx: None,
633 left: Some(Box::new(chain.join_tree)),
634 right: Some(Box::new(n.join_tree)),
635 },
636 });
637 self.create_logical_join(optimized_bushy_tree.join_tree, &mut join_ordering)?
638 } else {
639 return Err(RwError::from(ErrorCode::InternalError(
640 "no plan remain".into(),
641 )));
642 };
643
644 let total_col_num = self.inner2output.source_size();
645 let reorder_mapping = {
646 let mut reorder_mapping = vec![None; total_col_num];
647
648 join_ordering
649 .iter()
650 .cloned()
651 .flat_map(|input_idx| {
652 (0..self.inputs[input_idx].schema().len())
653 .map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
654 })
655 .enumerate()
656 .for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
657 reorder_mapping
658 };
659 output =
660 LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
661 .into();
662
663 output = LogicalFilter::create(output, condition);
664 output =
665 LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
666 Ok(output)
667 }
668
669 pub(crate) fn input_col_nums(&self) -> Vec<usize> {
670 self.inputs.iter().map(|i| i.schema().len()).collect()
671 }
672
673 fn get_join_graph(&self) -> Result<(BTreeMap<usize, GraphNode>, Condition)> {
675 let mut nodes: BTreeMap<_, _> = (0..self.inputs.len())
676 .map(|idx| GraphNode {
677 id: idx,
678 relations: BTreeSet::default(),
679 join_tree: JoinTreeNode {
680 idx: Some(idx),
681 left: None,
682 right: None,
683 height: 0,
684 },
685 })
686 .enumerate()
687 .collect();
688
689 let condition = self.on.clone();
690 let condition = self.eq_condition_derivation(condition)?;
691 let (eq_join_conditions, _) = condition
692 .clone()
693 .split_by_input_col_nums(&self.input_col_nums(), true);
694
695 for ((src, dst), _) in eq_join_conditions {
696 nodes.get_mut(&src).unwrap().relations.insert(dst);
697 nodes.get_mut(&dst).unwrap().relations.insert(src);
698 }
699
700 Ok((nodes, condition))
701 }
702
703 fn eq_condition_derivation(&self, mut condition: Condition) -> Result<Condition> {
705 let (eq_join_conditions, _) = condition
706 .clone()
707 .split_by_input_col_nums(&self.input_col_nums(), true);
708
709 let mut new_conj: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
710 let mut input_ref_map = BTreeMap::new();
711
712 for con in eq_join_conditions.values() {
713 for conj in &con.conjunctions {
714 let (l, r) = conj.as_eq_cond().unwrap();
715 new_conj.entry(l.index).or_default().insert(r.index);
716 new_conj.entry(r.index).or_default().insert(l.index);
717 input_ref_map.insert(l.index, Some(l));
718 input_ref_map.insert(r.index, Some(r));
719 }
720 }
721
722 let mut new_pairs = BTreeSet::new();
723
724 for conjs in new_conj.values() {
725 if conjs.len() < 2 {
726 continue;
727 }
728
729 let conjs = conjs.iter().copied().collect_vec();
730 for i in 0..conjs.len() {
731 for j in i + 1..conjs.len() {
732 if !new_conj.get(&conjs[i]).unwrap().contains(&conjs[j]) {
733 if conjs[i] < conjs[j] {
734 new_pairs.insert((conjs[i], conjs[j]));
735 } else {
736 new_pairs.insert((conjs[j], conjs[i]));
737 }
738 }
739 }
740 }
741 }
742 for (i, j) in new_pairs {
743 condition
744 .conjunctions
745 .push(ExprImpl::FunctionCall(Box::new(FunctionCall::new(
746 ExprType::Equal,
747 vec![
748 ExprImpl::InputRef(Box::new(
749 input_ref_map.get(&i).unwrap().as_ref().unwrap().clone(),
750 )),
751 ExprImpl::InputRef(Box::new(
752 input_ref_map.get(&j).unwrap().as_ref().unwrap().clone(),
753 )),
754 ],
755 )?)));
756 }
757 Ok(condition)
758 }
759
760 fn create_logical_join(
762 &self,
763 mut join_tree: JoinTreeNode,
764 join_ordering: &mut Vec<usize>,
765 ) -> Result<PlanRef> {
766 Ok(match (join_tree.left.take(), join_tree.right.take()) {
767 (Some(l), Some(r)) => LogicalJoin::new(
768 self.create_logical_join(*l, join_ordering)?,
769 self.create_logical_join(*r, join_ordering)?,
770 JoinType::Inner,
771 Condition::true_cond(),
772 )
773 .into(),
774 (None, None) => {
775 if let Some(idx) = join_tree.idx {
776 join_ordering.push(idx);
777 self.inputs[idx].clone()
778 } else {
779 return Err(RwError::from(ErrorCode::InternalError(
780 "id of the leaf node not found in the join tree".into(),
781 )));
782 }
783 }
784 (_, _) => {
785 return Err(RwError::from(ErrorCode::InternalError(
786 "only leaf node can have None subtree".into(),
787 )));
788 }
789 })
790 }
791}
792
/// A node of a binary join tree. Leaves refer to an input by index; inner nodes join their
/// two subtrees.
#[derive(Clone, Debug, Default)]
struct JoinTreeNode {
    /// Index of the input for a leaf; `None` for an inner node.
    idx: Option<usize>,
    /// Left subtree; `None` for a leaf.
    left: Option<Box<JoinTreeNode>>,
    /// Right subtree; `None` for a leaf.
    right: Option<Box<JoinTreeNode>>,
    /// Height of the tree rooted here; `0` for a leaf.
    height: usize,
}
801
802#[derive(Clone, Debug)]
804struct GraphNode {
805 id: usize,
806 join_tree: JoinTreeNode,
807 relations: BTreeSet<usize>,
808}
809
810impl ToStream for LogicalMultiJoin {
811 fn logical_rewrite_for_stream(
812 &self,
813 _ctx: &mut RewriteStreamContext,
814 ) -> Result<(PlanRef, ColIndexMapping)> {
815 panic!(
816 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
817 a temporary lifetime. It only facilitates join reordering during logical planning."
818 )
819 }
820
821 fn to_stream(&self, _ctx: &mut ToStreamContext) -> Result<PlanRef> {
822 panic!(
823 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
824 a temporary lifetime. It only facilitates join reordering during logical planning."
825 )
826 }
827}
828
829impl ToBatch for LogicalMultiJoin {
830 fn to_batch(&self) -> Result<PlanRef> {
831 panic!(
832 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
833 a temporary lifetime. It only facilitates join reordering during logical planning."
834 )
835 }
836}
837
838impl ColPrunable for LogicalMultiJoin {
839 fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
840 panic!(
841 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
842 a temporary lifetime. It only facilitates join reordering during logical planning."
843 )
844 }
845}
846
847impl ExprRewritable for LogicalMultiJoin {
848 fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef {
849 panic!(
850 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
851 a temporary lifetime. It only facilitates join reordering during logical planning."
852 )
853 }
854}
855
856impl ExprVisitable for LogicalMultiJoin {
857 fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {
858 panic!(
859 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
860 a temporary lifetime. It only facilitates join reordering during logical planning."
861 )
862 }
863}
864
865impl PredicatePushdown for LogicalMultiJoin {
866 fn predicate_pushdown(
867 &self,
868 _predicate: Condition,
869 _ctx: &mut PredicatePushdownContext,
870 ) -> PlanRef {
871 panic!(
872 "Method not available for `LogicalMultiJoin` which is a placeholder node with \
873 a temporary lifetime. It only facilitates join reordering during logical planning."
874 )
875 }
876}
877
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use risingwave_common::catalog::Field;
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::InputRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::optimizer::property::FunctionalDependency;

    /// Checks the functional dependencies derived for a multi-join of three inputs:
    ///
    /// - `t1(v0, v1)` with `v0 --> v1`
    /// - `t2(v2, v3, v4)` with `v2 --> { v3, v4 }`
    /// - `t3(v5, v6)` with `{} --> v5`
    ///
    /// joined on `v0 = 0 AND v1 = v3 AND v4 = v5`, outputting `[v0, v1, v4, v5]`.
    #[tokio::test]
    async fn fd_derivation_multi_join() {
        let ctx = OptimizerContext::mock().await;

        // An empty `LogicalValues` of int32 columns `names` with the dependency `from --> to`.
        let int32_values = |names: &[&str], from: &[usize], to: &[usize]| {
            let fields: Vec<Field> = names
                .iter()
                .map(|name| Field::with_name(DataType::Int32, *name))
                .collect();
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(from, to);
            values
        };
        let t1 = int32_values(&["v0", "v1"], &[0], &[1]);
        let t2 = int32_values(&["v2", "v3", "v4"], &[0], &[1, 2]);
        let t3 = int32_values(&["v5", "v6"], &[], &[0]);

        // Small expression builders over the internal (concatenated) column space.
        let col = |idx: usize| -> ExprImpl { InputRef::new(idx, DataType::Int32).into() };
        let call = |func: Type, args: Vec<ExprImpl>| -> ExprImpl {
            FunctionCall::new(func, args).unwrap().into()
        };

        // v0 = 0 AND (v1 = v3 AND v4 = v5)
        let on = call(
            Type::And,
            vec![
                call(Type::Equal, vec![col(0), ExprImpl::literal_int(0)]),
                call(
                    Type::And,
                    vec![
                        call(Type::Equal, vec![col(1), col(3)]),
                        call(Type::Equal, vec![col(4), col(5)]),
                    ],
                ),
            ],
        );

        let multi_join = LogicalMultiJoin::new(
            vec![t1.into(), t2.into(), t3.into()],
            Condition::with_expr(on),
            vec![0, 1, 4, 5],
        );

        // Expected dependencies, expressed over the four output columns.
        let expected_fd_set: HashSet<_> = [
            FunctionalDependency::with_indices(4, &[0], &[1]),
            FunctionalDependency::with_indices(4, &[], &[0, 3]),
            FunctionalDependency::with_indices(4, &[2], &[3]),
            FunctionalDependency::with_indices(4, &[3], &[2]),
        ]
        .into_iter()
        .collect();
        let actual_fd_set: HashSet<_> = multi_join
            .functional_dependency()
            .as_dependencies()
            .iter()
            .cloned()
            .collect();
        assert_eq!(expected_fd_set, actual_fd_set);
    }
}