risingwave_frontend/expr/
session_timezone.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::types::DataType;
16pub use risingwave_pb::expr::expr_node::Type as ExprType;
17
18pub use crate::expr::expr_rewriter::ExprRewriter;
19pub use crate::expr::function_call::FunctionCall;
20use crate::expr::{Expr, ExprImpl, ExprVisitor};
21use crate::session::current;
22
/// Expression rewriter that resolves timezone-dependent casts, comparisons,
/// arithmetic and date/time functions by inlining the session timezone as an
/// explicit literal argument.
pub struct SessionTimezone {
    /// Name of the session timezone that gets inlined.
    timezone: String,
    /// Set once any rewrite has actually inlined the session timezone.
    used: bool,
}
30
31impl ExprRewriter for SessionTimezone {
32    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
33        let (func_type, inputs, ret) = func_call.decompose();
34        let inputs: Vec<ExprImpl> = inputs
35            .into_iter()
36            .map(|expr| self.rewrite_expr(expr))
37            .collect();
38        if let Some(expr) = self.with_timezone(func_type, &inputs, ret.clone()) {
39            self.mark_used();
40            expr
41        } else {
42            FunctionCall::new_unchecked(func_type, inputs, ret).into()
43        }
44    }
45}
46
47impl SessionTimezone {
48    pub fn new(timezone: String) -> Self {
49        Self {
50            timezone,
51            used: false,
52        }
53    }
54
55    pub fn timezone(&self) -> String {
56        self.timezone.clone()
57    }
58
59    pub fn used(&self) -> bool {
60        self.used
61    }
62
63    fn mark_used(&mut self) {
64        if !self.used {
65            self.used = true;
66            current::notice_to_user(format!(
67                "Your session timezone is {}. It was used in the interpretation of timestamps and dates in your query. If this is unintended, \
68                change your timezone to match that of your data's with `set timezone = [timezone]` or \
69                rewrite your query with an explicit timezone conversion, e.g. with `AT TIME ZONE`.\n",
70                self.timezone
71            ));
72        }
73    }
74
75    // Inlines conversions based on session timezone if required by the function
76    fn with_timezone(
77        &self,
78        func_type: ExprType,
79        inputs: &[ExprImpl],
80        return_type: DataType,
81    ) -> Option<ExprImpl> {
82        match func_type {
83            // `input_timestamptz::varchar`
84            // => `cast_with_time_zone(input_timestamptz, zone_string)`
85            // `input_varchar::timestamptz`
86            // => `cast_with_time_zone(input_varchar, zone_string)`
87            // `input_date::timestamptz`
88            // => `input_date::timestamp AT TIME ZONE zone_string`
89            // `input_timestamp::timestamptz`
90            // => `input_timestamp AT TIME ZONE zone_string`
91            // `input_timestamptz::date`
92            // => `(input_timestamptz AT TIME ZONE zone_string)::date`
93            // `input_timestamptz::time`
94            // => `(input_timestamptz AT TIME ZONE zone_string)::time`
95            // `input_timestamptz::timestamp`
96            // => `input_timestamptz AT TIME ZONE zone_string`
97            ExprType::Cast => {
98                assert_eq!(inputs.len(), 1);
99                let mut input = inputs[0].clone();
100                let input_type = input.return_type();
101                match (input_type, return_type.clone()) {
102                    (DataType::Timestamptz, DataType::Varchar)
103                    | (DataType::Varchar, DataType::Timestamptz) => {
104                        Some(self.cast_with_timezone(input, return_type))
105                    }
106                    (DataType::Date, DataType::Timestamptz)
107                    | (DataType::Timestamp, DataType::Timestamptz) => {
108                        input = input.cast_explicit(DataType::Timestamp).unwrap();
109                        Some(self.at_timezone(input))
110                    }
111                    (DataType::Timestamptz, DataType::Date)
112                    | (DataType::Timestamptz, DataType::Time)
113                    | (DataType::Timestamptz, DataType::Timestamp) => {
114                        input = self.at_timezone(input);
115                        input = input.cast_explicit(return_type).unwrap();
116                        Some(input)
117                    }
118                    _ => None,
119                }
120            }
121            // `lhs_date CMP rhs_timestamptz`
122            // => `(lhs_date::timestamp AT TIME ZONE zone_string) CMP rhs_timestamptz`
123            // `lhs_timestamp CMP rhs_timestamptz`
124            // => `(lhs_timestamp AT TIME ZONE zone_string) CMP rhs_timestamptz`
125            // `lhs_timestamptz CMP rhs_date`
126            // => `lhs_timestamptz CMP (rhs_date::timestamp AT TIME ZONE zone_string)`
127            // `lhs_timestamptz CMP rhs_timestamp`
128            // => `lhs_timestamptz CMP (rhs_timestamp AT TIME ZONE zone_string)`
129            ExprType::Equal
130            | ExprType::NotEqual
131            | ExprType::LessThan
132            | ExprType::LessThanOrEqual
133            | ExprType::GreaterThan
134            | ExprType::GreaterThanOrEqual
135            | ExprType::IsDistinctFrom
136            | ExprType::IsNotDistinctFrom => {
137                assert_eq!(inputs.len(), 2);
138                let mut inputs = inputs.to_vec();
139                for idx in 0..2 {
140                    if matches!(inputs[(idx + 1) % 2].return_type(), DataType::Timestamptz)
141                        && matches!(
142                            inputs[idx % 2].return_type(),
143                            DataType::Date | DataType::Timestamp
144                        )
145                    {
146                        let mut to_cast = inputs[idx % 2].clone();
147                        // Cast to `Timestamp` first, then use `AT TIME ZONE` to convert to
148                        // `Timestamptz`
149                        to_cast = to_cast.cast_explicit(DataType::Timestamp).unwrap();
150                        inputs[idx % 2] = self.at_timezone(to_cast);
151                        return Some(
152                            FunctionCall::new_unchecked(func_type, inputs, return_type).into(),
153                        );
154                    }
155                }
156                None
157            }
158            // `add(lhs_interval, rhs_timestamptz)`
159            // => `add_with_time_zone(rhs_timestamptz, lhs_interval, zone_string)`
160            // `add(lhs_timestamptz, rhs_interval)`
161            // => `add_with_time_zone(lhs_timestamptz, rhs_interval, zone_string)`
162            // `subtract(lhs_timestamptz, rhs_interval)`
163            // => `subtract_with_time_zone(lhs_timestamptz, rhs_interval, zone_string)`
164            ExprType::Subtract | ExprType::Add => {
165                assert_eq!(inputs.len(), 2);
166                let canonical_match = matches!(inputs[0].return_type(), DataType::Timestamptz)
167                    && matches!(inputs[1].return_type(), DataType::Interval);
168                let inverse_match = matches!(inputs[1].return_type(), DataType::Timestamptz)
169                    && matches!(inputs[0].return_type(), DataType::Interval);
170                assert!(!(inverse_match && func_type == ExprType::Subtract)); // This should never have been parsed.
171                if canonical_match || inverse_match {
172                    let (orig_timestamptz, interval) =
173                        if func_type == ExprType::Add && inverse_match {
174                            (inputs[1].clone(), inputs[0].clone())
175                        } else {
176                            (inputs[0].clone(), inputs[1].clone())
177                        };
178                    let new_type = match func_type {
179                        ExprType::Add => ExprType::AddWithTimeZone,
180                        ExprType::Subtract => ExprType::SubtractWithTimeZone,
181                        _ => unreachable!(),
182                    };
183                    let rewritten_expr = FunctionCall::new(
184                        new_type,
185                        vec![
186                            orig_timestamptz,
187                            interval,
188                            ExprImpl::literal_varchar(self.timezone()),
189                        ],
190                    )
191                    .unwrap()
192                    .into();
193                    return Some(rewritten_expr);
194                }
195                None
196            }
197            // `date_trunc(field_string, input_timestamptz)`
198            // => `date_trunc(field_string, input_timestamptz, zone_string)`
199            ExprType::DateTrunc | ExprType::Extract | ExprType::DatePart => {
200                if !(inputs.len() == 2 && inputs[1].return_type() == DataType::Timestamptz) {
201                    return None;
202                }
203                assert_eq!(inputs[0].return_type(), DataType::Varchar);
204                if let ExprImpl::Literal(lit) = &inputs[0]
205                    && matches!(func_type, ExprType::Extract | ExprType::DatePart)
206                    && lit
207                        .get_data()
208                        .as_ref()
209                        .is_none_or(|v| v.as_utf8().eq_ignore_ascii_case("epoch"))
210                {
211                    // No need to rewrite when field is `null` or `epoch`.
212                    // This is optional but avoids false warning in common case.
213                    return None;
214                }
215                let mut new_inputs = inputs.to_vec();
216                new_inputs.push(ExprImpl::literal_varchar(self.timezone()));
217                Some(FunctionCall::new(func_type, new_inputs).unwrap().into())
218            }
219            // `char_to_timestamptz(input_string, format_string)`
220            // => `char_to_timestamptz(input_string, format_string, zone_string)`
221            ExprType::CharToTimestamptz => {
222                if !(inputs.len() == 2
223                    && inputs[0].return_type() == DataType::Varchar
224                    && inputs[1].return_type() == DataType::Varchar)
225                {
226                    return None;
227                }
228                let mut new_inputs = inputs.to_vec();
229                new_inputs.push(ExprImpl::literal_varchar(self.timezone()));
230                Some(FunctionCall::new(func_type, new_inputs).unwrap().into())
231            }
232            // `to_char(input_timestamptz, format_string)`
233            // => `to_char(input_timestamptz, format_string, zone_string)`
234            ExprType::ToChar => {
235                if !(inputs.len() == 2
236                    && inputs[0].return_type() == DataType::Timestamptz
237                    && inputs[1].return_type() == DataType::Varchar)
238                {
239                    return None;
240                }
241                let mut new_inputs = inputs.to_vec();
242                new_inputs.push(ExprImpl::literal_varchar(self.timezone()));
243                Some(FunctionCall::new(func_type, new_inputs).unwrap().into())
244            }
245            _ => None,
246        }
247    }
248
249    fn at_timezone(&self, input: ExprImpl) -> ExprImpl {
250        FunctionCall::new(
251            ExprType::AtTimeZone,
252            vec![input, ExprImpl::literal_varchar(self.timezone.clone())],
253        )
254        .unwrap()
255        .into()
256    }
257
258    fn cast_with_timezone(&self, input: ExprImpl, return_type: DataType) -> ExprImpl {
259        FunctionCall::new_unchecked(
260            ExprType::CastWithTimeZone,
261            vec![input, ExprImpl::literal_varchar(self.timezone.clone())],
262            return_type,
263        )
264        .into()
265    }
266}
267
/// Expression visitor that records whether a `timestamptz` value is involved.
#[derive(Default)]
pub struct TimestamptzExprFinder {
    /// Becomes `true` once a `timestamptz` has been encountered during a visit.
    has: bool,
}

impl TimestamptzExprFinder {
    /// Reports whether any visit so far has encountered a `timestamptz`.
    pub fn has(&self) -> bool {
        self.has
    }
}
278
279impl ExprVisitor for TimestamptzExprFinder {
280    fn visit_function_call(&mut self, func_call: &FunctionCall) {
281        if func_call.return_type() == DataType::Timestamptz {
282            self.has = true;
283            return;
284        }
285
286        for input in &func_call.inputs {
287            if input.return_type() == DataType::Timestamptz {
288                self.has = true;
289                return;
290            }
291        }
292
293        func_call
294            .inputs()
295            .iter()
296            .for_each(|expr| self.visit_expr(expr));
297    }
298}