risingwave_frontend/optimizer/rule/
intersect_to_semi_join_rule.rs1use risingwave_common::types::DataType::Boolean;
16use risingwave_common::util::iter_util::ZipEqDebug;
17use risingwave_pb::plan_common::JoinType;
18
19use super::{BoxedRule, Rule};
20use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
21use crate::optimizer::PlanRef;
22use crate::optimizer::plan_node::generic::Agg;
23use crate::optimizer::plan_node::{LogicalIntersect, LogicalJoin, PlanTreeNode};
24
25pub struct IntersectToSemiJoinRule {}
26impl Rule for IntersectToSemiJoinRule {
27 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
28 let logical_intersect: &LogicalIntersect = plan.as_logical_intersect()?;
29 let all = logical_intersect.all();
30 if all {
31 return None;
32 }
33
34 let inputs = logical_intersect.inputs();
35 let join = inputs
36 .into_iter()
37 .fold(None, |left, right| match left {
38 None => Some(right),
39 Some(left) => {
40 let on =
41 IntersectToSemiJoinRule::gen_null_safe_equal(left.clone(), right.clone());
42 Some(LogicalJoin::create(left, right, JoinType::LeftSemi, on))
43 }
44 })
45 .unwrap();
46
47 Some(Agg::new(vec![], (0..join.schema().len()).collect(), join).into())
48 }
49}
50
51impl IntersectToSemiJoinRule {
52 pub(crate) fn gen_null_safe_equal(left: PlanRef, right: PlanRef) -> ExprImpl {
53 let arms = (left
54 .schema()
55 .fields()
56 .iter()
57 .zip_eq_debug(right.schema().fields())
58 .enumerate())
59 .map(|(i, (left_field, right_field))| {
60 ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
61 ExprType::IsNotDistinctFrom,
62 vec![
63 ExprImpl::InputRef(Box::new(InputRef::new(i, left_field.data_type()))),
64 ExprImpl::InputRef(Box::new(InputRef::new(
65 i + left.schema().len(),
66 right_field.data_type(),
67 ))),
68 ],
69 Boolean,
70 )))
71 });
72 ExprImpl::and(arms)
73 }
74}
75
76impl IntersectToSemiJoinRule {
77 pub fn create() -> BoxedRule {
78 Box::new(IntersectToSemiJoinRule {})
79 }
80}