risingwave_frontend/expr/
expr_rewriter.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::util::recursive::{Recurse, tracker};
16
17use super::{
18    AggCall, CorrelatedInputRef, EXPR_DEPTH_THRESHOLD, EXPR_TOO_DEEP_NOTICE, ExprImpl,
19    FunctionCall, FunctionCallWithLambda, InputRef, Literal, Parameter, Subquery, TableFunction,
20    UserDefinedFunction, WindowFunction,
21};
22use crate::expr::Now;
23use crate::session::current::notice_to_user;
24
25/// The default implementation of [`ExprRewriter::rewrite_expr`] that simply dispatches to other
26/// methods based on the type of the expression.
27///
28/// You can use this function as a helper to reduce boilerplate code when implementing the trait.
29// TODO: This is essentially a mimic of `super` pattern from OO languages. Ideally, we should
30// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477.
31pub fn default_rewrite_expr<R: ExprRewriter + ?Sized>(
32    rewriter: &mut R,
33    expr: ExprImpl,
34) -> ExprImpl {
35    // TODO: Implementors may choose to not use this function at all, in which case we will fail
36    // to track the recursion and grow the stack as necessary. The current approach is only a
37    // best-effort attempt to prevent stack overflow.
38    tracker!().recurse(|t| {
39        if t.depth_reaches(EXPR_DEPTH_THRESHOLD) {
40            notice_to_user(EXPR_TOO_DEEP_NOTICE);
41        }
42
43        match expr {
44            ExprImpl::InputRef(inner) => rewriter.rewrite_input_ref(*inner),
45            ExprImpl::Literal(inner) => rewriter.rewrite_literal(*inner),
46            ExprImpl::FunctionCall(inner) => rewriter.rewrite_function_call(*inner),
47            ExprImpl::FunctionCallWithLambda(inner) => {
48                rewriter.rewrite_function_call_with_lambda(*inner)
49            }
50            ExprImpl::AggCall(inner) => rewriter.rewrite_agg_call(*inner),
51            ExprImpl::Subquery(inner) => rewriter.rewrite_subquery(*inner),
52            ExprImpl::CorrelatedInputRef(inner) => rewriter.rewrite_correlated_input_ref(*inner),
53            ExprImpl::TableFunction(inner) => rewriter.rewrite_table_function(*inner),
54            ExprImpl::WindowFunction(inner) => rewriter.rewrite_window_function(*inner),
55            ExprImpl::UserDefinedFunction(inner) => rewriter.rewrite_user_defined_function(*inner),
56            ExprImpl::Parameter(inner) => rewriter.rewrite_parameter(*inner),
57            ExprImpl::Now(inner) => rewriter.rewrite_now(*inner),
58        }
59    })
60}
61
62/// By default, `ExprRewriter` simply traverses the expression tree and leaves nodes unchanged.
63/// Implementations can override a subset of methods and perform transformation on some particular
64/// types of expression.
65pub trait ExprRewriter {
66    fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
67        default_rewrite_expr(self, expr)
68    }
69
70    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
71        let (func_type, inputs, ret) = func_call.decompose();
72        let inputs = inputs
73            .into_iter()
74            .map(|expr| self.rewrite_expr(expr))
75            .collect();
76        FunctionCall::new_unchecked(func_type, inputs, ret).into()
77    }
78
79    fn rewrite_function_call_with_lambda(&mut self, func_call: FunctionCallWithLambda) -> ExprImpl {
80        let (func_type, inputs, lambda_arg, ret) = func_call.into_parts();
81        let inputs = inputs
82            .into_iter()
83            .map(|expr| self.rewrite_expr(expr))
84            .collect();
85        FunctionCallWithLambda::new_unchecked(func_type, inputs, lambda_arg, ret).into()
86    }
87
88    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
89        let AggCall {
90            agg_type,
91            return_type,
92            args,
93            distinct,
94            order_by,
95            filter,
96            direct_args,
97        } = agg_call;
98        let args = args
99            .into_iter()
100            .map(|expr| self.rewrite_expr(expr))
101            .collect();
102        let order_by = order_by.rewrite_expr(self);
103        let filter = filter.rewrite_expr(self);
104        AggCall {
105            agg_type,
106            return_type,
107            args,
108            distinct,
109            order_by,
110            filter,
111            direct_args,
112        }
113        .into()
114    }
115
116    fn rewrite_parameter(&mut self, parameter: Parameter) -> ExprImpl {
117        parameter.into()
118    }
119
120    fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
121        literal.into()
122    }
123
124    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
125        input_ref.into()
126    }
127
128    fn rewrite_subquery(&mut self, subquery: Subquery) -> ExprImpl {
129        subquery.into()
130    }
131
132    fn rewrite_correlated_input_ref(&mut self, input_ref: CorrelatedInputRef) -> ExprImpl {
133        input_ref.into()
134    }
135
136    fn rewrite_table_function(&mut self, table_func: TableFunction) -> ExprImpl {
137        let TableFunction {
138            args,
139            return_type,
140            function_type,
141            user_defined: udtf_catalog,
142        } = table_func;
143        let args = args
144            .into_iter()
145            .map(|expr| self.rewrite_expr(expr))
146            .collect();
147        TableFunction {
148            args,
149            return_type,
150            function_type,
151            user_defined: udtf_catalog,
152        }
153        .into()
154    }
155
156    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
157        let WindowFunction {
158            kind,
159            return_type,
160            args,
161            ignore_nulls,
162            partition_by,
163            order_by,
164            frame,
165        } = window_func;
166        let args = args
167            .into_iter()
168            .map(|expr| self.rewrite_expr(expr))
169            .collect();
170        WindowFunction {
171            kind,
172            return_type,
173            args,
174            ignore_nulls,
175            partition_by,
176            order_by,
177            frame,
178        }
179        .into()
180    }
181
182    fn rewrite_user_defined_function(&mut self, udf: UserDefinedFunction) -> ExprImpl {
183        let UserDefinedFunction { args, catalog } = udf;
184        let args = args
185            .into_iter()
186            .map(|expr| self.rewrite_expr(expr))
187            .collect();
188        UserDefinedFunction { args, catalog }.into()
189    }
190
191    fn rewrite_now(&mut self, now: Now) -> ExprImpl {
192        now.into()
193    }
194}