risingwave_frontend/optimizer/rule/
merge_multijoin_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::super::plan_node::*;
16use super::Rule;
17use crate::optimizer::rule::BoxedRule;
18
/// Optimizer rule that merges adjacent inner joins, filters and projections into a single
/// `LogicalMultiJoin`.
///
/// The rule carries no state; construct it with `MergeMultiJoinRule {}` or via
/// `MergeMultiJoinRule::create()`.
pub struct MergeMultiJoinRule;
21
22impl Rule for MergeMultiJoinRule {
23    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
24        let multijoin_builder = LogicalMultiJoinBuilder::new(plan);
25        if multijoin_builder.inputs().len() <= 2 {
26            return None;
27        }
28        Some(multijoin_builder.build().into())
29    }
30}
31
32impl MergeMultiJoinRule {
33    pub fn create() -> BoxedRule {
34        Box::new(MergeMultiJoinRule {})
35    }
36}
37
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::utils::Condition;

    /// Builds the predicate `$lhs = $rhs` over two input columns of type `ty`.
    fn col_eq(lhs: usize, rhs: usize, ty: &DataType) -> ExprImpl {
        let operands = vec![
            ExprImpl::InputRef(Box::new(InputRef::new(lhs, ty.clone()))),
            ExprImpl::InputRef(Box::new(InputRef::new(rhs, ty.clone()))),
        ];
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, operands).unwrap(),
        ))
    }

    // Plan under test (column indices are relative to each node's own output):
    //
    //   Join(on: $2 = $8)
    //   +-- mid   (v7, v8, v9)
    //   +-- Filter($1 = $3)
    //       +-- Join(on: true)
    //           +-- left  (v1, v2, v3)
    //           +-- right (v4, v5, v6)
    //
    // The builder should flatten this into one multi-join over [mid, left, right].
    // It should keep the outer predicate verbatim and shift the filter predicate
    // by mid's width (3 columns).
    #[tokio::test]
    async fn test_merge_multijoin_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();

        // An empty `LogicalValues` exposing exactly the given columns.
        let values_over = |cols: &[Field]| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: cols.to_vec(),
                },
                ctx.clone(),
            )
        };
        let left = values_over(&fields[0..3]);
        let right = values_over(&fields[3..6]);
        let mid = values_over(&fields[6..9]);

        let join_type = JoinType::Inner;

        // Lower level: left x right with a trivially-true condition, filtered by $1 = $3.
        let on_0 = col_eq(1, 3, &ty);
        let lower_join = LogicalJoin::new(
            left.clone().into(),
            right.clone().into(),
            join_type,
            Condition::true_cond(),
        );
        let filtered = LogicalFilter::new(lower_join.into(), Condition::with_expr(on_0));

        // Upper level: mid x (filtered lower join) on $2 = $8.
        let on_1 = col_eq(2, 8, &ty);
        let upper_join = LogicalJoin::new(
            mid.clone().into(),
            filtered.into(),
            join_type,
            Condition::with_expr(on_1.clone()),
        );

        let multi_join = LogicalMultiJoinBuilder::new(upper_join.into()).build();

        // Inputs appear in the order they were collected from the join tree.
        let expected_schemas = vec![mid.schema(), left.schema(), right.schema()];
        for (input, schema) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), schema);
        }

        // Both predicates survive the merge: the outer one unchanged, and the filter's
        // one re-indexed past mid's three columns ($1 = $3 becomes $4 = $6).
        let on = multi_join.on();
        assert_eq!(on.conjunctions.len(), 2);
        assert!(on.conjunctions.contains(&on_1));
        assert!(on.conjunctions.contains(&col_eq(4, 6, &ty)));
    }
}