risingwave_frontend/optimizer/rule/
apply_project_transpose_rule.rs1use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::{ApplyOffsetRewriter, BoxedRule, Rule};
19use crate::expr::{ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::PlanRef;
21use crate::optimizer::plan_node::{LogicalApply, LogicalProject};
22
23pub struct ApplyProjectTransposeRule {}
45impl Rule for ApplyProjectTransposeRule {
46 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
47 let apply: &LogicalApply = plan.as_logical_apply()?;
48 let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
49 apply.clone().decompose();
50 let project = right.as_logical_project()?;
51 assert_eq!(join_type, JoinType::Inner);
52
53 let mut exprs: Vec<ExprImpl> = left
56 .schema()
57 .data_types()
58 .into_iter()
59 .enumerate()
60 .map(|(index, data_type)| InputRef::new(index, data_type).into())
61 .collect();
62
63 let (proj_exprs, proj_input) = project.clone().decompose();
64
65 let mut rewriter =
67 ApplyOffsetRewriter::new(left.schema().len(), &correlated_indices, correlated_id);
68
69 let new_proj_exprs: Vec<ExprImpl> = proj_exprs
70 .into_iter()
71 .map(|expr| rewriter.rewrite_expr(expr))
72 .collect_vec();
73
74 exprs.extend(new_proj_exprs.clone());
75
76 let mut rewriter = ApplyOnConditionRewriter {
77 left_input_len: left.schema().len(),
78 mapping: new_proj_exprs,
79 };
80 let new_on = on.rewrite_expr(&mut rewriter);
81 let new_apply = LogicalApply::create(
82 left,
83 proj_input,
84 join_type,
85 new_on,
86 correlated_id,
87 correlated_indices,
88 max_one_row,
89 );
90
91 let new_project = LogicalProject::create(new_apply, exprs);
92 Some(new_project)
93 }
94}
95
96impl ApplyProjectTransposeRule {
97 pub fn create() -> BoxedRule {
98 Box::new(ApplyProjectTransposeRule {})
99 }
100}
101
102pub struct ApplyOnConditionRewriter {
103 pub left_input_len: usize,
104 pub mapping: Vec<ExprImpl>,
105}
106
107impl ExprRewriter for ApplyOnConditionRewriter {
108 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
109 if input_ref.index >= self.left_input_len {
110 self.mapping[input_ref.index() - self.left_input_len].clone()
111 } else {
112 input_ref.into()
113 }
114 }
115}