risingwave_frontend/expr/
pure.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16
17use expr_node::Type;
18use risingwave_pb::expr::expr_node;
19
20use super::{ExprImpl, ExprVisitor};
21use crate::expr::FunctionCall;
22
/// Expression visitor state used to decide whether an expression is pure.
///
/// The analyzer starts out "pure" (`impure` is `None`). Each impure node the visitor encounters
/// stores a human-readable description here, overwriting any description recorded earlier.
#[derive(Debug, Default)]
pub(crate) struct ImpureAnalyzer {
    /// Description of an impure sub-expression found so far, or `None` if none has been seen.
    ///
    /// Static descriptions are borrowed; descriptions that embed a function name are owned.
    impure: Option<Cow<'static, str>>,
}
27
28impl ImpureAnalyzer {
29    /// Returns `true` if the expression is impure.
30    ///
31    /// Only call this method after visiting the expression.
32    pub fn is_impure(&self) -> bool {
33        self.impure.is_some()
34    }
35
36    /// Returns the description of the impure expression if it is impure, for error reporting.
37    /// `None` if the expression is pure.
38    ///
39    /// Only call this method after visiting the expression.
40    pub fn impure_expr_desc(&self) -> Option<&str> {
41        self.impure.as_deref()
42    }
43}
44
45impl ExprVisitor for ImpureAnalyzer {
46    fn visit_user_defined_function(&mut self, func_call: &super::UserDefinedFunction) {
47        let name = &func_call.catalog.name;
48        self.impure = Some(format!("user-defined function `{name}`").into());
49    }
50
51    fn visit_table_function(&mut self, func_call: &super::TableFunction) {
52        use crate::expr::table_function::TableFunctionType as Type;
53        let func_type = func_call.function_type;
54        match func_type {
55            Type::Unspecified => unreachable!(),
56
57            // deterministic
58            Type::GenerateSeries
59            | Type::Unnest
60            | Type::RegexpMatches
61            | Type::Range
62            | Type::GenerateSubscripts
63            | Type::PgExpandarray
64            | Type::JsonbArrayElements
65            | Type::JsonbArrayElementsText
66            | Type::JsonbEach
67            | Type::JsonbEachText
68            | Type::JsonbObjectKeys
69            | Type::JsonbPathQuery
70            | Type::JsonbPopulateRecordset
71            | Type::JsonbToRecordset => {
72                func_call.args.iter().for_each(|expr| self.visit_expr(expr));
73            }
74
75            // indeterministic
76            Type::FileScan
77            | Type::PostgresQuery
78            | Type::MysqlQuery
79            | Type::InternalBackfillProgress
80            | Type::InternalSourceBackfillProgress
81            | Type::InternalGetChannelDeltaStats
82            | Type::PgGetKeywords => {
83                self.impure = Some(func_type.as_str_name().into());
84            }
85            Type::UserDefined => {
86                let name = &func_call.user_defined.as_ref().unwrap().name;
87                self.impure = Some(format!("user-defined table function `{name}`").into());
88            }
89        }
90    }
91
92    fn visit_now(&mut self, _: &super::Now) {
93        self.impure = Some("NOW or PROCTIME".into());
94    }
95
96    fn visit_function_call(&mut self, func_call: &super::FunctionCall) {
97        let func_type = func_call.func_type();
98        match func_type {
99            Type::Unspecified => unreachable!(),
100            Type::Add
101            | Type::Subtract
102            | Type::Multiply
103            | Type::Divide
104            | Type::Modulus
105            | Type::Equal
106            | Type::NotEqual
107            | Type::LessThan
108            | Type::LessThanOrEqual
109            | Type::GreaterThan
110            | Type::GreaterThanOrEqual
111            | Type::And
112            | Type::Or
113            | Type::Not
114            | Type::In
115            | Type::Some
116            | Type::All
117            | Type::BitwiseAnd
118            | Type::BitwiseOr
119            | Type::BitwiseXor
120            | Type::BitwiseNot
121            | Type::BitwiseShiftLeft
122            | Type::BitwiseShiftRight
123            | Type::Extract
124            | Type::DatePart
125            | Type::TumbleStart
126            | Type::SecToTimestamptz
127            | Type::AtTimeZone
128            | Type::DateTrunc
129            | Type::DateBin
130            | Type::MakeDate
131            | Type::MakeTime
132            | Type::MakeTimestamp
133            | Type::CharToTimestamptz
134            | Type::CharToDate
135            | Type::CastWithTimeZone
136            | Type::AddWithTimeZone
137            | Type::SubtractWithTimeZone
138            | Type::Cast
139            | Type::Substr
140            | Type::Length
141            | Type::Like
142            | Type::ILike
143            | Type::SimilarToEscape
144            | Type::Upper
145            | Type::Lower
146            | Type::Trim
147            | Type::Replace
148            | Type::Position
149            | Type::Ltrim
150            | Type::Rtrim
151            | Type::Case
152            | Type::ConstantLookup
153            | Type::RoundDigit
154            | Type::Round
155            | Type::Ascii
156            | Type::Translate
157            | Type::Coalesce
158            | Type::ConcatWs
159            | Type::ConcatWsVariadic
160            | Type::Abs
161            | Type::SplitPart
162            | Type::Ceil
163            | Type::Floor
164            | Type::Trunc
165            | Type::ToChar
166            | Type::Md5
167            | Type::CharLength
168            | Type::Repeat
169            | Type::ConcatOp
170            | Type::ByteaConcatOp
171            | Type::Concat
172            | Type::ConcatVariadic
173            | Type::BoolOut
174            | Type::OctetLength
175            | Type::BitLength
176            | Type::Overlay
177            | Type::RegexpMatch
178            | Type::RegexpReplace
179            | Type::RegexpCount
180            | Type::RegexpSplitToArray
181            | Type::RegexpEq
182            | Type::Pow
183            | Type::Exp
184            | Type::Ln
185            | Type::Log10
186            | Type::Chr
187            | Type::StartsWith
188            | Type::Initcap
189            | Type::Lpad
190            | Type::Rpad
191            | Type::Reverse
192            | Type::Strpos
193            | Type::ToAscii
194            | Type::ToHex
195            | Type::QuoteIdent
196            | Type::Sin
197            | Type::Cos
198            | Type::Tan
199            | Type::Cot
200            | Type::Asin
201            | Type::Acos
202            | Type::Acosd
203            | Type::Atan
204            | Type::Atan2
205            | Type::Atand
206            | Type::Atan2d
207            | Type::Sqrt
208            | Type::Cbrt
209            | Type::Sign
210            | Type::Scale
211            | Type::MinScale
212            | Type::TrimScale
213            | Type::Gamma
214            | Type::Lgamma
215            | Type::Left
216            | Type::Right
217            | Type::Degrees
218            | Type::Radians
219            | Type::IsTrue
220            | Type::IsNotTrue
221            | Type::IsFalse
222            | Type::IsNotFalse
223            | Type::IsNull
224            | Type::IsNotNull
225            | Type::IsDistinctFrom
226            | Type::IsNotDistinctFrom
227            | Type::Neg
228            | Type::Field
229            | Type::Array
230            | Type::ArrayAccess
231            | Type::ArrayRangeAccess
232            | Type::Row
233            | Type::ArrayToString
234            | Type::ArrayCat
235            | Type::ArrayMax
236            | Type::ArraySum
237            | Type::ArraySort
238            | Type::ArrayAppend
239            | Type::ArrayReverse
240            | Type::ArrayPrepend
241            | Type::FormatType
242            | Type::ArrayDistinct
243            | Type::ArrayMin
244            | Type::ArrayDims
245            | Type::ArrayLength
246            | Type::Cardinality
247            | Type::TrimArray
248            | Type::ArrayRemove
249            | Type::ArrayReplace
250            | Type::ArrayPosition
251            | Type::ArrayContains
252            | Type::ArrayContained
253            | Type::ArrayFlatten
254            | Type::HexToInt256
255            | Type::JsonbConcat
256            | Type::JsonbAccess
257            | Type::JsonbAccessStr
258            | Type::JsonbExtractPath
259            | Type::JsonbExtractPathVariadic
260            | Type::JsonbExtractPathText
261            | Type::JsonbExtractPathTextVariadic
262            | Type::JsonbTypeof
263            | Type::JsonbArrayLength
264            | Type::JsonbObject
265            | Type::JsonbPretty
266            | Type::JsonbDeletePath
267            | Type::JsonbContains
268            | Type::JsonbContained
269            | Type::JsonbExists
270            | Type::JsonbExistsAny
271            | Type::JsonbExistsAll
272            | Type::JsonbStripNulls
273            | Type::JsonbBuildArray
274            | Type::JsonbBuildArrayVariadic
275            | Type::JsonbBuildObject
276            | Type::JsonbPopulateRecord
277            | Type::JsonbToArray
278            | Type::JsonbToRecord
279            | Type::JsonbBuildObjectVariadic
280            | Type::JsonbPathExists
281            | Type::JsonbPathMatch
282            | Type::JsonbPathQueryArray
283            | Type::JsonbPathQueryFirst
284            | Type::JsonbSet
285            | Type::JsonbPopulateMap
286            | Type::IsJson
287            | Type::ToJsonb
288            | Type::Sind
289            | Type::Cosd
290            | Type::Cotd
291            | Type::Asind
292            | Type::Sinh
293            | Type::Cosh
294            | Type::Coth
295            | Type::Tanh
296            | Type::Atanh
297            | Type::Asinh
298            | Type::Acosh
299            | Type::Decode
300            | Type::Encode
301            | Type::GetBit
302            | Type::GetByte
303            | Type::SetBit
304            | Type::SetByte
305            | Type::BitCount
306            | Type::Sha1
307            | Type::Sha224
308            | Type::Sha256
309            | Type::Sha384
310            | Type::Sha512
311            | Type::Hmac
312            | Type::SecureCompare
313            | Type::Decrypt
314            | Type::Encrypt
315            | Type::Tand
316            | Type::ArrayPositions
317            | Type::StringToArray
318            | Type::Format
319            | Type::FormatVariadic
320            | Type::PgwireSend
321            | Type::PgwireRecv
322            | Type::ArrayTransform
323            | Type::Greatest
324            | Type::Least
325            | Type::ConvertFrom
326            | Type::ConvertTo
327            | Type::IcebergTransform
328            | Type::InetNtoa
329            | Type::InetAton
330            | Type::QuoteLiteral
331            | Type::QuoteNullable
332            | Type::MapFromEntries
333            | Type::MapAccess
334            | Type::MapKeys
335            | Type::MapValues
336            | Type::MapEntries
337            | Type::MapFromKeyValues
338            | Type::MapCat
339            | Type::MapContains
340            | Type::MapDelete
341            | Type::MapFilter
342            | Type::MapInsert
343            | Type::MapLength
344            | Type::L2Distance
345            | Type::CosineDistance
346            | Type::L1Distance
347            | Type::InnerProduct
348            | Type::VecConcat
349            | Type::L2Norm
350            | Type::L2Normalize
351            | Type::Subvector
352            // TODO: `rw_vnode` is more like STABLE instead of IMMUTABLE, because even its result is
353            // deterministic, it needs to read the total vnode count from the context, which means that
354            // it cannot be evaluated during constant folding. We have to treat it pure here so it can be used
355            // internally without materialization.
356            | Type::Vnode
357            | Type::VnodeUser
358            | Type::RwEpochToTs
359            | Type::CheckNotNull
360            | Type::CompositeCast =>
361            // expression output is deterministic(same result for the same input)
362            {
363                func_call
364                    .inputs()
365                    .iter()
366                    .for_each(|expr| self.visit_expr(expr));
367            }
368            // expression output is not deterministic
369            Type::TestFeature
370            | Type::License
371            | Type::Proctime
372            | Type::PgSleep
373            | Type::PgSleepFor
374            | Type::PgSleepUntil
375            | Type::CastRegclass
376            | Type::PgGetIndexdef
377            | Type::ColDescription
378            | Type::PgGetViewdef
379            | Type::PgGetUserbyid
380            | Type::PgIndexesSize
381            | Type::PgRelationSize
382            | Type::PgGetSerialSequence
383            | Type::PgIndexColumnHasProperty
384            | Type::HasTablePrivilege
385            | Type::HasAnyColumnPrivilege
386            | Type::HasSchemaPrivilege
387            | Type::MakeTimestamptz
388            | Type::PgIsInRecovery
389            | Type::RwRecoveryStatus
390            | Type::RwClusterId
391            | Type::RwFragmentVnodes
392            | Type::RwActorVnodes
393            | Type::PgTableIsVisible
394            | Type::HasFunctionPrivilege
395            | Type::OpenaiEmbedding
396            | Type::HasDatabasePrivilege
397            | Type::Random => self.impure = Some(func_type.as_str_name().into()),
398        }
399    }
400}
401
402pub fn is_pure(expr: &ExprImpl) -> bool {
403    !is_impure(expr)
404}
405
406pub fn is_impure(expr: &ExprImpl) -> bool {
407    let mut a = ImpureAnalyzer::default();
408    a.visit_expr(expr);
409    a.is_impure()
410}
411
412pub fn is_impure_func_call(func_call: &FunctionCall) -> bool {
413    let mut a = ImpureAnalyzer::default();
414    a.visit_function_call(func_call);
415    a.is_impure()
416}
417
418/// Returns the description of the impure expression if it is impure, for error reporting.
419/// `None` if the expression is pure.
420pub fn impure_expr_desc(expr: &ExprImpl) -> Option<String> {
421    let mut a = ImpureAnalyzer::default();
422    a.visit_expr(expr);
423    a.impure_expr_desc().map(|s| s.to_owned())
424}
425
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use crate::expr::{ExprImpl, FunctionCall, InputRef, is_impure, is_pure};

    /// Builds a function-call expression, panicking if the call cannot be constructed.
    fn call(func_type: Type, inputs: Vec<ExprImpl>) -> ExprImpl {
        FunctionCall::new(func_type, inputs).unwrap().into()
    }

    /// Builds a reference to input column 0 with the given type.
    fn first_column(data_type: DataType) -> ExprImpl {
        InputRef::new(0, data_type).into()
    }

    /// Asserts that both purity helpers agree the expression is pure.
    fn expect_pure(expr: &ExprImpl) {
        assert!(is_pure(expr));
        assert!(!is_impure(expr));
    }

    /// Asserts that both purity helpers agree the expression is impure.
    fn expect_impure(expr: &ExprImpl) {
        assert!(!is_pure(expr));
        assert!(is_impure(expr));
    }

    #[test]
    fn test_pure_funcs() {
        // `$0 + $0` only involves a deterministic function over input references.
        let sum = call(
            Type::Add,
            vec![
                first_column(DataType::Int16),
                first_column(DataType::Int16),
            ],
        );
        expect_pure(&sum);

        // `$0 > proctime()` nests an impure call inside a deterministic comparison.
        let proctime = call(Type::Proctime, vec![]);
        let comparison = call(
            Type::GreaterThan,
            vec![first_column(DataType::Timestamptz), proctime],
        );
        expect_impure(&comparison);
    }
}