risingwave_frontend/optimizer/rule/
join_commute_rule.rs1use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::prelude::{PlanRef, *};
19use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::plan_node::LogicalJoin;
21
22pub struct JoinCommuteRule {}
30impl Rule<Logical> for JoinCommuteRule {
31 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
32 let join: &LogicalJoin = plan.as_logical_join()?;
33 let join_type = join.join_type();
34 match join_type {
35 JoinType::RightOuter | JoinType::RightSemi | JoinType::RightAnti => {
36 let (left, right, on, join_type, output_indices) = join.clone().decompose();
37
38 let left_len = left.schema().len();
39 let right_len = right.schema().len();
40
41 let new_output_indices = output_indices
42 .into_iter()
43 .map(|i| {
44 if i < left_len {
45 i + right_len
46 } else {
47 i - left_len
48 }
49 })
50 .collect_vec();
51
52 let mut condition_rewriter = Rewriter {
53 join_left_len: left_len,
54 join_left_offset: right_len as isize,
55 join_right_offset: -(left_len as isize),
56 };
57 let new_on = on.rewrite_expr(&mut condition_rewriter);
58
59 let new_join = LogicalJoin::with_output_indices(
60 right,
61 left,
62 Self::inverse_join_type(join_type),
63 new_on,
64 new_output_indices,
65 );
66
67 Some(new_join.into())
68 }
69 JoinType::Inner
70 | JoinType::LeftOuter
71 | JoinType::LeftSemi
72 | JoinType::LeftAnti
73 | JoinType::FullOuter
74 | JoinType::AsofInner
75 | JoinType::AsofLeftOuter
76 | JoinType::Unspecified => None,
77 }
78 }
79}
80
81struct Rewriter {
82 join_left_len: usize,
83 join_left_offset: isize,
84 join_right_offset: isize,
85}
86impl ExprRewriter for Rewriter {
87 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
88 if input_ref.index < self.join_left_len {
89 InputRef::new(
90 (input_ref.index() as isize + self.join_left_offset) as usize,
91 input_ref.return_type(),
92 )
93 .into()
94 } else {
95 InputRef::new(
96 (input_ref.index() as isize + self.join_right_offset) as usize,
97 input_ref.return_type(),
98 )
99 .into()
100 }
101 }
102}
103
104impl JoinCommuteRule {
105 pub fn create() -> BoxedRule {
106 Box::new(JoinCommuteRule {})
107 }
108
109 fn inverse_join_type(join_type: JoinType) -> JoinType {
110 match join_type {
111 JoinType::Unspecified => JoinType::Unspecified,
112 JoinType::Inner => JoinType::Inner,
113 JoinType::LeftOuter => JoinType::RightOuter,
114 JoinType::RightOuter => JoinType::LeftOuter,
115 JoinType::FullOuter => JoinType::FullOuter,
116 JoinType::LeftSemi => JoinType::RightSemi,
117 JoinType::LeftAnti => JoinType::RightAnti,
118 JoinType::RightSemi => JoinType::LeftSemi,
119 JoinType::RightAnti => JoinType::LeftAnti,
120 JoinType::AsofInner | JoinType::AsofLeftOuter => unreachable!(),
121 }
122 }
123}