risingwave_frontend/optimizer/rule/
apply_table_function_to_project_set_rule.rs1use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::prelude::{PlanRef, *};
19use crate::expr::{CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::plan_node::generic::GenericPlanRef;
21use crate::optimizer::plan_node::{
22 LogicalApply, LogicalProject, LogicalProjectSet, LogicalValues, PlanTreeNodeUnary,
23};
24
/// Optimizer rule that eliminates an inner `LogicalApply` whose right side is a table-function
/// `Project(ProjectSet(Values))` over a single empty row, by evaluating that project set
/// directly on top of the apply's left input.
pub struct ApplyTableFunctionToProjectSetRule {}
55
56impl Rule<Logical> for ApplyTableFunctionToProjectSetRule {
57 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
58 let apply: &LogicalApply = plan.as_logical_apply()?;
59 let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
60 apply.clone().decompose();
61
62 if max_one_row || join_type != JoinType::Inner || !on.always_true() {
63 return None;
64 }
65
66 let right_project = right.as_logical_project()?;
68 let right_project_input = right_project.input();
69 let right_project_set: &LogicalProjectSet = right_project_input.as_logical_project_set()?;
70 let right_project_set_input = right_project_set.input();
71 let right_values: &LogicalValues = right_project_set_input.as_logical_values()?;
72 if !right_values.is_empty_scalar() {
73 return None;
74 }
75
76 let left_len = left.schema().len();
77
78 let mut select_list: Vec<ExprImpl> = left
82 .schema()
83 .data_types()
84 .into_iter()
85 .enumerate()
86 .map(|(idx, ty)| InputRef::new(idx, ty).into())
87 .collect();
88
89 let mut corr_rewriter =
90 CorrelatedInputRefToInputRefRewriter::new(correlated_id, correlated_indices, left_len);
91 let rhs_select_list = right_project_set
92 .select_list()
93 .iter()
94 .cloned()
95 .map(|e| corr_rewriter.rewrite_expr(e))
96 .collect_vec();
97 if !corr_rewriter.touched() {
98 return None;
100 }
101 select_list.extend(rhs_select_list);
102
103 let new_project_set: PlanRef = LogicalProjectSet::new(left.clone(), select_list).into();
104
105 let mut out_exprs: Vec<ExprImpl> = left
113 .schema()
114 .data_types()
115 .into_iter()
116 .enumerate()
117 .map(|(idx, ty)| InputRef::new(idx + 1, ty).into())
118 .collect();
119
120 let mut shift_rewriter = ShiftRhsInputRefRewriter::new(left_len);
121 let rhs_project_exprs = right_project
122 .exprs()
123 .iter()
124 .cloned()
125 .map(|e| shift_rewriter.rewrite_expr(e))
126 .collect_vec();
127 out_exprs.extend(rhs_project_exprs);
128
129 Some(LogicalProject::new(new_project_set, out_exprs).into())
130 }
131}
132
133impl ApplyTableFunctionToProjectSetRule {
134 pub fn create() -> BoxedRule {
135 Box::new(ApplyTableFunctionToProjectSetRule {})
136 }
137}
138
139struct CorrelatedInputRefToInputRefRewriter {
140 correlated_id: CorrelatedId,
141 correlated_indices: Vec<usize>,
142 left_len: usize,
143 touched: bool,
144}
145
146impl CorrelatedInputRefToInputRefRewriter {
147 fn new(correlated_id: CorrelatedId, correlated_indices: Vec<usize>, left_len: usize) -> Self {
148 Self {
149 correlated_id,
150 correlated_indices,
151 left_len,
152 touched: false,
153 }
154 }
155
156 fn touched(&self) -> bool {
157 self.touched
158 }
159}
160
161impl ExprRewriter for CorrelatedInputRefToInputRefRewriter {
162 fn rewrite_correlated_input_ref(
163 &mut self,
164 correlated_input_ref: CorrelatedInputRef,
165 ) -> ExprImpl {
166 if correlated_input_ref.correlated_id() != self.correlated_id {
167 return correlated_input_ref.into();
168 }
169 let idx = correlated_input_ref.index();
170 if idx >= self.left_len || !self.correlated_indices.contains(&idx) {
171 return correlated_input_ref.into();
173 }
174 self.touched = true;
175 InputRef::new(idx, correlated_input_ref.return_type()).into()
176 }
177}
178
179struct ShiftRhsInputRefRewriter {
180 left_len: usize,
181}
182
183impl ShiftRhsInputRefRewriter {
184 fn new(left_len: usize) -> Self {
185 Self { left_len }
186 }
187}
188
189impl ExprRewriter for ShiftRhsInputRefRewriter {
190 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
191 if input_ref.index() == 0 {
192 input_ref.into()
194 } else {
195 InputRef::new(input_ref.index() + self.left_len, input_ref.return_type()).into()
196 }
197 }
198}