risingwave_frontend/optimizer/rule/correlated_expr_rewriter.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::expr::{CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, InputRef};
16
17/// Base rewriter for pulling up correlated expressions.
18///
19/// Provides common functionality for rewriting correlated input references to regular input references.
20/// Different rules can extend this by implementing their own `rewrite_input_ref` behavior.
21pub struct CorrelatedExprRewriter {
22 pub correlated_id: CorrelatedId,
23}
24
25impl CorrelatedExprRewriter {
26 pub fn new(correlated_id: CorrelatedId) -> Self {
27 Self { correlated_id }
28 }
29
30 /// Common logic for rewriting correlated input references to input references.
31 pub fn rewrite_correlated_input_ref_impl(
32 &mut self,
33 correlated_input_ref: CorrelatedInputRef,
34 ) -> ExprImpl {
35 // Convert correlated_input_ref to input_ref.
36 // only rewrite the correlated_input_ref with the same correlated_id
37 if correlated_input_ref.correlated_id() == self.correlated_id {
38 InputRef::new(
39 correlated_input_ref.index(),
40 correlated_input_ref.return_type(),
41 )
42 .into()
43 } else {
44 correlated_input_ref.into()
45 }
46 }
47}
48
49/// Rewriter for pulling up correlated predicates.
50///
51/// Collects all `InputRef`s and shifts their indices for use in join conditions.
52pub struct PredicateRewriter {
53 pub base: CorrelatedExprRewriter,
54 // All uncorrelated `InputRef`s in the expression.
55 pub input_refs: Vec<InputRef>,
56 pub index: usize,
57}
58
59impl PredicateRewriter {
60 pub fn new(correlated_id: CorrelatedId, index: usize) -> Self {
61 Self {
62 base: CorrelatedExprRewriter::new(correlated_id),
63 input_refs: vec![],
64 index,
65 }
66 }
67}
68
69impl ExprRewriter for PredicateRewriter {
70 fn rewrite_correlated_input_ref(
71 &mut self,
72 correlated_input_ref: CorrelatedInputRef,
73 ) -> ExprImpl {
74 self.base
75 .rewrite_correlated_input_ref_impl(correlated_input_ref)
76 }
77
78 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
79 let data_type = input_ref.return_type();
80
81 // It will be appended to exprs in LogicalProject, so its index remain the same.
82 self.input_refs.push(input_ref);
83
84 // Rewrite input_ref's index to its new location.
85 let input_ref = InputRef::new(self.index, data_type);
86 self.index += 1;
87 input_ref.into()
88 }
89}
90
91/// Rewriter for pulling up correlated project expressions with values.
92///
93/// For inlining scalar subqueries, we don't need to collect or shift input references.
94pub struct ProjectValueRewriter {
95 pub base: CorrelatedExprRewriter,
96}
97
98impl ProjectValueRewriter {
99 pub fn new(correlated_id: CorrelatedId) -> Self {
100 Self {
101 base: CorrelatedExprRewriter::new(correlated_id),
102 }
103 }
104}
105
106impl ExprRewriter for ProjectValueRewriter {
107 fn rewrite_correlated_input_ref(
108 &mut self,
109 correlated_input_ref: CorrelatedInputRef,
110 ) -> ExprImpl {
111 self.base
112 .rewrite_correlated_input_ref_impl(correlated_input_ref)
113 }
114
115 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
116 // For project value rule, input_refs in the project should reference Values
117 // Since we're inlining, we don't need to preserve these references
118 // They should be constant-folded or handled separately
119 input_ref.into()
120 }
121}