1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18use risingwave_common::types::DataType::Boolean;
19use risingwave_pb::plan_common::JoinType;
20
21use super::{BoxedRule, Rule};
22use crate::expr::{
23 CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall,
24 InputRef,
25};
26use crate::optimizer::PlanRef;
27use crate::optimizer::plan_node::generic::GenericPlanRef;
28use crate::optimizer::plan_node::{
29 LogicalApply, LogicalFilter, LogicalJoin, PlanTreeNode, PlanTreeNodeBinary,
30};
31use crate::optimizer::plan_visitor::{ExprCorrelatedIdFinder, PlanCorrelatedIdFinder};
32use crate::optimizer::rule::apply_offset_rewriter::ApplyCorrelatedIndicesConverter;
33use crate::utils::{ColIndexMapping, Condition};
34
/// Optimizer rule that transposes a `LogicalApply` with the `LogicalJoin` sitting on its
/// right input, pushing the apply into the join's left input, its right input, or both.
///
/// The rule carries no state; obtain an instance through `ApplyJoinTransposeRule::create`.
pub struct ApplyJoinTransposeRule {}
89impl Rule for ApplyJoinTransposeRule {
90 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
91 let apply: &LogicalApply = plan.as_logical_apply()?;
92 let (
93 apply_left,
94 apply_right,
95 apply_on,
96 apply_join_type,
97 correlated_id,
98 correlated_indices,
99 max_one_row,
100 ) = apply.clone().decompose();
101
102 if max_one_row {
103 return None;
104 }
105
106 assert_eq!(apply_join_type, JoinType::Inner);
107 let join: &LogicalJoin = apply_right.as_logical_join()?;
108
109 let mut finder = ExprCorrelatedIdFinder::default();
110 join.on().visit_expr(&mut finder);
111 let join_cond_has_correlated_id = finder.contains(&correlated_id);
112 let join_left_has_correlated_id =
113 PlanCorrelatedIdFinder::find_correlated_id(join.left(), &correlated_id);
114 let join_right_has_correlated_id =
115 PlanCorrelatedIdFinder::find_correlated_id(join.right(), &correlated_id);
116
117 if !join_cond_has_correlated_id
121 && !join_left_has_correlated_id
122 && !join_right_has_correlated_id
123 {
124 return None;
125 }
126
127 if !join.output_indices_are_trivial() {
130 let new_apply_right = crate::optimizer::rule::ProjectJoinSeparateRule::create()
131 .apply(join.clone().into())
132 .unwrap();
133 return Some(apply.clone_with_inputs(&[apply_left, new_apply_right]));
134 }
135
136 let (push_left, push_right) = match join.join_type() {
137 JoinType::LeftSemi
140 | JoinType::LeftAnti
141 | JoinType::LeftOuter
142 | JoinType::AsofLeftOuter => {
143 if !join_right_has_correlated_id {
144 (true, false)
145 } else {
146 (true, true)
147 }
148 }
149 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
152 if !join_left_has_correlated_id {
153 (false, true)
154 } else {
155 (true, true)
156 }
157 }
158 JoinType::Inner | JoinType::AsofInner => {
160 if join_cond_has_correlated_id
161 && !join_right_has_correlated_id
162 && !join_left_has_correlated_id
163 {
164 (true, false)
165 } else {
166 (join_left_has_correlated_id, join_right_has_correlated_id)
167 }
168 }
169 JoinType::FullOuter => (true, true),
171 JoinType::Unspecified => unreachable!(),
172 };
173
174 let out = if push_left && push_right {
175 self.push_apply_both_side(
176 apply_left,
177 join,
178 apply_on,
179 apply_join_type,
180 correlated_id,
181 correlated_indices,
182 )
183 } else if push_left {
184 self.push_apply_left_side(
185 apply_left,
186 join,
187 apply_on,
188 apply_join_type,
189 correlated_id,
190 correlated_indices,
191 )
192 } else if push_right {
193 self.push_apply_right_side(
194 apply_left,
195 join,
196 apply_on,
197 apply_join_type,
198 correlated_id,
199 correlated_indices,
200 )
201 } else {
202 unreachable!();
203 };
204 assert_eq!(out.schema(), plan.schema());
205 Some(out)
206 }
207}
208
209impl ApplyJoinTransposeRule {
210 fn push_apply_left_side(
211 &self,
212 apply_left: PlanRef,
213 join: &LogicalJoin,
214 apply_on: Condition,
215 apply_join_type: JoinType,
216 correlated_id: CorrelatedId,
217 correlated_indices: Vec<usize>,
218 ) -> PlanRef {
219 let apply_left_len = apply_left.schema().len();
220 let join_left_len = join.left().schema().len();
221 let mut rewriter = Rewriter {
222 join_left_len,
223 join_left_offset: apply_left_len as isize,
224 join_right_offset: apply_left_len as isize,
225 index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
226 &correlated_indices,
227 ),
228 correlated_id,
229 };
230
231 let new_join_condition = Condition {
233 conjunctions: join
234 .on()
235 .clone()
236 .into_iter()
237 .map(|expr| rewriter.rewrite_expr(expr))
238 .collect_vec(),
239 };
240
241 let mut left_apply_condition: Vec<ExprImpl> = vec![];
242 let mut other_condition: Vec<ExprImpl> = vec![];
243
244 match join.join_type() {
245 JoinType::LeftSemi | JoinType::LeftAnti => {
246 left_apply_condition.extend(apply_on);
247 }
248 JoinType::Inner
249 | JoinType::LeftOuter
250 | JoinType::RightOuter
251 | JoinType::FullOuter
252 | JoinType::AsofInner
253 | JoinType::AsofLeftOuter => {
254 let apply_len = apply_left_len + join.schema().len();
255 let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
256 d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
257
258 let (left, other): (Vec<_>, Vec<_>) = apply_on
259 .into_iter()
260 .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t1_bit_set));
261 left_apply_condition.extend(left);
262 other_condition.extend(other);
263 }
264 JoinType::RightSemi | JoinType::RightAnti | JoinType::Unspecified => unreachable!(),
265 }
266
267 let new_join_left = LogicalApply::create(
268 apply_left,
269 join.left(),
270 apply_join_type,
271 Condition {
272 conjunctions: left_apply_condition,
273 },
274 correlated_id,
275 correlated_indices,
276 false,
277 );
278
279 let new_join = LogicalJoin::new(
280 new_join_left,
281 join.right(),
282 join.join_type(),
283 new_join_condition,
284 );
285
286 LogicalFilter::create(
288 new_join.into(),
289 Condition {
290 conjunctions: other_condition,
291 },
292 )
293 }
294
295 fn push_apply_right_side(
296 &self,
297 apply_left: PlanRef,
298 join: &LogicalJoin,
299 apply_on: Condition,
300 apply_join_type: JoinType,
301 correlated_id: CorrelatedId,
302 correlated_indices: Vec<usize>,
303 ) -> PlanRef {
304 let apply_left_len = apply_left.schema().len();
305 let join_left_len = join.left().schema().len();
306 let mut rewriter = Rewriter {
307 join_left_len,
308 join_left_offset: 0,
309 join_right_offset: apply_left_len as isize,
310 index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
311 &correlated_indices,
312 ),
313 correlated_id,
314 };
315
316 let new_join_condition = Condition {
318 conjunctions: join
319 .on()
320 .clone()
321 .into_iter()
322 .map(|expr| rewriter.rewrite_expr(expr))
323 .collect_vec(),
324 };
325
326 let mut right_apply_condition: Vec<ExprImpl> = vec![];
327 let mut other_condition: Vec<ExprImpl> = vec![];
328
329 match join.join_type() {
330 JoinType::RightSemi | JoinType::RightAnti => {
331 right_apply_condition.extend(apply_on);
332 }
333 JoinType::Inner
334 | JoinType::LeftOuter
335 | JoinType::RightOuter
336 | JoinType::FullOuter
337 | JoinType::AsofInner
338 | JoinType::AsofLeftOuter => {
339 let apply_len = apply_left_len + join.schema().len();
340 let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
341 d_t2_bit_set.set_range(0..apply_left_len, true);
342 d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
343
344 let (right, other): (Vec<_>, Vec<_>) = apply_on
345 .into_iter()
346 .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t2_bit_set));
347 right_apply_condition.extend(right);
348 other_condition.extend(other);
349
350 let mut right_apply_condition_rewriter = Rewriter {
352 join_left_len: apply_left_len,
353 join_left_offset: 0,
354 join_right_offset: -(join_left_len as isize),
355 index_mapping: ColIndexMapping::empty(0, 0),
356 correlated_id,
357 };
358
359 right_apply_condition = right_apply_condition
360 .into_iter()
361 .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
362 .collect_vec();
363 }
364 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Unspecified => unreachable!(),
365 }
366
367 let new_join_right = LogicalApply::create(
368 apply_left,
369 join.right(),
370 apply_join_type,
371 Condition {
372 conjunctions: right_apply_condition,
373 },
374 correlated_id,
375 correlated_indices,
376 false,
377 );
378 let (output_indices, target_size) = {
379 let (apply_left_len, join_right_len) = match apply_join_type {
380 JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
381 JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
382 _ => (apply_left_len, join.right().schema().len()),
383 };
384
385 let left_iter = join_left_len..join_left_len + apply_left_len;
386 let right_iter = (0..join_left_len).chain(
387 join_left_len + apply_left_len..join_left_len + apply_left_len + join_right_len,
388 );
389
390 let output_indices: Vec<_> = match join.join_type() {
391 JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
392 JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
393 _ => left_iter.chain(right_iter).collect(),
394 };
395
396 let target_size = join_left_len + apply_left_len + join_right_len;
397 (output_indices, target_size)
398 };
399 let mut output_indices_mapping = ColIndexMapping::new(
400 output_indices.iter().map(|x| Some(*x)).collect(),
401 target_size,
402 );
403 let new_join = LogicalJoin::new(
404 join.left(),
405 new_join_right,
406 join.join_type(),
407 new_join_condition,
408 )
409 .clone_with_output_indices(output_indices);
410
411 LogicalFilter::create(
413 new_join.into(),
414 Condition {
415 conjunctions: other_condition,
416 }
417 .rewrite_expr(&mut output_indices_mapping),
418 )
419 }
420
421 fn push_apply_both_side(
422 &self,
423 apply_left: PlanRef,
424 join: &LogicalJoin,
425 apply_on: Condition,
426 apply_join_type: JoinType,
427 correlated_id: CorrelatedId,
428 correlated_indices: Vec<usize>,
429 ) -> PlanRef {
430 let apply_left_len = apply_left.schema().len();
431 let join_left_len = join.left().schema().len();
432 let mut rewriter = Rewriter {
433 join_left_len,
434 join_left_offset: apply_left_len as isize,
435 join_right_offset: 2 * apply_left_len as isize,
436 index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
437 &correlated_indices,
438 ),
439 correlated_id,
440 };
441
442 let natural_conjunctions = apply_left
444 .schema()
445 .fields
446 .iter()
447 .enumerate()
448 .map(|(i, field)| {
449 Self::create_null_safe_equal_expr(
450 i,
451 field.data_type.clone(),
452 i + join_left_len + apply_left_len,
453 field.data_type.clone(),
454 )
455 })
456 .collect_vec();
457 let new_join_condition = Condition {
458 conjunctions: join
459 .on()
460 .clone()
461 .into_iter()
462 .map(|expr| rewriter.rewrite_expr(expr))
463 .chain(natural_conjunctions)
464 .collect_vec(),
465 };
466
467 let mut left_apply_condition: Vec<ExprImpl> = vec![];
468 let mut right_apply_condition: Vec<ExprImpl> = vec![];
469 let mut other_condition: Vec<ExprImpl> = vec![];
470
471 match join.join_type() {
472 JoinType::LeftSemi | JoinType::LeftAnti => {
473 left_apply_condition.extend(apply_on);
474 }
475 JoinType::RightSemi | JoinType::RightAnti => {
476 right_apply_condition.extend(apply_on);
477 }
478 JoinType::Inner
479 | JoinType::LeftOuter
480 | JoinType::RightOuter
481 | JoinType::FullOuter
482 | JoinType::AsofInner
483 | JoinType::AsofLeftOuter => {
484 let apply_len = apply_left_len + join.schema().len();
485 let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
486 let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
487 d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
488 d_t2_bit_set.set_range(0..apply_left_len, true);
489 d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
490
491 for (key, group) in &apply_on.into_iter().chunk_by(|expr| {
492 let collect_bit_set = expr.collect_input_refs(apply_len);
493 if collect_bit_set.is_subset(&d_t1_bit_set) {
494 0
495 } else if collect_bit_set.is_subset(&d_t2_bit_set) {
496 1
497 } else {
498 2
499 }
500 }) {
501 let vec = group.collect_vec();
502 match key {
503 0 => left_apply_condition.extend(vec),
504 1 => right_apply_condition.extend(vec),
505 2 => other_condition.extend(vec),
506 _ => unreachable!(),
507 }
508 }
509
510 let mut right_apply_condition_rewriter = Rewriter {
512 join_left_len: apply_left_len,
513 join_left_offset: 0,
514 join_right_offset: -(join_left_len as isize),
515 index_mapping: ColIndexMapping::empty(0, 0),
516 correlated_id,
517 };
518
519 right_apply_condition = right_apply_condition
520 .into_iter()
521 .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
522 .collect_vec();
523 }
524 JoinType::Unspecified => unreachable!(),
525 }
526
527 let new_join_left = LogicalApply::create(
528 apply_left.clone(),
529 join.left(),
530 apply_join_type,
531 Condition {
532 conjunctions: left_apply_condition,
533 },
534 correlated_id,
535 correlated_indices.clone(),
536 false,
537 );
538 let new_join_right = LogicalApply::create(
539 apply_left,
540 join.right(),
541 apply_join_type,
542 Condition {
543 conjunctions: right_apply_condition,
544 },
545 correlated_id,
546 correlated_indices,
547 false,
548 );
549
550 let (output_indices, target_size) = {
551 let (apply_left_len, join_right_len) = match apply_join_type {
552 JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
553 JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
554 _ => (apply_left_len, join.right().schema().len()),
555 };
556
557 let left_iter = 0..join_left_len + apply_left_len;
558 let right_iter = join_left_len + apply_left_len * 2
559 ..join_left_len + apply_left_len * 2 + join_right_len;
560
561 let output_indices: Vec<_> = match join.join_type() {
562 JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
563 JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
564 _ => left_iter.chain(right_iter).collect(),
565 };
566
567 let target_size = join_left_len + apply_left_len * 2 + join_right_len;
568 (output_indices, target_size)
569 };
570 let new_join = LogicalJoin::new(
571 new_join_left,
572 new_join_right,
573 join.join_type(),
574 new_join_condition,
575 )
576 .clone_with_output_indices(output_indices.clone());
577
578 match join.join_type() {
579 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => {
580 new_join.into()
581 }
582 JoinType::Inner
583 | JoinType::LeftOuter
584 | JoinType::RightOuter
585 | JoinType::FullOuter
586 | JoinType::AsofInner
587 | JoinType::AsofLeftOuter => {
588 let mut output_indices_mapping = ColIndexMapping::new(
589 output_indices.iter().map(|x| Some(*x)).collect(),
590 target_size,
591 );
592 LogicalFilter::create(
594 new_join.into(),
595 Condition {
596 conjunctions: other_condition,
597 }
598 .rewrite_expr(&mut output_indices_mapping),
599 )
600 }
601 JoinType::Unspecified => unreachable!(),
602 }
603 }
604
605 fn create_null_safe_equal_expr(
606 left: usize,
607 left_data_type: DataType,
608 right: usize,
609 right_data_type: DataType,
610 ) -> ExprImpl {
611 ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
613 ExprType::IsNotDistinctFrom,
614 vec![
615 ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
616 ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
617 ],
618 Boolean,
619 )))
620 }
621}
622
623impl ApplyJoinTransposeRule {
624 pub fn create() -> BoxedRule {
625 Box::new(ApplyJoinTransposeRule {})
626 }
627}
628
629struct Rewriter {
631 join_left_len: usize,
632 join_left_offset: isize,
633 join_right_offset: isize,
634 index_mapping: ColIndexMapping,
635 correlated_id: CorrelatedId,
636}
637impl ExprRewriter for Rewriter {
638 fn rewrite_correlated_input_ref(
639 &mut self,
640 correlated_input_ref: CorrelatedInputRef,
641 ) -> ExprImpl {
642 if correlated_input_ref.correlated_id() == self.correlated_id {
643 InputRef::new(
644 self.index_mapping.map(correlated_input_ref.index()),
645 correlated_input_ref.return_type(),
646 )
647 .into()
648 } else {
649 correlated_input_ref.into()
650 }
651 }
652
653 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
654 if input_ref.index < self.join_left_len {
655 InputRef::new(
656 (input_ref.index() as isize + self.join_left_offset) as usize,
657 input_ref.return_type(),
658 )
659 .into()
660 } else {
661 InputRef::new(
662 (input_ref.index() as isize + self.join_right_offset) as usize,
663 input_ref.return_type(),
664 )
665 .into()
666 }
667 }
668}