risingwave_frontend/optimizer/rule/left_deep_tree_join_ordering_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::prelude::{PlanRef, *};
16
/// Optimizer rule that reorders a logical multi join into a left-deep tree of joins,
/// following the order produced by the multi join's `heuristic_ordering`.
pub struct LeftDeepTreeJoinOrderingRule {}
19
20impl Rule<Logical> for LeftDeepTreeJoinOrderingRule {
21    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
22        let join = plan.as_logical_multi_join()?;
23        // check if join is inner and can be merged into multijoin
24        let join_ordering = join.heuristic_ordering().ok()?; // maybe panic here instead?
25        let left_deep_join = join.as_reordered_left_deep_join(&join_ordering);
26        Some(left_deep_join)
27    }
28}
29
30impl LeftDeepTreeJoinOrderingRule {
31    pub fn create() -> BoxedRule {
32        Box::new(LeftDeepTreeJoinOrderingRule {})
33    }
34}
35
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::optimizer::plan_node::*;
    use crate::utils::Condition;

    #[tokio::test]
    async fn test_heuristic_join_ordering_from_multijoin() {
        // Join graph under test: only A and B share an equality predicate, and
        // nothing links C to either of them.
        //
        // A-B C
        //
        // The input plan joins them as:
        //
        //      inner
        //     /   |
        //  cross  B
        //  / |
        // A  C
        //
        // and the heuristic ordering should correspond to the left-deep tree:
        //
        //     cross
        //     /   |
        //  inner  C
        //  / |
        // A  B

        let int_ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Nine int columns v1..v9, three per relation.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(int_ty.clone(), format!("v{}", i)))
            .collect();

        // Builds a `LogicalValues` whose schema is the given slice of `fields`.
        let make_values = |cols: std::ops::Range<usize>| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: fields[cols].to_vec(),
                },
                ctx.clone(),
            )
        };
        let rel_a = make_values(0..3);
        let rel_c = make_values(3..6);
        let rel_b = make_values(6..9);

        // A joined with C under an always-true condition, i.e. a cross join.
        let a_cross_c = LogicalJoin::new(
            rel_a.clone().into(),
            rel_c.clone().into(),
            JoinType::Inner,
            Condition::true_cond(),
        );

        // In the concatenated (A, C, B) schema, column 2 is A.v3 and column 8 is B.v9.
        let input_ref =
            |index: usize| ExprImpl::InputRef(Box::new(InputRef::new(index, int_ty.clone())));
        let a_eq_b = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![input_ref(2), input_ref(8)]).unwrap(),
        ));
        let inner_with_b = LogicalJoin::new(
            a_cross_c.into(),
            rel_b.clone().into(),
            JoinType::Inner,
            Condition::with_expr(a_eq_b),
        );

        let multi_join = LogicalMultiJoinBuilder::new(inner_with_b.into()).build();

        // The flattened multi join keeps its inputs in plan order: A, C, B.
        let expected_schemas = vec![rel_a.schema(), rel_c.schema(), rel_b.schema()];
        for (input, expected) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), expected);
        }

        // Input indices 0, 2, 1 are A, B, C.
        assert_eq!(multi_join.heuristic_ordering().unwrap(), vec![0, 2, 1]);
    }
}