use itertools::{EitherOrBoth, Itertools};
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_pb::plan_common::JoinType;
use super::{EqJoinPredicate, GenericPlanNode, GenericPlanRef};
use crate::expr::{ExprRewriter, ExprVisitor};
use crate::optimizer::optimizer_context::OptimizerContextRef;
use crate::optimizer::plan_node::stream;
use crate::optimizer::plan_node::utils::TableCatalogBuilder;
use crate::optimizer::property::FunctionalDependencySet;
use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
use crate::TableCatalog;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Join<PlanRef> {
pub left: PlanRef,
pub right: PlanRef,
pub on: Condition,
pub join_type: JoinType,
pub output_indices: Vec<usize>,
}
pub(crate) fn has_repeated_element(slice: &[usize]) -> bool {
(1..slice.len()).any(|i| slice[i..].contains(&slice[i - 1]))
}
impl<PlanRef: GenericPlanRef> Join<PlanRef> {
pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
self.on = self.on.clone().rewrite_expr(r);
}
pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
self.on.visit_expr(v);
}
pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
let left_len = self.left.schema().len();
let right_len = self.right.schema().len();
let eq_predicate = EqJoinPredicate::create(left_len, right_len, self.on.clone());
eq_predicate.eq_indexes()
}
pub fn new(
left: PlanRef,
right: PlanRef,
on: Condition,
join_type: JoinType,
output_indices: Vec<usize>,
) -> Self {
debug_assert!(!has_repeated_element(&output_indices));
Self {
left,
right,
on,
join_type,
output_indices,
}
}
}
impl<I: stream::StreamPlanRef> Join<I> {
pub fn infer_internal_and_degree_table_catalog(
input: I,
join_key_indices: Vec<usize>,
dk_indices_in_jk: Vec<usize>,
) -> (TableCatalog, TableCatalog, Vec<usize>) {
let schema = input.schema();
let internal_table_dist_keys = dk_indices_in_jk
.iter()
.map(|idx| join_key_indices[*idx])
.collect_vec();
let degree_table_dist_keys = dk_indices_in_jk.clone();
let join_key_len = join_key_indices.len();
let mut pk_indices = join_key_indices;
let mut deduped_input_pk_indices = vec![];
for input_pk_idx in input.stream_key().unwrap() {
if !pk_indices.contains(input_pk_idx)
&& !deduped_input_pk_indices.contains(input_pk_idx)
{
deduped_input_pk_indices.push(*input_pk_idx);
}
}
pk_indices.extend(deduped_input_pk_indices.clone());
let mut internal_table_catalog_builder = TableCatalogBuilder::default();
let internal_columns_fields = schema.fields().to_vec();
internal_columns_fields.iter().for_each(|field| {
internal_table_catalog_builder.add_column(field);
});
pk_indices.iter().for_each(|idx| {
internal_table_catalog_builder.add_order_column(*idx, OrderType::ascending())
});
let mut degree_table_catalog_builder = TableCatalogBuilder::default();
let degree_column_field = Field::with_name(DataType::Int64, "_degree");
pk_indices.iter().enumerate().for_each(|(order_idx, idx)| {
degree_table_catalog_builder.add_column(&internal_columns_fields[*idx]);
degree_table_catalog_builder.add_order_column(order_idx, OrderType::ascending());
});
degree_table_catalog_builder.add_column(°ree_column_field);
degree_table_catalog_builder
.set_value_indices(vec![degree_table_catalog_builder.columns().len() - 1]);
internal_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk.clone());
degree_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk);
(
internal_table_catalog_builder.build(internal_table_dist_keys, join_key_len),
degree_table_catalog_builder.build(degree_table_dist_keys, join_key_len),
deduped_input_pk_indices,
)
}
}
impl<PlanRef: GenericPlanRef> GenericPlanNode for Join<PlanRef> {
fn schema(&self) -> Schema {
let left_schema = self.left.schema();
let right_schema = self.right.schema();
let i2l = self.i2l_col_mapping();
let i2r = self.i2r_col_mapping();
let fields = self
.output_indices
.iter()
.map(|&i| match (i2l.try_map(i), i2r.try_map(i)) {
(Some(l_i), None) => left_schema.fields()[l_i].clone(),
(None, Some(r_i)) => right_schema.fields()[r_i].clone(),
_ => panic!(
"left len {}, right len {}, i {}, lmap {:?}, rmap {:?}",
left_schema.len(),
right_schema.len(),
i,
i2l,
i2r
),
})
.collect();
Schema { fields }
}
fn stream_key(&self) -> Option<Vec<usize>> {
let eq_indexes = self.eq_indexes();
let left_pk = self.left.stream_key()?;
let right_pk = self.right.stream_key()?;
let l2i = self.l2i_col_mapping();
let r2i = self.r2i_col_mapping();
let full_out_col_num = self.internal_column_num();
let i2o = ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num);
let mut pk_indices = left_pk
.iter()
.map(|index| l2i.try_map(*index))
.chain(right_pk.iter().map(|index| r2i.try_map(*index)))
.flatten()
.map(|index| i2o.try_map(index))
.collect::<Option<Vec<_>>>()?;
let l2i = self.l2i_col_mapping();
let r2i = self.r2i_col_mapping();
let full_out_col_num = self.internal_column_num();
let i2o = ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num);
let either_or_both = self.add_which_join_key_to_pk();
for (lk, rk) in eq_indexes {
match either_or_both {
EitherOrBoth::Left(_) => {
if let Some(rk) = r2i.try_map(rk) {
if let Some(out_k) = i2o.try_map(rk) {
pk_indices.retain(|&x| x != out_k);
}
}
if let Some(lk) = l2i.try_map(lk) {
let out_k = i2o.try_map(lk)?;
if !pk_indices.contains(&out_k) {
pk_indices.push(out_k);
}
}
}
EitherOrBoth::Right(_) => {
if let Some(lk) = l2i.try_map(lk) {
if let Some(out_k) = i2o.try_map(lk) {
pk_indices.retain(|&x| x != out_k);
}
}
if let Some(rk) = r2i.try_map(rk) {
let out_k = i2o.try_map(rk)?;
if !pk_indices.contains(&out_k) {
pk_indices.push(out_k);
}
}
}
EitherOrBoth::Both(_, _) => {
if let Some(lk) = l2i.try_map(lk) {
let out_k = i2o.try_map(lk)?;
if !pk_indices.contains(&out_k) {
pk_indices.push(out_k);
}
}
if let Some(rk) = r2i.try_map(rk) {
let out_k = i2o.try_map(rk)?;
if !pk_indices.contains(&out_k) {
pk_indices.push(out_k);
}
}
}
};
}
Some(pk_indices)
}
fn ctx(&self) -> OptimizerContextRef {
self.left.ctx()
}
fn functional_dependency(&self) -> FunctionalDependencySet {
let left_len = self.left.schema().len();
let right_len = self.right.schema().len();
let left_fd_set = self.left.functional_dependency().clone();
let right_fd_set = self.right.functional_dependency().clone();
let full_out_col_num = self.internal_column_num();
let get_new_left_fd_set = |left_fd_set: FunctionalDependencySet| {
ColIndexMapping::with_shift_offset(left_len, 0)
.composite(&ColIndexMapping::identity(full_out_col_num))
.rewrite_functional_dependency_set(left_fd_set)
};
let get_new_right_fd_set = |right_fd_set: FunctionalDependencySet| {
ColIndexMapping::with_shift_offset(right_len, left_len.try_into().unwrap())
.rewrite_functional_dependency_set(right_fd_set)
};
let fd_set: FunctionalDependencySet = match self.join_type {
JoinType::Inner | JoinType::AsofInner => {
let mut fd_set = FunctionalDependencySet::new(full_out_col_num);
for i in &self.on.conjunctions {
if let Some((col, _)) = i.as_eq_const() {
fd_set.add_constant_columns(&[col.index()])
} else if let Some((left, right)) = i.as_eq_cond() {
fd_set.add_functional_dependency_by_column_indices(
&[left.index()],
&[right.index()],
);
fd_set.add_functional_dependency_by_column_indices(
&[right.index()],
&[left.index()],
);
}
}
get_new_left_fd_set(left_fd_set)
.into_dependencies()
.into_iter()
.chain(get_new_right_fd_set(right_fd_set).into_dependencies())
.for_each(|fd| fd_set.add_functional_dependency(fd));
fd_set
}
JoinType::LeftOuter | JoinType::AsofLeftOuter => get_new_left_fd_set(left_fd_set),
JoinType::RightOuter => get_new_right_fd_set(right_fd_set),
JoinType::FullOuter => FunctionalDependencySet::new(full_out_col_num),
JoinType::LeftSemi | JoinType::LeftAnti => left_fd_set,
JoinType::RightSemi | JoinType::RightAnti => right_fd_set,
JoinType::Unspecified => unreachable!(),
};
ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num)
.rewrite_functional_dependency_set(fd_set)
}
}
impl<PlanRef> Join<PlanRef> {
pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
(
self.left,
self.right,
self.on,
self.join_type,
self.output_indices,
)
}
pub fn full_out_col_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize {
match join_type {
JoinType::Inner
| JoinType::LeftOuter
| JoinType::RightOuter
| JoinType::FullOuter
| JoinType::AsofInner
| JoinType::AsofLeftOuter => left_len + right_len,
JoinType::LeftSemi | JoinType::LeftAnti => left_len,
JoinType::RightSemi | JoinType::RightAnti => right_len,
JoinType::Unspecified => unreachable!(),
}
}
}
impl<PlanRef: GenericPlanRef> Join<PlanRef> {
pub fn with_full_output(
left: PlanRef,
right: PlanRef,
join_type: JoinType,
on: Condition,
) -> Self {
let out_column_num =
Self::full_out_col_num(left.schema().len(), right.schema().len(), join_type);
Self {
left,
right,
join_type,
on,
output_indices: (0..out_column_num).collect(),
}
}
pub fn internal_column_num(&self) -> usize {
Self::full_out_col_num(
self.left.schema().len(),
self.right.schema().len(),
self.join_type,
)
}
pub fn is_full_out(&self) -> bool {
self.output_indices.len() == self.internal_column_num()
}
pub fn i2l_col_mapping(&self) -> ColIndexMapping {
let left_len = self.left.schema().len();
let right_len = self.right.schema().len();
match self.join_type {
JoinType::Inner
| JoinType::LeftOuter
| JoinType::RightOuter
| JoinType::FullOuter
| JoinType::AsofInner
| JoinType::AsofLeftOuter => {
ColIndexMapping::identity_or_none(left_len + right_len, left_len)
}
JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::identity(left_len),
JoinType::RightSemi | JoinType::RightAnti => {
ColIndexMapping::empty(right_len, left_len)
}
JoinType::Unspecified => unreachable!(),
}
}
pub fn i2r_col_mapping(&self) -> ColIndexMapping {
let left_len = self.left.schema().len();
let right_len = self.right.schema().len();
match self.join_type {
JoinType::Inner
| JoinType::LeftOuter
| JoinType::RightOuter
| JoinType::FullOuter
| JoinType::AsofInner
| JoinType::AsofLeftOuter => {
ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
}
JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len, right_len),
JoinType::RightSemi | JoinType::RightAnti => ColIndexMapping::identity(right_len),
JoinType::Unspecified => unreachable!(),
}
}
pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
let left_len = self.left.schema().len();
let right_len = self.right.schema().len();
ColIndexMapping::identity_or_none(left_len + right_len, left_len)
}
pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
let left_len = self.left.schema().len();
let right_len = self.right.schema().len();
ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
}
pub fn l2i_col_mapping(&self) -> ColIndexMapping {
self.i2l_col_mapping()
.inverse()
.expect("must be invertible")
}
pub fn r2i_col_mapping(&self) -> ColIndexMapping {
self.i2r_col_mapping()
.inverse()
.expect("must be invertible")
}
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
ColIndexMapping::with_remaining_columns(&self.output_indices, self.internal_column_num())
}
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
ColIndexMapping::new(
self.output_indices.iter().map(|x| Some(*x)).collect(),
self.internal_column_num(),
)
}
pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> {
match self.join_type {
JoinType::Inner | JoinType::AsofInner => {
EitherOrBoth::Left(())
}
JoinType::LeftOuter
| JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::AsofLeftOuter => EitherOrBoth::Left(()),
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
EitherOrBoth::Right(())
}
JoinType::FullOuter => EitherOrBoth::Both((), ()),
JoinType::Unspecified => unreachable!(),
}
}
pub fn concat_schema(&self) -> Schema {
Schema::new(
[
self.left.schema().fields.clone(),
self.right.schema().fields.clone(),
]
.concat(),
)
}
}
pub fn push_down_into_join(
predicate: &mut Condition,
left_col_num: usize,
right_col_num: usize,
ty: JoinType,
push_temporal_predicate: bool,
) -> (Condition, Condition, Condition) {
let (left, right) = push_down_to_inputs(
predicate,
left_col_num,
right_col_num,
can_push_left_from_filter(ty),
can_push_right_from_filter(ty),
push_temporal_predicate,
);
let on = if can_push_on_from_filter(ty) {
let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
if push_temporal_predicate {
Condition { conjunctions }
} else {
let on = Condition {
conjunctions: conjunctions
.extract_if(|expr| expr.count_nows() == 0)
.collect(),
};
predicate.conjunctions = conjunctions;
on
}
} else {
Condition::true_cond()
};
(left, right, on)
}
pub fn push_down_join_condition(
on_condition: &mut Condition,
left_col_num: usize,
right_col_num: usize,
ty: JoinType,
push_temporal_predicate: bool,
) -> (Condition, Condition) {
push_down_to_inputs(
on_condition,
left_col_num,
right_col_num,
can_push_left_from_on(ty),
can_push_right_from_on(ty),
push_temporal_predicate,
)
}
fn push_down_to_inputs(
predicate: &mut Condition,
left_col_num: usize,
right_col_num: usize,
push_left: bool,
push_right: bool,
push_temporal_predicate: bool,
) -> (Condition, Condition) {
let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
let (mut left, right, mut others) = if push_temporal_predicate {
Condition { conjunctions }.split(left_col_num, right_col_num)
} else {
let temporal_filter_cons = conjunctions
.extract_if(|e| e.count_nows() != 0)
.collect_vec();
let (left, right, mut others) =
Condition { conjunctions }.split(left_col_num, right_col_num);
others.conjunctions.extend(temporal_filter_cons);
(left, right, others)
};
if !push_left {
others.conjunctions.extend(left);
left = Condition::true_cond();
};
let right = if push_right {
let mut mapping = ColIndexMapping::with_shift_offset(
left_col_num + right_col_num,
-(left_col_num as isize),
);
right.rewrite_expr(&mut mapping)
} else {
others.conjunctions.extend(right);
Condition::true_cond()
};
predicate.conjunctions = others.conjunctions;
(left, right)
}
pub fn can_push_left_from_filter(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti
)
}
pub fn can_push_right_from_filter(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::RightOuter | JoinType::RightSemi | JoinType::RightAnti
)
}
pub fn can_push_on_from_filter(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi
)
}
pub fn can_push_left_from_on(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi
)
}
pub fn can_push_right_from_on(ty: JoinType) -> bool {
matches!(
ty,
JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi
)
}