risingwave_frontend/optimizer/rule/
translate_apply_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::{BoxedRule, Rule};
21use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef};
24use crate::optimizer::plan_node::{
25    LogicalApply, LogicalJoin, LogicalProject, LogicalScan, LogicalShare, PlanTreeNodeBinary,
26    PlanTreeNodeUnary,
27};
28use crate::utils::{ColIndexMapping, Condition};
29
/// Rewrites a `LogicalApply` into the canonical form used by general unnesting, as described in
/// the paper "Unnesting Arbitrary Queries".
///
/// Before:
///
/// ```text
///     LogicalApply
///    /            \
///  LHS           RHS
/// ```
///
/// After:
///
/// ```text
///      LogicalJoin
///    /            \
///  LHS        LogicalApply
///             /           \
///          Domain         RHS
/// ```
///
/// `Domain` is the set of distinct values taken by the columns of `LHS` that `RHS` is correlated
/// with.
pub struct TranslateApplyRule {
    /// When the domain must be computed from the whole left input, wrap that input in a
    /// `LogicalShare` before it is used by both the domain and the new apply.
    enable_share_plan: bool,
}
53
54impl Rule for TranslateApplyRule {
55    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
56        let apply: &LogicalApply = plan.as_logical_apply()?;
57        if apply.translated() {
58            return None;
59        }
60        let mut left: PlanRef = apply.left();
61        let right: PlanRef = apply.right();
62        let apply_left_len = left.schema().len();
63        let correlated_indices = apply.correlated_indices();
64
65        let mut index_mapping =
66            ColIndexMapping::new(vec![None; apply_left_len], correlated_indices.len());
67        let mut data_types = HashMap::new();
68        let mut index = 0;
69
70        // First try to rewrite the left side of the apply.
71        // TODO: remove the rewrite and always use the general way to calculate the domain
72        //      after we support DAG.
73        let domain: PlanRef = if let Some(rewritten_left) = Self::rewrite(
74            &left,
75            correlated_indices.clone(),
76            0,
77            &mut index_mapping,
78            &mut data_types,
79            &mut index,
80        ) {
81            // This `LogicalProject` is used to make sure that after `LogicalApply`'s left was
82            // rewritten, the new index of `correlated_index` is always at its position in
83            // `correlated_indices`.
84            let exprs = correlated_indices
85                .clone()
86                .into_iter()
87                .enumerate()
88                .map(|(i, correlated_index)| {
89                    let index = index_mapping.map(correlated_index);
90                    let data_type = rewritten_left.schema().fields()[index].data_type.clone();
91                    index_mapping.put(correlated_index, Some(i));
92                    InputRef::new(index, data_type).into()
93                })
94                .collect();
95            let project = LogicalProject::create(rewritten_left, exprs);
96            let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
97            distinct.into()
98        } else {
99            // The left side of the apply is not SPJ. We need to use the general way to calculate
100            // the domain. Distinct + Project + The Left of Apply
101
102            // Use Share
103            left = if self.enable_share_plan {
104                let logical_share = LogicalShare::new(left);
105                logical_share.into()
106            } else {
107                left
108            };
109
110            let exprs = correlated_indices
111                .clone()
112                .into_iter()
113                .map(|correlated_index| {
114                    let data_type = left.schema().fields()[correlated_index].data_type.clone();
115                    InputRef::new(correlated_index, data_type).into()
116                })
117                .collect();
118            let project = LogicalProject::create(left.clone(), exprs);
119            let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
120            distinct.into()
121        };
122
123        let eq_predicates = correlated_indices
124            .into_iter()
125            .enumerate()
126            .map(|(i, correlated_index)| {
127                let shifted_index = i + apply_left_len;
128                let data_type = domain.schema().fields()[i].data_type.clone();
129                let left = InputRef::new(correlated_index, data_type.clone());
130                let right = InputRef::new(shifted_index, data_type);
131                // use null-safe equal
132                FunctionCall::new_unchecked(
133                    ExprType::IsNotDistinctFrom,
134                    vec![left.into(), right.into()],
135                    DataType::Boolean,
136                )
137                .into()
138            })
139            .collect::<Vec<ExprImpl>>();
140
141        let new_apply = apply.clone_with_left_right(left, right);
142        let new_node = new_apply.translate_apply(domain, eq_predicates);
143        Some(new_node)
144    }
145}
146
147impl TranslateApplyRule {
148    pub fn create(enable_share_plan: bool) -> BoxedRule {
149        Box::new(TranslateApplyRule { enable_share_plan })
150    }
151
152    /// Rewrite `LogicalApply`'s left according to `correlated_indices`.
153    ///
154    /// Assumption: only `LogicalJoin`, `LogicalScan`, `LogicalProject` and `LogicalFilter` are in
155    /// the left.
156    fn rewrite(
157        plan: &PlanRef,
158        correlated_indices: Vec<usize>,
159        offset: usize,
160        index_mapping: &mut ColIndexMapping,
161        data_types: &mut HashMap<usize, DataType>,
162        index: &mut usize,
163    ) -> Option<PlanRef> {
164        if let Some(join) = plan.as_logical_join() {
165            Self::rewrite_join(
166                join,
167                correlated_indices,
168                offset,
169                index_mapping,
170                data_types,
171                index,
172            )
173        } else if let Some(apply) = plan.as_logical_apply() {
174            Self::rewrite_apply(
175                apply,
176                correlated_indices,
177                offset,
178                index_mapping,
179                data_types,
180                index,
181            )
182        } else if let Some(scan) = plan.as_logical_scan() {
183            Self::rewrite_scan(
184                scan,
185                correlated_indices,
186                offset,
187                index_mapping,
188                data_types,
189                index,
190            )
191        } else if let Some(filter) = plan.as_logical_filter() {
192            Self::rewrite(
193                &filter.input(),
194                correlated_indices,
195                offset,
196                index_mapping,
197                data_types,
198                index,
199            )
200        } else {
201            // TODO: better to return an error.
202            None
203        }
204    }
205
206    fn rewrite_join(
207        join: &LogicalJoin,
208        required_col_idx: Vec<usize>,
209        mut offset: usize,
210        index_mapping: &mut ColIndexMapping,
211        data_types: &mut HashMap<usize, DataType>,
212        index: &mut usize,
213    ) -> Option<PlanRef> {
214        // TODO: Do we need to take the `on` into account?
215        let left_len = join.left().schema().len();
216        let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
217            .into_iter()
218            .partition(|idx| *idx < left_len);
219        let mut rewrite =
220            |plan: PlanRef, mut indices: Vec<usize>, is_right: bool| -> Option<PlanRef> {
221                if is_right {
222                    indices.iter_mut().for_each(|index| *index -= left_len);
223                    offset += left_len;
224                }
225                Self::rewrite(&plan, indices, offset, index_mapping, data_types, index)
226            };
227        match (left_idxs.is_empty(), right_idxs.is_empty()) {
228            (true, false) => {
229                // Only accept join which doesn't generate null columns.
230                match join.join_type() {
231                    JoinType::Inner
232                    | JoinType::LeftSemi
233                    | JoinType::RightSemi
234                    | JoinType::LeftAnti
235                    | JoinType::RightAnti
236                    | JoinType::RightOuter
237                    | JoinType::AsofInner => rewrite(join.right(), right_idxs, true),
238                    JoinType::LeftOuter | JoinType::FullOuter | JoinType::AsofLeftOuter => None,
239                    JoinType::Unspecified => unreachable!(),
240                }
241            }
242            (false, true) => {
243                // Only accept join which doesn't generate null columns.
244                match join.join_type() {
245                    JoinType::Inner
246                    | JoinType::LeftSemi
247                    | JoinType::RightSemi
248                    | JoinType::LeftAnti
249                    | JoinType::RightAnti
250                    | JoinType::LeftOuter
251                    | JoinType::AsofInner
252                    | JoinType::AsofLeftOuter => rewrite(join.left(), left_idxs, false),
253                    JoinType::RightOuter | JoinType::FullOuter => None,
254                    JoinType::Unspecified => unreachable!(),
255                }
256            }
257            (false, false) => {
258                // Only accept join which doesn't generate null columns.
259                match join.join_type() {
260                    JoinType::Inner
261                    | JoinType::LeftSemi
262                    | JoinType::RightSemi
263                    | JoinType::LeftAnti
264                    | JoinType::RightAnti
265                    | JoinType::AsofInner => {
266                        let left = rewrite(join.left(), left_idxs, false)?;
267                        let right = rewrite(join.right(), right_idxs, true)?;
268                        let new_join =
269                            LogicalJoin::new(left, right, join.join_type(), Condition::true_cond());
270                        Some(new_join.into())
271                    }
272                    JoinType::LeftOuter
273                    | JoinType::RightOuter
274                    | JoinType::FullOuter
275                    | JoinType::AsofLeftOuter => None,
276                    JoinType::Unspecified => unreachable!(),
277                }
278            }
279            _ => None,
280        }
281    }
282
283    /// ```text
284    ///             LogicalApply
285    ///            /            \
286    ///     LogicalApply       RHS1
287    ///    /            \
288    ///  LHS           RHS2
289    /// ```
290    ///
291    /// A common structure of multi scalar subqueries is a chain of `LogicalApply`. To avoid exponential growth of the domain operator, we need to rewrite the apply and try to simplify the domain as much as possible.
292    /// We use a top-down apply order to rewrite the apply, so that we don't need to handle operator like project and aggregation generated by the domain calculation.
293    /// As a cost, we need to add a flag `translated` to the apply operator to remind `translate_apply_rule` that the apply has been translated.
294    fn rewrite_apply(
295        apply: &LogicalApply,
296        required_col_idx: Vec<usize>,
297        offset: usize,
298        index_mapping: &mut ColIndexMapping,
299        data_types: &mut HashMap<usize, DataType>,
300        index: &mut usize,
301    ) -> Option<PlanRef> {
302        // TODO: Do we need to take the `on` into account?
303        let left_len = apply.left().schema().len();
304        let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
305            .into_iter()
306            .partition(|idx| *idx < left_len);
307        if !left_idxs.is_empty() && right_idxs.is_empty() {
308            // Deal with multi scalar subqueries
309            match apply.join_type() {
310                JoinType::Inner
311                | JoinType::LeftSemi
312                | JoinType::LeftAnti
313                | JoinType::LeftOuter
314                | JoinType::AsofInner
315                | JoinType::AsofLeftOuter => {
316                    let plan = apply.left();
317                    Self::rewrite(&plan, left_idxs, offset, index_mapping, data_types, index)
318                }
319                JoinType::RightOuter
320                | JoinType::RightAnti
321                | JoinType::RightSemi
322                | JoinType::FullOuter => None,
323                JoinType::Unspecified => unreachable!(),
324            }
325        } else {
326            None
327        }
328    }
329
330    fn rewrite_scan(
331        scan: &LogicalScan,
332        required_col_idx: Vec<usize>,
333        offset: usize,
334        index_mapping: &mut ColIndexMapping,
335        data_types: &mut HashMap<usize, DataType>,
336        index: &mut usize,
337    ) -> Option<PlanRef> {
338        for i in &required_col_idx {
339            let correlated_index = *i + offset;
340            index_mapping.put(correlated_index, Some(*index));
341            data_types.insert(
342                correlated_index,
343                scan.schema().fields()[*i].data_type.clone(),
344            );
345            *index += 1;
346        }
347
348        Some(scan.clone_with_output_indices(required_col_idx).into())
349    }
350}