use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::types::DataType;
use risingwave_common::types::DataType::Boolean;
use risingwave_pb::plan_common::JoinType;
use super::{BoxedRule, Rule};
use crate::expr::{
CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall,
InputRef,
};
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::{
LogicalApply, LogicalFilter, LogicalJoin, PlanTreeNode, PlanTreeNodeBinary,
};
use crate::optimizer::plan_visitor::{ExprCorrelatedIdFinder, PlanCorrelatedIdFinder};
use crate::optimizer::rule::apply_offset_rewriter::ApplyCorrelatedIndicesConverter;
use crate::optimizer::PlanRef;
use crate::utils::{ColIndexMapping, Condition};
/// Optimizer rule that matches a `LogicalApply` whose right input is a `LogicalJoin` and
/// moves the apply below the join, onto the join's left input, right input, or both.
pub struct ApplyJoinTransposeRule {}
impl Rule for ApplyJoinTransposeRule {
    /// Tries to push an inner `LogicalApply` below the `LogicalJoin` on its right input.
    ///
    /// Returns `None` when the plan is not an apply over a join, when the apply is a
    /// `max_one_row` apply, or when neither the join condition nor either join input
    /// references the apply's `correlated_id`.
    ///
    /// # Panics
    ///
    /// Panics if the apply's join type is not `JoinType::Inner`, or if the rewritten plan's
    /// schema differs from the input plan's schema.
    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
        let apply: &LogicalApply = plan.as_logical_apply()?;
        let (
            apply_left,
            apply_right,
            apply_on,
            apply_join_type,
            correlated_id,
            correlated_indices,
            max_one_row,
        ) = apply.clone().decompose();

        if max_one_row {
            return None;
        }
        assert_eq!(apply_join_type, JoinType::Inner);

        let join: &LogicalJoin = apply_right.as_logical_join()?;

        // Find out where the correlated id is referenced: in the join condition and/or
        // somewhere below either join input.
        let cond_correlated = {
            let mut id_finder = ExprCorrelatedIdFinder::default();
            join.on().visit_expr(&mut id_finder);
            id_finder.contains(&correlated_id)
        };
        let left_correlated =
            PlanCorrelatedIdFinder::find_correlated_id(join.left(), &correlated_id);
        let right_correlated =
            PlanCorrelatedIdFinder::find_correlated_id(join.right(), &correlated_id);

        // Nothing under the apply refers to this correlated id: not this rule's job.
        if !(cond_correlated || left_correlated || right_correlated) {
            return None;
        }

        // The push-down helpers index into the join's full output, so a join carrying
        // non-trivial output indices is first split by `ProjectJoinSeparateRule`, and the
        // apply is rebuilt on top of that result.
        if !join.output_indices_are_trivial() {
            let separated_right = crate::optimizer::rule::ProjectJoinSeparateRule::create()
                .apply(join.clone().into())
                .unwrap();
            return Some(apply.clone_with_inputs(&[apply_left, separated_right]));
        }

