risingwave_frontend/optimizer/rule/
apply_eliminate_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{LogicalFilter, LogicalJoin, LogicalProject};
24use crate::optimizer::plan_visitor::PlanCorrelatedIdFinder;
25use crate::utils::Condition;
26
/// Eliminate `LogicalApply` if we can't find its `correlated_id` in its RHS.
///
/// The rule does not fire for an apply that is flagged `max_one_row`.
///
/// Before:
///
/// ```text
///    LogicalApply
///    /           \
///  Domain       RHS
/// ```
///
/// The Domain (and with it the DAG) can be removed only when the `on` condition equates every
/// Domain column with a column of the RHS. In that case the Domain columns are re-created by
/// projecting the matching RHS columns in front of the RHS's own columns, and rows in which
/// any of those leading columns is NULL are filtered out, because an equality never matches
/// NULL.
/// After:
///
/// ```text
///  LogicalFilter (Null reject for equal)
///        |
///  LogicalProject
///        |
///       RHS
/// ```
///
///
/// If it can't remove DAG, the apply is simply turned into a join with the same `on`
/// condition.
/// After:
///
/// ```text
///     LogicalJoin
///    /           \
///  Domain       RHS
/// ```
pub struct ApplyEliminateRule {}
58impl Rule for ApplyEliminateRule {
59    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
60        let apply = plan.as_logical_apply()?;
61        let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
62            apply.clone().decompose();
63
64        if max_one_row {
65            return None;
66        }
67
68        // Still can find `correlated_id`, so bail out.
69        if PlanCorrelatedIdFinder::find_correlated_id(right.clone(), &correlated_id) {
70            return None;
71        }
72
73        let apply_left_len = left.schema().len();
74        assert_eq!(join_type, JoinType::Inner);
75
76        // Record the mapping from `CorrelatedInputRef`'s index to `InputRef`'s index.
77        // We currently can remove DAG only if ALL the `CorrelatedInputRef` are equal joined to
78        // `InputRef`.
79        // TODO: Do some transformation for IN, and try to remove DAG for it.
80        let mut column_mapping = HashMap::new();
81        on.conjunctions.iter().for_each(|expr| {
82            if let ExprImpl::FunctionCall(func_call) = expr {
83                if let Some((left, right, data_type)) = Self::check(func_call, apply_left_len) {
84                    column_mapping.insert(left, (right, data_type));
85                }
86            }
87        });
88        if column_mapping.len() == apply_left_len {
89            // Remove DAG.
90
91            // Replace `LogicalApply` with `LogicalProject` and insert the `InputRef`s which is
92            // equal to `CorrelatedInputRef` at the beginning of `LogicalProject`.
93            // See the fourth section of Unnesting Arbitrary Queries for how to do the optimization.
94            let mut exprs: Vec<ExprImpl> = (0..correlated_indices.len())
95                .map(|i| {
96                    let (col_index, data_type) = column_mapping.get(&i).unwrap();
97                    InputRef::new(*col_index - apply_left_len, data_type.clone()).into()
98                })
99                .collect();
100            exprs.extend(
101                right
102                    .schema()
103                    .data_types()
104                    .into_iter()
105                    .enumerate()
106                    .map(|(index, data_type)| InputRef::new(index, data_type).into()),
107            );
108            let project = LogicalProject::create(right, exprs);
109
110            // Null reject for equal
111            let filter_exprs: Vec<ExprImpl> = (0..correlated_indices.len())
112                .map(|i| {
113                    ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
114                        ExprType::IsNotNull,
115                        vec![ExprImpl::InputRef(Box::new(InputRef::new(
116                            i,
117                            project.schema().fields[i].data_type.clone(),
118                        )))],
119                        DataType::Boolean,
120                    )))
121                })
122                .collect();
123
124            let filter = LogicalFilter::create(
125                project,
126                Condition {
127                    conjunctions: filter_exprs,
128                },
129            );
130
131            Some(filter)
132        } else {
133            let join = LogicalJoin::new(left, right, join_type, on);
134            Some(join.into())
135        }
136    }
137}
138
139impl ApplyEliminateRule {
140    pub fn create() -> BoxedRule {
141        Box::new(ApplyEliminateRule {})
142    }
143
144    /// Check whether the `func_call` is like v1 = v2, in which v1 and v2 belong respectively to
145    /// `LogicalApply`'s left and right.
146    fn check(func_call: &FunctionCall, apply_left_len: usize) -> Option<(usize, usize, DataType)> {
147        let inputs = func_call.inputs();
148        if func_call.func_type() == ExprType::Equal && inputs.len() == 2 {
149            let left = &inputs[0];
150            let right = &inputs[1];
151            match (left, right) {
152                (ExprImpl::InputRef(left), ExprImpl::InputRef(right)) => {
153                    let left_type = left.return_type();
154                    let left = left.index();
155                    let right_type = right.return_type();
156                    let right = right.index();
157                    if left < apply_left_len && right >= apply_left_len {
158                        Some((left, right, right_type))
159                    } else if left >= apply_left_len && right < apply_left_len {
160                        Some((right, left, left_type))
161                    } else {
162                        None
163                    }
164                }
165                _ => None,
166            }
167        } else {
168            None
169        }
170    }
171}