risingwave_frontend/expr/
pure.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16
17use expr_node::Type;
18use risingwave_pb::expr::expr_node;
19
20use super::{ExprImpl, ExprVisitor};
21use crate::expr::FunctionCall;
22
/// Collects whether an expression contains an impure construct.
///
/// A freshly created (`Default`) analyzer has recorded nothing and therefore reports "pure".
/// Query it only after it has been driven over the expression of interest.
#[derive(Default)]
pub(crate) struct ImpureAnalyzer {
    /// Human-readable description of an impure construct that was recorded, or `None` if
    /// nothing impure has been recorded so far.
    impure: Option<Cow<'static, str>>,
}

impl ImpureAnalyzer {
    /// Tells whether an impure construct has been recorded.
    ///
    /// The answer is only meaningful once the expression has been visited.
    pub fn is_impure(&self) -> bool {
        self.impure_expr_desc().is_some()
    }

    /// Borrows the description of the recorded impure construct, intended for error
    /// reporting. Yields `None` when nothing impure has been recorded.
    ///
    /// The answer is only meaningful once the expression has been visited.
    pub fn impure_expr_desc(&self) -> Option<&str> {
        self.impure.as_deref()
    }
}
44
45impl ExprVisitor for ImpureAnalyzer {
46    fn visit_user_defined_function(&mut self, func_call: &super::UserDefinedFunction) {
47        let name = &func_call.catalog.name;
48        self.impure = Some(format!("user-defined function `{name}`").into());
49    }
50
51    fn visit_table_function(&mut self, func_call: &super::TableFunction) {
52        use crate::expr::table_function::TableFunctionType as Type;
53        let func_type = func_call.function_type;
54        match func_type {
55            Type::Unspecified => unreachable!(),
56
57            // deterministic
58            Type::GenerateSeries
59            | Type::Unnest
60            | Type::RegexpMatches
61            | Type::Range
62            | Type::GenerateSubscripts
63            | Type::PgExpandarray
64            | Type::JsonbArrayElements
65            | Type::JsonbArrayElementsText
66            | Type::JsonbEach
67            | Type::JsonbEachText
68            | Type::JsonbObjectKeys
69            | Type::JsonbPathQuery
70            | Type::JsonbPopulateRecordset
71            | Type::JsonbToRecordset => {
72                func_call.args.iter().for_each(|expr| self.visit_expr(expr));
73            }
74
75            // indeterministic
76            Type::FileScan
77            | Type::PostgresQuery
78            | Type::MysqlQuery
79            | Type::InternalBackfillProgress
80            | Type::InternalSourceBackfillProgress
81            | Type::InternalGetChannelDeltaStats
82            | Type::PgGetKeywords => {
83                self.impure = Some(func_type.as_str_name().into());
84            }
85            Type::UserDefined => {
86                let name = &func_call.user_defined.as_ref().unwrap().name;
87                self.impure = Some(format!("user-defined table function `{name}`").into());
88            }
89        }
90    }
91
92    fn visit_now(&mut self, _: &super::Now) {
93        self.impure = Some("NOW or PROCTIME".into());
94    }
95
96    fn visit_secret_ref(&mut self, secret_ref: &super::SecretRef) {
97        self.impure = Some(format!("secret reference `{}`", secret_ref.secret_name).into());
98    }
99
100    fn visit_function_call(&mut self, func_call: &super::FunctionCall) {
101        let func_type = func_call.func_type();
102        match func_type {
103            Type::Unspecified => unreachable!(),
104            #[expect(deprecated)]
105            Type::Add
106            | Type::Subtract
107            | Type::Multiply
108            | Type::Divide
109            | Type::Modulus
110            | Type::Equal
111            | Type::NotEqual
112            | Type::LessThan
113            | Type::LessThanOrEqual
114            | Type::GreaterThan
115            | Type::GreaterThanOrEqual
116            | Type::And
117            | Type::Or
118            | Type::Not
119            | Type::In
120            | Type::Some
121            | Type::All
122            | Type::BitwiseAnd
123            | Type::BitwiseOr
124            | Type::BitwiseXor
125            | Type::BitwiseNot
126            | Type::BitwiseShiftLeft
127            | Type::BitwiseShiftRight
128            | Type::Extract
129            | Type::DatePart
130            | Type::TumbleStart
131            | Type::SecToTimestamptz
132            | Type::AtTimeZone
133            | Type::DateTrunc
134            | Type::DateBin
135            | Type::MakeDate
136            | Type::MakeTime
137            | Type::MakeTimestamp
138            | Type::CharToTimestamptz
139            | Type::CharToDate
140            | Type::CastWithTimeZone
141            | Type::AddWithTimeZone
142            | Type::SubtractWithTimeZone
143            | Type::Cast
144            | Type::Substr
145            | Type::Length
146            | Type::Like
147            | Type::ILike
148            | Type::SimilarToEscape
149            | Type::Upper
150            | Type::Lower
151            | Type::Trim
152            | Type::Replace
153            | Type::Position
154            | Type::Ltrim
155            | Type::Rtrim
156            | Type::Case
157            | Type::ConstantLookup
158            | Type::RoundDigit
159            | Type::Round
160            | Type::Ascii
161            | Type::Translate
162            | Type::Coalesce
163            | Type::ConcatWs
164            | Type::ConcatWsVariadic
165            | Type::Abs
166            | Type::SplitPart
167            | Type::Ceil
168            | Type::Floor
169            | Type::Trunc
170            | Type::ToChar
171            | Type::Md5
172            | Type::CharLength
173            | Type::Repeat
174            | Type::ConcatOp
175            | Type::ByteaConcatOp
176            | Type::Concat
177            | Type::ConcatVariadic
178            | Type::BoolOut
179            | Type::OctetLength
180            | Type::BitLength
181            | Type::Overlay
182            | Type::RegexpMatch
183            | Type::RegexpReplace
184            | Type::RegexpCount
185            | Type::RegexpSplitToArray
186            | Type::RegexpEq
187            | Type::Pow
188            | Type::Exp
189            | Type::Ln
190            | Type::Log10
191            | Type::Chr
192            | Type::StartsWith
193            | Type::Initcap
194            | Type::Lpad
195            | Type::Rpad
196            | Type::Reverse
197            | Type::Strpos
198            | Type::ToAscii
199            | Type::ToHex
200            | Type::QuoteIdent
201            | Type::Sin
202            | Type::Cos
203            | Type::Tan
204            | Type::Cot
205            | Type::Asin
206            | Type::Acos
207            | Type::Acosd
208            | Type::Atan
209            | Type::Atan2
210            | Type::Atand
211            | Type::Atan2d
212            | Type::Sqrt
213            | Type::Cbrt
214            | Type::Sign
215            | Type::Scale
216            | Type::MinScale
217            | Type::TrimScale
218            | Type::Gamma
219            | Type::Lgamma
220            | Type::Left
221            | Type::Right
222            | Type::Degrees
223            | Type::Radians
224            | Type::IsTrue
225            | Type::IsNotTrue
226            | Type::IsFalse
227            | Type::IsNotFalse
228            | Type::IsNull
229            | Type::IsNotNull
230            | Type::IsDistinctFrom
231            | Type::IsNotDistinctFrom
232            | Type::Neg
233            | Type::Field
234            | Type::Array
235            | Type::ArrayAccess
236            | Type::ArrayRangeAccess
237            | Type::Row
238            | Type::ArrayToString
239            | Type::ArrayCat
240            | Type::ArrayMax
241            | Type::ArraySum
242            | Type::ArraySort
243            | Type::ArrayAppend
244            | Type::ArrayReverse
245            | Type::ArrayPrepend
246            | Type::FormatType
247            | Type::ArrayDistinct
248            | Type::ArrayMin
249            | Type::ArrayDims
250            | Type::ArrayLength
251            | Type::Cardinality
252            | Type::TrimArray
253            | Type::ArrayRemove
254            | Type::ArrayReplace
255            | Type::ArrayPosition
256            | Type::ArrayContains
257            | Type::ArrayContained
258            | Type::ArrayOverlaps
259            | Type::ArrayFlatten
260            | Type::HexToInt256
261            | Type::JsonbConcat
262            | Type::JsonbAccess
263            | Type::JsonbAccessStr
264            | Type::JsonbExtractPath
265            | Type::JsonbExtractPathVariadic
266            | Type::JsonbExtractPathText
267            | Type::JsonbExtractPathTextVariadic
268            | Type::JsonbTypeof
269            | Type::JsonbArrayLength
270            | Type::JsonbObject
271            | Type::JsonbPretty
272            | Type::JsonbDeletePath
273            | Type::JsonbContains
274            | Type::JsonbContained
275            | Type::JsonbExists
276            | Type::JsonbExistsAny
277            | Type::JsonbExistsAll
278            | Type::JsonbStripNulls
279            | Type::JsonbBuildArray
280            | Type::JsonbBuildArrayVariadic
281            | Type::JsonbBuildObject
282            | Type::JsonbPopulateRecord
283            | Type::JsonbToArray
284            | Type::JsonbToRecord
285            | Type::JsonbBuildObjectVariadic
286            | Type::JsonbPathExists
287            | Type::JsonbPathMatch
288            | Type::JsonbPathQueryArray
289            | Type::JsonbPathQueryFirst
290            | Type::JsonbSet
291            | Type::JsonbPopulateMap
292            | Type::IsJson
293            | Type::ToJsonb
294            | Type::Sind
295            | Type::Cosd
296            | Type::Cotd
297            | Type::Asind
298            | Type::Sinh
299            | Type::Cosh
300            | Type::Coth
301            | Type::Tanh
302            | Type::Atanh
303            | Type::Asinh
304            | Type::Acosh
305            | Type::Decode
306            | Type::Encode
307            | Type::GetBit
308            | Type::GetByte
309            | Type::SetBit
310            | Type::SetByte
311            | Type::BitCount
312            | Type::Sha1
313            | Type::Sha224
314            | Type::Sha256
315            | Type::Sha384
316            | Type::Sha512
317            | Type::Crc32
318            | Type::Crc32c
319            | Type::Hmac
320            | Type::SecureCompare
321            | Type::Decrypt
322            | Type::Encrypt
323            | Type::Tand
324            | Type::ArrayPositions
325            | Type::StringToArray
326            | Type::Format
327            | Type::FormatVariadic
328            | Type::PgwireSend
329            | Type::PgwireRecv
330            | Type::ArrayTransform
331            | Type::Greatest
332            | Type::Least
333            | Type::ConvertFrom
334            | Type::ConvertTo
335            | Type::IcebergTransform
336            | Type::InetNtoa
337            | Type::InetAton
338            | Type::QuoteLiteral
339            | Type::QuoteNullable
340            | Type::MapFromEntries
341            | Type::MapAccess
342            | Type::MapKeys
343            | Type::MapValues
344            | Type::MapEntries
345            | Type::MapFromKeyValues
346            | Type::MapCat
347            | Type::MapContains
348            | Type::MapDelete
349            | Type::MapFilter
350            | Type::MapInsert
351            | Type::MapLength
352            | Type::L2Distance
353            | Type::CosineDistance
354            | Type::L1Distance
355            | Type::InnerProduct
356            | Type::VecConcat
357            | Type::L2Norm
358            | Type::L2Normalize
359            | Type::Subvector
360            // TODO: `rw_vnode` is more like STABLE instead of IMMUTABLE, because even its result is
361            // deterministic, it needs to read the total vnode count from the context, which means that
362            // it cannot be evaluated during constant folding. We have to treat it pure here so it can be used
363            // internally without materialization.
364            | Type::Vnode
365            | Type::VnodeUser
366            | Type::RwEpochToTs
367            | Type::CheckNotNull
368            | Type::CompositeCast =>
369            // expression output is deterministic(same result for the same input)
370            {
371                func_call
372                    .inputs()
373                    .iter()
374                    .for_each(|expr| self.visit_expr(expr));
375            }
376            // expression output is not deterministic
377            Type::TestFeature
378            | Type::License
379            | Type::Proctime
380            | Type::PgSleep
381            | Type::PgSleepFor
382            | Type::PgSleepUntil
383            | Type::CastRegclass
384            | Type::PgGetIndexdef
385            | Type::ColDescription
386            | Type::PgGetViewdef
387            | Type::PgGetUserbyid
388            | Type::PgIndexesSize
389            | Type::PgRelationSize
390            | Type::PgGetSerialSequence
391            | Type::PgIndexColumnHasProperty
392            | Type::HasTablePrivilege
393            | Type::HasAnyColumnPrivilege
394            | Type::HasSchemaPrivilege
395            | Type::MakeTimestamptz
396            | Type::PgIsInRecovery
397            | Type::RwRecoveryStatus
398            | Type::RwClusterId
399            | Type::RwFragmentVnodes
400            | Type::RwActorVnodes
401            | Type::PgTableIsVisible
402            | Type::HasFunctionPrivilege
403            | Type::OpenaiEmbedding
404            | Type::HasDatabasePrivilege
405            | Type::Random => self.impure = Some(func_type.as_str_name().into()),
406        }
407    }
408}
409
410pub fn is_pure(expr: &ExprImpl) -> bool {
411    !is_impure(expr)
412}
413
414pub fn is_impure(expr: &ExprImpl) -> bool {
415    let mut a = ImpureAnalyzer::default();
416    a.visit_expr(expr);
417    a.is_impure()
418}
419
420pub fn is_impure_func_call(func_call: &FunctionCall) -> bool {
421    let mut a = ImpureAnalyzer::default();
422    a.visit_function_call(func_call);
423    a.is_impure()
424}
425
426/// Returns the description of the impure expression if it is impure, for error reporting.
427/// `None` if the expression is pure.
428pub fn impure_expr_desc(expr: &ExprImpl) -> Option<String> {
429    let mut a = ImpureAnalyzer::default();
430    a.visit_expr(expr);
431    a.impure_expr_desc().map(|s| s.to_owned())
432}
433
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use crate::expr::{ExprImpl, FunctionCall, InputRef, is_impure, is_pure};

    /// Checks that `is_pure` and `is_impure` both classify `expr` as pure.
    fn expect_pure(expr: &ExprImpl) {
        assert!(is_pure(expr));
        assert!(!is_impure(expr));
    }

    /// Checks that `is_pure` and `is_impure` both classify `expr` as impure.
    fn expect_impure(expr: &ExprImpl) {
        assert!(!is_pure(expr));
        assert!(is_impure(expr));
    }

    #[test]
    fn test_pure_funcs() {
        // `$0 + $0`: a deterministic operator applied to input references only.
        let lhs: ExprImpl = InputRef::new(0, DataType::Int16).into();
        let rhs: ExprImpl = InputRef::new(0, DataType::Int16).into();
        let sum: ExprImpl = FunctionCall::new(Type::Add, vec![lhs, rhs])
            .unwrap()
            .into();
        expect_pure(&sum);

        // `$0 > proctime()`: the comparison is deterministic, but one of its inputs is not,
        // so the whole expression must be reported as impure.
        let column: ExprImpl = InputRef::new(0, DataType::Timestamptz).into();
        let proctime: ExprImpl = FunctionCall::new(Type::Proctime, vec![]).unwrap().into();
        let comparison: ExprImpl = FunctionCall::new(Type::GreaterThan, vec![column, proctime])
            .unwrap()
            .into();
        expect_impure(&comparison);
    }
}