risingwave_frontend/optimizer/rule/
join_commute_rule.rs1use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::{BoxedRule, Rule};
19use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::PlanRef;
21use crate::optimizer::plan_node::LogicalJoin;
22
23pub struct JoinCommuteRule {}
31impl Rule for JoinCommuteRule {
32 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
33 let join: &LogicalJoin = plan.as_logical_join()?;
34 let join_type = join.join_type();
35 match join_type {
36 JoinType::RightOuter | JoinType::RightSemi | JoinType::RightAnti => {
37 let (left, right, on, join_type, output_indices) = join.clone().decompose();
38
39 let left_len = left.schema().len();
40 let right_len = right.schema().len();
41
42 let new_output_indices = output_indices
43 .into_iter()
44 .map(|i| {
45 if i < left_len {
46 i + right_len
47 } else {
48 i - left_len
49 }
50 })
51 .collect_vec();
52
53 let mut condition_rewriter = Rewriter {
54 join_left_len: left_len,
55 join_left_offset: right_len as isize,
56 join_right_offset: -(left_len as isize),
57 };
58 let new_on = on.rewrite_expr(&mut condition_rewriter);
59
60 let new_join = LogicalJoin::with_output_indices(
61 right,
62 left,
63 Self::inverse_join_type(join_type),
64 new_on,
65 new_output_indices,
66 );
67
68 Some(new_join.into())
69 }
70 JoinType::Inner
71 | JoinType::LeftOuter
72 | JoinType::LeftSemi
73 | JoinType::LeftAnti
74 | JoinType::FullOuter
75 | JoinType::AsofInner
76 | JoinType::AsofLeftOuter
77 | JoinType::Unspecified => None,
78 }
79 }
80}
81
82struct Rewriter {
83 join_left_len: usize,
84 join_left_offset: isize,
85 join_right_offset: isize,
86}
87impl ExprRewriter for Rewriter {
88 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
89 if input_ref.index < self.join_left_len {
90 InputRef::new(
91 (input_ref.index() as isize + self.join_left_offset) as usize,
92 input_ref.return_type(),
93 )
94 .into()
95 } else {
96 InputRef::new(
97 (input_ref.index() as isize + self.join_right_offset) as usize,
98 input_ref.return_type(),
99 )
100 .into()
101 }
102 }
103}
104
105impl JoinCommuteRule {
106 pub fn create() -> BoxedRule {
107 Box::new(JoinCommuteRule {})
108 }
109
110 fn inverse_join_type(join_type: JoinType) -> JoinType {
111 match join_type {
112 JoinType::Unspecified => JoinType::Unspecified,
113 JoinType::Inner => JoinType::Inner,
114 JoinType::LeftOuter => JoinType::RightOuter,
115 JoinType::RightOuter => JoinType::LeftOuter,
116 JoinType::FullOuter => JoinType::FullOuter,
117 JoinType::LeftSemi => JoinType::RightSemi,
118 JoinType::LeftAnti => JoinType::RightAnti,
119 JoinType::RightSemi => JoinType::LeftSemi,
120 JoinType::RightAnti => JoinType::LeftAnti,
121 JoinType::AsofInner | JoinType::AsofLeftOuter => unreachable!(),
122 }
123 }
124}