        let (push_left, push_right) = match join.join_type() {
            // Left-preserving joins always push to the left, and additionally to the right
            // when the right input is correlated.
            JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::LeftOuter
            | JoinType::AsofLeftOuter => (true, right_correlated),
            // Mirror image for right-preserving joins.
            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
                (left_correlated, true)
            }
            // Inner joins push only to the correlated side(s); when just the join condition
            // is correlated, the left side is chosen.
            JoinType::Inner | JoinType::AsofInner => {
                if cond_correlated && !(right_correlated || left_correlated) {
                    (true, false)
                } else {
                    (left_correlated, right_correlated)
                }
            }
            JoinType::FullOuter => (true, true),
            JoinType::Unspecified => unreachable!(),
        };

        let out = match (push_left, push_right) {
            (true, true) => self.push_apply_both_side(
                apply_left,
                join,
                apply_on,
                apply_join_type,
                correlated_id,
                correlated_indices,
            ),
            (true, false) => self.push_apply_left_side(
                apply_left,
                join,
                apply_on,
                apply_join_type,
                correlated_id,
                correlated_indices,
            ),
            (false, true) => self.push_apply_right_side(
                apply_left,
                join,
                apply_on,
                apply_join_type,
                correlated_id,
                correlated_indices,
            ),
            (false, false) => unreachable!(),
        };

        // The rewrite must not change the shape of the plan's output.
        assert_eq!(out.schema(), plan.schema());
        Some(out)
    }
}
impl ApplyJoinTransposeRule {
/// Pushes the apply onto the left input of `join` only.
///
/// Builds `Filter(Join(Apply(apply_left, join.left()), join.right()))`:
/// - the join condition is rewritten so that every input ref is shifted by the width of
///   `apply_left` and correlated refs with `correlated_id` become plain input refs;
/// - conjuncts of `apply_on` that only touch `apply_left` and `join.left()` columns go into
///   the new apply, the rest end up in the filter on top.
///
/// # Panics
///
/// Panics for `RightSemi`, `RightAnti` and `Unspecified` joins.
fn push_apply_left_side(
    &self,
    apply_left: PlanRef,
    join: &LogicalJoin,
    apply_on: Condition,
    apply_join_type: JoinType,
    correlated_id: CorrelatedId,
    correlated_indices: Vec<usize>,
) -> PlanRef {
    let domain_len = apply_left.schema().len();
    let left_input_len = join.left().schema().len();

    // The columns of `apply_left` are placed in front of the join's left input, so both
    // halves of the old join condition move right by the same amount.
    let mut on_rewriter = Rewriter {
        join_left_len: left_input_len,
        join_left_offset: domain_len as isize,
        join_right_offset: domain_len as isize,
        index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
            &correlated_indices,
        ),
        correlated_id,
    };
    let rewritten_on: Vec<ExprImpl> = join
        .on()
        .clone()
        .into_iter()
        .map(|cond| on_rewriter.rewrite_expr(cond))
        .collect();

    // Split the apply condition into the part that moves into the new apply and the part
    // that stays above the join.
    let (pushed_on, residual_on): (Vec<ExprImpl>, Vec<ExprImpl>) = match join.join_type() {
        JoinType::LeftSemi | JoinType::LeftAnti => (apply_on.into_iter().collect(), Vec::new()),
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::RightOuter
        | JoinType::FullOuter
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {
            let apply_len = domain_len + join.schema().len();
            // Bits set for the columns of `apply_left` followed by the join's left input.
            let mut domain_and_left = FixedBitSet::with_capacity(apply_len);
            domain_and_left.set_range(0..domain_len + left_input_len, true);
            apply_on.into_iter().partition(|cond| {
                cond.collect_input_refs(apply_len)
                    .is_subset(&domain_and_left)
            })
        }
        JoinType::RightSemi | JoinType::RightAnti | JoinType::Unspecified => unreachable!(),
    };

