risingwave_frontend/expr/
pure.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16
17use expr_node::Type;
18use risingwave_pb::expr::expr_node;
19
20use super::{ExprImpl, ExprVisitor};
21use crate::expr::FunctionCall;
22
/// Visitor state that remembers whether an impure sub-expression has been encountered.
///
/// A freshly constructed (`Default`) analyzer has recorded nothing and therefore reports
/// the expression as pure.
#[derive(Default)]
pub(crate) struct ImpureAnalyzer {
    /// Description of a recorded impure sub-expression; `None` while nothing impure has
    /// been recorded.
    impure: Option<Cow<'static, str>>,
}

impl ImpureAnalyzer {
    /// Tells whether an impure sub-expression has been recorded.
    ///
    /// The answer is only meaningful once the expression has been visited.
    pub fn is_impure(&self) -> bool {
        self.impure_expr_desc().is_some()
    }

    /// Borrows the description of the recorded impure sub-expression, intended for error
    /// reporting. Yields `None` when nothing impure has been recorded.
    ///
    /// The answer is only meaningful once the expression has been visited.
    pub fn impure_expr_desc(&self) -> Option<&str> {
        let recorded = self.impure.as_ref()?;
        Some(recorded.as_ref())
    }
}
44
/// Traversal hooks that flag impure constructs.
///
/// Each hook below that detects an impure construct stores a description of it in
/// `self.impure`, replacing whatever description was stored before. Hooks that classify a
/// call as deterministic only recurse into the call's arguments.
impl ExprVisitor for ImpureAnalyzer {
    /// Records every user-defined function call as impure, using the function's catalog name
    /// in the description. The call's arguments are not visited.
    fn visit_user_defined_function(&mut self, func_call: &super::UserDefinedFunction) {
        let name = &func_call.catalog.name;
        self.impure = Some(format!("user-defined function `{name}`").into());
    }

    /// Classifies a table function call by its function type.
    ///
    /// Deterministic table functions only have their arguments visited; non-deterministic
    /// and user-defined table functions are recorded as impure without visiting arguments.
    ///
    /// # Panics
    ///
    /// Panics on `Type::Unspecified`, and on `Type::UserDefined` when `user_defined` is
    /// `None`.
    fn visit_table_function(&mut self, func_call: &super::TableFunction) {
        // Within this method `Type` refers to the table function type, not `expr_node::Type`.
        use crate::expr::table_function::TableFunctionType as Type;
        let func_type = func_call.function_type;
        match func_type {
            Type::Unspecified => unreachable!(),

            // deterministic
            Type::GenerateSeries
            | Type::Unnest
            | Type::RegexpMatches
            | Type::Range
            | Type::GenerateSubscripts
            | Type::PgExpandarray
            | Type::JsonbArrayElements
            | Type::JsonbArrayElementsText
            | Type::JsonbEach
            | Type::JsonbEachText
            | Type::JsonbObjectKeys
            | Type::JsonbPathQuery
            | Type::JsonbPopulateRecordset
            | Type::JsonbToRecordset => {
                func_call.args.iter().for_each(|expr| self.visit_expr(expr));
            }

            // non-deterministic
            Type::FileScan
            | Type::PostgresQuery
            | Type::MysqlQuery
            | Type::InternalBackfillProgress
            | Type::InternalSourceBackfillProgress
            | Type::InternalGetChannelDeltaStats
            | Type::PgGetKeywords => {
                self.impure = Some(func_type.as_str_name().into());
            }
            Type::UserDefined => {
                let name = &func_call.user_defined.as_ref().unwrap().name;
                self.impure = Some(format!("user-defined table function `{name}`").into());
            }
        }
    }

    /// Records any `Now` expression as impure.
    fn visit_now(&mut self, _: &super::Now) {
        self.impure = Some("NOW or PROCTIME".into());
    }

    /// Records any secret reference as impure, including the secret's name in the description.
    fn visit_secret_ref(&mut self, secret_ref: &super::SecretRef) {
        self.impure = Some(format!("secret reference `{}`", secret_ref.secret_name).into());
    }

