risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs1use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19#[derive(Default)]
38pub struct Strong {
39 null_columns: FixedBitSet,
40}
41
42impl Strong {
43 fn new(null_columns: FixedBitSet) -> Self {
44 Self { null_columns }
45 }
46
47 pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51 let strong = Strong::new(null_columns);
52 strong.is_null_visit(expr)
53 }
54
55 fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56 self.null_columns.contains(input_ref.index())
57 }
58
59 fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60 match expr {
61 ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62 ExprImpl::Literal(literal) => literal.get_data().is_none(),
63 ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64 ExprImpl::FunctionCallWithLambda(_) => false,
65 ExprImpl::AggCall(_) => false,
66 ExprImpl::Subquery(_) => false,
67 ExprImpl::CorrelatedInputRef(_) => false,
68 ExprImpl::TableFunction(_) => false,
69 ExprImpl::WindowFunction(_) => false,
70 ExprImpl::UserDefinedFunction(_) => false,
71 ExprImpl::Parameter(_) => false,
72 ExprImpl::Now(_) => false,
73 }
74 }
75
76 fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77 match func_call.func_type() {
78 ExprType::IsNull
80 | ExprType::IsNotNull
81 | ExprType::IsDistinctFrom
82 | ExprType::IsNotDistinctFrom
83 | ExprType::IsTrue
84 | ExprType::QuoteNullable
85 | ExprType::IsNotTrue
86 | ExprType::IsFalse
87 | ExprType::IsNotFalse
88 | ExprType::CheckNotNull => false,
89 ExprType::Not
91 | ExprType::Equal
92 | ExprType::NotEqual
93 | ExprType::LessThan
94 | ExprType::LessThanOrEqual
95 | ExprType::GreaterThan
96 | ExprType::GreaterThanOrEqual
97 | ExprType::Like
98 | ExprType::Add
99 | ExprType::AddWithTimeZone
100 | ExprType::Subtract
101 | ExprType::Multiply
102 | ExprType::Modulus
103 | ExprType::Divide
104 | ExprType::Cast
105 | ExprType::Trim
106 | ExprType::Ltrim
107 | ExprType::Rtrim
108 | ExprType::Ceil
109 | ExprType::Floor
110 | ExprType::Extract
111 | ExprType::Greatest
112 | ExprType::Least => self.any_null(func_call),
113 ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
115 ExprType::In
118 | ExprType::Some
119 | ExprType::All
120 | ExprType::BitwiseAnd
121 | ExprType::BitwiseOr
122 | ExprType::BitwiseXor
123 | ExprType::BitwiseNot
124 | ExprType::BitwiseShiftLeft
125 | ExprType::BitwiseShiftRight
126 | ExprType::DatePart
127 | ExprType::TumbleStart
128 | ExprType::MakeDate
129 | ExprType::MakeTime
130 | ExprType::MakeTimestamp
131 | ExprType::SecToTimestamptz
132 | ExprType::AtTimeZone
133 | ExprType::DateTrunc
134 | ExprType::DateBin
135 | ExprType::CharToTimestamptz
136 | ExprType::CharToDate
137 | ExprType::CastWithTimeZone
138 | ExprType::SubtractWithTimeZone
139 | ExprType::MakeTimestamptz
140 | ExprType::Substr
141 | ExprType::Length
142 | ExprType::ILike
143 | ExprType::SimilarToEscape
144 | ExprType::Upper
145 | ExprType::Lower
146 | ExprType::Replace
147 | ExprType::Position
148 | ExprType::Case
149 | ExprType::ConstantLookup
150 | ExprType::RoundDigit
151 | ExprType::Round
152 | ExprType::Ascii
153 | ExprType::Translate
154 | ExprType::Concat
155 | ExprType::ConcatVariadic
156 | ExprType::ConcatWs
157 | ExprType::ConcatWsVariadic
158 | ExprType::Abs
159 | ExprType::SplitPart
160 | ExprType::ToChar
161 | ExprType::Md5
162 | ExprType::CharLength
163 | ExprType::Repeat
164 | ExprType::ConcatOp
165 | ExprType::BoolOut
166 | ExprType::OctetLength
167 | ExprType::BitLength
168 | ExprType::Overlay
169 | ExprType::RegexpMatch
170 | ExprType::RegexpReplace
171 | ExprType::RegexpCount
172 | ExprType::RegexpSplitToArray
173 | ExprType::RegexpEq
174 | ExprType::Pow
175 | ExprType::Exp
176 | ExprType::Chr
177 | ExprType::StartsWith
178 | ExprType::Initcap
179 | ExprType::Lpad
180 | ExprType::Rpad
181 | ExprType::Reverse
182 | ExprType::Strpos
183 | ExprType::ToAscii
184 | ExprType::ToHex
185 | ExprType::QuoteIdent
186 | ExprType::QuoteLiteral
187 | ExprType::Sin
188 | ExprType::Cos
189 | ExprType::Tan
190 | ExprType::Cot
191 | ExprType::Asin
192 | ExprType::Acos
193 | ExprType::Acosd
194 | ExprType::Atan
195 | ExprType::Atan2
196 | ExprType::Atand
197 | ExprType::Atan2d
198 | ExprType::Sind
199 | ExprType::Cosd
200 | ExprType::Cotd
201 | ExprType::Tand
202 | ExprType::Asind
203 | ExprType::Sqrt
204 | ExprType::Degrees
205 | ExprType::Radians
206 | ExprType::Cosh
207 | ExprType::Tanh
208 | ExprType::Coth
209 | ExprType::Asinh
210 | ExprType::Acosh
211 | ExprType::Atanh
212 | ExprType::Sinh
213 | ExprType::Trunc
214 | ExprType::Ln
215 | ExprType::Log10
216 | ExprType::Cbrt
217 | ExprType::Sign
218 | ExprType::Scale
219 | ExprType::MinScale
220 | ExprType::TrimScale
221 | ExprType::Encode
222 | ExprType::Decode
223 | ExprType::Sha1
224 | ExprType::Sha224
225 | ExprType::Sha256
226 | ExprType::Sha384
227 | ExprType::Sha512
228 | ExprType::Hmac
229 | ExprType::SecureCompare
230 | ExprType::Left
231 | ExprType::Right
232 | ExprType::Format
233 | ExprType::FormatVariadic
234 | ExprType::PgwireSend
235 | ExprType::PgwireRecv
236 | ExprType::ConvertFrom
237 | ExprType::ConvertTo
238 | ExprType::Decrypt
239 | ExprType::Encrypt
240 | ExprType::Neg
241 | ExprType::Field
242 | ExprType::Array
243 | ExprType::ArrayAccess
244 | ExprType::Row
245 | ExprType::ArrayToString
246 | ExprType::ArrayRangeAccess
247 | ExprType::ArrayCat
248 | ExprType::ArrayAppend
249 | ExprType::ArrayPrepend
250 | ExprType::FormatType
251 | ExprType::ArrayDistinct
252 | ExprType::ArrayLength
253 | ExprType::Cardinality
254 | ExprType::ArrayRemove
255 | ExprType::ArrayPositions
256 | ExprType::TrimArray
257 | ExprType::StringToArray
258 | ExprType::ArrayPosition
259 | ExprType::ArrayReplace
260 | ExprType::ArrayDims
261 | ExprType::ArrayTransform
262 | ExprType::ArrayMin
263 | ExprType::ArrayMax
264 | ExprType::ArraySum
265 | ExprType::ArraySort
266 | ExprType::ArrayContains
267 | ExprType::ArrayContained
268 | ExprType::ArrayFlatten
269 | ExprType::HexToInt256
270 | ExprType::JsonbAccess
271 | ExprType::JsonbAccessStr
272 | ExprType::JsonbExtractPath
273 | ExprType::JsonbExtractPathVariadic
274 | ExprType::JsonbExtractPathText
275 | ExprType::JsonbExtractPathTextVariadic
276 | ExprType::JsonbTypeof
277 | ExprType::JsonbArrayLength
278 | ExprType::IsJson
279 | ExprType::JsonbConcat
280 | ExprType::JsonbObject
281 | ExprType::JsonbPretty
282 | ExprType::JsonbContains
283 | ExprType::JsonbContained
284 | ExprType::JsonbExists
285 | ExprType::JsonbExistsAny
286 | ExprType::JsonbExistsAll
287 | ExprType::JsonbDeletePath
288 | ExprType::JsonbStripNulls
289 | ExprType::ToJsonb
290 | ExprType::JsonbBuildArray
291 | ExprType::JsonbBuildArrayVariadic
292 | ExprType::JsonbBuildObject
293 | ExprType::JsonbBuildObjectVariadic
294 | ExprType::JsonbPathExists
295 | ExprType::JsonbPathMatch
296 | ExprType::JsonbPathQueryArray
297 | ExprType::JsonbPathQueryFirst
298 | ExprType::JsonbPopulateRecord
299 | ExprType::JsonbToRecord
300 | ExprType::JsonbSet
301 | ExprType::JsonbPopulateMap
302 | ExprType::MapFromEntries
303 | ExprType::MapAccess
304 | ExprType::MapKeys
305 | ExprType::MapValues
306 | ExprType::MapEntries
307 | ExprType::MapFromKeyValues
308 | ExprType::MapCat
309 | ExprType::MapContains
310 | ExprType::MapDelete
311 | ExprType::MapInsert
312 | ExprType::MapLength
313 | ExprType::Vnode
314 | ExprType::VnodeUser
315 | ExprType::TestPaidTier
316 | ExprType::License
317 | ExprType::Proctime
318 | ExprType::PgSleep
319 | ExprType::PgSleepFor
320 | ExprType::PgSleepUntil
321 | ExprType::CastRegclass
322 | ExprType::PgGetIndexdef
323 | ExprType::ColDescription
324 | ExprType::PgGetViewdef
325 | ExprType::PgGetUserbyid
326 | ExprType::PgIndexesSize
327 | ExprType::PgRelationSize
328 | ExprType::PgGetSerialSequence
329 | ExprType::PgIndexColumnHasProperty
330 | ExprType::PgIsInRecovery
331 | ExprType::PgTableIsVisible
332 | ExprType::RwRecoveryStatus
333 | ExprType::IcebergTransform
334 | ExprType::HasTablePrivilege
335 | ExprType::HasFunctionPrivilege
336 | ExprType::HasAnyColumnPrivilege
337 | ExprType::HasSchemaPrivilege
338 | ExprType::InetAton
339 | ExprType::InetNtoa
340 | ExprType::RwEpochToTs => false,
341 ExprType::Unspecified => unreachable!(),
342 }
343 }
344
345 fn any_null(&self, func_call: &FunctionCall) -> bool {
346 func_call
347 .inputs()
348 .iter()
349 .any(|expr| self.is_null_visit(expr))
350 }
351
352 fn all_null(&self, func_call: &FunctionCall) -> bool {
353 func_call
354 .inputs()
355 .iter()
356 .all(|expr| self.is_null_visit(expr))
357 }
358}
359
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;
    use crate::expr::ExprImpl::Literal;

    /// A NULL literal is NULL; a non-NULL literal is not.
    #[test]
    fn test_literal() {
        let null_columns = FixedBitSet::with_capacity(1);
        let expr = Literal(crate::expr::Literal::new(None, DataType::Varchar).into());
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = Literal(
            crate::expr::Literal::new(Some("test".to_owned().into()), DataType::Varchar).into(),
        );
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// Column references are not NULL when no column is marked as NULL.
    #[test]
    fn test_input_ref1() {
        let null_columns = FixedBitSet::with_capacity(2);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// Column references are NULL when their column is marked as NULL.
    #[test]
    fn test_input_ref2() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `c0 = 1 OR c1 IS NULL` with only `c0` NULL: the `IS NULL` operand is never NULL, so the
    /// `OR` is not provably NULL.
    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Or,
            vec![
                FunctionCall::new_unchecked(
                    ExprType::Equal,
                    vec![
                        InputRef::new(0, DataType::Int64).into(),
                        Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
                    ],
                    DataType::Boolean,
                )
                .into(),
                FunctionCall::new_unchecked(
                    ExprType::IsNull,
                    vec![InputRef::new(1, DataType::Int64).into()],
                    DataType::Boolean,
                )
                .into(),
            ],
            DataType::Boolean,
        )
        .into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// `c0 / c1` is NULL when both operands are NULL.
    #[test]
    fn test_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = FunctionCall::new_unchecked(
            ExprType::Divide,
            vec![
                InputRef::new(0, DataType::Decimal).into(),
                InputRef::new(1, DataType::Decimal).into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `0.8 * (c0 / c1)` is NULL when only `c0` is NULL, because both operators are strict.
    #[test]
    fn test_multiply_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Multiply,
            vec![
                Literal(crate::expr::Literal::new(Some(0.8f64.into()), DataType::Float64).into()),
                FunctionCall::new_unchecked(
                    ExprType::Divide,
                    vec![
                        InputRef::new(0, DataType::Decimal).into(),
                        InputRef::new(1, DataType::Decimal).into(),
                    ],
                    DataType::Decimal,
                )
                .into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// Generates a test asserting whether `$expr` is provably NULL when input columns 0 and 1
    /// are both NULL.
    macro_rules! gen_test {
        ($func:ident, $expr:expr, $expected:expr) => {
            #[test]
            fn $func() {
                // Mark every referenced column as NULL. With no NULL columns, a strict
                // (any-NULL) function would also yield `false`, so the cases below could not
                // tell "never NULL" apart from "strict".
                let mut null_columns = FixedBitSet::with_capacity(2);
                null_columns.insert(0);
                null_columns.insert(1);
                let expr = $expr;
                assert_eq!(Strong::is_null(&expr, null_columns), $expected);
            }
        };
    }

    gen_test!(
        test_is_not_null,
        FunctionCall::new_unchecked(
            ExprType::IsNotNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_null,
        FunctionCall::new_unchecked(
            ExprType::IsNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsNotDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_true,
        FunctionCall::new_unchecked(
            ExprType::IsTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_true,
        FunctionCall::new_unchecked(
            ExprType::IsNotTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_false,
        FunctionCall::new_unchecked(
            ExprType::IsFalse,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_false,
        FunctionCall::new_unchecked(
            ExprType::IsNotFalse,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_coalesce_all_null,
        FunctionCall::new_unchecked(
            ExprType::Coalesce,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Varchar
        )
        .into(),
        true
    );
}