risingwave_frontend/optimizer/rule/
merge_multijoin_rule.rs1use super::prelude::{PlanRef, *};
16use crate::optimizer::plan_node::*;
17
18pub struct MergeMultiJoinRule {}
20
21impl Rule<Logical> for MergeMultiJoinRule {
22 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
23 let multijoin_builder = LogicalMultiJoinBuilder::new(plan);
24 if multijoin_builder.inputs().len() <= 2 {
25 return None;
26 }
27 Some(multijoin_builder.build().into())
28 }
29}
30
31impl MergeMultiJoinRule {
32 pub fn create() -> BoxedRule {
33 Box::new(MergeMultiJoinRule {})
34 }
35}
36
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::utils::Condition;

    /// Builds the predicate `$lhs = $rhs`, where both operands are input references of type `ty`.
    fn input_ref_eq(lhs: usize, rhs: usize, ty: &DataType) -> ExprImpl {
        let operands = vec![
            ExprImpl::InputRef(Box::new(InputRef::new(lhs, ty.clone()))),
            ExprImpl::InputRef(Box::new(InputRef::new(rhs, ty.clone()))),
        ];
        let call = FunctionCall::new(Type::Equal, operands).unwrap();
        ExprImpl::FunctionCall(Box::new(call))
    }

    /// Builds `mid JOIN ((left JOIN right) FILTER v2 = v4) ON v9 = v6` and checks that
    /// `LogicalMultiJoinBuilder` flattens it into one multi-join.
    ///
    /// The expected result has inputs `[mid, left, right]`. Its `on` condition holds exactly two
    /// conjuncts: the outer join condition unchanged, and the filter predicate re-indexed into the
    /// multi-join's combined column space.
    #[tokio::test]
    async fn test_merge_multijoin_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Nine Int32 columns named `v1`..`v9`, handed out three per input below.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let values_over = |cols: std::ops::Range<usize>| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: fields[cols].to_vec(),
                },
                ctx.clone(),
            )
        };
        let left = values_over(0..3);
        let right = values_over(3..6);
        let mid = values_over(6..9);

        let join_type = JoinType::Inner;

        // `left JOIN right` carries no join condition. Instead, `v2 = v4` (output columns 1 and 3)
        // is applied by a filter stacked on top of the join.
        let on_0 = input_ref_eq(1, 3, &ty);
        let filter_on_join = LogicalFilter::new(
            LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::true_cond(),
            )
            .into(),
            Condition::with_expr(on_0),
        );

        // `mid` goes on the left, so the combined layout is `mid ++ left ++ right`. In that
        // layout, columns 2 and 8 are `v9` and `v6`.
        let on_1 = input_ref_eq(2, 8, &ty);
        let join_1 = LogicalJoin::new(
            mid.clone().into(),
            filter_on_join.into(),
            join_type,
            Condition::with_expr(on_1.clone()),
        );

        let multi_join = LogicalMultiJoinBuilder::new(join_1.into()).build();

        // The inputs are flattened in left-to-right order.
        let expected_schemas = vec![mid.schema(), left.schema(), right.schema()];
        for (input, expected) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), expected);
        }

        let conjunctions = &multi_join.on().conjunctions;
        assert_eq!(conjunctions.len(), 2);
        assert!(conjunctions.contains(&on_1));

        // The filter's `$1 = $3` now sits behind `mid`'s three columns, so it becomes `$4 = $6`.
        let on_0_shifted = input_ref_eq(4, 6, &ty);
        assert!(conjunctions.contains(&on_0_shifted));
    }
}