risingwave_frontend/optimizer/plan_node/
stream_join_common.rs1use itertools::Itertools;
16use risingwave_common::util::iter_util::ZipEqFast;
17use risingwave_pb::plan_common::JoinType;
18
19use super::{EqJoinPredicate, generic};
20use crate::PlanRef;
21use crate::optimizer::property::Distribution;
22use crate::utils::ColIndexMappingRewriteExt;
23
/// Common helpers for the stream join plan nodes.
///
/// This is a field-less unit struct that only serves as a namespace for the associated
/// functions in its `impl` block; none of them take `self`.
pub struct StreamJoinCommon;
25
26impl StreamJoinCommon {
27 pub(super) fn get_dist_key_in_join_key(
28 left_dk_indices: &[usize],
29 right_dk_indices: &[usize],
30 eq_join_predicate: &EqJoinPredicate,
31 ) -> Vec<usize> {
32 let left_jk_indices = eq_join_predicate.left_eq_indexes();
33 let right_jk_indices = &eq_join_predicate.right_eq_indexes();
34 assert_eq!(left_jk_indices.len(), right_jk_indices.len());
35 let mut dk_indices_in_jk = vec![];
36 for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) {
37 for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) {
38 if right_jk_indices[dk_idx_in_jk] == *r_dk_idx {
39 dk_indices_in_jk.push(dk_idx_in_jk);
40 break;
41 }
42 }
43 }
44 assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len());
45 dk_indices_in_jk
46 }
47
48 pub(super) fn derive_dist(
49 left: &Distribution,
50 right: &Distribution,
51 logical: &generic::Join<PlanRef>,
52 ) -> Distribution {
53 match (left, right) {
54 (Distribution::Single, Distribution::Single) => Distribution::Single,
55 (Distribution::HashShard(_), Distribution::HashShard(_)) => {
56 match logical.join_type {
59 JoinType::Unspecified => {
60 unreachable!()
61 }
62 JoinType::FullOuter => Distribution::SomeShard,
63 JoinType::Inner
64 | JoinType::LeftOuter
65 | JoinType::LeftSemi
66 | JoinType::LeftAnti
67 | JoinType::AsofInner
68 | JoinType::AsofLeftOuter => {
69 let l2o = logical
70 .l2i_col_mapping()
71 .composite(&logical.i2o_col_mapping());
72 l2o.rewrite_provided_distribution(left)
73 }
74 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
75 let r2o = logical
76 .r2i_col_mapping()
77 .composite(&logical.i2o_col_mapping());
78 r2o.rewrite_provided_distribution(right)
79 }
80 }
81 }
82 (_, _) => unreachable!(
83 "suspicious distribution: left: {:?}, right: {:?}",
84 left, right
85 ),
86 }
87 }
88}