risingwave_frontend/optimizer/rule/
translate_apply_rule.rs1use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::{BoxedRule, Rule};
21use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef};
24use crate::optimizer::plan_node::{
25 LogicalApply, LogicalJoin, LogicalProject, LogicalScan, LogicalShare, PlanTreeNodeBinary,
26 PlanTreeNodeUnary,
27};
28use crate::utils::{ColIndexMapping, Condition};
29
/// Optimizer rule that translates a not-yet-translated `LogicalApply` by first
/// computing the distinct values ("domain") of its correlated columns.
pub struct TranslateApplyRule {
    /// When the domain cannot be derived by rewriting the apply's left input, wrap
    /// that left input in a `LogicalShare` so the same subplan feeds both the domain
    /// computation and the translated apply.
    enable_share_plan: bool,
}
53
54impl Rule for TranslateApplyRule {
55 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
56 let apply: &LogicalApply = plan.as_logical_apply()?;
57 if apply.translated() {
58 return None;
59 }
60 let mut left: PlanRef = apply.left();
61 let right: PlanRef = apply.right();
62 let apply_left_len = left.schema().len();
63 let correlated_indices = apply.correlated_indices();
64
65 let mut index_mapping =
66 ColIndexMapping::new(vec![None; apply_left_len], correlated_indices.len());
67 let mut data_types = HashMap::new();
68 let mut index = 0;
69
70 let domain: PlanRef = if let Some(rewritten_left) = Self::rewrite(
74 &left,
75 correlated_indices.clone(),
76 0,
77 &mut index_mapping,
78 &mut data_types,
79 &mut index,
80 ) {
81 let exprs = correlated_indices
85 .clone()
86 .into_iter()
87 .enumerate()
88 .map(|(i, correlated_index)| {
89 let index = index_mapping.map(correlated_index);
90 let data_type = rewritten_left.schema().fields()[index].data_type.clone();
91 index_mapping.put(correlated_index, Some(i));
92 InputRef::new(index, data_type).into()
93 })
94 .collect();
95 let project = LogicalProject::create(rewritten_left, exprs);
96 let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
97 distinct.into()
98 } else {
99 left = if self.enable_share_plan {
104 let logical_share = LogicalShare::new(left);
105 logical_share.into()
106 } else {
107 left
108 };
109
110 let exprs = correlated_indices
111 .clone()
112 .into_iter()
113 .map(|correlated_index| {
114 let data_type = left.schema().fields()[correlated_index].data_type.clone();
115 InputRef::new(correlated_index, data_type).into()
116 })
117 .collect();
118 let project = LogicalProject::create(left.clone(), exprs);
119 let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
120 distinct.into()
121 };
122
123 let eq_predicates = correlated_indices
124 .into_iter()
125 .enumerate()
126 .map(|(i, correlated_index)| {
127 let shifted_index = i + apply_left_len;
128 let data_type = domain.schema().fields()[i].data_type.clone();
129 let left = InputRef::new(correlated_index, data_type.clone());
130 let right = InputRef::new(shifted_index, data_type);
131 FunctionCall::new_unchecked(
133 ExprType::IsNotDistinctFrom,
134 vec![left.into(), right.into()],
135 DataType::Boolean,
136 )
137 .into()
138 })
139 .collect::<Vec<ExprImpl>>();
140
141 let new_apply = apply.clone_with_left_right(left, right);
142 let new_node = new_apply.translate_apply(domain, eq_predicates);
143 Some(new_node)
144 }
145}
146
147impl TranslateApplyRule {
148 pub fn create(enable_share_plan: bool) -> BoxedRule {
149 Box::new(TranslateApplyRule { enable_share_plan })
150 }
151
152 fn rewrite(
157 plan: &PlanRef,
158 correlated_indices: Vec<usize>,
159 offset: usize,
160 index_mapping: &mut ColIndexMapping,
161 data_types: &mut HashMap<usize, DataType>,
162 index: &mut usize,
163 ) -> Option<PlanRef> {
164 if let Some(join) = plan.as_logical_join() {
165 Self::rewrite_join(
166 join,
167 correlated_indices,
168 offset,
169 index_mapping,
170 data_types,
171 index,
172 )
173 } else if let Some(apply) = plan.as_logical_apply() {
174 Self::rewrite_apply(
175 apply,
176 correlated_indices,
177 offset,
178 index_mapping,
179 data_types,
180 index,
181 )
182 } else if let Some(scan) = plan.as_logical_scan() {
183 Self::rewrite_scan(
184 scan,
185 correlated_indices,
186 offset,
187 index_mapping,
188 data_types,
189 index,
190 )
191 } else if let Some(filter) = plan.as_logical_filter() {
192 Self::rewrite(
193 &filter.input(),
194 correlated_indices,
195 offset,
196 index_mapping,
197 data_types,
198 index,
199 )
200 } else {
201 None
203 }
204 }
205
206 fn rewrite_join(
207 join: &LogicalJoin,
208 required_col_idx: Vec<usize>,
209 mut offset: usize,
210 index_mapping: &mut ColIndexMapping,
211 data_types: &mut HashMap<usize, DataType>,
212 index: &mut usize,
213 ) -> Option<PlanRef> {
214 let left_len = join.left().schema().len();
216 let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
217 .into_iter()
218 .partition(|idx| *idx < left_len);
219 let mut rewrite =
220 |plan: PlanRef, mut indices: Vec<usize>, is_right: bool| -> Option<PlanRef> {
221 if is_right {
222 indices.iter_mut().for_each(|index| *index -= left_len);
223 offset += left_len;
224 }
225 Self::rewrite(&plan, indices, offset, index_mapping, data_types, index)
226 };
227 match (left_idxs.is_empty(), right_idxs.is_empty()) {
228 (true, false) => {
229 match join.join_type() {
231 JoinType::Inner
232 | JoinType::LeftSemi
233 | JoinType::RightSemi
234 | JoinType::LeftAnti
235 | JoinType::RightAnti
236 | JoinType::RightOuter
237 | JoinType::AsofInner => rewrite(join.right(), right_idxs, true),
238 JoinType::LeftOuter | JoinType::FullOuter | JoinType::AsofLeftOuter => None,
239 JoinType::Unspecified => unreachable!(),
240 }
241 }
242 (false, true) => {
243 match join.join_type() {
245 JoinType::Inner
246 | JoinType::LeftSemi
247 | JoinType::RightSemi
248 | JoinType::LeftAnti
249 | JoinType::RightAnti
250 | JoinType::LeftOuter
251 | JoinType::AsofInner
252 | JoinType::AsofLeftOuter => rewrite(join.left(), left_idxs, false),
253 JoinType::RightOuter | JoinType::FullOuter => None,
254 JoinType::Unspecified => unreachable!(),
255 }
256 }
257 (false, false) => {
258 match join.join_type() {
260 JoinType::Inner
261 | JoinType::LeftSemi
262 | JoinType::RightSemi
263 | JoinType::LeftAnti
264 | JoinType::RightAnti
265 | JoinType::AsofInner => {
266 let left = rewrite(join.left(), left_idxs, false)?;
267 let right = rewrite(join.right(), right_idxs, true)?;
268 let new_join =
269 LogicalJoin::new(left, right, join.join_type(), Condition::true_cond());
270 Some(new_join.into())
271 }
272 JoinType::LeftOuter
273 | JoinType::RightOuter
274 | JoinType::FullOuter
275 | JoinType::AsofLeftOuter => None,
276 JoinType::Unspecified => unreachable!(),
277 }
278 }
279 _ => None,
280 }
281 }
282
283 fn rewrite_apply(
295 apply: &LogicalApply,
296 required_col_idx: Vec<usize>,
297 offset: usize,
298 index_mapping: &mut ColIndexMapping,
299 data_types: &mut HashMap<usize, DataType>,
300 index: &mut usize,
301 ) -> Option<PlanRef> {
302 let left_len = apply.left().schema().len();
304 let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
305 .into_iter()
306 .partition(|idx| *idx < left_len);
307 if !left_idxs.is_empty() && right_idxs.is_empty() {
308 match apply.join_type() {
310 JoinType::Inner
311 | JoinType::LeftSemi
312 | JoinType::LeftAnti
313 | JoinType::LeftOuter
314 | JoinType::AsofInner
315 | JoinType::AsofLeftOuter => {
316 let plan = apply.left();
317 Self::rewrite(&plan, left_idxs, offset, index_mapping, data_types, index)
318 }
319 JoinType::RightOuter
320 | JoinType::RightAnti
321 | JoinType::RightSemi
322 | JoinType::FullOuter => None,
323 JoinType::Unspecified => unreachable!(),
324 }
325 } else {
326 None
327 }
328 }
329
330 fn rewrite_scan(
331 scan: &LogicalScan,
332 required_col_idx: Vec<usize>,
333 offset: usize,
334 index_mapping: &mut ColIndexMapping,
335 data_types: &mut HashMap<usize, DataType>,
336 index: &mut usize,
337 ) -> Option<PlanRef> {
338 for i in &required_col_idx {
339 let correlated_index = *i + offset;
340 index_mapping.put(correlated_index, Some(*index));
341 data_types.insert(
342 correlated_index,
343 scan.schema().fields()[*i].data_type.clone(),
344 );
345 *index += 1;
346 }
347
348 Some(scan.clone_with_output_indices(required_col_idx).into())
349 }
350}