// risingwave_frontend/optimizer/rule/left_deep_tree_join_ordering_rule.rs
use super::prelude::{PlanRef, *};
16
/// Optimizer rule that replaces a logical multi-join with the join produced by
/// `as_reordered_left_deep_join`, using the multi-join's own heuristic ordering.
///
/// The type carries no state; obtain a boxed instance through `create`.
pub struct LeftDeepTreeJoinOrderingRule {}
19
20impl Rule<Logical> for LeftDeepTreeJoinOrderingRule {
21 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
22 let join = plan.as_logical_multi_join()?;
23 let join_ordering = join.heuristic_ordering().ok()?; let left_deep_join = join.as_reordered_left_deep_join(&join_ordering);
26 Some(left_deep_join)
27 }
28}
29
30impl LeftDeepTreeJoinOrderingRule {
31 pub fn create() -> BoxedRule {
32 Box::new(LeftDeepTreeJoinOrderingRule {})
33 }
34}
35
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::iter_util::ZipEqFast;
    use risingwave_pb::expr::expr_node::Type;
    use risingwave_pb::plan_common::JoinType;

    use crate::expr::{ExprImpl, FunctionCall, InputRef};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::generic::GenericPlanRef;
    use crate::optimizer::plan_node::*;
    use crate::utils::Condition;

    /// Checks the heuristic ordering computed for a multi-join built from two
    /// nested inner joins.
    ///
    /// Three empty relations are created, each with three `Int32` columns:
    /// `relation_a` takes `fields[0..3]`, `relation_c` takes `fields[3..6]`
    /// and `relation_b` takes `fields[6..9]`.
    ///
    /// * The inner join is `relation_a` with `relation_c` under an always-true
    ///   condition.
    /// * The outer join adds `relation_b` with the condition
    ///   `InputRef(2) = InputRef(8)`. Presumably index 2 is the last column of
    ///   `relation_a` and index 8 the last column of `relation_b`, assuming the
    ///   join output concatenates its inputs' columns in order. TODO confirm.
    ///
    /// The multi-join must expose its inputs as `[a, c, b]`, and the heuristic
    /// ordering must be `[0, 2, 1]`.
    #[tokio::test]
    async fn test_heuristic_join_ordering_from_multijoin() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;

        // Nine columns named v1..v9, handed out three per relation below.
        let fields: Vec<Field> = (1..10)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();

        // Builds an empty `LogicalValues` whose schema is the given slice of `fields`.
        let values_over = |columns: std::ops::Range<usize>| {
            LogicalValues::new(
                vec![],
                Schema {
                    fields: fields[columns].to_vec(),
                },
                ctx.clone(),
            )
        };
        // Construction order (a, c, b) is kept as in the original test.
        let relation_a = values_over(0..3);
        let relation_c = values_over(3..6);
        let relation_b = values_over(6..9);

        let join_type = JoinType::Inner;

        // Inner join of a and c with no restricting predicate.
        let a_join_c = LogicalJoin::new(
            relation_a.clone().into(),
            relation_c.clone().into(),
            join_type,
            Condition::true_cond(),
        );

        // Equality predicate between input columns 2 and 8.
        let equal_args = vec![
            ExprImpl::InputRef(Box::new(InputRef::new(2, ty.clone()))),
            ExprImpl::InputRef(Box::new(InputRef::new(8, ty))),
        ];
        let equality = FunctionCall::new(Type::Equal, equal_args).unwrap();
        let on_clause = ExprImpl::FunctionCall(Box::new(equality));

        let outer_join = LogicalJoin::new(
            a_join_c.into(),
            relation_b.clone().into(),
            join_type,
            Condition::with_expr(on_clause),
        );

        let multi_join = LogicalMultiJoinBuilder::new(outer_join.into()).build();

        // The multi-join's inputs must keep the original left-to-right order.
        let expected_schemas = vec![
            relation_a.schema(),
            relation_c.schema(),
            relation_b.schema(),
        ];
        for (input, expected) in multi_join.inputs().iter().zip_eq_fast(expected_schemas) {
            assert_eq!(input.schema(), expected);
        }

        assert_eq!(multi_join.heuristic_ordering().unwrap(), vec![0, 2, 1]);
    }
}