// risingwave_frontend/optimizer/plan_node/stream_join_common.rs
use itertools::Itertools;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_pb::plan_common::JoinType;
use super::{generic, EqJoinPredicate};
use crate::optimizer::property::Distribution;
use crate::utils::ColIndexMappingRewriteExt;
use crate::PlanRef;
/// Field-less unit struct that serves as a namespace: it carries no data, and
/// the join helpers (`get_dist_key_in_join_key`, `derive_dist`) are associated
/// functions defined in its `impl` block.
pub struct StreamJoinCommon;
impl StreamJoinCommon {
    /// Locates every distribution-key column pair inside the equi-join key.
    ///
    /// `left_dk_indices` and `right_dk_indices` are walked in lockstep (via
    /// `zip_eq_fast`). For each `(left, right)` pair the result contains the
    /// first position `p` for which `left_eq_indexes()[p] == left` and
    /// `right_eq_indexes()[p] == right`. Positions are returned in the order
    /// of the distribution-key pairs.
    ///
    /// # Panics
    ///
    /// Panics when the left and right join-key index lists differ in length,
    /// or when at least one distribution-key pair has no matching position in
    /// the join key (the result would then be shorter than `left_dk_indices`).
    pub(super) fn get_dist_key_in_join_key(
        left_dk_indices: &[usize],
        right_dk_indices: &[usize],
        eq_join_predicate: &EqJoinPredicate,
    ) -> Vec<usize> {
        let left_jk = eq_join_predicate.left_eq_indexes();
        let right_jk = eq_join_predicate.right_eq_indexes();
        assert_eq!(left_jk.len(), right_jk.len());

        let positions_in_jk: Vec<usize> = left_dk_indices
            .iter()
            .zip_eq_fast(right_dk_indices.iter())
            .filter_map(|(left_dk, right_dk)| {
                // Among all join-key slots whose left column is this dist-key
                // column, keep the first one whose right column matches too.
                left_jk
                    .iter()
                    .positions(|jk| jk == left_dk)
                    .find(|&pos| right_jk[pos] == *right_dk)
            })
            .collect();

        // Every distribution-key pair must have been found in the join key.
        assert_eq!(positions_in_jk.len(), left_dk_indices.len());
        positions_in_jk
    }

    /// Derives the output distribution of a join from its input distributions.
    ///
    /// * Both inputs `Single` -> `Single`.
    /// * Both inputs `HashShard` -> depends on `logical.join_type`:
    ///   * `FullOuter` -> `SomeShard`.
    ///   * `Inner`, `LeftOuter`, `LeftSemi`, `LeftAnti`, `AsofInner`,
    ///     `AsofLeftOuter` -> the left distribution rewritten through the
    ///     composition of the `l2i` and `i2o` column mappings.
    ///   * `RightSemi`, `RightAnti`, `RightOuter` -> the right distribution
    ///     rewritten through the composition of the `r2i` and `i2o` mappings.
    ///
    /// NOTE(review): `FullOuter` presumably falls back to `SomeShard` because
    /// neither side's hash key survives intact in the output — confirm.
    ///
    /// # Panics
    ///
    /// Panics for any other combination of input distributions, and for
    /// `JoinType::Unspecified`.
    pub(super) fn derive_dist(
        left: &Distribution,
        right: &Distribution,
        logical: &generic::Join<PlanRef>,
    ) -> Distribution {
        if matches!((left, right), (Distribution::Single, Distribution::Single)) {
            return Distribution::Single;
        }
        if !matches!(
            (left, right),
            (Distribution::HashShard(_), Distribution::HashShard(_))
        ) {
            unreachable!(
                "suspicious distribution: left: {:?}, right: {:?}",
                left, right
            );
        }

        // From here on both sides are hash-sharded.
        match logical.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {
                let l2i = logical.l2i_col_mapping();
                let i2o = logical.i2o_col_mapping();
                l2i.composite(&i2o).rewrite_provided_distribution(left)
            }
            JoinType::RightOuter | JoinType::RightSemi | JoinType::RightAnti => {
                let r2i = logical.r2i_col_mapping();
                let i2o = logical.i2o_col_mapping();
                r2i.composite(&i2o).rewrite_provided_distribution(right)
            }
            JoinType::FullOuter => Distribution::SomeShard,
            JoinType::Unspecified => unreachable!(),
        }
    }
}