    let applied_left = LogicalApply::create(
        apply_left,
        join.left(),
        apply_join_type,
        Condition {
            conjunctions: pushed_on,
        },
        correlated_id,
        correlated_indices,
        false,
    );
    let transposed_join = LogicalJoin::new(
        applied_left,
        join.right(),
        join.join_type(),
        Condition {
            conjunctions: rewritten_on,
        },
    );
    LogicalFilter::create(
        transposed_join.into(),
        Condition {
            conjunctions: residual_on,
        },
    )
}
/// Pushes the apply onto the right input of `join` only.
///
/// Builds `Filter(Join(join.left(), Apply(apply_left, join.right())))`, where the new join
/// carries output indices that put the `apply_left` columns first again, followed by the
/// join's left columns and then its right columns.
///
/// - The join condition is rewritten: refs to the join's left input are kept, refs to its
///   right input are shifted by the width of `apply_left`, and correlated refs with
///   `correlated_id` become plain input refs.
/// - Conjuncts of `apply_on` that only touch `apply_left` and `join.right()` columns are
///   re-indexed and moved into the new apply; the rest end up in the filter on top.
///
/// # Panics
///
/// Panics for `LeftSemi`, `LeftAnti` and `Unspecified` joins.
fn push_apply_right_side(
    &self,
    apply_left: PlanRef,
    join: &LogicalJoin,
    apply_on: Condition,
    apply_join_type: JoinType,
    correlated_id: CorrelatedId,
    correlated_indices: Vec<usize>,
) -> PlanRef {
    let apply_left_len = apply_left.schema().len();
    let join_left_len = join.left().schema().len();
    let mut rewriter = Rewriter {
        join_left_len,
        join_left_offset: 0,
        join_right_offset: apply_left_len as isize,
        index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
            &correlated_indices,
        ),
        correlated_id,
    };
    let new_join_condition = Condition {
        conjunctions: join
            .on()
            .clone()
            .into_iter()
            .map(|expr| rewriter.rewrite_expr(expr))
            .collect_vec(),
    };

    let mut right_apply_condition: Vec<ExprImpl> = vec![];
    let mut other_condition: Vec<ExprImpl> = vec![];
    match join.join_type() {
        JoinType::RightSemi | JoinType::RightAnti => {
            right_apply_condition.extend(apply_on);
        }
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::RightOuter
        | JoinType::FullOuter
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {
            let apply_len = apply_left_len + join.schema().len();
            // Bits set for the `apply_left` columns and the join's right columns.
            let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
            d_t2_bit_set.set_range(0..apply_left_len, true);
            d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
            let (right, other): (Vec<_>, Vec<_>) = apply_on
                .into_iter()
                .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t2_bit_set));
            right_apply_condition.extend(right);
            other_condition.extend(other);

            // Inside the new apply the join's left columns are gone, so refs past the
            // `apply_left` columns move left by `join_left_len`.
            let mut right_apply_condition_rewriter = Rewriter {
                join_left_len: apply_left_len,
                join_left_offset: 0,
                join_right_offset: -(join_left_len as isize),
                index_mapping: ColIndexMapping::empty(0, 0),
                correlated_id,
            };
            right_apply_condition = right_apply_condition
                .into_iter()
                .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
                .collect_vec();
        }
        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Unspecified => unreachable!(),
    }

    let new_join_right = LogicalApply::create(
        apply_left,
        join.right(),
        apply_join_type,
        Condition {
            conjunctions: right_apply_condition,
        },
        correlated_id,
        correlated_indices,
        false,
    );

    // Output indices that move the `apply_left` columns back in front of the join's left
    // columns.
    let output_indices: Vec<usize> = {
        let (apply_left_len, join_right_len) = match apply_join_type {
            JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
            JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
            _ => (apply_left_len, join.right().schema().len()),
        };
        let left_iter = join_left_len..join_left_len + apply_left_len;
        let right_iter = (0..join_left_len).chain(
            join_left_len + apply_left_len..join_left_len + apply_left_len + join_right_len,
        );
        match join.join_type() {
            JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
            JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
            _ => left_iter.chain(right_iter).collect(),
        }
    };

    let new_join = LogicalJoin::new(
        join.left(),
        new_join_right,
        join.join_type(),
        new_join_condition,
    )
    .clone_with_output_indices(output_indices);

    // `other_condition` is still indexed against the original apply output. The filter sits
    // above the join's output indices, which restore exactly that column order, so the
    // condition must be used as is. Remapping it through the output indices would turn it
    // into join-internal positions and make it reference the wrong columns.
    LogicalFilter::create(
        new_join.into(),
        Condition {
            conjunctions: other_condition,
        },
    )
}
/// Pushes the apply onto both inputs of `join`.
///
/// Builds `Join(Apply(apply_left, join.left()), Apply(apply_left, join.right()))`, with a
/// filter on top for join types other than semi/anti. The new join carries output indices
/// that drop the second copy of the `apply_left` columns.
///
/// - The join condition is rewritten: refs to the join's left input are shifted by the width
///   of `apply_left`, refs to its right input by twice that width, and correlated refs with
///   `correlated_id` become plain input refs. One `IS NOT DISTINCT FROM` conjunct per
///   `apply_left` column is appended to tie the two copies of `apply_left` together.
/// - Each conjunct of `apply_on` goes to the left apply if it only touches `apply_left` and
///   `join.left()` columns, else to the right apply if it only touches `apply_left` and
///   `join.right()` columns, else to the filter on top.
///
/// # Panics
///
/// Panics for `Unspecified` joins.
fn push_apply_both_side(
    &self,
    apply_left: PlanRef,
    join: &LogicalJoin,
    apply_on: Condition,
    apply_join_type: JoinType,
    correlated_id: CorrelatedId,
    correlated_indices: Vec<usize>,
) -> PlanRef {
    let apply_left_len = apply_left.schema().len();
    let join_left_len = join.left().schema().len();
    let mut rewriter = Rewriter {
        join_left_len,
        join_left_offset: apply_left_len as isize,
        join_right_offset: 2 * apply_left_len as isize,
        index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
            &correlated_indices,
        ),
        correlated_id,
    };

    // Null-safe equality between column `i` of the left copy of `apply_left` and column `i`
    // of the right copy, which starts after the left apply's columns.
    let natural_conjunctions = apply_left
        .schema()
        .fields
        .iter()
        .enumerate()
        .map(|(i, field)| {
            Self::create_null_safe_equal_expr(
                i,
                field.data_type.clone(),
                i + join_left_len + apply_left_len,
                field.data_type.clone(),
            )
        })
        .collect_vec();
    let new_join_condition = Condition {
        conjunctions: join
            .on()
            .clone()
            .into_iter()
            .map(|expr| rewriter.rewrite_expr(expr))
            .chain(natural_conjunctions)
            .collect_vec(),
    };

    let mut left_apply_condition: Vec<ExprImpl> = vec![];
    let mut right_apply_condition: Vec<ExprImpl> = vec![];
    let mut other_condition: Vec<ExprImpl> = vec![];
    match join.join_type() {
        JoinType::LeftSemi | JoinType::LeftAnti => {
            left_apply_condition.extend(apply_on);
        }
        JoinType::RightSemi | JoinType::RightAnti => {
            right_apply_condition.extend(apply_on);
        }
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::RightOuter
        | JoinType::FullOuter
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {
            let apply_len = apply_left_len + join.schema().len();
            // `d_t1`: `apply_left` columns plus the join's left columns.
            // `d_t2`: `apply_left` columns plus the join's right columns.
            let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
            let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
            d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
            d_t2_bit_set.set_range(0..apply_left_len, true);
            d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);

            // Route every conjunct to its bucket; relative order within a bucket is kept.
            for expr in apply_on {
                let input_refs = expr.collect_input_refs(apply_len);
                if input_refs.is_subset(&d_t1_bit_set) {
                    left_apply_condition.push(expr);
                } else if input_refs.is_subset(&d_t2_bit_set) {
                    right_apply_condition.push(expr);
                } else {
                    other_condition.push(expr);
                }
            }

            // Inside the right apply the join's left columns are gone, so refs past the
            // `apply_left` columns move left by `join_left_len`.
            let mut right_apply_condition_rewriter = Rewriter {
                join_left_len: apply_left_len,
                join_left_offset: 0,
                join_right_offset: -(join_left_len as isize),
                index_mapping: ColIndexMapping::empty(0, 0),
                correlated_id,
            };
            right_apply_condition = right_apply_condition
                .into_iter()
                .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
                .collect_vec();
        }
        JoinType::Unspecified => unreachable!(),
    }

    let new_join_left = LogicalApply::create(
        apply_left.clone(),
        join.left(),
        apply_join_type,
        Condition {
            conjunctions: left_apply_condition,
        },
        correlated_id,
        correlated_indices.clone(),
        false,
    );
    let new_join_right = LogicalApply::create(
        apply_left,
        join.right(),
        apply_join_type,
        Condition {
            conjunctions: right_apply_condition,
        },
        correlated_id,
        correlated_indices,
        false,
    );

    // Output indices that keep the left apply's columns and skip the second copy of the
    // `apply_left` columns in front of the join's right columns.
    let output_indices: Vec<usize> = {
        let (apply_left_len, join_right_len) = match apply_join_type {
            JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
            JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
            _ => (apply_left_len, join.right().schema().len()),
        };
        let left_iter = 0..join_left_len + apply_left_len;
        let right_iter = join_left_len + apply_left_len * 2
            ..join_left_len + apply_left_len * 2 + join_right_len;
        match join.join_type() {
            JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
            JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
            _ => left_iter.chain(right_iter).collect(),
        }
    };

    let new_join = LogicalJoin::new(
        new_join_left,
        new_join_right,
        join.join_type(),
        new_join_condition,
    )
    .clone_with_output_indices(output_indices);

    match join.join_type() {
        // For semi/anti joins the whole apply condition was pushed into one of the applies,
        // so nothing is left for a filter.
        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => {
            new_join.into()
        }
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::RightOuter
        | JoinType::FullOuter
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {
            // `other_condition` is still indexed against the original apply output. The
            // filter sits above the join's output indices, which restore exactly that
            // column order, so the condition must be used as is. Remapping it through the
            // output indices would shift refs to the join's right columns by
            // `apply_left_len` and make them reference the wrong columns.
            LogicalFilter::create(
                new_join.into(),
                Condition {
                    conjunctions: other_condition,
                },
            )
        }
        JoinType::Unspecified => unreachable!(),
    }
}
fn create_null_safe_equal_expr(
left: usize,
left_data_type: DataType,
right: usize,
right_data_type: DataType,
) -> ExprImpl {
ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
ExprType::IsNotDistinctFrom,
vec![
ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
],
Boolean,
)))
}
}
impl ApplyJoinTransposeRule {
    /// Returns a new boxed instance of this rule.
    pub fn create() -> BoxedRule {
        let rule = ApplyJoinTransposeRule {};
        Box::new(rule)
    }
}
/// Expression rewriter used when moving conditions across the apply/join boundary: it shifts
/// input refs by a per-side offset and turns matching correlated input refs into input refs.
struct Rewriter {
// Input refs with an index below this value are shifted by `join_left_offset`; all other
// input refs are shifted by `join_right_offset`.
join_left_len: usize,
// Signed shift added to input ref indices below `join_left_len`.
join_left_offset: isize,
// Signed shift added to input ref indices at or above `join_left_len`.
join_right_offset: isize,
// Maps the index of a matching correlated input ref to the index of the input ref that
// replaces it.
index_mapping: ColIndexMapping,
// Only correlated input refs carrying this id are replaced; others are left untouched.
correlated_id: CorrelatedId,
}
impl ExprRewriter for Rewriter {
    /// Replaces a correlated input ref that carries `self.correlated_id` with a plain input
    /// ref whose index comes from `self.index_mapping`; any other correlated input ref is
    /// returned unchanged.
    fn rewrite_correlated_input_ref(
        &mut self,
        correlated_input_ref: CorrelatedInputRef,
    ) -> ExprImpl {
        if correlated_input_ref.correlated_id() != self.correlated_id {
            return correlated_input_ref.into();
        }
        let mapped_index = self.index_mapping.map(correlated_input_ref.index());
        InputRef::new(mapped_index, correlated_input_ref.return_type()).into()
    }

    /// Shifts an input ref by `join_left_offset` when its index is below `join_left_len`,
    /// and by `join_right_offset` otherwise.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        let offset = if input_ref.index < self.join_left_len {
            self.join_left_offset
        } else {
            self.join_right_offset
        };
        let shifted_index = (input_ref.index() as isize + offset) as usize;
        InputRef::new(shifted_index, input_ref.return_type()).into()
    }
}