risingwave_frontend/optimizer/rule/
apply_eliminate_rule.rs1use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::plan_node::{LogicalFilter, LogicalJoin, LogicalProject};
23use crate::optimizer::plan_visitor::PlanCorrelatedIdFinder;
24use crate::utils::Condition;
25
/// Optimizer rule that eliminates a `LogicalApply` whose right side no longer references the
/// apply's correlated id.
///
/// There are two outcomes, and the `Rule` implementation below decides between them:
/// * The apply's left side is dropped and reproduced from right-side columns through a project
///   and an `IS NOT NULL` filter.
/// * The apply is rewritten into a plain `LogicalJoin`.
pub struct ApplyEliminateRule {}
57impl Rule<Logical> for ApplyEliminateRule {
58 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
59 let apply = plan.as_logical_apply()?;
60 let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
61 apply.clone().decompose();
62
63 if max_one_row {
64 return None;
65 }
66
67 if PlanCorrelatedIdFinder::find_correlated_id(right.clone(), &correlated_id) {
69 return None;
70 }
71
72 let apply_left_len = left.schema().len();
73 assert_eq!(join_type, JoinType::Inner);
74
75 let mut column_mapping = HashMap::new();
80 on.conjunctions.iter().for_each(|expr| {
81 if let ExprImpl::FunctionCall(func_call) = expr
82 && let Some((left, right, data_type)) = Self::check(func_call, apply_left_len)
83 {
84 column_mapping.insert(left, (right, data_type));
85 }
86 });
87 if column_mapping.len() == apply_left_len {
88 let mut exprs: Vec<ExprImpl> = (0..correlated_indices.len())
94 .map(|i| {
95 let (col_index, data_type) = column_mapping.get(&i).unwrap();
96 InputRef::new(*col_index - apply_left_len, data_type.clone()).into()
97 })
98 .collect();
99 exprs.extend(
100 right
101 .schema()
102 .data_types()
103 .into_iter()
104 .enumerate()
105 .map(|(index, data_type)| InputRef::new(index, data_type).into()),
106 );
107 let project = LogicalProject::create(right, exprs);
108
109 let filter_exprs: Vec<ExprImpl> = (0..correlated_indices.len())
111 .map(|i| {
112 ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
113 ExprType::IsNotNull,
114 vec![ExprImpl::InputRef(Box::new(InputRef::new(
115 i,
116 project.schema().fields[i].data_type.clone(),
117 )))],
118 DataType::Boolean,
119 )))
120 })
121 .collect();
122
123 let filter = LogicalFilter::create(
124 project,
125 Condition {
126 conjunctions: filter_exprs,
127 },
128 );
129
130 Some(filter)
131 } else {
132 let join = LogicalJoin::new(left, right, join_type, on);
133 Some(join.into())
134 }
135 }
136}
137
138impl ApplyEliminateRule {
139 pub fn create() -> BoxedRule {
140 Box::new(ApplyEliminateRule {})
141 }
142
143 fn check(func_call: &FunctionCall, apply_left_len: usize) -> Option<(usize, usize, DataType)> {
146 let inputs = func_call.inputs();
147 if func_call.func_type() == ExprType::Equal && inputs.len() == 2 {
148 let left = &inputs[0];
149 let right = &inputs[1];
150 match (left, right) {
151 (ExprImpl::InputRef(left), ExprImpl::InputRef(right)) => {
152 let left_type = left.return_type();
153 let left = left.index();
154 let right_type = right.return_type();
155 let right = right.index();
156 if left < apply_left_len && right >= apply_left_len {
157 Some((left, right, right_type))
158 } else if left >= apply_left_len && right < apply_left_len {
159 Some((right, left, left_type))
160 } else {
161 None
162 }
163 }
164 _ => None,
165 }
166 } else {
167 None
168 }
169 }
170}