risingwave_frontend/optimizer/rule/
merge_multijoin_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::prelude::{PlanRef, *};
16use crate::optimizer::plan_node::*;
17
/// Merges adjacent inner joins, filters and projections into a single `LogicalMultiJoin`.
///
/// The rule is stateless; an instance carries no configuration. Deriving `Debug` and
/// `Default` lets it be printed in optimizer traces and constructed generically.
#[derive(Debug, Default)]
pub struct MergeMultiJoinRule {}
20
21impl Rule<Logical> for MergeMultiJoinRule {
22    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
23        let multijoin_builder = LogicalMultiJoinBuilder::new(plan);
24        if multijoin_builder.inputs().len() <= 2 {
25            return None;
26        }
27        Some(multijoin_builder.build().into())
28    }
29}
30
31impl MergeMultiJoinRule {
32    pub fn create() -> BoxedRule {
33        Box::new(MergeMultiJoinRule {})
34    }
35}
36
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::utils::Condition;

    /// Builds the predicate `$lhs = $rhs` where both input columns are `Int32`.
    fn int32_eq(lhs: usize, rhs: usize) -> ExprImpl {
        let column = |idx: usize| ExprImpl::InputRef(Box::new(InputRef::new(idx, DataType::Int32)));
        let call = FunctionCall::new(Type::Equal, vec![column(lhs), column(rhs)]).unwrap();
        ExprImpl::FunctionCall(Box::new(call))
    }

    /// `mid JOIN Filter(left JOIN right)` is flattened into one multi-join.
    ///
    /// The expected inputs are `[mid, left, right]`. The expected conditions are the
    /// top join's ON predicate and the filter predicate, the latter re-indexed into
    /// the merged column space.
    #[tokio::test]
    async fn test_merge_multijoin_join() {
        let ctx = OptimizerContext::mock().await;
        // Nine Int32 columns v1..=v9; each input below takes a run of three.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect();
        let empty_values = |lo: usize, hi: usize| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: fields[lo..hi].to_vec(),
                },
                ctx.clone(),
            )
        };
        let left = empty_values(0, 3);
        let right = empty_values(3, 6);
        let mid = empty_values(6, 9);

        // Bottom: `left JOIN right` with a trivially-true ON clause; the real
        // predicate `$1 = $3` sits in a filter on top of the join instead.
        let bottom_join = LogicalJoin::new(
            left.clone().into(),
            right.clone().into(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        let bottom_filter =
            LogicalFilter::new(bottom_join.into(), Condition::with_expr(int32_eq(1, 3)));

        // Top: `mid JOIN bottom_filter` on `$2 = $8`, indexed in the top join's own schema.
        let top_cond = int32_eq(2, 8);
        let top_join = LogicalJoin::new(
            mid.clone().into(),
            bottom_filter.into(),
            JoinType::Inner,
            Condition::with_expr(top_cond.clone()),
        );
        let multi_join = LogicalMultiJoinBuilder::new(top_join.into()).build();

        let expected_schemas = vec![mid.schema(), left.schema(), right.schema()];
        for (input, schema) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), schema);
        }

        let conjunctions = &multi_join.on().conjunctions;
        assert_eq!(conjunctions.len(), 2);
        assert!(conjunctions.contains(&top_cond));
        // `$1 = $3` was relative to `left JOIN right`. With `mid`'s three columns now in
        // front, every index moves up by three, giving `$4 = $6`.
        assert!(conjunctions.contains(&int32_eq(4, 6)));
    }
}
142}