risingwave_frontend/optimizer/rule/
apply_table_function_to_project_set_rule.rs

1// Copyright 2026 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::plan_common::JoinType;
17
18use super::prelude::{PlanRef, *};
19use crate::expr::{CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, InputRef};
20use crate::optimizer::plan_node::generic::GenericPlanRef;
21use crate::optimizer::plan_node::{
22    LogicalApply, LogicalProject, LogicalProjectSet, LogicalValues, PlanTreeNodeUnary,
23};
24
/// Convert a correlated `LogicalApply` with a converted table function RHS into a `LogicalProjectSet`
/// over LHS, so that we can avoid the Domain/Distinct introduced by general apply translation.
///
/// This targets the pattern produced by [`TableFunctionToProjectSetRule`]:
///
/// ```text
/// LogicalApply (Inner, on true)
///  /                       \
/// LHS                 LogicalProject
///                          |
///                    LogicalProjectSet
///                          |
///                     LogicalValues
/// ```
///
/// After:
///
/// ```text
/// LogicalProject
///      |
/// LogicalProjectSet
///      |
///     LHS
/// ```
///
/// The rule only fires for an inner apply with an always-true `on` condition and no
/// `max_one_row` requirement, whose RHS `LogicalProjectSet` select list references the LHS
/// through this apply's correlated input refs. In every other case the plan is left untouched.
///
/// The output columns of the resulting `LogicalProject` are the LHS columns followed by the
/// expressions of the original RHS `LogicalProject`.
///
/// The resulting `ProjectSet` is stateless in streaming, and its hidden `projected_row_id` can be
/// used to derive ordinality.
///
/// [`TableFunctionToProjectSetRule`]: crate::optimizer::rule::TableFunctionToProjectSetRule
pub struct ApplyTableFunctionToProjectSetRule {}
55
56impl Rule<Logical> for ApplyTableFunctionToProjectSetRule {
57    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
58        let apply: &LogicalApply = plan.as_logical_apply()?;
59        let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
60            apply.clone().decompose();
61
62        if max_one_row || join_type != JoinType::Inner || !on.always_true() {
63            return None;
64        }
65
66        // Only handle RHS generated by `TableFunctionToProjectSetRule`.
67        let right_project = right.as_logical_project()?;
68        let right_project_input = right_project.input();
69        let right_project_set: &LogicalProjectSet = right_project_input.as_logical_project_set()?;
70        let right_project_set_input = right_project_set.input();
71        let right_values: &LogicalValues = right_project_set_input.as_logical_values()?;
72        if !right_values.is_empty_scalar() {
73            return None;
74        }
75
76        let left_len = left.schema().len();
77
78        // Build a new ProjectSet over LHS:
79        // - keep all LHS columns, so they are duplicated per unnested row
80        // - append RHS ProjectSet expressions, rewriting CorrelatedInputRef -> InputRef on LHS
81        let mut select_list: Vec<ExprImpl> = left
82            .schema()
83            .data_types()
84            .into_iter()
85            .enumerate()
86            .map(|(idx, ty)| InputRef::new(idx, ty).into())
87            .collect();
88
89        let mut corr_rewriter =
90            CorrelatedInputRefToInputRefRewriter::new(correlated_id, correlated_indices, left_len);
91        let rhs_select_list = right_project_set
92            .select_list()
93            .iter()
94            .cloned()
95            .map(|e| corr_rewriter.rewrite_expr(e))
96            .collect_vec();
97        if !corr_rewriter.touched() {
98            // If RHS doesn't reference LHS, this is likely not a lateral table function apply.
99            return None;
100        }
101        select_list.extend(rhs_select_list);
102
103        let new_project_set: PlanRef = LogicalProjectSet::new(left.clone(), select_list).into();
104
105        // Reconstruct the RHS `LogicalProject` on top of the new ProjectSet output.
106        // New ProjectSet output layout:
107        //   0: projected_row_id
108        //   1..=left_len: LHS columns
109        //   left_len+1..: RHS ProjectSet outputs (old idx >= 1)
110        //
111        // We output: [LHS columns..., RHS project exprs...]
112        let mut out_exprs: Vec<ExprImpl> = left
113            .schema()
114            .data_types()
115            .into_iter()
116            .enumerate()
117            .map(|(idx, ty)| InputRef::new(idx + 1, ty).into())
118            .collect();
119
120        let mut shift_rewriter = ShiftRhsInputRefRewriter::new(left_len);
121        let rhs_project_exprs = right_project
122            .exprs()
123            .iter()
124            .cloned()
125            .map(|e| shift_rewriter.rewrite_expr(e))
126            .collect_vec();
127        out_exprs.extend(rhs_project_exprs);
128
129        Some(LogicalProject::new(new_project_set, out_exprs).into())
130    }
131}
132
133impl ApplyTableFunctionToProjectSetRule {
134    pub fn create() -> BoxedRule {
135        Box::new(ApplyTableFunctionToProjectSetRule {})
136    }
137}
138
139struct CorrelatedInputRefToInputRefRewriter {
140    correlated_id: CorrelatedId,
141    correlated_indices: Vec<usize>,
142    left_len: usize,
143    touched: bool,
144}
145
146impl CorrelatedInputRefToInputRefRewriter {
147    fn new(correlated_id: CorrelatedId, correlated_indices: Vec<usize>, left_len: usize) -> Self {
148        Self {
149            correlated_id,
150            correlated_indices,
151            left_len,
152            touched: false,
153        }
154    }
155
156    fn touched(&self) -> bool {
157        self.touched
158    }
159}
160
161impl ExprRewriter for CorrelatedInputRefToInputRefRewriter {
162    fn rewrite_correlated_input_ref(
163        &mut self,
164        correlated_input_ref: CorrelatedInputRef,
165    ) -> ExprImpl {
166        if correlated_input_ref.correlated_id() != self.correlated_id {
167            return correlated_input_ref.into();
168        }
169        let idx = correlated_input_ref.index();
170        if idx >= self.left_len || !self.correlated_indices.contains(&idx) {
171            // Be conservative: only rewrite the correlated indices recorded on the apply.
172            return correlated_input_ref.into();
173        }
174        self.touched = true;
175        InputRef::new(idx, correlated_input_ref.return_type()).into()
176    }
177}
178
/// Expression rewriter that moves `InputRef` indices to make room for the LHS columns.
struct ShiftRhsInputRefRewriter {
    /// Number of LHS columns, i.e. the amount by which non-zero indices are shifted.
    left_len: usize,
}

impl ShiftRhsInputRefRewriter {
    /// Creates a rewriter for an LHS with `left_len` columns.
    fn new(left_len: usize) -> Self {
        ShiftRhsInputRefRewriter { left_len }
    }
}
188
189impl ExprRewriter for ShiftRhsInputRefRewriter {
190    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
191        if input_ref.index() == 0 {
192            // projected_row_id stays at 0
193            input_ref.into()
194        } else {
195            InputRef::new(input_ref.index() + self.left_len, input_ref.return_type()).into()
196        }
197    }
198}