risingwave_frontend/expr/
session_timezone.rs1use risingwave_common::types::DataType;
16pub use risingwave_pb::expr::expr_node::Type as ExprType;
17
18pub use crate::expr::expr_rewriter::ExprRewriter;
19pub use crate::expr::function_call::FunctionCall;
20use crate::expr::{Expr, ExprImpl, ExprVisitor};
21use crate::session::current;
22
/// Expression rewriter that makes session-timezone dependencies explicit.
///
/// Casts, comparisons, `timestamptz +/- interval` arithmetic and a few timestamp functions
/// (`date_trunc`, `extract`, `date_part`, `char_to_timestamptz`, `to_char`) are rewritten so
/// that the session timezone is passed to them as an explicit string literal.
#[derive(Debug)]
pub struct SessionTimezone {
    /// Name of the session timezone, embedded verbatim as a varchar literal in rewrites.
    timezone: String,
    /// Whether at least one expression was rewritten using the session timezone.
    /// Also ensures the user notice about timezone usage is emitted only once.
    used: bool,
}
30
31impl ExprRewriter for SessionTimezone {
32 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
33 let (func_type, inputs, ret) = func_call.decompose();
34 let inputs: Vec<ExprImpl> = inputs
35 .into_iter()
36 .map(|expr| self.rewrite_expr(expr))
37 .collect();
38 if let Some(expr) = self.with_timezone(func_type, &inputs, &ret) {
39 self.mark_used();
40 expr
41 } else {
42 FunctionCall::new_unchecked(func_type, inputs, ret).into()
43 }
44 }
45}
46
47impl SessionTimezone {
48 pub fn new(timezone: String) -> Self {
49 Self {
50 timezone,
51 used: false,
52 }
53 }
54
55 pub fn timezone(&self) -> String {
56 self.timezone.clone()
57 }
58
59 pub fn used(&self) -> bool {
60 self.used
61 }
62
63 fn mark_used(&mut self) {
64 if !self.used {
65 self.used = true;
66 current::notice_to_user(format!(
67 "Your session timezone is {}. It was used in the interpretation of timestamps and dates in your query. If this is unintended, \
68 change your timezone to match that of your data's with `set timezone = [timezone]` or \
69 rewrite your query with an explicit timezone conversion, e.g. with `AT TIME ZONE`.\n",
70 self.timezone
71 ));
72 }
73 }
74
75 fn with_timezone(
77 &self,
78 func_type: ExprType,
79 inputs: &[ExprImpl],
80 return_type: &DataType,
81 ) -> Option<ExprImpl> {
82 match func_type {
83 ExprType::Cast => {
98 assert_eq!(inputs.len(), 1);
99 let mut input = inputs[0].clone();
100 let input_type = input.return_type();
101 match (input_type, return_type) {
102 (DataType::Timestamptz, DataType::Varchar)
103 | (DataType::Varchar, DataType::Timestamptz) => {
104 Some(self.cast_with_timezone(input, return_type.clone()))
105 }
106 (DataType::Date, DataType::Timestamptz)
107 | (DataType::Timestamp, DataType::Timestamptz) => {
108 input = input.cast_explicit(&DataType::Timestamp).unwrap();
109 Some(self.at_timezone(input))
110 }
111 (DataType::Timestamptz, DataType::Date)
112 | (DataType::Timestamptz, DataType::Time)
113 | (DataType::Timestamptz, DataType::Timestamp) => {
114 input = self.at_timezone(input);
115 input = input.cast_explicit(return_type).unwrap();
116 Some(input)
117 }
118 _ => None,
119 }
120 }
121 ExprType::Equal
130 | ExprType::NotEqual
131 | ExprType::LessThan
132 | ExprType::LessThanOrEqual
133 | ExprType::GreaterThan
134 | ExprType::GreaterThanOrEqual
135 | ExprType::IsDistinctFrom
136 | ExprType::IsNotDistinctFrom => {
137 assert_eq!(inputs.len(), 2);
138 let mut inputs = inputs.to_vec();
139 for idx in 0..2 {
140 if matches!(inputs[(idx + 1) % 2].return_type(), DataType::Timestamptz)
141 && matches!(
142 inputs[idx % 2].return_type(),
143 DataType::Date | DataType::Timestamp
144 )
145 {
146 let mut to_cast = inputs[idx % 2].clone();
147 to_cast = to_cast.cast_explicit(&DataType::Timestamp).unwrap();
150 inputs[idx % 2] = self.at_timezone(to_cast);
151 return Some(
152 FunctionCall::new_unchecked(func_type, inputs, return_type.clone())
153 .into(),
154 );
155 }
156 }
157 None
158 }
159 ExprType::Subtract | ExprType::Add => {
166 assert_eq!(inputs.len(), 2);
167 let canonical_match = matches!(inputs[0].return_type(), DataType::Timestamptz)
168 && matches!(inputs[1].return_type(), DataType::Interval);
169 let inverse_match = matches!(inputs[1].return_type(), DataType::Timestamptz)
170 && matches!(inputs[0].return_type(), DataType::Interval);
171 assert!(!(inverse_match && func_type == ExprType::Subtract)); if canonical_match || inverse_match {
173 let (orig_timestamptz, interval) =
174 if func_type == ExprType::Add && inverse_match {
175 (inputs[1].clone(), inputs[0].clone())
176 } else {
177 (inputs[0].clone(), inputs[1].clone())
178 };
179 let new_type = match func_type {
180 ExprType::Add => ExprType::AddWithTimeZone,
181 ExprType::Subtract => ExprType::SubtractWithTimeZone,
182 _ => unreachable!(),
183 };
184 let rewritten_expr = FunctionCall::new(
185 new_type,
186 vec![
187 orig_timestamptz,
188 interval,
189 ExprImpl::literal_varchar(self.timezone()),
190 ],
191 )
192 .unwrap()
193 .into();
194 return Some(rewritten_expr);
195 }
196 None
197 }
198 ExprType::DateTrunc | ExprType::Extract | ExprType::DatePart => {
201 if !(inputs.len() == 2 && inputs[1].return_type() == DataType::Timestamptz) {
202 return None;
203 }
204 assert_eq!(inputs[0].return_type(), DataType::Varchar);
205 if let ExprImpl::Literal(lit) = &inputs[0]
206 && matches!(func_type, ExprType::Extract | ExprType::DatePart)
207 && lit
208 .get_data()
209 .as_ref()
210 .is_none_or(|v| v.as_utf8().eq_ignore_ascii_case("epoch"))
211 {
212 return None;
215 }
216 let mut new_inputs = inputs.to_vec();
217 new_inputs.push(ExprImpl::literal_varchar(self.timezone()));
218 Some(FunctionCall::new(func_type, new_inputs).unwrap().into())
219 }
220 ExprType::CharToTimestamptz => {
223 if !(inputs.len() == 2
224 && inputs[0].return_type() == DataType::Varchar
225 && inputs[1].return_type() == DataType::Varchar)
226 {
227 return None;
228 }
229 let mut new_inputs = inputs.to_vec();
230 new_inputs.push(ExprImpl::literal_varchar(self.timezone()));
231 Some(FunctionCall::new(func_type, new_inputs).unwrap().into())
232 }
233 ExprType::ToChar => {
236 if !(inputs.len() == 2
237 && inputs[0].return_type() == DataType::Timestamptz
238 && inputs[1].return_type() == DataType::Varchar)
239 {
240 return None;
241 }
242 let mut new_inputs = inputs.to_vec();
243 new_inputs.push(ExprImpl::literal_varchar(self.timezone()));
244 Some(FunctionCall::new(func_type, new_inputs).unwrap().into())
245 }
246 _ => None,
247 }
248 }
249
250 fn at_timezone(&self, input: ExprImpl) -> ExprImpl {
251 FunctionCall::new(
252 ExprType::AtTimeZone,
253 vec![input, ExprImpl::literal_varchar(self.timezone.clone())],
254 )
255 .unwrap()
256 .into()
257 }
258
259 fn cast_with_timezone(&self, input: ExprImpl, return_type: DataType) -> ExprImpl {
260 FunctionCall::new_unchecked(
261 ExprType::CastWithTimeZone,
262 vec![input, ExprImpl::literal_varchar(self.timezone.clone())],
263 return_type,
264 )
265 .into()
266 }
267}
268
/// Expression visitor that records whether any visited function call returns, or directly
/// takes as an input, a `timestamptz` value.
#[derive(Debug, Default)]
pub struct TimestamptzExprFinder {
    /// Set once a `timestamptz`-typed function call or function input has been encountered.
    has: bool,
}

impl TimestamptzExprFinder {
    /// Returns `true` if a `timestamptz` value was found in the visited expressions.
    #[must_use]
    pub fn has(&self) -> bool {
        self.has
    }
}
279
280impl ExprVisitor for TimestamptzExprFinder {
281 fn visit_function_call(&mut self, func_call: &FunctionCall) {
282 if func_call.return_type() == DataType::Timestamptz {
283 self.has = true;
284 return;
285 }
286
287 for input in &func_call.inputs {
288 if input.return_type() == DataType::Timestamptz {
289 self.has = true;
290 return;
291 }
292 }
293
294 func_call
295 .inputs()
296 .iter()
297 .for_each(|expr| self.visit_expr(expr));
298 }
299}