risingwave_frontend/optimizer/rule/left_deep_tree_join_ordering_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use super::super::plan_node::*;
16use super::Rule;
17use crate::optimizer::rule::BoxedRule;
18
/// Optimizer rule that rewrites a `LogicalMultiJoin` into a left-deep join tree.
///
/// The order in which the multi join's inputs are joined is taken from
/// `LogicalMultiJoin::heuristic_ordering`; the rule does not fire when that
/// ordering cannot be computed.
pub struct LeftDeepTreeJoinOrderingRule {}
21
22impl Rule for LeftDeepTreeJoinOrderingRule {
23    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
24        let join = plan.as_logical_multi_join()?;
25        // check if join is inner and can be merged into multijoin
26        let join_ordering = join.heuristic_ordering().ok()?; // maybe panic here instead?
27        let left_deep_join = join.as_reordered_left_deep_join(&join_ordering);
28        Some(left_deep_join)
29    }
30}
31
32impl LeftDeepTreeJoinOrderingRule {
33    pub fn create() -> BoxedRule {
34        Box::new(LeftDeepTreeJoinOrderingRule {})
35    }
36}
37
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::utils::Condition;

    /// Join graph `A-B  C`: A and B are linked by an equality predicate, while C
    /// is only attached through an always-true (cross) join.
    ///
    /// The plan is built as:
    ///
    /// ```text
    ///        inner (A.v3 = B.v9)
    ///       /     \
    ///    cross     B
    ///    /   \
    ///   A     C
    /// ```
    ///
    /// and the heuristic ordering is expected to correspond to:
    ///
    /// ```text
    ///        cross
    ///       /     \
    ///    inner     C
    ///    /   \
    ///   A     B
    /// ```
    #[tokio::test]
    async fn test_heuristic_join_ordering_from_multijoin() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Columns v1..v9; each relation takes three consecutive ones.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let values_over = |first_col: usize| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: fields[first_col..first_col + 3].to_vec(),
                },
                ctx.clone(),
            )
        };
        let relation_a = values_over(0); // v1..v3
        let relation_c = values_over(3); // v4..v6
        let relation_b = values_over(6); // v7..v9

        // A x C with an always-true condition, i.e. a cross join.
        let cross_ac = LogicalJoin::new(
            relation_a.clone().into(),
            relation_c.clone().into(),
            JoinType::Inner,
            Condition::true_cond(),
        );

        // In the concatenated (A, C, B) schema, column 2 is A.v3 and column 8 is B.v9.
        let a_eq_b: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(2, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(8, ty))),
                ],
            )
            .unwrap(),
        ));
        let top_join = LogicalJoin::new(
            cross_ac.into(),
            relation_b.clone().into(),
            JoinType::Inner,
            Condition::with_expr(a_eq_b),
        );

        let multi_join = LogicalMultiJoinBuilder::new(top_join.into()).build();

        // Flattening preserves the original left-to-right input order: A, C, B.
        let expected_schemas = vec![
            relation_a.schema(),
            relation_c.schema(),
            relation_b.schema(),
        ];
        for (input, expected) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), expected);
        }

        // Inputs [A, C, B] are reordered so that B (index 2) follows A (index 0).
        assert_eq!(multi_join.heuristic_ordering().unwrap(), vec![0, 2, 1]);
    }
}