risingwave_frontend/optimizer/rule/
apply_eliminate_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::plan_node::{LogicalFilter, LogicalJoin, LogicalProject};
23use crate::optimizer::plan_visitor::PlanCorrelatedIdFinder;
24use crate::utils::Condition;
25
/// Optimizer rule that eliminates a `LogicalApply` once its `correlated_id` can no longer be
/// found in its RHS.
///
/// Before:
///
/// ```text
///    LogicalApply
///    /           \
///  Domain       RHS
/// ```
///
/// After, when the DAG can be removed (every Domain column is equated to an RHS column by the
/// apply's `on` condition):
///
/// ```text
///  LogicalFilter (null-reject the columns that stand in for the Domain)
///        |
///  LogicalProject (equated RHS columns first, then all RHS columns)
///        |
///       RHS
/// ```
///
/// After, when the DAG cannot be removed:
///
/// ```text
///     LogicalJoin
///    /           \
///  Domain       RHS
/// ```
pub struct ApplyEliminateRule {}
57impl Rule<Logical> for ApplyEliminateRule {
58    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
59        let apply = plan.as_logical_apply()?;
60        let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
61            apply.clone().decompose();
62
63        if max_one_row {
64            return None;
65        }
66
67        // Still can find `correlated_id`, so bail out.
68        if PlanCorrelatedIdFinder::find_correlated_id(right.clone(), &correlated_id) {
69            return None;
70        }
71
72        let apply_left_len = left.schema().len();
73        assert_eq!(join_type, JoinType::Inner);
74
75        // Record the mapping from `CorrelatedInputRef`'s index to `InputRef`'s index.
76        // We currently can remove DAG only if ALL the `CorrelatedInputRef` are equal joined to
77        // `InputRef`.
78        // TODO: Do some transformation for IN, and try to remove DAG for it.
79        let mut column_mapping = HashMap::new();
80        on.conjunctions.iter().for_each(|expr| {
81            if let ExprImpl::FunctionCall(func_call) = expr
82                && let Some((left, right, data_type)) = Self::check(func_call, apply_left_len)
83            {
84                column_mapping.insert(left, (right, data_type));
85            }
86        });
87        if column_mapping.len() == apply_left_len {
88            // Remove DAG.
89
90            // Replace `LogicalApply` with `LogicalProject` and insert the `InputRef`s which is
91            // equal to `CorrelatedInputRef` at the beginning of `LogicalProject`.
92            // See the fourth section of Unnesting Arbitrary Queries for how to do the optimization.
93            let mut exprs: Vec<ExprImpl> = (0..correlated_indices.len())
94                .map(|i| {
95                    let (col_index, data_type) = column_mapping.get(&i).unwrap();
96                    InputRef::new(*col_index - apply_left_len, data_type.clone()).into()
97                })
98                .collect();
99            exprs.extend(
100                right
101                    .schema()
102                    .data_types()
103                    .into_iter()
104                    .enumerate()
105                    .map(|(index, data_type)| InputRef::new(index, data_type).into()),
106            );
107            let project = LogicalProject::create(right, exprs);
108
109            // Null reject for equal
110            let filter_exprs: Vec<ExprImpl> = (0..correlated_indices.len())
111                .map(|i| {
112                    ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
113                        ExprType::IsNotNull,
114                        vec![ExprImpl::InputRef(Box::new(InputRef::new(
115                            i,
116                            project.schema().fields[i].data_type.clone(),
117                        )))],
118                        DataType::Boolean,
119                    )))
120                })
121                .collect();
122
123            let filter = LogicalFilter::create(
124                project,
125                Condition {
126                    conjunctions: filter_exprs,
127                },
128            );
129
130            Some(filter)
131        } else {
132            let join = LogicalJoin::new(left, right, join_type, on);
133            Some(join.into())
134        }
135    }
136}
137
138impl ApplyEliminateRule {
139    pub fn create() -> BoxedRule {
140        Box::new(ApplyEliminateRule {})
141    }
142
143    /// Check whether the `func_call` is like v1 = v2, in which v1 and v2 belong respectively to
144    /// `LogicalApply`'s left and right.
145    fn check(func_call: &FunctionCall, apply_left_len: usize) -> Option<(usize, usize, DataType)> {
146        let inputs = func_call.inputs();
147        if func_call.func_type() == ExprType::Equal && inputs.len() == 2 {
148            let left = &inputs[0];
149            let right = &inputs[1];
150            match (left, right) {
151                (ExprImpl::InputRef(left), ExprImpl::InputRef(right)) => {
152                    let left_type = left.return_type();
153                    let left = left.index();
154                    let right_type = right.return_type();
155                    let right = right.index();
156                    if left < apply_left_len && right >= apply_left_len {
157                        Some((left, right, right_type))
158                    } else if left >= apply_left_len && right < apply_left_len {
159                        Some((right, left, left_type))
160                    } else {
161                        None
162                    }
163                }
164                _ => None,
165            }
166        } else {
167            None
168        }
169    }
170}