risingwave_frontend/optimizer/rule/
left_deep_tree_join_ordering_rule.rs1use super::super::plan_node::*;
16use super::Rule;
17use crate::optimizer::rule::BoxedRule;
18
/// Optimizer rule that rewrites a `LogicalMultiJoin` into a left-deep join tree, using the
/// input order produced by the multi-join's `heuristic_ordering`.
///
/// The rule holds no state; obtain a boxed instance via
/// [`LeftDeepTreeJoinOrderingRule::create`].
pub struct LeftDeepTreeJoinOrderingRule {}
21
22impl Rule for LeftDeepTreeJoinOrderingRule {
23 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
24 let join = plan.as_logical_multi_join()?;
25 let join_ordering = join.heuristic_ordering().ok()?; let left_deep_join = join.as_reordered_left_deep_join(&join_ordering);
28 Some(left_deep_join)
29 }
30}
31
32impl LeftDeepTreeJoinOrderingRule {
33 pub fn create() -> BoxedRule {
34 Box::new(LeftDeepTreeJoinOrderingRule {})
35 }
36}
37
#[cfg(test)]
mod tests {

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use super::*;
    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::utils::Condition;

    /// Builds `(a JOIN c ON true) JOIN b ON $2 = $8`, flattens it into a `LogicalMultiJoin`,
    /// and checks that the heuristic ordering visits the inputs as a, b, c.
    #[tokio::test]
    async fn test_heuristic_join_ordering_from_multijoin() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Int32 columns v1..=v9; each relation below takes a consecutive triple.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(ty.clone(), format!("v{i}")))
            .collect();

        // Empty `LogicalValues` relation exposing the given columns.
        let values_over = |columns: &[Field], ctx| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: columns.to_vec(),
                },
                ctx,
            )
        };
        let relation_a = values_over(&fields[0..3], ctx.clone());
        let relation_c = values_over(&fields[3..6], ctx.clone());
        let relation_b = values_over(&fields[6..9], ctx);

        // a ⋈ c with an always-true condition.
        let a_join_c = LogicalJoin::new(
            relation_a.clone().into(),
            relation_c.clone().into(),
            JoinType::Inner,
            Condition::true_cond(),
        );

        // $2 = $8: over the concatenated (a, c, b) columns this is a.v3 = b.v9.
        let a_eq_b = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(2, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(8, ty))),
                ],
            )
            .unwrap(),
        ));
        let root = LogicalJoin::new(
            a_join_c.into(),
            relation_b.clone().into(),
            JoinType::Inner,
            Condition::with_expr(a_eq_b),
        );

        let multi_join = LogicalMultiJoinBuilder::new(root.into()).build();

        // Flattening must keep the leaf relations in a, c, b order.
        let expected_schemas = vec![relation_a.schema(), relation_c.schema(), relation_b.schema()];
        for (input, expected) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), expected);
        }

        // Input indices: 0 = a, 1 = c, 2 = b, so [0, 2, 1] means a, then b, then c.
        assert_eq!(multi_join.heuristic_ordering().unwrap(), vec![0, 2, 1]);
    }
}