1use itertools::{EitherOrBoth, Itertools};
16use risingwave_common::catalog::{Field, Schema};
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::OrderType;
19use risingwave_pb::plan_common::JoinType;
20
21use super::{EqJoinPredicate, GenericPlanNode, GenericPlanRef};
22use crate::TableCatalog;
23use crate::expr::{ExprRewriter, ExprVisitor};
24use crate::optimizer::optimizer_context::OptimizerContextRef;
25use crate::optimizer::plan_node::StreamPlanRef;
26use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata as _;
27use crate::optimizer::plan_node::stream::prelude::*;
28use crate::optimizer::plan_node::utils::TableCatalogBuilder;
29use crate::optimizer::property::{FunctionalDependencySet, StreamKind};
30use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
31
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum JoinOn {
39 Condition(Condition),
40 EqPredicate(EqJoinPredicate),
41}
42
43impl JoinOn {
44 pub fn as_condition(&self) -> Condition {
53 match self {
54 JoinOn::Condition(cond) => cond.clone(),
55 JoinOn::EqPredicate(pred) => pred.all_cond(),
56 }
57 }
58
59 pub fn as_condition_ref(&self) -> Option<&Condition> {
64 match self {
65 JoinOn::Condition(cond) => Some(cond),
66 JoinOn::EqPredicate(_) => None,
67 }
68 }
69
70 pub fn as_eq_predicate_ref(&self) -> Option<&EqJoinPredicate> {
76 match self {
77 JoinOn::Condition(_) => None,
78 JoinOn::EqPredicate(pred) => Some(pred),
79 }
80 }
81
82 pub fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
87 match self {
88 JoinOn::Condition(cond) => {
89 *cond = cond.clone().rewrite_expr(r);
90 }
91 JoinOn::EqPredicate(pred) => {
92 *pred = pred.rewrite_exprs(r);
93 }
94 }
95 }
96
97 pub fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
101 match self {
102 JoinOn::Condition(cond) => cond.visit_expr(v),
103 JoinOn::EqPredicate(pred) => pred.visit_exprs(v),
105 }
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct Join<PlanRef> {
117 pub left: PlanRef,
118 pub right: PlanRef,
119 pub on: JoinOn,
120 pub join_type: JoinType,
121 pub output_indices: Vec<usize>,
122}
123
/// Returns `true` if any value occurs more than once in `slice`.
///
/// Each element is compared against the elements that follow it, so no allocation is needed.
pub(crate) fn has_repeated_element(slice: &[usize]) -> bool {
    slice
        .iter()
        .enumerate()
        .any(|(pos, value)| slice[pos + 1..].contains(value))
}
127
128impl<PlanRef: GenericPlanRef> Join<PlanRef> {
129 pub(crate) fn clone_with_inputs<OtherPlanRef>(
130 &self,
131 left: OtherPlanRef,
132 right: OtherPlanRef,
133 ) -> Join<OtherPlanRef> {
134 Join {
135 left,
136 right,
137 on: self.on.clone(),
138 join_type: self.join_type,
139 output_indices: self.output_indices.clone(),
140 }
141 }
142
143 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
144 self.on.rewrite_exprs(r);
145 }
146
147 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
148 self.on.visit_exprs(v);
149 }
150
151 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
152 let left_len = self.left.schema().len();
153 let right_len = self.right.schema().len();
154 match &self.on {
155 JoinOn::Condition(on) => {
156 EqJoinPredicate::create(left_len, right_len, on.clone()).eq_indexes()
157 }
158 JoinOn::EqPredicate(pred) => pred.eq_indexes(),
159 }
160 }
161
162 pub fn new(
163 left: PlanRef,
164 right: PlanRef,
165 on: Condition,
166 join_type: JoinType,
167 output_indices: Vec<usize>,
168 ) -> Self {
169 debug_assert!(!has_repeated_element(&output_indices));
171 Self {
172 left,
173 right,
174 on: JoinOn::Condition(on),
175 join_type,
176 output_indices,
177 }
178 }
179
180 pub fn new_with_eq_predicate(
181 left: PlanRef,
182 right: PlanRef,
183 eq_join_predicate: EqJoinPredicate,
184 join_type: JoinType,
185 output_indices: Vec<usize>,
186 ) -> Self {
187 debug_assert!(!has_repeated_element(&output_indices));
188 Self {
189 left,
190 right,
191 on: JoinOn::EqPredicate(eq_join_predicate),
192 join_type,
193 output_indices,
194 }
195 }
196}
197
198impl Join<StreamPlanRef> {
199 pub fn stream_kind(&self) -> Result<StreamKind> {
200 let left_kind = reject_upsert_input!(self.left, "Join");
201 let right_kind = reject_upsert_input!(self.right, "Join");
202
203 if let JoinType::Inner | JoinType::AsofInner = self.join_type
205 && let StreamKind::AppendOnly = left_kind
206 && let StreamKind::AppendOnly = right_kind
207 {
208 Ok(StreamKind::AppendOnly)
209 } else {
210 Ok(StreamKind::Retract)
211 }
212 }
213
214 pub fn infer_internal_and_degree_table_catalog(
216 input: StreamPlanRef,
217 join_key_indices: Vec<usize>,
218 dk_indices_in_jk: Vec<usize>,
219 ) -> (TableCatalog, TableCatalog, Vec<usize>) {
220 let schema = input.schema();
221
222 let internal_table_dist_keys = dk_indices_in_jk
223 .iter()
224 .map(|idx| join_key_indices[*idx])
225 .collect_vec();
226
227 let degree_table_dist_keys = dk_indices_in_jk.clone();
228
229 let join_key_len = join_key_indices.len();
231 let mut pk_indices = join_key_indices;
232
233 let mut deduped_input_pk_indices = vec![];
235 for input_pk_idx in input.stream_key().unwrap() {
236 if !pk_indices.contains(input_pk_idx)
237 && !deduped_input_pk_indices.contains(input_pk_idx)
238 {
239 deduped_input_pk_indices.push(*input_pk_idx);
240 }
241 }
242
243 pk_indices.extend(deduped_input_pk_indices.clone());
244
245 let mut internal_table_catalog_builder = TableCatalogBuilder::default();
247 let internal_columns_fields = schema.fields().to_vec();
248
249 internal_columns_fields.iter().for_each(|field| {
250 internal_table_catalog_builder.add_column(field);
251 });
252 pk_indices.iter().for_each(|idx| {
253 internal_table_catalog_builder.add_order_column(*idx, OrderType::ascending())
254 });
255
256 let mut degree_table_catalog_builder = TableCatalogBuilder::default();
258
259 let degree_column_field = Field::with_name(DataType::Int64, "_degree");
260
261 pk_indices.iter().enumerate().for_each(|(order_idx, idx)| {
262 degree_table_catalog_builder.add_column(&internal_columns_fields[*idx]);
263 degree_table_catalog_builder.add_order_column(order_idx, OrderType::ascending());
264 });
265 degree_table_catalog_builder.add_column(°ree_column_field);
266 degree_table_catalog_builder
267 .set_value_indices(vec![degree_table_catalog_builder.columns().len() - 1]);
268
269 internal_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk.clone());
270 degree_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk);
271
272 (
273 internal_table_catalog_builder.build(internal_table_dist_keys, join_key_len),
274 degree_table_catalog_builder.build(degree_table_dist_keys, join_key_len),
275 deduped_input_pk_indices,
276 )
277 }
278}
279
280impl<PlanRef: GenericPlanRef> GenericPlanNode for Join<PlanRef> {
281 fn schema(&self) -> Schema {
282 let left_schema = self.left.schema();
283 let right_schema = self.right.schema();
284 let i2l = self.i2l_col_mapping();
285 let i2r = self.i2r_col_mapping();
286 let fields = self
287 .output_indices
288 .iter()
289 .map(|&i| match (i2l.try_map(i), i2r.try_map(i)) {
290 (Some(l_i), None) => left_schema.fields()[l_i].clone(),
291 (None, Some(r_i)) => right_schema.fields()[r_i].clone(),
292 _ => panic!(
293 "left len {}, right len {}, i {}, lmap {:?}, rmap {:?}",
294 left_schema.len(),
295 right_schema.len(),
296 i,
297 i2l,
298 i2r
299 ),
300 })
301 .collect();
302 Schema { fields }
303 }
304
305 fn stream_key(&self) -> Option<Vec<usize>> {
306 let eq_indexes = self.eq_indexes();
307 let left_pk = self.left.stream_key()?;
308 let right_pk = self.right.stream_key()?;
309 let l2i = self.l2i_col_mapping();
310 let r2i = self.r2i_col_mapping();
311 let full_out_col_num = self.internal_column_num();
312 let i2o = ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num);
313
314 let mut pk_indices_internal = left_pk
316 .iter()
317 .map(|index| l2i.try_map(*index))
318 .chain(right_pk.iter().map(|index| r2i.try_map(*index)))
319 .flatten()
320 .collect::<Vec<_>>();
321
322 let either_or_both = self.add_which_join_key_to_pk();
323
324 for (lk, rk) in eq_indexes {
325 match either_or_both {
326 EitherOrBoth::Left(_) => {
327 if let Some(rk_internal) = r2i.try_map(rk) {
333 pk_indices_internal.retain(|&x| x != rk_internal);
334 }
335 if let Some(lk_internal) = l2i.try_map(lk)
337 && !pk_indices_internal.contains(&lk_internal)
338 {
339 pk_indices_internal.push(lk_internal);
340 }
341 }
342 EitherOrBoth::Right(_) => {
343 if let Some(lk_internal) = l2i.try_map(lk) {
346 pk_indices_internal.retain(|&x| x != lk_internal);
347 }
348 if let Some(rk_internal) = r2i.try_map(rk)
350 && !pk_indices_internal.contains(&rk_internal)
351 {
352 pk_indices_internal.push(rk_internal);
353 }
354 }
355 EitherOrBoth::Both(_, _) => {
356 if let Some(lk_internal) = l2i.try_map(lk)
357 && !pk_indices_internal.contains(&lk_internal)
358 {
359 pk_indices_internal.push(lk_internal);
360 }
361 if let Some(rk_internal) = r2i.try_map(rk)
362 && !pk_indices_internal.contains(&rk_internal)
363 {
364 pk_indices_internal.push(rk_internal);
365 }
366 }
367 };
368 }
369
370 let pk_indices = pk_indices_internal
372 .iter()
373 .map(|&index| i2o.try_map(index))
374 .collect::<Option<Vec<_>>>()?;
375
376 Some(pk_indices)
377 }
378
379 fn ctx(&self) -> OptimizerContextRef {
380 self.left.ctx()
381 }
382
383 fn functional_dependency(&self) -> FunctionalDependencySet {
384 let left_len = self.left.schema().len();
385 let right_len = self.right.schema().len();
386 let left_fd_set = self.left.functional_dependency().clone();
387 let right_fd_set = self.right.functional_dependency().clone();
388
389 let full_out_col_num = self.internal_column_num();
390
391 let get_new_left_fd_set = |left_fd_set: FunctionalDependencySet| {
392 ColIndexMapping::with_shift_offset(left_len, 0)
393 .composite(&ColIndexMapping::identity(full_out_col_num))
394 .rewrite_functional_dependency_set(left_fd_set)
395 };
396 let get_new_right_fd_set = |right_fd_set: FunctionalDependencySet| {
397 ColIndexMapping::with_shift_offset(right_len, left_len.try_into().unwrap())
398 .rewrite_functional_dependency_set(right_fd_set)
399 };
400 let fd_set: FunctionalDependencySet = match self.join_type {
401 JoinType::Inner | JoinType::AsofInner => {
402 let mut fd_set = FunctionalDependencySet::new(full_out_col_num);
403 for i in &self.on.as_condition().conjunctions {
404 if let Some((col, _)) = i.as_eq_const() {
405 fd_set.add_constant_columns(&[col.index()])
406 } else if let Some((left, right)) = i.as_eq_cond() {
407 fd_set.add_functional_dependency_by_column_indices(
408 &[left.index()],
409 &[right.index()],
410 );
411 fd_set.add_functional_dependency_by_column_indices(
412 &[right.index()],
413 &[left.index()],
414 );
415 }
416 }
417 get_new_left_fd_set(left_fd_set)
418 .into_dependencies()
419 .into_iter()
420 .chain(get_new_right_fd_set(right_fd_set).into_dependencies())
421 .for_each(|fd| fd_set.add_functional_dependency(fd));
422 fd_set
423 }
424 JoinType::LeftOuter | JoinType::AsofLeftOuter => get_new_left_fd_set(left_fd_set),
425 JoinType::RightOuter => get_new_right_fd_set(right_fd_set),
426 JoinType::FullOuter => FunctionalDependencySet::new(full_out_col_num),
427 JoinType::LeftSemi | JoinType::LeftAnti => left_fd_set,
428 JoinType::RightSemi | JoinType::RightAnti => right_fd_set,
429 JoinType::Unspecified => unreachable!(),
430 };
431 ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num)
432 .rewrite_functional_dependency_set(fd_set)
433 }
434}
435
436impl<PlanRef> Join<PlanRef> {
437 pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
438 (
439 self.left,
440 self.right,
441 self.on.as_condition(),
442 self.join_type,
443 self.output_indices,
444 )
445 }
446}
447
448impl<PlanRef: GenericPlanRef> Join<PlanRef> {
449 pub fn full_out_col_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize {
450 match join_type {
451 JoinType::Inner
452 | JoinType::LeftOuter
453 | JoinType::RightOuter
454 | JoinType::FullOuter
455 | JoinType::AsofInner
456 | JoinType::AsofLeftOuter => left_len + right_len,
457 JoinType::LeftSemi | JoinType::LeftAnti => left_len,
458 JoinType::RightSemi | JoinType::RightAnti => right_len,
459 JoinType::Unspecified => unreachable!(),
460 }
461 }
462
463 pub fn with_full_output(
464 left: PlanRef,
465 right: PlanRef,
466 join_type: JoinType,
467 on: Condition,
468 ) -> Self {
469 let out_column_num =
470 Self::full_out_col_num(left.schema().len(), right.schema().len(), join_type);
471 Self {
472 left,
473 right,
474 join_type,
475 on: JoinOn::Condition(on),
476 output_indices: (0..out_column_num).collect(),
477 }
478 }
479
480 pub fn with_full_output_eq_predicate(
481 left: PlanRef,
482 right: PlanRef,
483 join_type: JoinType,
484 eq_join_predicate: EqJoinPredicate,
485 ) -> Self {
486 let out_column_num =
487 Self::full_out_col_num(left.schema().len(), right.schema().len(), join_type);
488 Self {
489 left,
490 right,
491 join_type,
492 on: JoinOn::EqPredicate(eq_join_predicate),
493 output_indices: (0..out_column_num).collect(),
494 }
495 }
496
497 pub fn internal_column_num(&self) -> usize {
498 Self::full_out_col_num(
499 self.left.schema().len(),
500 self.right.schema().len(),
501 self.join_type,
502 )
503 }
504
505 pub fn is_full_out(&self) -> bool {
506 self.output_indices.len() == self.internal_column_num()
507 }
508
509 pub fn i2l_col_mapping(&self) -> ColIndexMapping {
511 let left_len = self.left.schema().len();
512 let right_len = self.right.schema().len();
513
514 match self.join_type {
515 JoinType::Inner
516 | JoinType::LeftOuter
517 | JoinType::RightOuter
518 | JoinType::FullOuter
519 | JoinType::AsofInner
520 | JoinType::AsofLeftOuter => {
521 ColIndexMapping::identity_or_none(left_len + right_len, left_len)
522 }
523
524 JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::identity(left_len),
525 JoinType::RightSemi | JoinType::RightAnti => {
526 ColIndexMapping::empty(right_len, left_len)
527 }
528 JoinType::Unspecified => unreachable!(),
529 }
530 }
531
532 pub fn i2r_col_mapping(&self) -> ColIndexMapping {
534 let left_len = self.left.schema().len();
535 let right_len = self.right.schema().len();
536
537 match self.join_type {
538 JoinType::Inner
539 | JoinType::LeftOuter
540 | JoinType::RightOuter
541 | JoinType::FullOuter
542 | JoinType::AsofInner
543 | JoinType::AsofLeftOuter => {
544 ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
545 }
546 JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len, right_len),
547 JoinType::RightSemi | JoinType::RightAnti => ColIndexMapping::identity(right_len),
548 JoinType::Unspecified => unreachable!(),
549 }
550 }
551
552 pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
554 let left_len = self.left.schema().len();
555 let right_len = self.right.schema().len();
556
557 ColIndexMapping::identity_or_none(left_len + right_len, left_len)
558 }
559
560 pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
562 let left_len = self.left.schema().len();
563 let right_len = self.right.schema().len();
564
565 ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
566 }
567
568 pub fn l2i_col_mapping(&self) -> ColIndexMapping {
570 self.i2l_col_mapping()
571 .inverse()
572 .expect("must be invertible")
573 }
574
575 pub fn r2i_col_mapping(&self) -> ColIndexMapping {
577 self.i2r_col_mapping()
578 .inverse()
579 .expect("must be invertible")
580 }
581
582 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
584 ColIndexMapping::with_remaining_columns(&self.output_indices, self.internal_column_num())
585 }
586
587 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
589 ColIndexMapping::new(
592 self.output_indices.iter().map(|x| Some(*x)).collect(),
593 self.internal_column_num(),
594 )
595 }
596
597 pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> {
598 match self.join_type {
599 JoinType::Inner | JoinType::AsofInner => {
600 EitherOrBoth::Left(())
604 }
605 JoinType::LeftOuter
606 | JoinType::LeftSemi
607 | JoinType::LeftAnti
608 | JoinType::AsofLeftOuter => EitherOrBoth::Left(()),
609 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
610 EitherOrBoth::Right(())
611 }
612 JoinType::FullOuter => EitherOrBoth::Both((), ()),
613 JoinType::Unspecified => unreachable!(),
614 }
615 }
616
617 pub fn concat_schema(&self) -> Schema {
618 Schema::new(
619 [
620 self.left.schema().fields.clone(),
621 self.right.schema().fields.clone(),
622 ]
623 .concat(),
624 )
625 }
626}
627
628pub fn push_down_into_join(
634 predicate: &mut Condition,
635 left_col_num: usize,
636 right_col_num: usize,
637 ty: JoinType,
638 push_temporal_predicate: bool,
639) -> (Condition, Condition, Condition) {
640 let (left, right) = push_down_to_inputs(
641 predicate,
642 left_col_num,
643 right_col_num,
644 can_push_left_from_filter(ty),
645 can_push_right_from_filter(ty),
646 push_temporal_predicate,
647 );
648
649 let on = if can_push_on_from_filter(ty) {
650 let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
651
652 if push_temporal_predicate {
653 Condition { conjunctions }
654 } else {
655 let on = Condition {
657 conjunctions: conjunctions
658 .extract_if(.., |expr| expr.count_nows() == 0)
659 .collect(),
660 };
661 predicate.conjunctions = conjunctions;
662 on
663 }
664 } else {
665 Condition::true_cond()
666 };
667 (left, right, on)
668}
669
670pub fn push_down_join_condition(
675 on_condition: &mut Condition,
676 left_col_num: usize,
677 right_col_num: usize,
678 ty: JoinType,
679 push_temporal_predicate: bool,
680) -> (Condition, Condition) {
681 push_down_to_inputs(
682 on_condition,
683 left_col_num,
684 right_col_num,
685 can_push_left_from_on(ty),
686 can_push_right_from_on(ty),
687 push_temporal_predicate,
688 )
689}
690
691fn push_down_to_inputs(
696 predicate: &mut Condition,
697 left_col_num: usize,
698 right_col_num: usize,
699 push_left: bool,
700 push_right: bool,
701 push_temporal_predicate: bool,
702) -> (Condition, Condition) {
703 let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
704 let (mut left, right, mut others) = if push_temporal_predicate {
705 Condition { conjunctions }.split(left_col_num, right_col_num)
706 } else {
707 let temporal_filter_cons = conjunctions
708 .extract_if(.., |e| e.count_nows() != 0)
709 .collect_vec();
710 let (left, right, mut others) =
711 Condition { conjunctions }.split(left_col_num, right_col_num);
712
713 others.conjunctions.extend(temporal_filter_cons);
714 (left, right, others)
715 };
716
717 if !push_left {
718 others.conjunctions.extend(left);
719 left = Condition::true_cond();
720 };
721
722 let right = if push_right {
723 let mut mapping = ColIndexMapping::with_shift_offset(
724 left_col_num + right_col_num,
725 -(left_col_num as isize),
726 );
727 right.rewrite_expr(&mut mapping)
728 } else {
729 others.conjunctions.extend(right);
730 Condition::true_cond()
731 };
732
733 predicate.conjunctions = others.conjunctions;
734
735 (left, right)
736}
737
738pub fn can_push_left_from_filter(ty: JoinType) -> bool {
739 matches!(
740 ty,
741 JoinType::Inner
742 | JoinType::LeftOuter
743 | JoinType::LeftSemi
744 | JoinType::LeftAnti
745 | JoinType::AsofInner
746 | JoinType::AsofLeftOuter
747 )
748}
749
750pub fn can_push_right_from_filter(ty: JoinType) -> bool {
751 matches!(
752 ty,
753 JoinType::Inner
754 | JoinType::RightOuter
755 | JoinType::RightSemi
756 | JoinType::RightAnti
757 | JoinType::AsofInner
758 )
759}
760
761pub fn can_push_on_from_filter(ty: JoinType) -> bool {
762 matches!(
763 ty,
764 JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi
765 )
766}
767
768pub fn can_push_left_from_on(ty: JoinType) -> bool {
769 matches!(
770 ty,
771 JoinType::Inner
772 | JoinType::RightOuter
773 | JoinType::LeftSemi
774 | JoinType::AsofInner
775 | JoinType::AsofLeftOuter
776 )
777}
778
779pub fn can_push_right_from_on(ty: JoinType) -> bool {
780 matches!(
781 ty,
782 JoinType::Inner
783 | JoinType::LeftOuter
784 | JoinType::RightSemi
785 | JoinType::AsofInner
786 | JoinType::AsofLeftOuter
787 )
788}