risingwave_frontend/optimizer/rule/
translate_apply_rule.rs1use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef};
23use crate::optimizer::plan_node::{
24 LogicalApply, LogicalJoin, LogicalProject, LogicalScan, LogicalShare, PlanTreeNodeBinary,
25 PlanTreeNodeUnary,
26};
27use crate::utils::{ColIndexMapping, Condition};
28
/// Optimizer rule that translates a `LogicalApply` by means of the distinct values (the
/// "domain") of its correlated columns.
pub struct TranslateApplyRule {
    /// When the domain has to be computed from the apply's left input directly (the fallback
    /// path), wrap that input in a `LogicalShare` before it is used for both the domain and the
    /// translated apply.
    enable_share_plan: bool,
}
52
53impl Rule<Logical> for TranslateApplyRule {
54 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
55 let apply: &LogicalApply = plan.as_logical_apply()?;
56 if apply.translated() {
57 return None;
58 }
59 let mut left: PlanRef = apply.left();
60 let right: PlanRef = apply.right();
61 let apply_left_len = left.schema().len();
62 let correlated_indices = apply.correlated_indices();
63
64 let mut index_mapping =
65 ColIndexMapping::new(vec![None; apply_left_len], correlated_indices.len());
66 let mut data_types = HashMap::new();
67 let mut index = 0;
68
69 let domain: PlanRef = if let Some(rewritten_left) = Self::rewrite(
73 &left,
74 correlated_indices.clone(),
75 0,
76 &mut index_mapping,
77 &mut data_types,
78 &mut index,
79 ) {
80 let exprs = correlated_indices
84 .clone()
85 .into_iter()
86 .enumerate()
87 .map(|(i, correlated_index)| {
88 let index = index_mapping.map(correlated_index);
89 let data_type = rewritten_left.schema().fields()[index].data_type.clone();
90 index_mapping.put(correlated_index, Some(i));
91 InputRef::new(index, data_type).into()
92 })
93 .collect();
94 let project = LogicalProject::create(rewritten_left, exprs);
95 let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
96 distinct.into()
97 } else {
98 left = if self.enable_share_plan {
103 let logical_share = LogicalShare::new(left);
104 logical_share.into()
105 } else {
106 left
107 };
108
109 let exprs = correlated_indices
110 .clone()
111 .into_iter()
112 .map(|correlated_index| {
113 let data_type = left.schema().fields()[correlated_index].data_type.clone();
114 InputRef::new(correlated_index, data_type).into()
115 })
116 .collect();
117 let project = LogicalProject::create(left.clone(), exprs);
118 let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
119 distinct.into()
120 };
121
122 let eq_predicates = correlated_indices
123 .into_iter()
124 .enumerate()
125 .map(|(i, correlated_index)| {
126 let shifted_index = i + apply_left_len;
127 let data_type = domain.schema().fields()[i].data_type.clone();
128 let left = InputRef::new(correlated_index, data_type.clone());
129 let right = InputRef::new(shifted_index, data_type);
130 FunctionCall::new_unchecked(
132 ExprType::IsNotDistinctFrom,
133 vec![left.into(), right.into()],
134 DataType::Boolean,
135 )
136 .into()
137 })
138 .collect::<Vec<ExprImpl>>();
139
140 let new_apply = apply.clone_with_left_right(left, right);
141 let new_node = new_apply.translate_apply(domain, eq_predicates);
142 Some(new_node)
143 }
144}
145
146impl TranslateApplyRule {
147 pub fn create(enable_share_plan: bool) -> BoxedRule {
148 Box::new(TranslateApplyRule { enable_share_plan })
149 }
150
151 fn rewrite(
156 plan: &PlanRef,
157 correlated_indices: Vec<usize>,
158 offset: usize,
159 index_mapping: &mut ColIndexMapping,
160 data_types: &mut HashMap<usize, DataType>,
161 index: &mut usize,
162 ) -> Option<PlanRef> {
163 if let Some(join) = plan.as_logical_join() {
164 Self::rewrite_join(
165 join,
166 correlated_indices,
167 offset,
168 index_mapping,
169 data_types,
170 index,
171 )
172 } else if let Some(apply) = plan.as_logical_apply() {
173 Self::rewrite_apply(
174 apply,
175 correlated_indices,
176 offset,
177 index_mapping,
178 data_types,
179 index,
180 )
181 } else if let Some(scan) = plan.as_logical_scan() {
182 Self::rewrite_scan(
183 scan,
184 correlated_indices,
185 offset,
186 index_mapping,
187 data_types,
188 index,
189 )
190 } else if let Some(filter) = plan.as_logical_filter() {
191 Self::rewrite(
192 &filter.input(),
193 correlated_indices,
194 offset,
195 index_mapping,
196 data_types,
197 index,
198 )
199 } else {
200 None
202 }
203 }
204
205 fn rewrite_join(
206 join: &LogicalJoin,
207 required_col_idx: Vec<usize>,
208 mut offset: usize,
209 index_mapping: &mut ColIndexMapping,
210 data_types: &mut HashMap<usize, DataType>,
211 index: &mut usize,
212 ) -> Option<PlanRef> {
213 let left_len = join.left().schema().len();
215 let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
216 .into_iter()
217 .partition(|idx| *idx < left_len);
218 let mut rewrite =
219 |plan: PlanRef, mut indices: Vec<usize>, is_right: bool| -> Option<PlanRef> {
220 if is_right {
221 indices.iter_mut().for_each(|index| *index -= left_len);
222 offset += left_len;
223 }
224 Self::rewrite(&plan, indices, offset, index_mapping, data_types, index)
225 };
226 match (left_idxs.is_empty(), right_idxs.is_empty()) {
227 (true, false) => {
228 match join.join_type() {
230 JoinType::Inner
231 | JoinType::LeftSemi
232 | JoinType::RightSemi
233 | JoinType::LeftAnti
234 | JoinType::RightAnti
235 | JoinType::RightOuter
236 | JoinType::AsofInner => rewrite(join.right(), right_idxs, true),
237 JoinType::LeftOuter | JoinType::FullOuter | JoinType::AsofLeftOuter => None,
238 JoinType::Unspecified => unreachable!(),
239 }
240 }
241 (false, true) => {
242 match join.join_type() {
244 JoinType::Inner
245 | JoinType::LeftSemi
246 | JoinType::RightSemi
247 | JoinType::LeftAnti
248 | JoinType::RightAnti
249 | JoinType::LeftOuter
250 | JoinType::AsofInner
251 | JoinType::AsofLeftOuter => rewrite(join.left(), left_idxs, false),
252 JoinType::RightOuter | JoinType::FullOuter => None,
253 JoinType::Unspecified => unreachable!(),
254 }
255 }
256 (false, false) => {
257 match join.join_type() {
259 JoinType::Inner
260 | JoinType::LeftSemi
261 | JoinType::RightSemi
262 | JoinType::LeftAnti
263 | JoinType::RightAnti
264 | JoinType::AsofInner => {
265 let left = rewrite(join.left(), left_idxs, false)?;
266 let right = rewrite(join.right(), right_idxs, true)?;
267 let new_join =
268 LogicalJoin::new(left, right, join.join_type(), Condition::true_cond());
269 Some(new_join.into())
270 }
271 JoinType::LeftOuter
272 | JoinType::RightOuter
273 | JoinType::FullOuter
274 | JoinType::AsofLeftOuter => None,
275 JoinType::Unspecified => unreachable!(),
276 }
277 }
278 _ => None,
279 }
280 }
281
282 fn rewrite_apply(
294 apply: &LogicalApply,
295 required_col_idx: Vec<usize>,
296 offset: usize,
297 index_mapping: &mut ColIndexMapping,
298 data_types: &mut HashMap<usize, DataType>,
299 index: &mut usize,
300 ) -> Option<PlanRef> {
301 let left_len = apply.left().schema().len();
303 let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
304 .into_iter()
305 .partition(|idx| *idx < left_len);
306 if !left_idxs.is_empty() && right_idxs.is_empty() {
307 match apply.join_type() {
309 JoinType::Inner
310 | JoinType::LeftSemi
311 | JoinType::LeftAnti
312 | JoinType::LeftOuter
313 | JoinType::AsofInner
314 | JoinType::AsofLeftOuter => {
315 let plan = apply.left();
316 Self::rewrite(&plan, left_idxs, offset, index_mapping, data_types, index)
317 }
318 JoinType::RightOuter
319 | JoinType::RightAnti
320 | JoinType::RightSemi
321 | JoinType::FullOuter => None,
322 JoinType::Unspecified => unreachable!(),
323 }
324 } else {
325 None
326 }
327 }
328
329 fn rewrite_scan(
330 scan: &LogicalScan,
331 required_col_idx: Vec<usize>,
332 offset: usize,
333 index_mapping: &mut ColIndexMapping,
334 data_types: &mut HashMap<usize, DataType>,
335 index: &mut usize,
336 ) -> Option<PlanRef> {
337 for i in &required_col_idx {
338 let correlated_index = *i + offset;
339 index_mapping.put(correlated_index, Some(*index));
340 data_types.insert(
341 correlated_index,
342 scan.schema().fields()[*i].data_type.clone(),
343 );
344 *index += 1;
345 }
346
347 Some(scan.clone_with_output_indices(required_col_idx).into())
348 }
349}