1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18use risingwave_common::types::DataType::Boolean;
19use risingwave_pb::plan_common::JoinType;
20
21use super::prelude::{PlanRef, *};
22use crate::expr::{
23 CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall,
24 InputRef,
25};
26use crate::optimizer::plan_node::generic::GenericPlanRef;
27use crate::optimizer::plan_node::{
28 LogicalApply, LogicalFilter, LogicalJoin, PlanTreeNode, PlanTreeNodeBinary,
29};
30use crate::optimizer::plan_visitor::{ExprCorrelatedIdFinder, PlanCorrelatedIdFinder};
31use crate::optimizer::rule::apply_offset_rewriter::ApplyCorrelatedIndicesConverter;
32use crate::utils::{ColIndexMapping, Condition};
33
/// Optimizer rule that matches a `LogicalApply` whose right input is a `LogicalJoin` and pushes
/// the apply below that join: into the join's left input, its right input, or both.
///
/// The rule does not fire for applies with `max_one_row` set, and it asserts that the apply's
/// own join type is `Inner`.
pub struct ApplyJoinTransposeRule {}
88impl Rule<Logical> for ApplyJoinTransposeRule {
89 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
90 let apply: &LogicalApply = plan.as_logical_apply()?;
91 let (
92 apply_left,
93 apply_right,
94 apply_on,
95 apply_join_type,
96 correlated_id,
97 correlated_indices,
98 max_one_row,
99 ) = apply.clone().decompose();
100
101 if max_one_row {
102 return None;
103 }
104
105 assert_eq!(apply_join_type, JoinType::Inner);
106 let join: &LogicalJoin = apply_right.as_logical_join()?;
107
108 let mut finder = ExprCorrelatedIdFinder::default();
109 join.on().visit_expr(&mut finder);
110 let join_cond_has_correlated_id = finder.contains(&correlated_id);
111 let join_left_has_correlated_id =
112 PlanCorrelatedIdFinder::find_correlated_id(join.left(), &correlated_id);
113 let join_right_has_correlated_id =
114 PlanCorrelatedIdFinder::find_correlated_id(join.right(), &correlated_id);
115
116 if !join_cond_has_correlated_id
120 && !join_left_has_correlated_id
121 && !join_right_has_correlated_id
122 {
123 return None;
124 }
125
126 if !join.output_indices_are_trivial() {
129 let new_apply_right = crate::optimizer::rule::ProjectJoinSeparateRule::create()
130 .apply(join.clone().into())
131 .unwrap();
132 return Some(apply.clone_with_inputs(&[apply_left, new_apply_right]));
133 }
134
135 let (push_left, push_right) = match join.join_type() {
136 JoinType::LeftSemi
139 | JoinType::LeftAnti
140 | JoinType::LeftOuter
141 | JoinType::AsofLeftOuter => {
142 if !join_right_has_correlated_id {
143 (true, false)
144 } else {
145 (true, true)
146 }
147 }
148 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
151 if !join_left_has_correlated_id {
152 (false, true)
153 } else {
154 (true, true)
155 }
156 }
157 JoinType::Inner | JoinType::AsofInner => {
159 if join_cond_has_correlated_id
160 && !join_right_has_correlated_id
161 && !join_left_has_correlated_id
162 {
163 (true, false)
164 } else {
165 (join_left_has_correlated_id, join_right_has_correlated_id)
166 }
167 }
168 JoinType::FullOuter => (true, true),
170 JoinType::Unspecified => unreachable!(),
171 };
172
173 let out = if push_left && push_right {
174 self.push_apply_both_side(
175 apply_left,
176 join,
177 apply_on,
178 apply_join_type,
179 correlated_id,
180 correlated_indices,
181 )
182 } else if push_left {
183 self.push_apply_left_side(
184 apply_left,
185 join,
186 apply_on,
187 apply_join_type,
188 correlated_id,
189 correlated_indices,
190 )
191 } else if push_right {
192 self.push_apply_right_side(
193 apply_left,
194 join,
195 apply_on,
196 apply_join_type,
197 correlated_id,
198 correlated_indices,
199 )
200 } else {
201 unreachable!();
202 };
203 assert_eq!(out.schema(), plan.schema());
204 Some(out)
205 }
206}
207
208impl ApplyJoinTransposeRule {
209 fn push_apply_left_side(
210 &self,
211 apply_left: PlanRef,
212 join: &LogicalJoin,
213 apply_on: Condition,
214 apply_join_type: JoinType,
215 correlated_id: CorrelatedId,
216 correlated_indices: Vec<usize>,
217 ) -> PlanRef {
218 let apply_left_len = apply_left.schema().len();
219 let join_left_len = join.left().schema().len();
220 let mut rewriter = Rewriter {
221 join_left_len,
222 join_left_offset: apply_left_len as isize,
223 join_right_offset: apply_left_len as isize,
224 index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
225 &correlated_indices,
226 ),
227 correlated_id,
228 };
229
230 let new_join_condition = Condition {
232 conjunctions: join
233 .on()
234 .clone()
235 .into_iter()
236 .map(|expr| rewriter.rewrite_expr(expr))
237 .collect_vec(),
238 };
239
240 let mut left_apply_condition: Vec<ExprImpl> = vec![];
241 let mut other_condition: Vec<ExprImpl> = vec![];
242
243 match join.join_type() {
244 JoinType::LeftSemi | JoinType::LeftAnti => {
245 left_apply_condition.extend(apply_on);
246 }
247 JoinType::Inner
248 | JoinType::LeftOuter
249 | JoinType::RightOuter
250 | JoinType::FullOuter
251 | JoinType::AsofInner
252 | JoinType::AsofLeftOuter => {
253 let apply_len = apply_left_len + join.schema().len();
254 let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
255 d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
256
257 let (left, other): (Vec<_>, Vec<_>) = apply_on
258 .into_iter()
259 .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t1_bit_set));
260 left_apply_condition.extend(left);
261 other_condition.extend(other);
262 }
263 JoinType::RightSemi | JoinType::RightAnti | JoinType::Unspecified => unreachable!(),
264 }
265
266 let new_join_left = LogicalApply::create(
267 apply_left,
268 join.left(),
269 apply_join_type,
270 Condition {
271 conjunctions: left_apply_condition,
272 },
273 correlated_id,
274 correlated_indices,
275 false,
276 );
277
278 let new_join = LogicalJoin::new(
279 new_join_left,
280 join.right(),
281 join.join_type(),
282 new_join_condition,
283 );
284
285 LogicalFilter::create(
287 new_join.into(),
288 Condition {
289 conjunctions: other_condition,
290 },
291 )
292 }
293
294 fn push_apply_right_side(
295 &self,
296 apply_left: PlanRef,
297 join: &LogicalJoin,
298 apply_on: Condition,
299 apply_join_type: JoinType,
300 correlated_id: CorrelatedId,
301 correlated_indices: Vec<usize>,
302 ) -> PlanRef {
303 let apply_left_len = apply_left.schema().len();
304 let join_left_len = join.left().schema().len();
305 let mut rewriter = Rewriter {
306 join_left_len,
307 join_left_offset: 0,
308 join_right_offset: apply_left_len as isize,
309 index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
310 &correlated_indices,
311 ),
312 correlated_id,
313 };
314
315 let new_join_condition = Condition {
317 conjunctions: join
318 .on()
319 .clone()
320 .into_iter()
321 .map(|expr| rewriter.rewrite_expr(expr))
322 .collect_vec(),
323 };
324
325 let mut right_apply_condition: Vec<ExprImpl> = vec![];
326 let mut other_condition: Vec<ExprImpl> = vec![];
327
328 match join.join_type() {
329 JoinType::RightSemi | JoinType::RightAnti => {
330 right_apply_condition.extend(apply_on);
331 }
332 JoinType::Inner
333 | JoinType::LeftOuter
334 | JoinType::RightOuter
335 | JoinType::FullOuter
336 | JoinType::AsofInner
337 | JoinType::AsofLeftOuter => {
338 let apply_len = apply_left_len + join.schema().len();
339 let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
340 d_t2_bit_set.set_range(0..apply_left_len, true);
341 d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
342
343 let (right, other): (Vec<_>, Vec<_>) = apply_on
344 .into_iter()
345 .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t2_bit_set));
346 right_apply_condition.extend(right);
347 other_condition.extend(other);
348
349 let mut right_apply_condition_rewriter = Rewriter {
351 join_left_len: apply_left_len,
352 join_left_offset: 0,
353 join_right_offset: -(join_left_len as isize),
354 index_mapping: ColIndexMapping::empty(0, 0),
355 correlated_id,
356 };
357
358 right_apply_condition = right_apply_condition
359 .into_iter()
360 .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
361 .collect_vec();
362 }
363 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Unspecified => unreachable!(),
364 }
365
366 let new_join_right = LogicalApply::create(
367 apply_left,
368 join.right(),
369 apply_join_type,
370 Condition {
371 conjunctions: right_apply_condition,
372 },
373 correlated_id,
374 correlated_indices,
375 false,
376 );
377 let (output_indices, target_size) = {
378 let (apply_left_len, join_right_len) = match apply_join_type {
379 JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
380 JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
381 _ => (apply_left_len, join.right().schema().len()),
382 };
383
384 let left_iter = join_left_len..join_left_len + apply_left_len;
385 let right_iter = (0..join_left_len).chain(
386 join_left_len + apply_left_len..join_left_len + apply_left_len + join_right_len,
387 );
388
389 let output_indices: Vec<_> = match join.join_type() {
390 JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
391 JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
392 _ => left_iter.chain(right_iter).collect(),
393 };
394
395 let target_size = join_left_len + apply_left_len + join_right_len;
396 (output_indices, target_size)
397 };
398 let mut output_indices_mapping = ColIndexMapping::new(
399 output_indices.iter().map(|x| Some(*x)).collect(),
400 target_size,
401 );
402 let new_join = LogicalJoin::new(
403 join.left(),
404 new_join_right,
405 join.join_type(),
406 new_join_condition,
407 )
408 .clone_with_output_indices(output_indices);
409
410 LogicalFilter::create(
412 new_join.into(),
413 Condition {
414 conjunctions: other_condition,
415 }
416 .rewrite_expr(&mut output_indices_mapping),
417 )
418 }
419
420 fn push_apply_both_side(
421 &self,
422 apply_left: PlanRef,
423 join: &LogicalJoin,
424 apply_on: Condition,
425 apply_join_type: JoinType,
426 correlated_id: CorrelatedId,
427 correlated_indices: Vec<usize>,
428 ) -> PlanRef {
429 let apply_left_len = apply_left.schema().len();
430 let join_left_len = join.left().schema().len();
431 let mut rewriter = Rewriter {
432 join_left_len,
433 join_left_offset: apply_left_len as isize,
434 join_right_offset: 2 * apply_left_len as isize,
435 index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
436 &correlated_indices,
437 ),
438 correlated_id,
439 };
440
441 let natural_conjunctions = apply_left
443 .schema()
444 .fields
445 .iter()
446 .enumerate()
447 .map(|(i, field)| {
448 Self::create_null_safe_equal_expr(
449 i,
450 field.data_type.clone(),
451 i + join_left_len + apply_left_len,
452 field.data_type.clone(),
453 )
454 })
455 .collect_vec();
456 let new_join_condition = Condition {
457 conjunctions: join
458 .on()
459 .clone()
460 .into_iter()
461 .map(|expr| rewriter.rewrite_expr(expr))
462 .chain(natural_conjunctions)
463 .collect_vec(),
464 };
465
466 let mut left_apply_condition: Vec<ExprImpl> = vec![];
467 let mut right_apply_condition: Vec<ExprImpl> = vec![];
468 let mut other_condition: Vec<ExprImpl> = vec![];
469
470 match join.join_type() {
471 JoinType::LeftSemi | JoinType::LeftAnti => {
472 left_apply_condition.extend(apply_on);
473 }
474 JoinType::RightSemi | JoinType::RightAnti => {
475 right_apply_condition.extend(apply_on);
476 }
477 JoinType::Inner
478 | JoinType::LeftOuter
479 | JoinType::RightOuter
480 | JoinType::FullOuter
481 | JoinType::AsofInner
482 | JoinType::AsofLeftOuter => {
483 let apply_len = apply_left_len + join.schema().len();
484 let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
485 let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
486 d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
487 d_t2_bit_set.set_range(0..apply_left_len, true);
488 d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
489
490 for (key, group) in &apply_on.into_iter().chunk_by(|expr| {
491 let collect_bit_set = expr.collect_input_refs(apply_len);
492 if collect_bit_set.is_subset(&d_t1_bit_set) {
493 0
494 } else if collect_bit_set.is_subset(&d_t2_bit_set) {
495 1
496 } else {
497 2
498 }
499 }) {
500 let vec = group.collect_vec();
501 match key {
502 0 => left_apply_condition.extend(vec),
503 1 => right_apply_condition.extend(vec),
504 2 => other_condition.extend(vec),
505 _ => unreachable!(),
506 }
507 }
508
509 let mut right_apply_condition_rewriter = Rewriter {
511 join_left_len: apply_left_len,
512 join_left_offset: 0,
513 join_right_offset: -(join_left_len as isize),
514 index_mapping: ColIndexMapping::empty(0, 0),
515 correlated_id,
516 };
517
518 right_apply_condition = right_apply_condition
519 .into_iter()
520 .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
521 .collect_vec();
522 }
523 JoinType::Unspecified => unreachable!(),
524 }
525
526 let new_join_left = LogicalApply::create(
527 apply_left.clone(),
528 join.left(),
529 apply_join_type,
530 Condition {
531 conjunctions: left_apply_condition,
532 },
533 correlated_id,
534 correlated_indices.clone(),
535 false,
536 );
537 let new_join_right = LogicalApply::create(
538 apply_left,
539 join.right(),
540 apply_join_type,
541 Condition {
542 conjunctions: right_apply_condition,
543 },
544 correlated_id,
545 correlated_indices,
546 false,
547 );
548
549 let (output_indices, target_size) = {
550 let (apply_left_len, join_right_len) = match apply_join_type {
551 JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
552 JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
553 _ => (apply_left_len, join.right().schema().len()),
554 };
555
556 let left_iter = 0..join_left_len + apply_left_len;
557 let right_iter = join_left_len + apply_left_len * 2
558 ..join_left_len + apply_left_len * 2 + join_right_len;
559
560 let output_indices: Vec<_> = match join.join_type() {
561 JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
562 JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
563 _ => left_iter.chain(right_iter).collect(),
564 };
565
566 let target_size = join_left_len + apply_left_len * 2 + join_right_len;
567 (output_indices, target_size)
568 };
569 let new_join = LogicalJoin::new(
570 new_join_left,
571 new_join_right,
572 join.join_type(),
573 new_join_condition,
574 )
575 .clone_with_output_indices(output_indices.clone());
576
577 match join.join_type() {
578 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => {
579 new_join.into()
580 }
581 JoinType::Inner
582 | JoinType::LeftOuter
583 | JoinType::RightOuter
584 | JoinType::FullOuter
585 | JoinType::AsofInner
586 | JoinType::AsofLeftOuter => {
587 let mut output_indices_mapping = ColIndexMapping::new(
588 output_indices.iter().map(|x| Some(*x)).collect(),
589 target_size,
590 );
591 LogicalFilter::create(
593 new_join.into(),
594 Condition {
595 conjunctions: other_condition,
596 }
597 .rewrite_expr(&mut output_indices_mapping),
598 )
599 }
600 JoinType::Unspecified => unreachable!(),
601 }
602 }
603
604 fn create_null_safe_equal_expr(
605 left: usize,
606 left_data_type: DataType,
607 right: usize,
608 right_data_type: DataType,
609 ) -> ExprImpl {
610 ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
612 ExprType::IsNotDistinctFrom,
613 vec![
614 ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
615 ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
616 ],
617 Boolean,
618 )))
619 }
620}
621
622impl ApplyJoinTransposeRule {
623 pub fn create() -> BoxedRule {
624 Box::new(ApplyJoinTransposeRule {})
625 }
626}
627
/// Expression rewriter that shifts input-ref indices by a per-side offset and replaces
/// correlated input refs carrying `correlated_id` with plain input refs.
struct Rewriter {
    /// Input refs with an index below this value are shifted by `join_left_offset`; all other
    /// input refs are shifted by `join_right_offset`.
    join_left_len: usize,
    /// Signed shift added to the index of input refs below `join_left_len`.
    join_left_offset: isize,
    /// Signed shift added to the index of input refs at or above `join_left_len`.
    join_right_offset: isize,
    /// Maps the index of a matching correlated input ref to the index of the input ref that
    /// replaces it.
    index_mapping: ColIndexMapping,
    /// Only correlated input refs with this id are rewritten; others are left untouched.
    correlated_id: CorrelatedId,
}
636impl ExprRewriter for Rewriter {
637 fn rewrite_correlated_input_ref(
638 &mut self,
639 correlated_input_ref: CorrelatedInputRef,
640 ) -> ExprImpl {
641 if correlated_input_ref.correlated_id() == self.correlated_id {
642 InputRef::new(
643 self.index_mapping.map(correlated_input_ref.index()),
644 correlated_input_ref.return_type(),
645 )
646 .into()
647 } else {
648 correlated_input_ref.into()
649 }
650 }
651
652 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
653 if input_ref.index < self.join_left_len {
654 InputRef::new(
655 (input_ref.index() as isize + self.join_left_offset) as usize,
656 input_ref.return_type(),
657 )
658 .into()
659 } else {
660 InputRef::new(
661 (input_ref.index() as isize + self.join_right_offset) as usize,
662 input_ref.return_type(),
663 )
664 .into()
665 }
666 }
667}