risingwave_frontend/expr/
expr_rewriter.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::util::recursive::{Recurse, tracker};
16
17use super::{
18    AggCall, CorrelatedInputRef, EXPR_DEPTH_THRESHOLD, EXPR_TOO_DEEP_NOTICE, ExprImpl,
19    FunctionCall, FunctionCallWithLambda, InputRef, Literal, Parameter, SecretRef, Subquery,
20    TableFunction, UserDefinedFunction, WindowFunction,
21};
22use crate::expr::Now;
23use crate::session::current::notice_to_user;
24
25/// The default implementation of [`ExprRewriter::rewrite_expr`] that simply dispatches to other
26/// methods based on the type of the expression.
27///
28/// You can use this function as a helper to reduce boilerplate code when implementing the trait.
29// TODO: This is essentially a mimic of `super` pattern from OO languages. Ideally, we should
30// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477.
31pub fn default_rewrite_expr<R: ExprRewriter + ?Sized>(
32    rewriter: &mut R,
33    expr: ExprImpl,
34) -> ExprImpl {
35    // TODO: Implementors may choose to not use this function at all, in which case we will fail
36    // to track the recursion and grow the stack as necessary. The current approach is only a
37    // best-effort attempt to prevent stack overflow.
38    tracker!().recurse(|t| {
39        if t.depth_reaches(EXPR_DEPTH_THRESHOLD) {
40            notice_to_user(EXPR_TOO_DEEP_NOTICE);
41        }
42
43        match expr {
44            ExprImpl::InputRef(inner) => rewriter.rewrite_input_ref(*inner),
45            ExprImpl::Literal(inner) => rewriter.rewrite_literal(*inner),
46            ExprImpl::FunctionCall(inner) => rewriter.rewrite_function_call(*inner),
47            ExprImpl::FunctionCallWithLambda(inner) => {
48                rewriter.rewrite_function_call_with_lambda(*inner)
49            }
50            ExprImpl::AggCall(inner) => rewriter.rewrite_agg_call(*inner),
51            ExprImpl::Subquery(inner) => rewriter.rewrite_subquery(*inner),
52            ExprImpl::CorrelatedInputRef(inner) => rewriter.rewrite_correlated_input_ref(*inner),
53            ExprImpl::TableFunction(inner) => rewriter.rewrite_table_function(*inner),
54            ExprImpl::WindowFunction(inner) => rewriter.rewrite_window_function(*inner),
55            ExprImpl::UserDefinedFunction(inner) => rewriter.rewrite_user_defined_function(*inner),
56            ExprImpl::Parameter(inner) => rewriter.rewrite_parameter(*inner),
57            ExprImpl::Now(inner) => rewriter.rewrite_now(*inner),
58            ExprImpl::SecretRef(inner) => rewriter.rewrite_secret_ref(*inner),
59        }
60    })
61}
62
63/// By default, `ExprRewriter` simply traverses the expression tree and leaves nodes unchanged.
64/// Implementations can override a subset of methods and perform transformation on some particular
65/// types of expression.
66pub trait ExprRewriter {
67    fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
68        default_rewrite_expr(self, expr)
69    }
70
71    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
72        let (func_type, inputs, ret) = func_call.decompose();
73        let inputs = inputs
74            .into_iter()
75            .map(|expr| self.rewrite_expr(expr))
76            .collect();
77        FunctionCall::new_unchecked(func_type, inputs, ret).into()
78    }
79
80    fn rewrite_function_call_with_lambda(&mut self, func_call: FunctionCallWithLambda) -> ExprImpl {
81        let (func_type, inputs, lambda_arg, ret) = func_call.into_parts();
82        let inputs = inputs
83            .into_iter()
84            .map(|expr| self.rewrite_expr(expr))
85            .collect();
86        FunctionCallWithLambda::new_unchecked(func_type, inputs, lambda_arg, ret).into()
87    }
88
89    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
90        let AggCall {
91            agg_type,
92            return_type,
93            args,
94            distinct,
95            order_by,
96            filter,
97            direct_args,
98        } = agg_call;
99        let args = args
100            .into_iter()
101            .map(|expr| self.rewrite_expr(expr))
102            .collect();
103        let order_by = order_by.rewrite_expr(self);
104        let filter = filter.rewrite_expr(self);
105        AggCall {
106            agg_type,
107            return_type,
108            args,
109            distinct,
110            order_by,
111            filter,
112            direct_args,
113        }
114        .into()
115    }
116
117    fn rewrite_parameter(&mut self, parameter: Parameter) -> ExprImpl {
118        parameter.into()
119    }
120
121    fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
122        literal.into()
123    }
124
125    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
126        input_ref.into()
127    }
128
129    fn rewrite_subquery(&mut self, subquery: Subquery) -> ExprImpl {
130        subquery.into()
131    }
132
133    fn rewrite_correlated_input_ref(&mut self, input_ref: CorrelatedInputRef) -> ExprImpl {
134        input_ref.into()
135    }
136
137    fn rewrite_table_function(&mut self, table_func: TableFunction) -> ExprImpl {
138        let TableFunction {
139            args,
140            return_type,
141            function_type,
142            user_defined: udtf_catalog,
143        } = table_func;
144        let args = args
145            .into_iter()
146            .map(|expr| self.rewrite_expr(expr))
147            .collect();
148        TableFunction {
149            args,
150            return_type,
151            function_type,
152            user_defined: udtf_catalog,
153        }
154        .into()
155    }
156
157    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
158        let WindowFunction {
159            kind,
160            return_type,
161            args,
162            ignore_nulls,
163            partition_by,
164            order_by,
165            frame,
166        } = window_func;
167        let args = args
168            .into_iter()
169            .map(|expr| self.rewrite_expr(expr))
170            .collect();
171        WindowFunction {
172            kind,
173            return_type,
174            args,
175            ignore_nulls,
176            partition_by,
177            order_by,
178            frame,
179        }
180        .into()
181    }
182
183    fn rewrite_user_defined_function(&mut self, udf: UserDefinedFunction) -> ExprImpl {
184        let UserDefinedFunction { args, catalog } = udf;
185        let args = args
186            .into_iter()
187            .map(|expr| self.rewrite_expr(expr))
188            .collect();
189        UserDefinedFunction { args, catalog }.into()
190    }
191
192    fn rewrite_now(&mut self, now: Now) -> ExprImpl {
193        now.into()
194    }
195
196    fn rewrite_secret_ref(&mut self, secret_ref: SecretRef) -> ExprImpl {
197        secret_ref.into()
198    }
199}