// risingwave_frontend/optimizer/plan_expr_visitor/strong.rs

// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use fixedbitset::FixedBitSet;

use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};

19/// This utilities are with the same definition in calcite.
20/// Utilities for strong predicates.
21/// A predicate is strong (or null-rejecting) with regards to selected subset of inputs
22/// if it is UNKNOWN if all inputs in selected subset are UNKNOWN.
23/// By the way, UNKNOWN is just the boolean form of NULL.
24///
25/// Examples:
26///
27/// UNKNOWN is strong in `[]` (definitely null)
28///
29/// `c = 1` is strong in `[c]` (definitely null if and only if c is null)
30///
31/// `c IS NULL` is not strong (always returns TRUE or FALSE, nevernull)
32///
33/// `p1 AND p2` is strong in `[p1, p2]` (definitely null if either p1 is null or p2 is null)
34///
35/// `p1 OR p2` is strong if p1 and p2 are strong
36
37#[derive(Default)]
38pub struct Strong {
39    null_columns: FixedBitSet,
40}
41
42impl Strong {
43    fn new(null_columns: FixedBitSet) -> Self {
44        Self { null_columns }
45    }
46
47    /// Returns whether the analyzed expression will *definitely* return null if
48    /// all of a given set of input columns are null.
49    /// Note: we could not assume any null-related property for the input expression if `is_null` returns false
50    pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51        let strong = Strong::new(null_columns);
52        strong.is_null_visit(expr)
53    }
54
55    fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56        self.null_columns.contains(input_ref.index())
57    }
58
59    fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60        match expr {
61            ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62            ExprImpl::Literal(literal) => literal.get_data().is_none(),
63            ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64            ExprImpl::FunctionCallWithLambda(_) => false,
65            ExprImpl::AggCall(_) => false,
66            ExprImpl::Subquery(_) => false,
67            ExprImpl::CorrelatedInputRef(_) => false,
68            ExprImpl::TableFunction(_) => false,
69            ExprImpl::WindowFunction(_) => false,
70            ExprImpl::UserDefinedFunction(_) => false,
71            ExprImpl::Parameter(_) => false,
72            ExprImpl::Now(_) | ExprImpl::SecretRef(_) => false,
73        }
74    }
75
76    fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77        match func_call.func_type() {
78            // NOT NULL: This kind of expression is never null. No need to look at its arguments, if it has any.
79            ExprType::IsNull
80            | ExprType::IsNotNull
81            | ExprType::IsDistinctFrom
82            | ExprType::IsNotDistinctFrom
83            | ExprType::IsTrue
84            | ExprType::QuoteNullable
85            | ExprType::IsNotTrue
86            | ExprType::IsFalse
87            | ExprType::IsNotFalse
88            | ExprType::CheckNotNull => false,
89            // ANY: This kind of expression is null if and only if at least one of its arguments is null.
90            ExprType::Not
91            | ExprType::Equal
92            | ExprType::NotEqual
93            | ExprType::LessThan
94            | ExprType::LessThanOrEqual
95            | ExprType::GreaterThan
96            | ExprType::GreaterThanOrEqual
97            | ExprType::Like
98            | ExprType::Add
99            | ExprType::AddWithTimeZone
100            | ExprType::Subtract
101            | ExprType::Multiply
102            | ExprType::Modulus
103            | ExprType::Divide
104            | ExprType::Cast
105            | ExprType::Trim
106            | ExprType::Ltrim
107            | ExprType::Rtrim
108            | ExprType::Ceil
109            | ExprType::Floor
110            | ExprType::Extract
111            | ExprType::L2Distance
112            | ExprType::CosineDistance
113            | ExprType::L1Distance
114            | ExprType::InnerProduct
115            | ExprType::VecConcat
116            | ExprType::L2Norm
117            | ExprType::L2Normalize
118            | ExprType::Subvector
119            | ExprType::Greatest
120            | ExprType::Least => self.any_null(func_call),
121            // ALL: This kind of expression is null if and only if all of its arguments are null.
122            ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
123            // TODO: Function like case when is important but current its structure is complicated, so we need to implement it later if necessary.
124            // Assume that any other expressions cannot be simplified.
125            #[expect(deprecated)]
126            ExprType::In
127            | ExprType::Some
128            | ExprType::All
129            | ExprType::BitwiseAnd
130            | ExprType::BitwiseOr
131            | ExprType::BitwiseXor
132            | ExprType::BitwiseNot
133            | ExprType::BitwiseShiftLeft
134            | ExprType::BitwiseShiftRight
135            | ExprType::DatePart
136            | ExprType::TumbleStart
137            | ExprType::MakeDate
138            | ExprType::MakeTime
139            | ExprType::MakeTimestamp
140            | ExprType::SecToTimestamptz
141            | ExprType::AtTimeZone
142            | ExprType::DateTrunc
143            | ExprType::DateBin
144            | ExprType::CharToTimestamptz
145            | ExprType::CharToDate
146            | ExprType::CastWithTimeZone
147            | ExprType::SubtractWithTimeZone
148            | ExprType::MakeTimestamptz
149            | ExprType::Substr
150            | ExprType::Length
151            | ExprType::ILike
152            | ExprType::SimilarToEscape
153            | ExprType::Upper
154            | ExprType::Lower
155            | ExprType::Replace
156            | ExprType::Position
157            | ExprType::Case
158            | ExprType::ConstantLookup
159            | ExprType::RoundDigit
160            | ExprType::Round
161            | ExprType::Ascii
162            | ExprType::Translate
163            | ExprType::Concat
164            | ExprType::ConcatVariadic
165            | ExprType::ConcatWs
166            | ExprType::ConcatWsVariadic
167            | ExprType::Abs
168            | ExprType::SplitPart
169            | ExprType::ToChar
170            | ExprType::Md5
171            | ExprType::CharLength
172            | ExprType::Repeat
173            | ExprType::ConcatOp
174            | ExprType::ByteaConcatOp
175            | ExprType::BoolOut
176            | ExprType::OctetLength
177            | ExprType::BitLength
178            | ExprType::Overlay
179            | ExprType::RegexpMatch
180            | ExprType::RegexpReplace
181            | ExprType::RegexpCount
182            | ExprType::RegexpSplitToArray
183            | ExprType::RegexpEq
184            | ExprType::Pow
185            | ExprType::Exp
186            | ExprType::Chr
187            | ExprType::StartsWith
188            | ExprType::Initcap
189            | ExprType::Lpad
190            | ExprType::Rpad
191            | ExprType::Reverse
192            | ExprType::Strpos
193            | ExprType::ToAscii
194            | ExprType::ToHex
195            | ExprType::QuoteIdent
196            | ExprType::QuoteLiteral
197            | ExprType::Sin
198            | ExprType::Cos
199            | ExprType::Tan
200            | ExprType::Cot
201            | ExprType::Asin
202            | ExprType::Acos
203            | ExprType::Acosd
204            | ExprType::Atan
205            | ExprType::Atan2
206            | ExprType::Atand
207            | ExprType::Atan2d
208            | ExprType::Sind
209            | ExprType::Cosd
210            | ExprType::Cotd
211            | ExprType::Tand
212            | ExprType::Asind
213            | ExprType::Sqrt
214            | ExprType::Degrees
215            | ExprType::Radians
216            | ExprType::Cosh
217            | ExprType::Tanh
218            | ExprType::Coth
219            | ExprType::Asinh
220            | ExprType::Acosh
221            | ExprType::Atanh
222            | ExprType::Sinh
223            | ExprType::Trunc
224            | ExprType::Ln
225            | ExprType::Log10
226            | ExprType::Cbrt
227            | ExprType::Sign
228            | ExprType::Scale
229            | ExprType::MinScale
230            | ExprType::TrimScale
231            | ExprType::Gamma
232            | ExprType::Lgamma
233            | ExprType::Encode
234            | ExprType::Decode
235            | ExprType::Sha1
236            | ExprType::Sha224
237            | ExprType::Sha256
238            | ExprType::Sha384
239            | ExprType::Sha512
240            | ExprType::Crc32
241            | ExprType::Crc32c
242            | ExprType::GetBit
243            | ExprType::GetByte
244            | ExprType::SetBit
245            | ExprType::SetByte
246            | ExprType::BitCount
247            | ExprType::Hmac
248            | ExprType::SecureCompare
249            | ExprType::Left
250            | ExprType::Right
251            | ExprType::Format
252            | ExprType::FormatVariadic
253            | ExprType::PgwireSend
254            | ExprType::PgwireRecv
255            | ExprType::ConvertFrom
256            | ExprType::ConvertTo
257            | ExprType::Decrypt
258            | ExprType::Encrypt
259            | ExprType::Neg
260            | ExprType::Field
261            | ExprType::Array
262            | ExprType::ArrayAccess
263            | ExprType::Row
264            | ExprType::ArrayToString
265            | ExprType::ArrayRangeAccess
266            | ExprType::ArrayCat
267            | ExprType::ArrayAppend
268            | ExprType::ArrayPrepend
269            | ExprType::FormatType
270            | ExprType::ArrayDistinct
271            | ExprType::ArrayLength
272            | ExprType::Cardinality
273            | ExprType::ArrayRemove
274            | ExprType::ArrayPositions
275            | ExprType::TrimArray
276            | ExprType::StringToArray
277            | ExprType::ArrayPosition
278            | ExprType::ArrayReplace
279            | ExprType::ArrayDims
280            | ExprType::ArrayTransform
281            | ExprType::ArrayMin
282            | ExprType::ArrayMax
283            | ExprType::ArraySum
284            | ExprType::ArraySort
285            | ExprType::ArrayReverse
286            | ExprType::ArrayContains
287            | ExprType::ArrayContained
288            | ExprType::ArrayOverlaps
289            | ExprType::ArrayFlatten
290            | ExprType::HexToInt256
291            | ExprType::JsonbAccess
292            | ExprType::JsonbAccessStr
293            | ExprType::JsonbExtractPath
294            | ExprType::JsonbExtractPathVariadic
295            | ExprType::JsonbExtractPathText
296            | ExprType::JsonbExtractPathTextVariadic
297            | ExprType::JsonbTypeof
298            | ExprType::JsonbArrayLength
299            | ExprType::IsJson
300            | ExprType::JsonbConcat
301            | ExprType::JsonbObject
302            | ExprType::JsonbPretty
303            | ExprType::JsonbContains
304            | ExprType::JsonbContained
305            | ExprType::JsonbExists
306            | ExprType::JsonbExistsAny
307            | ExprType::JsonbExistsAll
308            | ExprType::JsonbDeletePath
309            | ExprType::JsonbStripNulls
310            | ExprType::ToJsonb
311            | ExprType::JsonbBuildArray
312            | ExprType::JsonbBuildArrayVariadic
313            | ExprType::JsonbBuildObject
314            | ExprType::JsonbBuildObjectVariadic
315            | ExprType::JsonbPathExists
316            | ExprType::JsonbPathMatch
317            | ExprType::JsonbPathQueryArray
318            | ExprType::JsonbPathQueryFirst
319            | ExprType::JsonbPopulateRecord
320            | ExprType::JsonbToArray
321            | ExprType::JsonbToRecord
322            | ExprType::JsonbSet
323            | ExprType::JsonbPopulateMap
324            | ExprType::MapFromEntries
325            | ExprType::MapAccess
326            | ExprType::MapKeys
327            | ExprType::MapValues
328            | ExprType::MapEntries
329            | ExprType::MapFromKeyValues
330            | ExprType::MapCat
331            | ExprType::MapContains
332            | ExprType::MapDelete
333            | ExprType::MapFilter
334            | ExprType::MapInsert
335            | ExprType::MapLength
336            | ExprType::Vnode
337            | ExprType::VnodeUser
338            | ExprType::TestFeature
339            | ExprType::License
340            | ExprType::Proctime
341            | ExprType::PgSleep
342            | ExprType::PgSleepFor
343            | ExprType::PgSleepUntil
344            | ExprType::CastRegclass
345            | ExprType::PgGetIndexdef
346            | ExprType::ColDescription
347            | ExprType::PgGetViewdef
348            | ExprType::PgGetUserbyid
349            | ExprType::PgIndexesSize
350            | ExprType::PgRelationSize
351            | ExprType::PgGetSerialSequence
352            | ExprType::PgIndexColumnHasProperty
353            | ExprType::PgIsInRecovery
354            | ExprType::PgTableIsVisible
355            | ExprType::RwRecoveryStatus
356            | ExprType::RwClusterId
357            | ExprType::RwFragmentVnodes
358            | ExprType::RwActorVnodes
359            | ExprType::IcebergTransform
360            | ExprType::HasTablePrivilege
361            | ExprType::HasFunctionPrivilege
362            | ExprType::HasAnyColumnPrivilege
363            | ExprType::HasSchemaPrivilege
364            | ExprType::InetAton
365            | ExprType::InetNtoa
366            | ExprType::CompositeCast
367            | ExprType::RwEpochToTs
368            | ExprType::OpenaiEmbedding
369            | ExprType::HasDatabasePrivilege
370            | ExprType::Random => false,
371            ExprType::Unspecified => unreachable!(),
372        }
373    }
374
375    fn any_null(&self, func_call: &FunctionCall) -> bool {
376        func_call
377            .inputs()
378            .iter()
379            .any(|expr| self.is_null_visit(expr))
380    }
381
382    fn all_null(&self, func_call: &FunctionCall) -> bool {
383        func_call
384            .inputs()
385            .iter()
386            .all(|expr| self.is_null_visit(expr))
387    }
388}
389
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;
    use crate::expr::ExprImpl::Literal;

    /// A NULL literal is definitely null; a non-NULL literal is not.
    #[test]
    fn test_literal() {
        let null_columns = FixedBitSet::with_capacity(1);
        let expr = Literal(crate::expr::Literal::new(None, DataType::Varchar).into());
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = Literal(
            crate::expr::Literal::new(Some("test".to_owned().into()), DataType::Varchar).into(),
        );
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// Input references to columns outside the null set are not provably null.
    #[test]
    fn test_input_ref1() {
        let null_columns = FixedBitSet::with_capacity(2);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// Input references to columns inside the null set are definitely null.
    #[test]
    fn test_input_ref2() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `c1 = 1 OR c2 IS NULL` with only c1 null: the `IS NULL` side is never null,
    /// so the disjunction is not provably null.
    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Or,
            vec![
                FunctionCall::new_unchecked(
                    ExprType::Equal,
                    vec![
                        InputRef::new(0, DataType::Int64).into(),
                        Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
                    ],
                    DataType::Boolean,
                )
                .into(),
                FunctionCall::new_unchecked(
                    ExprType::IsNull,
                    vec![InputRef::new(1, DataType::Int64).into()],
                    DataType::Boolean,
                )
                .into(),
            ],
            DataType::Boolean,
        )
        .into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// `c1 / c2` with both columns null is definitely null.
    #[test]
    fn test_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = FunctionCall::new_unchecked(
            ExprType::Divide,
            vec![
                InputRef::new(0, DataType::Decimal).into(),
                InputRef::new(1, DataType::Decimal).into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// Test case for `0.8 * sum / count` where sum is null and count is not null.
    #[test]
    fn test_multiply_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Multiply,
            vec![
                Literal(crate::expr::Literal::new(Some(0.8f64.into()), DataType::Float64).into()),
                FunctionCall::new_unchecked(
                    ExprType::Divide,
                    vec![
                        InputRef::new(0, DataType::Decimal).into(),
                        InputRef::new(1, DataType::Decimal).into(),
                    ],
                    DataType::Decimal,
                )
                .into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// Generates a test that evaluates `$expr` against an empty null-column set
    /// (capacity 2) and compares the result of `Strong::is_null` with `$expected`.
    /// Used below for the predicates that never return null.
    macro_rules! gen_test {
        ($func:ident, $expr:expr, $expected:expr) => {
            #[test]
            fn $func() {
                let null_columns = FixedBitSet::with_capacity(2);
                let expr = $expr;
                assert_eq!(Strong::is_null(&expr, null_columns), $expected);
            }
        };
    }

    gen_test!(
        test_is_not_null,
        FunctionCall::new_unchecked(
            ExprType::IsNotNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_null,
        FunctionCall::new_unchecked(
            ExprType::IsNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsNotDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_true,
        FunctionCall::new_unchecked(
            ExprType::IsTrue,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_true,
        FunctionCall::new_unchecked(
            ExprType::IsNotTrue,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_false,
        FunctionCall::new_unchecked(
            ExprType::IsFalse,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_false,
        FunctionCall::new_unchecked(
            ExprType::IsNotFalse,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
}