risingwave_frontend/optimizer/rule/
translate_apply_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::DataType;
18use risingwave_pb::plan_common::JoinType;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
22use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef};
23use crate::optimizer::plan_node::{
24    LogicalApply, LogicalJoin, LogicalProject, LogicalScan, LogicalShare, PlanTreeNodeBinary,
25    PlanTreeNodeUnary,
26};
27use crate::utils::{ColIndexMapping, Condition};
28
/// Optimizer rule that brings a `LogicalApply` into a canonical, domain-based form, following
/// the general unnesting approach of the paper "Unnesting Arbitrary Queries".
///
/// Plan shape before the rule fires:
///
/// ```text
///     LogicalApply
///    /            \
///  LHS           RHS
/// ```
///
/// Plan shape after the rule fired:
///
/// ```text
///      LogicalJoin
///    /            \
///  LHS        LogicalApply
///             /           \
///          Domain         RHS
/// ```
///
/// `Domain` is built as a distinct aggregation over a projection of the correlated columns of
/// the apply's left side.
pub struct TranslateApplyRule {
    /// When `true`, the apply's left side is wrapped in a `LogicalShare` before the domain is
    /// derived through the general (non-rewritten) path.
    enable_share_plan: bool,
}
52
53impl Rule<Logical> for TranslateApplyRule {
54    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
55        let apply: &LogicalApply = plan.as_logical_apply()?;
56        if apply.translated() {
57            return None;
58        }
59        let mut left: PlanRef = apply.left();
60        let right: PlanRef = apply.right();
61        let apply_left_len = left.schema().len();
62        let correlated_indices = apply.correlated_indices();
63
64        let mut index_mapping =
65            ColIndexMapping::new(vec![None; apply_left_len], correlated_indices.len());
66        let mut data_types = HashMap::new();
67        let mut index = 0;
68
69        // First try to rewrite the left side of the apply.
70        // TODO: remove the rewrite and always use the general way to calculate the domain
71        //      after we support DAG.
72        let domain: PlanRef = if let Some(rewritten_left) = Self::rewrite(
73            &left,
74            correlated_indices.clone(),
75            0,
76            &mut index_mapping,
77            &mut data_types,
78            &mut index,
79        ) {
80            // This `LogicalProject` is used to make sure that after `LogicalApply`'s left was
81            // rewritten, the new index of `correlated_index` is always at its position in
82            // `correlated_indices`.
83            let exprs = correlated_indices
84                .clone()
85                .into_iter()
86                .enumerate()
87                .map(|(i, correlated_index)| {
88                    let index = index_mapping.map(correlated_index);
89                    let data_type = rewritten_left.schema().fields()[index].data_type.clone();
90                    index_mapping.put(correlated_index, Some(i));
91                    InputRef::new(index, data_type).into()
92                })
93                .collect();
94            let project = LogicalProject::create(rewritten_left, exprs);
95            let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
96            distinct.into()
97        } else {
98            // The left side of the apply is not SPJ. We need to use the general way to calculate
99            // the domain. Distinct + Project + The Left of Apply
100
101            // Use Share
102            left = if self.enable_share_plan {
103                let logical_share = LogicalShare::new(left);
104                logical_share.into()
105            } else {
106                left
107            };
108
109            let exprs = correlated_indices
110                .clone()
111                .into_iter()
112                .map(|correlated_index| {
113                    let data_type = left.schema().fields()[correlated_index].data_type.clone();
114                    InputRef::new(correlated_index, data_type).into()
115                })
116                .collect();
117            let project = LogicalProject::create(left.clone(), exprs);
118            let distinct = Agg::new(vec![], (0..project.schema().len()).collect(), project);
119            distinct.into()
120        };
121
122        let eq_predicates = correlated_indices
123            .into_iter()
124            .enumerate()
125            .map(|(i, correlated_index)| {
126                let shifted_index = i + apply_left_len;
127                let data_type = domain.schema().fields()[i].data_type.clone();
128                let left = InputRef::new(correlated_index, data_type.clone());
129                let right = InputRef::new(shifted_index, data_type);
130                // use null-safe equal
131                FunctionCall::new_unchecked(
132                    ExprType::IsNotDistinctFrom,
133                    vec![left.into(), right.into()],
134                    DataType::Boolean,
135                )
136                .into()
137            })
138            .collect::<Vec<ExprImpl>>();
139
140        let new_apply = apply.clone_with_left_right(left, right);
141        let new_node = new_apply.translate_apply(domain, eq_predicates);
142        Some(new_node)
143    }
144}
145
146impl TranslateApplyRule {
147    pub fn create(enable_share_plan: bool) -> BoxedRule {
148        Box::new(TranslateApplyRule { enable_share_plan })
149    }
150
151    /// Rewrite `LogicalApply`'s left according to `correlated_indices`.
152    ///
153    /// Assumption: only `LogicalJoin`, `LogicalScan`, `LogicalProject` and `LogicalFilter` are in
154    /// the left.
155    fn rewrite(
156        plan: &PlanRef,
157        correlated_indices: Vec<usize>,
158        offset: usize,
159        index_mapping: &mut ColIndexMapping,
160        data_types: &mut HashMap<usize, DataType>,
161        index: &mut usize,
162    ) -> Option<PlanRef> {
163        if let Some(join) = plan.as_logical_join() {
164            Self::rewrite_join(
165                join,
166                correlated_indices,
167                offset,
168                index_mapping,
169                data_types,
170                index,
171            )
172        } else if let Some(apply) = plan.as_logical_apply() {
173            Self::rewrite_apply(
174                apply,
175                correlated_indices,
176                offset,
177                index_mapping,
178                data_types,
179                index,
180            )
181        } else if let Some(scan) = plan.as_logical_scan() {
182            Self::rewrite_scan(
183                scan,
184                correlated_indices,
185                offset,
186                index_mapping,
187                data_types,
188                index,
189            )
190        } else if let Some(filter) = plan.as_logical_filter() {
191            Self::rewrite(
192                &filter.input(),
193                correlated_indices,
194                offset,
195                index_mapping,
196                data_types,
197                index,
198            )
199        } else {
200            // TODO: better to return an error.
201            None
202        }
203    }
204
205    fn rewrite_join(
206        join: &LogicalJoin,
207        required_col_idx: Vec<usize>,
208        mut offset: usize,
209        index_mapping: &mut ColIndexMapping,
210        data_types: &mut HashMap<usize, DataType>,
211        index: &mut usize,
212    ) -> Option<PlanRef> {
213        // TODO: Do we need to take the `on` into account?
214        let left_len = join.left().schema().len();
215        let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
216            .into_iter()
217            .partition(|idx| *idx < left_len);
218        let mut rewrite =
219            |plan: PlanRef, mut indices: Vec<usize>, is_right: bool| -> Option<PlanRef> {
220                if is_right {
221                    indices.iter_mut().for_each(|index| *index -= left_len);
222                    offset += left_len;
223                }
224                Self::rewrite(&plan, indices, offset, index_mapping, data_types, index)
225            };
226        match (left_idxs.is_empty(), right_idxs.is_empty()) {
227            (true, false) => {
228                // Only accept join which doesn't generate null columns.
229                match join.join_type() {
230                    JoinType::Inner
231                    | JoinType::LeftSemi
232                    | JoinType::RightSemi
233                    | JoinType::LeftAnti
234                    | JoinType::RightAnti
235                    | JoinType::RightOuter
236                    | JoinType::AsofInner => rewrite(join.right(), right_idxs, true),
237                    JoinType::LeftOuter | JoinType::FullOuter | JoinType::AsofLeftOuter => None,
238                    JoinType::Unspecified => unreachable!(),
239                }
240            }
241            (false, true) => {
242                // Only accept join which doesn't generate null columns.
243                match join.join_type() {
244                    JoinType::Inner
245                    | JoinType::LeftSemi
246                    | JoinType::RightSemi
247                    | JoinType::LeftAnti
248                    | JoinType::RightAnti
249                    | JoinType::LeftOuter
250                    | JoinType::AsofInner
251                    | JoinType::AsofLeftOuter => rewrite(join.left(), left_idxs, false),
252                    JoinType::RightOuter | JoinType::FullOuter => None,
253                    JoinType::Unspecified => unreachable!(),
254                }
255            }
256            (false, false) => {
257                // Only accept join which doesn't generate null columns.
258                match join.join_type() {
259                    JoinType::Inner
260                    | JoinType::LeftSemi
261                    | JoinType::RightSemi
262                    | JoinType::LeftAnti
263                    | JoinType::RightAnti
264                    | JoinType::AsofInner => {
265                        let left = rewrite(join.left(), left_idxs, false)?;
266                        let right = rewrite(join.right(), right_idxs, true)?;
267                        let new_join =
268                            LogicalJoin::new(left, right, join.join_type(), Condition::true_cond());
269                        Some(new_join.into())
270                    }
271                    JoinType::LeftOuter
272                    | JoinType::RightOuter
273                    | JoinType::FullOuter
274                    | JoinType::AsofLeftOuter => None,
275                    JoinType::Unspecified => unreachable!(),
276                }
277            }
278            _ => None,
279        }
280    }
281
282    /// ```text
283    ///             LogicalApply
284    ///            /            \
285    ///     LogicalApply       RHS1
286    ///    /            \
287    ///  LHS           RHS2
288    /// ```
289    ///
290    /// A common structure of multi scalar subqueries is a chain of `LogicalApply`. To avoid exponential growth of the domain operator, we need to rewrite the apply and try to simplify the domain as much as possible.
291    /// We use a top-down apply order to rewrite the apply, so that we don't need to handle operator like project and aggregation generated by the domain calculation.
292    /// As a cost, we need to add a flag `translated` to the apply operator to remind `translate_apply_rule` that the apply has been translated.
293    fn rewrite_apply(
294        apply: &LogicalApply,
295        required_col_idx: Vec<usize>,
296        offset: usize,
297        index_mapping: &mut ColIndexMapping,
298        data_types: &mut HashMap<usize, DataType>,
299        index: &mut usize,
300    ) -> Option<PlanRef> {
301        // TODO: Do we need to take the `on` into account?
302        let left_len = apply.left().schema().len();
303        let (left_idxs, right_idxs): (Vec<_>, Vec<_>) = required_col_idx
304            .into_iter()
305            .partition(|idx| *idx < left_len);
306        if !left_idxs.is_empty() && right_idxs.is_empty() {
307            // Deal with multi scalar subqueries
308            match apply.join_type() {
309                JoinType::Inner
310                | JoinType::LeftSemi
311                | JoinType::LeftAnti
312                | JoinType::LeftOuter
313                | JoinType::AsofInner
314                | JoinType::AsofLeftOuter => {
315                    let plan = apply.left();
316                    Self::rewrite(&plan, left_idxs, offset, index_mapping, data_types, index)
317                }
318                JoinType::RightOuter
319                | JoinType::RightAnti
320                | JoinType::RightSemi
321                | JoinType::FullOuter => None,
322                JoinType::Unspecified => unreachable!(),
323            }
324        } else {
325            None
326        }
327    }
328
329    fn rewrite_scan(
330        scan: &LogicalScan,
331        required_col_idx: Vec<usize>,
332        offset: usize,
333        index_mapping: &mut ColIndexMapping,
334        data_types: &mut HashMap<usize, DataType>,
335        index: &mut usize,
336    ) -> Option<PlanRef> {
337        for i in &required_col_idx {
338            let correlated_index = *i + offset;
339            index_mapping.put(correlated_index, Some(*index));
340            data_types.insert(
341                correlated_index,
342                scan.schema().fields()[*i].data_type.clone(),
343            );
344            *index += 1;
345        }
346
347        Some(scan.clone_with_output_indices(required_col_idx).into())
348    }
349}