risingwave_frontend/optimizer/rule/
apply_eliminate_rule.rs1use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{LogicalFilter, LogicalJoin, LogicalProject};
24use crate::optimizer::plan_visitor::PlanCorrelatedIdFinder;
25use crate::utils::Condition;
26
/// Optimizer rule that removes a `LogicalApply` whose right side no longer references the
/// Apply's correlated id.
///
/// The Apply is rewritten either into a projection plus filter over its right side, or into
/// a plain `LogicalJoin`; see the `Rule` implementation for the exact conditions.
pub struct ApplyEliminateRule;
58impl Rule for ApplyEliminateRule {
59 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
60 let apply = plan.as_logical_apply()?;
61 let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
62 apply.clone().decompose();
63
64 if max_one_row {
65 return None;
66 }
67
68 if PlanCorrelatedIdFinder::find_correlated_id(right.clone(), &correlated_id) {
70 return None;
71 }
72
73 let apply_left_len = left.schema().len();
74 assert_eq!(join_type, JoinType::Inner);
75
76 let mut column_mapping = HashMap::new();
81 on.conjunctions.iter().for_each(|expr| {
82 if let ExprImpl::FunctionCall(func_call) = expr {
83 if let Some((left, right, data_type)) = Self::check(func_call, apply_left_len) {
84 column_mapping.insert(left, (right, data_type));
85 }
86 }
87 });
88 if column_mapping.len() == apply_left_len {
89 let mut exprs: Vec<ExprImpl> = (0..correlated_indices.len())
95 .map(|i| {
96 let (col_index, data_type) = column_mapping.get(&i).unwrap();
97 InputRef::new(*col_index - apply_left_len, data_type.clone()).into()
98 })
99 .collect();
100 exprs.extend(
101 right
102 .schema()
103 .data_types()
104 .into_iter()
105 .enumerate()
106 .map(|(index, data_type)| InputRef::new(index, data_type).into()),
107 );
108 let project = LogicalProject::create(right, exprs);
109
110 let filter_exprs: Vec<ExprImpl> = (0..correlated_indices.len())
112 .map(|i| {
113 ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
114 ExprType::IsNotNull,
115 vec![ExprImpl::InputRef(Box::new(InputRef::new(
116 i,
117 project.schema().fields[i].data_type.clone(),
118 )))],
119 DataType::Boolean,
120 )))
121 })
122 .collect();
123
124 let filter = LogicalFilter::create(
125 project,
126 Condition {
127 conjunctions: filter_exprs,
128 },
129 );
130
131 Some(filter)
132 } else {
133 let join = LogicalJoin::new(left, right, join_type, on);
134 Some(join.into())
135 }
136 }
137}
138
139impl ApplyEliminateRule {
140 pub fn create() -> BoxedRule {
141 Box::new(ApplyEliminateRule {})
142 }
143
144 fn check(func_call: &FunctionCall, apply_left_len: usize) -> Option<(usize, usize, DataType)> {
147 let inputs = func_call.inputs();
148 if func_call.func_type() == ExprType::Equal && inputs.len() == 2 {
149 let left = &inputs[0];
150 let right = &inputs[1];
151 match (left, right) {
152 (ExprImpl::InputRef(left), ExprImpl::InputRef(right)) => {
153 let left_type = left.return_type();
154 let left = left.index();
155 let right_type = right.return_type();
156 let right = right.index();
157 if left < apply_left_len && right >= apply_left_len {
158 Some((left, right, right_type))
159 } else if left >= apply_left_len && right < apply_left_len {
160 Some((right, left, left_type))
161 } else {
162 None
163 }
164 }
165 _ => None,
166 }
167 } else {
168 None
169 }
170 }
171}