risingwave_frontend/optimizer/rule/
merge_multijoin_rule.rs1use super::super::plan_node::*;
16use super::Rule;
17use crate::optimizer::rule::BoxedRule;
18
/// Optimizer rule that uses `LogicalMultiJoinBuilder` to collapse a join tree
/// into a single multi-join plan node.
///
/// The rule only fires when the builder collects more than two inputs; see
/// `apply` in the `Rule` impl below.
pub struct MergeMultiJoinRule {}
21
22impl Rule for MergeMultiJoinRule {
23 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
24 let multijoin_builder = LogicalMultiJoinBuilder::new(plan);
25 if multijoin_builder.inputs().len() <= 2 {
26 return None;
27 }
28 Some(multijoin_builder.build().into())
29 }
30}
31
32impl MergeMultiJoinRule {
33 pub fn create() -> BoxedRule {
34 Box::new(MergeMultiJoinRule {})
35 }
36}
37
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::utils::Condition;

    /// Builds the predicate `$lhs = $rhs`, where both column references have
    /// type `ty`.
    fn input_ref_eq(lhs: usize, rhs: usize, ty: &DataType) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(lhs, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(rhs, ty.clone()))),
                ],
            )
            .unwrap(),
        ))
    }

    /// Building a multi-join from `mid JOIN Filter(left JOIN right)` must yield
    /// inputs `[mid, left, right]`, keep the top-level predicate unchanged, and
    /// re-index the filter predicate against the new input order.
    #[tokio::test]
    async fn test_merge_multijoin_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        // Nine columns v1..v9, three per input.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx.clone(),
        );
        let mid = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[6..9].to_vec(),
            },
            ctx,
        );

        let join_type = JoinType::Inner;
        // left.v2 = right.v4, indexed over the (left, right) join output.
        let on_0 = input_ref_eq(1, 3, &ty);
        let join_0 = LogicalJoin::new(
            left.clone().into(),
            right.clone().into(),
            join_type,
            Condition::true_cond(),
        );
        let filter_on_join = LogicalFilter::new(join_0.into(), Condition::with_expr(on_0));

        // mid.v9 = right.v6, indexed over the (mid, left, right) join output.
        let on_1 = input_ref_eq(2, 8, &ty);
        let join_1 = LogicalJoin::new(
            mid.clone().into(),
            filter_on_join.into(),
            join_type,
            Condition::with_expr(on_1.clone()),
        );
        let multijoin_builder = LogicalMultiJoinBuilder::new(join_1.into());
        let multi_join = multijoin_builder.build();

        for (input, schema) in multi_join.inputs().iter().zip_eq_fast(vec![
            mid.schema(),
            left.schema(),
            right.schema(),
        ]) {
            assert_eq!(input.schema(), schema);
        }

        assert_eq!(multi_join.on().conjunctions.len(), 2);
        assert!(multi_join.on().conjunctions.contains(&on_1));

        // on_0 after re-indexing: left now occupies columns 3..6 and right 6..9,
        // so left.v2 -> $4 and right.v4 -> $6.
        let on_0_shifted = input_ref_eq(4, 6, &ty);

        assert!(multi_join.on().conjunctions.contains(&on_0_shifted));
    }

    /// The rule itself must decline a join tree that yields only two inputs
    /// and fire on one that yields three.
    #[tokio::test]
    async fn test_merge_multijoin_rule_threshold() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let values = |name: &str| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: vec![Field::with_name(ty.clone(), name.to_string())],
                },
                ctx.clone(),
            )
        };
        let rule = MergeMultiJoinRule::create();

        let binary_join = LogicalJoin::new(
            values("a").into(),
            values("b").into(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        assert!(rule.apply(binary_join.into()).is_none());

        let nested = LogicalJoin::new(
            values("a").into(),
            values("b").into(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        let three_way = LogicalJoin::new(
            nested.into(),
            values("c").into(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        assert!(rule.apply(three_way.into()).is_some());
    }
}