risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs1use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19#[derive(Default)]
38pub struct Strong {
39 null_columns: FixedBitSet,
40}
41
42impl Strong {
43 fn new(null_columns: FixedBitSet) -> Self {
44 Self { null_columns }
45 }
46
47 pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51 let strong = Strong::new(null_columns);
52 strong.is_null_visit(expr)
53 }
54
55 fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56 self.null_columns.contains(input_ref.index())
57 }
58
59 fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60 match expr {
61 ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62 ExprImpl::Literal(literal) => literal.get_data().is_none(),
63 ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64 ExprImpl::FunctionCallWithLambda(_) => false,
65 ExprImpl::AggCall(_) => false,
66 ExprImpl::Subquery(_) => false,
67 ExprImpl::CorrelatedInputRef(_) => false,
68 ExprImpl::TableFunction(_) => false,
69 ExprImpl::WindowFunction(_) => false,
70 ExprImpl::UserDefinedFunction(_) => false,
71 ExprImpl::Parameter(_) => false,
72 ExprImpl::Now(_) => false,
73 }
74 }
75
76 fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77 match func_call.func_type() {
78 ExprType::IsNull
80 | ExprType::IsNotNull
81 | ExprType::IsDistinctFrom
82 | ExprType::IsNotDistinctFrom
83 | ExprType::IsTrue
84 | ExprType::QuoteNullable
85 | ExprType::IsNotTrue
86 | ExprType::IsFalse
87 | ExprType::IsNotFalse
88 | ExprType::CheckNotNull => false,
89 ExprType::Not
91 | ExprType::Equal
92 | ExprType::NotEqual
93 | ExprType::LessThan
94 | ExprType::LessThanOrEqual
95 | ExprType::GreaterThan
96 | ExprType::GreaterThanOrEqual
97 | ExprType::Like
98 | ExprType::Add
99 | ExprType::AddWithTimeZone
100 | ExprType::Subtract
101 | ExprType::Multiply
102 | ExprType::Modulus
103 | ExprType::Divide
104 | ExprType::Cast
105 | ExprType::Trim
106 | ExprType::Ltrim
107 | ExprType::Rtrim
108 | ExprType::Ceil
109 | ExprType::Floor
110 | ExprType::Extract
111 | ExprType::L2Distance
112 | ExprType::CosineDistance
113 | ExprType::L1Distance
114 | ExprType::InnerProduct
115 | ExprType::VecConcat
116 | ExprType::L2Norm
117 | ExprType::L2Normalize
118 | ExprType::Subvector
119 | ExprType::Greatest
120 | ExprType::Least => self.any_null(func_call),
121 ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
123 #[expect(deprecated)]
126 ExprType::In
127 | ExprType::Some
128 | ExprType::All
129 | ExprType::BitwiseAnd
130 | ExprType::BitwiseOr
131 | ExprType::BitwiseXor
132 | ExprType::BitwiseNot
133 | ExprType::BitwiseShiftLeft
134 | ExprType::BitwiseShiftRight
135 | ExprType::DatePart
136 | ExprType::TumbleStart
137 | ExprType::MakeDate
138 | ExprType::MakeTime
139 | ExprType::MakeTimestamp
140 | ExprType::SecToTimestamptz
141 | ExprType::AtTimeZone
142 | ExprType::DateTrunc
143 | ExprType::DateBin
144 | ExprType::CharToTimestamptz
145 | ExprType::CharToDate
146 | ExprType::CastWithTimeZone
147 | ExprType::SubtractWithTimeZone
148 | ExprType::MakeTimestamptz
149 | ExprType::Substr
150 | ExprType::Length
151 | ExprType::ILike
152 | ExprType::SimilarToEscape
153 | ExprType::Upper
154 | ExprType::Lower
155 | ExprType::Replace
156 | ExprType::Position
157 | ExprType::Case
158 | ExprType::ConstantLookup
159 | ExprType::RoundDigit
160 | ExprType::Round
161 | ExprType::Ascii
162 | ExprType::Translate
163 | ExprType::Concat
164 | ExprType::ConcatVariadic
165 | ExprType::ConcatWs
166 | ExprType::ConcatWsVariadic
167 | ExprType::Abs
168 | ExprType::SplitPart
169 | ExprType::ToChar
170 | ExprType::Md5
171 | ExprType::CharLength
172 | ExprType::Repeat
173 | ExprType::ConcatOp
174 | ExprType::ByteaConcatOp
175 | ExprType::BoolOut
176 | ExprType::OctetLength
177 | ExprType::BitLength
178 | ExprType::Overlay
179 | ExprType::RegexpMatch
180 | ExprType::RegexpReplace
181 | ExprType::RegexpCount
182 | ExprType::RegexpSplitToArray
183 | ExprType::RegexpEq
184 | ExprType::Pow
185 | ExprType::Exp
186 | ExprType::Chr
187 | ExprType::StartsWith
188 | ExprType::Initcap
189 | ExprType::Lpad
190 | ExprType::Rpad
191 | ExprType::Reverse
192 | ExprType::Strpos
193 | ExprType::ToAscii
194 | ExprType::ToHex
195 | ExprType::QuoteIdent
196 | ExprType::QuoteLiteral
197 | ExprType::Sin
198 | ExprType::Cos
199 | ExprType::Tan
200 | ExprType::Cot
201 | ExprType::Asin
202 | ExprType::Acos
203 | ExprType::Acosd
204 | ExprType::Atan
205 | ExprType::Atan2
206 | ExprType::Atand
207 | ExprType::Atan2d
208 | ExprType::Sind
209 | ExprType::Cosd
210 | ExprType::Cotd
211 | ExprType::Tand
212 | ExprType::Asind
213 | ExprType::Sqrt
214 | ExprType::Degrees
215 | ExprType::Radians
216 | ExprType::Cosh
217 | ExprType::Tanh
218 | ExprType::Coth
219 | ExprType::Asinh
220 | ExprType::Acosh
221 | ExprType::Atanh
222 | ExprType::Sinh
223 | ExprType::Trunc
224 | ExprType::Ln
225 | ExprType::Log10
226 | ExprType::Cbrt
227 | ExprType::Sign
228 | ExprType::Scale
229 | ExprType::MinScale
230 | ExprType::TrimScale
231 | ExprType::Gamma
232 | ExprType::Lgamma
233 | ExprType::Encode
234 | ExprType::Decode
235 | ExprType::Sha1
236 | ExprType::Sha224
237 | ExprType::Sha256
238 | ExprType::Sha384
239 | ExprType::Sha512
240 | ExprType::Crc32
241 | ExprType::Crc32c
242 | ExprType::GetBit
243 | ExprType::GetByte
244 | ExprType::SetBit
245 | ExprType::SetByte
246 | ExprType::BitCount
247 | ExprType::Hmac
248 | ExprType::SecureCompare
249 | ExprType::Left
250 | ExprType::Right
251 | ExprType::Format
252 | ExprType::FormatVariadic
253 | ExprType::PgwireSend
254 | ExprType::PgwireRecv
255 | ExprType::ConvertFrom
256 | ExprType::ConvertTo
257 | ExprType::Decrypt
258 | ExprType::Encrypt
259 | ExprType::Neg
260 | ExprType::Field
261 | ExprType::Array
262 | ExprType::ArrayAccess
263 | ExprType::Row
264 | ExprType::ArrayToString
265 | ExprType::ArrayRangeAccess
266 | ExprType::ArrayCat
267 | ExprType::ArrayAppend
268 | ExprType::ArrayPrepend
269 | ExprType::FormatType
270 | ExprType::ArrayDistinct
271 | ExprType::ArrayLength
272 | ExprType::Cardinality
273 | ExprType::ArrayRemove
274 | ExprType::ArrayPositions
275 | ExprType::TrimArray
276 | ExprType::StringToArray
277 | ExprType::ArrayPosition
278 | ExprType::ArrayReplace
279 | ExprType::ArrayDims
280 | ExprType::ArrayTransform
281 | ExprType::ArrayMin
282 | ExprType::ArrayMax
283 | ExprType::ArraySum
284 | ExprType::ArraySort
285 | ExprType::ArrayReverse
286 | ExprType::ArrayContains
287 | ExprType::ArrayContained
288 | ExprType::ArrayFlatten
289 | ExprType::HexToInt256
290 | ExprType::JsonbAccess
291 | ExprType::JsonbAccessStr
292 | ExprType::JsonbExtractPath
293 | ExprType::JsonbExtractPathVariadic
294 | ExprType::JsonbExtractPathText
295 | ExprType::JsonbExtractPathTextVariadic
296 | ExprType::JsonbTypeof
297 | ExprType::JsonbArrayLength
298 | ExprType::IsJson
299 | ExprType::JsonbConcat
300 | ExprType::JsonbObject
301 | ExprType::JsonbPretty
302 | ExprType::JsonbContains
303 | ExprType::JsonbContained
304 | ExprType::JsonbExists
305 | ExprType::JsonbExistsAny
306 | ExprType::JsonbExistsAll
307 | ExprType::JsonbDeletePath
308 | ExprType::JsonbStripNulls
309 | ExprType::ToJsonb
310 | ExprType::JsonbBuildArray
311 | ExprType::JsonbBuildArrayVariadic
312 | ExprType::JsonbBuildObject
313 | ExprType::JsonbBuildObjectVariadic
314 | ExprType::JsonbPathExists
315 | ExprType::JsonbPathMatch
316 | ExprType::JsonbPathQueryArray
317 | ExprType::JsonbPathQueryFirst
318 | ExprType::JsonbPopulateRecord
319 | ExprType::JsonbToArray
320 | ExprType::JsonbToRecord
321 | ExprType::JsonbSet
322 | ExprType::JsonbPopulateMap
323 | ExprType::MapFromEntries
324 | ExprType::MapAccess
325 | ExprType::MapKeys
326 | ExprType::MapValues
327 | ExprType::MapEntries
328 | ExprType::MapFromKeyValues
329 | ExprType::MapCat
330 | ExprType::MapContains
331 | ExprType::MapDelete
332 | ExprType::MapFilter
333 | ExprType::MapInsert
334 | ExprType::MapLength
335 | ExprType::Vnode
336 | ExprType::VnodeUser
337 | ExprType::TestFeature
338 | ExprType::License
339 | ExprType::Proctime
340 | ExprType::PgSleep
341 | ExprType::PgSleepFor
342 | ExprType::PgSleepUntil
343 | ExprType::CastRegclass
344 | ExprType::PgGetIndexdef
345 | ExprType::ColDescription
346 | ExprType::PgGetViewdef
347 | ExprType::PgGetUserbyid
348 | ExprType::PgIndexesSize
349 | ExprType::PgRelationSize
350 | ExprType::PgGetSerialSequence
351 | ExprType::PgIndexColumnHasProperty
352 | ExprType::PgIsInRecovery
353 | ExprType::PgTableIsVisible
354 | ExprType::RwRecoveryStatus
355 | ExprType::RwClusterId
356 | ExprType::RwFragmentVnodes
357 | ExprType::RwActorVnodes
358 | ExprType::IcebergTransform
359 | ExprType::HasTablePrivilege
360 | ExprType::HasFunctionPrivilege
361 | ExprType::HasAnyColumnPrivilege
362 | ExprType::HasSchemaPrivilege
363 | ExprType::InetAton
364 | ExprType::InetNtoa
365 | ExprType::CompositeCast
366 | ExprType::RwEpochToTs
367 | ExprType::OpenaiEmbedding
368 | ExprType::HasDatabasePrivilege
369 | ExprType::Random => false,
370 ExprType::Unspecified => unreachable!(),
371 }
372 }
373
374 fn any_null(&self, func_call: &FunctionCall) -> bool {
375 func_call
376 .inputs()
377 .iter()
378 .any(|expr| self.is_null_visit(expr))
379 }
380
381 fn all_null(&self, func_call: &FunctionCall) -> bool {
382 func_call
383 .inputs()
384 .iter()
385 .all(|expr| self.is_null_visit(expr))
386 }
387}
388
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;
    use crate::expr::ExprImpl::Literal;

    #[test]
    fn test_literal() {
        let null_columns = FixedBitSet::with_capacity(1);
        let expr = Literal(crate::expr::Literal::new(None, DataType::Varchar).into());
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = Literal(
            crate::expr::Literal::new(Some("test".to_owned().into()), DataType::Varchar).into(),
        );
        assert!(!Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_input_ref1() {
        let null_columns = FixedBitSet::with_capacity(2);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_input_ref2() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `c0 = 1 OR c1 IS NULL` with only c0 NULL: the OR is not provably NULL
    /// because its `IS NULL` branch is never NULL.
    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Or,
            vec![
                FunctionCall::new_unchecked(
                    ExprType::Equal,
                    vec![
                        InputRef::new(0, DataType::Int64).into(),
                        Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
                    ],
                    DataType::Boolean,
                )
                .into(),
                FunctionCall::new_unchecked(
                    ExprType::IsNull,
                    vec![InputRef::new(1, DataType::Int64).into()],
                    DataType::Boolean,
                )
                .into(),
            ],
            DataType::Boolean,
        )
        .into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    #[test]
    fn test_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = FunctionCall::new_unchecked(
            ExprType::Divide,
            vec![
                InputRef::new(0, DataType::Decimal).into(),
                InputRef::new(1, DataType::Decimal).into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `0.8 * (c0 / c1)` with c0 NULL: strictness propagates through both the
    /// nested divide and the outer multiply.
    #[test]
    fn test_multiply_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Multiply,
            vec![
                Literal(crate::expr::Literal::new(Some(0.8f64.into()), DataType::Float64).into()),
                FunctionCall::new_unchecked(
                    ExprType::Divide,
                    vec![
                        InputRef::new(0, DataType::Decimal).into(),
                        InputRef::new(1, DataType::Decimal).into(),
                    ],
                    DataType::Decimal,
                )
                .into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `coalesce(c0, c1)` is provably NULL only when every argument is NULL.
    #[test]
    fn test_coalesce() {
        let expr: ExprImpl = FunctionCall::new_unchecked(
            ExprType::Coalesce,
            vec![
                InputRef::new(0, DataType::Int64).into(),
                InputRef::new(1, DataType::Int64).into(),
            ],
            DataType::Int64,
        )
        .into();

        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        null_columns.insert(1);
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// Generates a test asserting whether `$expr` is proven NULL when input
    /// columns 0 and 1 are both NULL. Marking the columns NULL is what makes the
    /// "never NULL" cases meaningful: with non-NULL inputs they would pass
    /// trivially.
    macro_rules! gen_test {
        ($func:ident, $expr:expr, $expected:expr) => {
            #[test]
            fn $func() {
                let mut null_columns = FixedBitSet::with_capacity(2);
                null_columns.insert(0);
                null_columns.insert(1);
                let expr = $expr;
                assert_eq!(Strong::is_null(&expr, null_columns), $expected);
            }
        };
    }

    gen_test!(
        test_is_not_null,
        FunctionCall::new_unchecked(
            ExprType::IsNotNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_null,
        FunctionCall::new_unchecked(
            ExprType::IsNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsNotDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_true,
        FunctionCall::new_unchecked(
            ExprType::IsTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_true,
        FunctionCall::new_unchecked(
            ExprType::IsNotTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_false,
        FunctionCall::new_unchecked(
            ExprType::IsFalse,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_not,
        FunctionCall::new_unchecked(
            ExprType::Not,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        true
    );
}