    /// Classifies a built-in function call by its function type.
    ///
    /// For function types listed as deterministic, the call itself is not flagged and every
    /// input is visited recursively. For function types listed as non-deterministic, the
    /// type's string name is recorded as the impure description and the inputs are not
    /// visited.
    ///
    /// The match has no wildcard arm, so a newly added function type must be classified
    /// here explicitly before the code compiles.
    ///
    /// # Panics
    ///
    /// Panics on `Type::Unspecified`.
    fn visit_function_call(&mut self, func_call: &super::FunctionCall) {
        let func_type = func_call.func_type();
        match func_type {
            Type::Unspecified => unreachable!(),
            // NOTE(review): presumably at least one variant in the arm below is marked
            // deprecated in the generated protobuf code — confirm which one.
            #[expect(deprecated)]
            Type::Add
            | Type::Subtract
            | Type::Multiply
            | Type::Divide
            | Type::Modulus
            | Type::Equal
            | Type::NotEqual
            | Type::LessThan
            | Type::LessThanOrEqual
            | Type::GreaterThan
            | Type::GreaterThanOrEqual
            | Type::And
            | Type::Or
            | Type::Not
            | Type::In
            | Type::Some
            | Type::All
            | Type::BitwiseAnd
            | Type::BitwiseOr
            | Type::BitwiseXor
            | Type::BitwiseNot
            | Type::BitwiseShiftLeft
            | Type::BitwiseShiftRight
            | Type::Extract
            | Type::DatePart
            | Type::TumbleStart
            | Type::SecToTimestamptz
            | Type::AtTimeZone
            | Type::DateTrunc
            | Type::DateBin
            | Type::MakeDate
            | Type::MakeTime
            | Type::MakeTimestamp
            | Type::CharToTimestamptz
            | Type::CharToDate
            | Type::CastWithTimeZone
            | Type::AddWithTimeZone
            | Type::SubtractWithTimeZone
            | Type::Cast
            | Type::Substr
            | Type::Length
            | Type::Like
            | Type::ILike
            | Type::SimilarToEscape
            | Type::Upper
            | Type::Lower
            | Type::Trim
            | Type::Replace
            | Type::Position
            | Type::Ltrim
            | Type::Rtrim
            | Type::Case
            | Type::ConstantLookup
            | Type::RoundDigit
            | Type::Round
            | Type::Ascii
            | Type::Translate
            | Type::Coalesce
            | Type::ConcatWs
            | Type::ConcatWsVariadic
            | Type::Abs
            | Type::SplitPart
            | Type::Ceil
            | Type::Floor
            | Type::Trunc
            | Type::ToChar
            | Type::Md5
            | Type::CharLength
            | Type::Repeat
            | Type::ConcatOp
            | Type::ByteaConcatOp
            | Type::Concat
            | Type::ConcatVariadic
            | Type::BoolOut
            | Type::OctetLength
            | Type::BitLength
            | Type::Overlay
            | Type::RegexpMatch
            | Type::RegexpReplace
            | Type::RegexpCount
            | Type::RegexpSplitToArray
            | Type::RegexpEq
            | Type::Pow
            | Type::Exp
            | Type::Ln
            | Type::Log10
            | Type::Chr
            | Type::StartsWith
            | Type::Initcap
            | Type::Lpad
            | Type::Rpad
            | Type::Reverse
            | Type::Strpos
            | Type::ToAscii
            | Type::ToHex
            | Type::QuoteIdent
            | Type::Sin
            | Type::Cos
            | Type::Tan
            | Type::Cot
            | Type::Asin
            | Type::Acos
            | Type::Acosd
            | Type::Atan
            | Type::Atan2
            | Type::Atand
            | Type::Atan2d
            | Type::Sqrt
            | Type::Cbrt
            | Type::Sign
            | Type::Scale
            | Type::MinScale
            | Type::TrimScale
            | Type::Gamma
            | Type::Lgamma
            | Type::Left
            | Type::Right
            | Type::Degrees
            | Type::Radians
            | Type::IsTrue
            | Type::IsNotTrue
            | Type::IsFalse
            | Type::IsNotFalse
            | Type::IsNull
            | Type::IsNotNull
            | Type::IsDistinctFrom
            | Type::IsNotDistinctFrom
            | Type::Neg
            | Type::Field
            | Type::Array
            | Type::ArrayAccess
            | Type::ArrayRangeAccess
            | Type::Row
            | Type::ArrayToString
            | Type::ArrayCat
            | Type::ArrayMax
            | Type::ArraySum
            | Type::ArraySort
            | Type::ArrayAppend
            | Type::ArrayReverse
            | Type::ArrayPrepend
            | Type::FormatType
            | Type::ArrayDistinct
            | Type::ArrayMin
            | Type::ArrayDims
            | Type::ArrayLength
            | Type::Cardinality
            | Type::TrimArray
            | Type::ArrayRemove
            | Type::ArrayReplace
            | Type::ArrayPosition
            | Type::ArrayContains
            | Type::ArrayContained
            | Type::ArrayFlatten
            | Type::HexToInt256
            | Type::JsonbConcat
            | Type::JsonbAccess
            | Type::JsonbAccessStr
            | Type::JsonbExtractPath
            | Type::JsonbExtractPathVariadic
            | Type::JsonbExtractPathText
            | Type::JsonbExtractPathTextVariadic
            | Type::JsonbTypeof
            | Type::JsonbArrayLength
            | Type::JsonbObject
            | Type::JsonbPretty
            | Type::JsonbDeletePath
            | Type::JsonbContains
            | Type::JsonbContained
            | Type::JsonbExists
            | Type::JsonbExistsAny
            | Type::JsonbExistsAll
            | Type::JsonbStripNulls
            | Type::JsonbBuildArray
            | Type::JsonbBuildArrayVariadic
            | Type::JsonbBuildObject
            | Type::JsonbPopulateRecord
            | Type::JsonbToArray
            | Type::JsonbToRecord
            | Type::JsonbBuildObjectVariadic
            | Type::JsonbPathExists
            | Type::JsonbPathMatch
            | Type::JsonbPathQueryArray
            | Type::JsonbPathQueryFirst
            | Type::JsonbSet
            | Type::JsonbPopulateMap
            | Type::IsJson
            | Type::ToJsonb
            | Type::Sind
            | Type::Cosd
            | Type::Cotd
            | Type::Asind
            | Type::Sinh
            | Type::Cosh
            | Type::Coth
            | Type::Tanh
            | Type::Atanh
            | Type::Asinh
            | Type::Acosh
            | Type::Decode
            | Type::Encode
            | Type::GetBit
            | Type::GetByte
            | Type::SetBit
            | Type::SetByte
            | Type::BitCount
            | Type::Sha1
            | Type::Sha224
            | Type::Sha256
            | Type::Sha384
            | Type::Sha512
            | Type::Crc32
            | Type::Crc32c
            | Type::Hmac
            | Type::SecureCompare
            | Type::Decrypt
            | Type::Encrypt
            | Type::Tand
            | Type::ArrayPositions
            | Type::StringToArray
            | Type::Format
            | Type::FormatVariadic
            | Type::PgwireSend
            | Type::PgwireRecv
            | Type::ArrayTransform
            | Type::Greatest
            | Type::Least
            | Type::ConvertFrom
            | Type::ConvertTo
            | Type::IcebergTransform
            | Type::InetNtoa
            | Type::InetAton
            | Type::QuoteLiteral
            | Type::QuoteNullable
            | Type::MapFromEntries
            | Type::MapAccess
            | Type::MapKeys
            | Type::MapValues
            | Type::MapEntries
            | Type::MapFromKeyValues
            | Type::MapCat
            | Type::MapContains
            | Type::MapDelete
            | Type::MapFilter
            | Type::MapInsert
            | Type::MapLength
            | Type::L2Distance
            | Type::CosineDistance
            | Type::L1Distance
            | Type::InnerProduct
            | Type::VecConcat
            | Type::L2Norm
            | Type::L2Normalize
            | Type::Subvector
            // TODO: `rw_vnode` is more like STABLE than IMMUTABLE: even though its result is
            // deterministic, it needs to read the total vnode count from the context, which means
            // that it cannot be evaluated during constant folding. We have to treat it as pure
            // here so it can be used internally without materialization.
            | Type::Vnode
            | Type::VnodeUser
            | Type::RwEpochToTs
            | Type::CheckNotNull
            | Type::CompositeCast =>
            // expression output is deterministic (same result for the same input)
            {
                func_call
                    .inputs()
                    .iter()
                    .for_each(|expr| self.visit_expr(expr));
            }
            // expression output is not deterministic
            Type::TestFeature
            | Type::License
            | Type::Proctime
            | Type::PgSleep
            | Type::PgSleepFor
            | Type::PgSleepUntil
            | Type::CastRegclass
            | Type::PgGetIndexdef
            | Type::ColDescription
            | Type::PgGetViewdef
            | Type::PgGetUserbyid
            | Type::PgIndexesSize
            | Type::PgRelationSize
            | Type::PgGetSerialSequence
            | Type::PgIndexColumnHasProperty
            | Type::HasTablePrivilege
            | Type::HasAnyColumnPrivilege
            | Type::HasSchemaPrivilege
            | Type::MakeTimestamptz
            | Type::PgIsInRecovery
            | Type::RwRecoveryStatus
            | Type::RwClusterId
            | Type::RwFragmentVnodes
            | Type::RwActorVnodes
            | Type::PgTableIsVisible
            | Type::HasFunctionPrivilege
            | Type::OpenaiEmbedding
            | Type::HasDatabasePrivilege
            | Type::Random => self.impure = Some(func_type.as_str_name().into()),
        }
    }
}
408
409pub fn is_pure(expr: &ExprImpl) -> bool {
410    !is_impure(expr)
411}
412
413pub fn is_impure(expr: &ExprImpl) -> bool {
414    let mut a = ImpureAnalyzer::default();
415    a.visit_expr(expr);
416    a.is_impure()
417}
418
419pub fn is_impure_func_call(func_call: &FunctionCall) -> bool {
420    let mut a = ImpureAnalyzer::default();
421    a.visit_function_call(func_call);
422    a.is_impure()
423}
424
425/// Returns the description of the impure expression if it is impure, for error reporting.
426/// `None` if the expression is pure.
427pub fn impure_expr_desc(expr: &ExprImpl) -> Option<String> {
428    let mut a = ImpureAnalyzer::default();
429    a.visit_expr(expr);
430    a.impure_expr_desc().map(|s| s.to_owned())
431}
432
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;
    use risingwave_pb::expr::expr_node::Type;

    use crate::expr::{ExprImpl, FunctionCall, InputRef, is_impure, is_pure};

    /// Asserts that both purity predicates agree that `expr` is pure.
    fn expect_pure(expr: &ExprImpl) {
        assert!(is_pure(expr));
        assert!(!is_impure(expr));
    }

    /// Asserts that both purity predicates agree that `expr` is impure.
    fn expect_impure(expr: &ExprImpl) {
        assert!(!is_pure(expr));
        assert!(is_impure(expr));
    }

    /// Builds a call of `func_type` over `inputs` as an `ExprImpl`, panicking if the call
    /// cannot be constructed.
    fn call(func_type: Type, inputs: Vec<ExprImpl>) -> ExprImpl {
        FunctionCall::new(func_type, inputs).unwrap().into()
    }

    #[test]
    fn test_pure_funcs() {
        // Adding two input columns is expected to be pure.
        let lhs: ExprImpl = InputRef::new(0, DataType::Int16).into();
        let rhs: ExprImpl = InputRef::new(0, DataType::Int16).into();
        let sum = call(Type::Add, vec![lhs, rhs]);
        expect_pure(&sum);

        // A pure comparison becomes impure once one operand is `PROCTIME`.
        let ts_col: ExprImpl = InputRef::new(0, DataType::Timestamptz).into();
        let proctime = call(Type::Proctime, vec![]);
        let comparison = call(Type::GreaterThan, vec![ts_col, proctime]);
        expect_impure(&comparison);
    }
}