risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19/// This utilities are with the same definition in calcite.
20/// Utilities for strong predicates.
21/// A predicate is strong (or null-rejecting) with regards to selected subset of inputs
22/// if it is UNKNOWN if all inputs in selected subset are UNKNOWN.
23/// By the way, UNKNOWN is just the boolean form of NULL.
24///
25/// Examples:
26///
27/// UNKNOWN is strong in `[]` (definitely null)
28///
29/// `c = 1` is strong in `[c]` (definitely null if and only if c is null)
30///
31/// `c IS NULL` is not strong (always returns TRUE or FALSE, nevernull)
32///
33/// `p1 AND p2` is strong in `[p1, p2]` (definitely null if either p1 is null or p2 is null)
34///
35/// `p1 OR p2` is strong if p1 and p2 are strong
36
37#[derive(Default)]
38pub struct Strong {
39    null_columns: FixedBitSet,
40}
41
42impl Strong {
43    fn new(null_columns: FixedBitSet) -> Self {
44        Self { null_columns }
45    }
46
47    /// Returns whether the analyzed expression will *definitely* return null if
48    /// all of a given set of input columns are null.
49    /// Note: we could not assume any null-related property for the input expression if `is_null` returns false
50    pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51        let strong = Strong::new(null_columns);
52        strong.is_null_visit(expr)
53    }
54
55    fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56        self.null_columns.contains(input_ref.index())
57    }
58
59    fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60        match expr {
61            ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62            ExprImpl::Literal(literal) => literal.get_data().is_none(),
63            ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64            ExprImpl::FunctionCallWithLambda(_) => false,
65            ExprImpl::AggCall(_) => false,
66            ExprImpl::Subquery(_) => false,
67            ExprImpl::CorrelatedInputRef(_) => false,
68            ExprImpl::TableFunction(_) => false,
69            ExprImpl::WindowFunction(_) => false,
70            ExprImpl::UserDefinedFunction(_) => false,
71            ExprImpl::Parameter(_) => false,
72            ExprImpl::Now(_) => false,
73        }
74    }
75
76    fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77        match func_call.func_type() {
78            // NOT NULL: This kind of expression is never null. No need to look at its arguments, if it has any.
79            ExprType::IsNull
80            | ExprType::IsNotNull
81            | ExprType::IsDistinctFrom
82            | ExprType::IsNotDistinctFrom
83            | ExprType::IsTrue
84            | ExprType::QuoteNullable
85            | ExprType::IsNotTrue
86            | ExprType::IsFalse
87            | ExprType::IsNotFalse
88            | ExprType::CheckNotNull => false,
89            // ANY: This kind of expression is null if and only if at least one of its arguments is null.
90            ExprType::Not
91            | ExprType::Equal
92            | ExprType::NotEqual
93            | ExprType::LessThan
94            | ExprType::LessThanOrEqual
95            | ExprType::GreaterThan
96            | ExprType::GreaterThanOrEqual
97            | ExprType::Like
98            | ExprType::Add
99            | ExprType::AddWithTimeZone
100            | ExprType::Subtract
101            | ExprType::Multiply
102            | ExprType::Modulus
103            | ExprType::Divide
104            | ExprType::Cast
105            | ExprType::Trim
106            | ExprType::Ltrim
107            | ExprType::Rtrim
108            | ExprType::Ceil
109            | ExprType::Floor
110            | ExprType::Extract
111            | ExprType::Greatest
112            | ExprType::Least => self.any_null(func_call),
113            // ALL: This kind of expression is null if and only if all of its arguments are null.
114            ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
115            // TODO: Function like case when is important but current its structure is complicated, so we need to implement it later if necessary.
116            // Assume that any other expressions cannot be simplified.
117            ExprType::In
118            | ExprType::Some
119            | ExprType::All
120            | ExprType::BitwiseAnd
121            | ExprType::BitwiseOr
122            | ExprType::BitwiseXor
123            | ExprType::BitwiseNot
124            | ExprType::BitwiseShiftLeft
125            | ExprType::BitwiseShiftRight
126            | ExprType::DatePart
127            | ExprType::TumbleStart
128            | ExprType::MakeDate
129            | ExprType::MakeTime
130            | ExprType::MakeTimestamp
131            | ExprType::SecToTimestamptz
132            | ExprType::AtTimeZone
133            | ExprType::DateTrunc
134            | ExprType::CharToTimestamptz
135            | ExprType::CharToDate
136            | ExprType::CastWithTimeZone
137            | ExprType::SubtractWithTimeZone
138            | ExprType::MakeTimestamptz
139            | ExprType::Substr
140            | ExprType::Length
141            | ExprType::ILike
142            | ExprType::SimilarToEscape
143            | ExprType::Upper
144            | ExprType::Lower
145            | ExprType::Replace
146            | ExprType::Position
147            | ExprType::Case
148            | ExprType::ConstantLookup
149            | ExprType::RoundDigit
150            | ExprType::Round
151            | ExprType::Ascii
152            | ExprType::Translate
153            | ExprType::Concat
154            | ExprType::ConcatVariadic
155            | ExprType::ConcatWs
156            | ExprType::ConcatWsVariadic
157            | ExprType::Abs
158            | ExprType::SplitPart
159            | ExprType::ToChar
160            | ExprType::Md5
161            | ExprType::CharLength
162            | ExprType::Repeat
163            | ExprType::ConcatOp
164            | ExprType::BoolOut
165            | ExprType::OctetLength
166            | ExprType::BitLength
167            | ExprType::Overlay
168            | ExprType::RegexpMatch
169            | ExprType::RegexpReplace
170            | ExprType::RegexpCount
171            | ExprType::RegexpSplitToArray
172            | ExprType::RegexpEq
173            | ExprType::Pow
174            | ExprType::Exp
175            | ExprType::Chr
176            | ExprType::StartsWith
177            | ExprType::Initcap
178            | ExprType::Lpad
179            | ExprType::Rpad
180            | ExprType::Reverse
181            | ExprType::Strpos
182            | ExprType::ToAscii
183            | ExprType::ToHex
184            | ExprType::QuoteIdent
185            | ExprType::QuoteLiteral
186            | ExprType::Sin
187            | ExprType::Cos
188            | ExprType::Tan
189            | ExprType::Cot
190            | ExprType::Asin
191            | ExprType::Acos
192            | ExprType::Acosd
193            | ExprType::Atan
194            | ExprType::Atan2
195            | ExprType::Atand
196            | ExprType::Atan2d
197            | ExprType::Sind
198            | ExprType::Cosd
199            | ExprType::Cotd
200            | ExprType::Tand
201            | ExprType::Asind
202            | ExprType::Sqrt
203            | ExprType::Degrees
204            | ExprType::Radians
205            | ExprType::Cosh
206            | ExprType::Tanh
207            | ExprType::Coth
208            | ExprType::Asinh
209            | ExprType::Acosh
210            | ExprType::Atanh
211            | ExprType::Sinh
212            | ExprType::Trunc
213            | ExprType::Ln
214            | ExprType::Log10
215            | ExprType::Cbrt
216            | ExprType::Sign
217            | ExprType::Scale
218            | ExprType::MinScale
219            | ExprType::TrimScale
220            | ExprType::Encode
221            | ExprType::Decode
222            | ExprType::Sha1
223            | ExprType::Sha224
224            | ExprType::Sha256
225            | ExprType::Sha384
226            | ExprType::Sha512
227            | ExprType::Hmac
228            | ExprType::SecureCompare
229            | ExprType::Left
230            | ExprType::Right
231            | ExprType::Format
232            | ExprType::FormatVariadic
233            | ExprType::PgwireSend
234            | ExprType::PgwireRecv
235            | ExprType::ConvertFrom
236            | ExprType::ConvertTo
237            | ExprType::Decrypt
238            | ExprType::Encrypt
239            | ExprType::Neg
240            | ExprType::Field
241            | ExprType::Array
242            | ExprType::ArrayAccess
243            | ExprType::Row
244            | ExprType::ArrayToString
245            | ExprType::ArrayRangeAccess
246            | ExprType::ArrayCat
247            | ExprType::ArrayAppend
248            | ExprType::ArrayPrepend
249            | ExprType::FormatType
250            | ExprType::ArrayDistinct
251            | ExprType::ArrayLength
252            | ExprType::Cardinality
253            | ExprType::ArrayRemove
254            | ExprType::ArrayPositions
255            | ExprType::TrimArray
256            | ExprType::StringToArray
257            | ExprType::ArrayPosition
258            | ExprType::ArrayReplace
259            | ExprType::ArrayDims
260            | ExprType::ArrayTransform
261            | ExprType::ArrayMin
262            | ExprType::ArrayMax
263            | ExprType::ArraySum
264            | ExprType::ArraySort
265            | ExprType::ArrayContains
266            | ExprType::ArrayContained
267            | ExprType::HexToInt256
268            | ExprType::JsonbAccess
269            | ExprType::JsonbAccessStr
270            | ExprType::JsonbExtractPath
271            | ExprType::JsonbExtractPathVariadic
272            | ExprType::JsonbExtractPathText
273            | ExprType::JsonbExtractPathTextVariadic
274            | ExprType::JsonbTypeof
275            | ExprType::JsonbArrayLength
276            | ExprType::IsJson
277            | ExprType::JsonbConcat
278            | ExprType::JsonbObject
279            | ExprType::JsonbPretty
280            | ExprType::JsonbContains
281            | ExprType::JsonbContained
282            | ExprType::JsonbExists
283            | ExprType::JsonbExistsAny
284            | ExprType::JsonbExistsAll
285            | ExprType::JsonbDeletePath
286            | ExprType::JsonbStripNulls
287            | ExprType::ToJsonb
288            | ExprType::JsonbBuildArray
289            | ExprType::JsonbBuildArrayVariadic
290            | ExprType::JsonbBuildObject
291            | ExprType::JsonbBuildObjectVariadic
292            | ExprType::JsonbPathExists
293            | ExprType::JsonbPathMatch
294            | ExprType::JsonbPathQueryArray
295            | ExprType::JsonbPathQueryFirst
296            | ExprType::JsonbPopulateRecord
297            | ExprType::JsonbToRecord
298            | ExprType::JsonbSet
299            | ExprType::JsonbPopulateMap
300            | ExprType::MapFromEntries
301            | ExprType::MapAccess
302            | ExprType::MapKeys
303            | ExprType::MapValues
304            | ExprType::MapEntries
305            | ExprType::MapFromKeyValues
306            | ExprType::MapCat
307            | ExprType::MapContains
308            | ExprType::MapDelete
309            | ExprType::MapInsert
310            | ExprType::MapLength
311            | ExprType::Vnode
312            | ExprType::VnodeUser
313            | ExprType::TestPaidTier
314            | ExprType::License
315            | ExprType::Proctime
316            | ExprType::PgSleep
317            | ExprType::PgSleepFor
318            | ExprType::PgSleepUntil
319            | ExprType::CastRegclass
320            | ExprType::PgGetIndexdef
321            | ExprType::ColDescription
322            | ExprType::PgGetViewdef
323            | ExprType::PgGetUserbyid
324            | ExprType::PgIndexesSize
325            | ExprType::PgRelationSize
326            | ExprType::PgGetSerialSequence
327            | ExprType::PgIndexColumnHasProperty
328            | ExprType::PgIsInRecovery
329            | ExprType::PgTableIsVisible
330            | ExprType::RwRecoveryStatus
331            | ExprType::IcebergTransform
332            | ExprType::HasTablePrivilege
333            | ExprType::HasFunctionPrivilege
334            | ExprType::HasAnyColumnPrivilege
335            | ExprType::HasSchemaPrivilege
336            | ExprType::InetAton
337            | ExprType::InetNtoa
338            | ExprType::RwEpochToTs => false,
339            ExprType::Unspecified => unreachable!(),
340        }
341    }
342
343    fn any_null(&self, func_call: &FunctionCall) -> bool {
344        func_call
345            .inputs()
346            .iter()
347            .any(|expr| self.is_null_visit(expr))
348    }
349
350    fn all_null(&self, func_call: &FunctionCall) -> bool {
351        func_call
352            .inputs()
353            .iter()
354            .all(|expr| self.is_null_visit(expr))
355    }
356}
357
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;
    use crate::expr::ExprImpl::Literal;

    #[test]
    fn test_literal() {
        let null_columns = FixedBitSet::with_capacity(1);
        let expr = Literal(crate::expr::Literal::new(None, DataType::Varchar).into());
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = Literal(
            crate::expr::Literal::new(Some("test".to_owned().into()), DataType::Varchar).into(),
        );
        assert!(!Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_input_ref1() {
        let null_columns = FixedBitSet::with_capacity(2);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_input_ref2() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Or,
            vec![
                FunctionCall::new_unchecked(
                    ExprType::Equal,
                    vec![
                        InputRef::new(0, DataType::Int64).into(),
                        Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
                    ],
                    DataType::Boolean,
                )
                .into(),
                FunctionCall::new_unchecked(
                    ExprType::IsNull,
                    vec![InputRef::new(1, DataType::Int64).into()],
                    DataType::Boolean,
                )
                .into(),
            ],
            DataType::Boolean,
        )
        .into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = FunctionCall::new_unchecked(
            ExprType::Divide,
            vec![
                InputRef::new(0, DataType::Decimal).into(),
                InputRef::new(1, DataType::Decimal).into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `0.8 * (sum / count)` where `sum` (column 0) is null and `count` (column 1) is not.
    #[test]
    fn test_multiply_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Multiply,
            vec![
                Literal(crate::expr::Literal::new(Some(0.8f64.into()), DataType::Float64).into()),
                FunctionCall::new_unchecked(
                    ExprType::Divide,
                    vec![
                        InputRef::new(0, DataType::Decimal).into(),
                        InputRef::new(1, DataType::Decimal).into(),
                    ],
                    DataType::Decimal,
                )
                .into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `c0 AND c1` is only definitely null when both operands are null, because
    /// `FALSE AND NULL` is FALSE.
    #[test]
    fn test_and_requires_all_operands_null() {
        let expr: ExprImpl = FunctionCall::new_unchecked(
            ExprType::And,
            vec![
                InputRef::new(0, DataType::Boolean).into(),
                InputRef::new(1, DataType::Boolean).into(),
            ],
            DataType::Boolean,
        )
        .into();

        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        null_columns.insert(1);
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `COALESCE` skips null arguments, so it is only definitely null when all of
    /// its arguments are null.
    #[test]
    fn test_coalesce_requires_all_arguments_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);

        let expr: ExprImpl = FunctionCall::new_unchecked(
            ExprType::Coalesce,
            vec![
                InputRef::new(0, DataType::Int32).into(),
                Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
            ],
            DataType::Int32,
        )
        .into();
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        null_columns.insert(1);
        let expr: ExprImpl = FunctionCall::new_unchecked(
            ExprType::Coalesce,
            vec![
                InputRef::new(0, DataType::Int32).into(),
                InputRef::new(1, DataType::Int32).into(),
            ],
            DataType::Int32,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// Generates a test asserting the `Strong::is_null` result for `$expr` when
    /// columns 0 and 1 are both null.
    ///
    /// Marking the columns null matters: with an empty null set every expression
    /// would trivially be "not provably null", so the never-null classification
    /// could not be told apart from the strict (ANY) one.
    macro_rules! gen_test {
        ($func:ident, $expr:expr, $expected:expr) => {
            #[test]
            fn $func() {
                let mut null_columns = FixedBitSet::with_capacity(2);
                null_columns.insert(0);
                null_columns.insert(1);
                let expr = $expr;
                assert_eq!(Strong::is_null(&expr, null_columns), $expected);
            }
        };
    }

    gen_test!(
        test_is_not_null,
        FunctionCall::new_unchecked(
            ExprType::IsNotNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_null,
        FunctionCall::new_unchecked(
            ExprType::IsNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsNotDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_true,
        FunctionCall::new_unchecked(
            ExprType::IsTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_true,
        FunctionCall::new_unchecked(
            ExprType::IsNotTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_false,
        FunctionCall::new_unchecked(
            ExprType::IsFalse,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_false,
        FunctionCall::new_unchecked(
            ExprType::IsNotFalse,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
}