risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs1use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19#[derive(Default)]
38pub struct Strong {
39 null_columns: FixedBitSet,
40}
41
42impl Strong {
43 fn new(null_columns: FixedBitSet) -> Self {
44 Self { null_columns }
45 }
46
47 pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51 let strong = Strong::new(null_columns);
52 strong.is_null_visit(expr)
53 }
54
55 fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56 self.null_columns.contains(input_ref.index())
57 }
58
59 fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60 match expr {
61 ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62 ExprImpl::Literal(literal) => literal.get_data().is_none(),
63 ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64 ExprImpl::FunctionCallWithLambda(_) => false,
65 ExprImpl::AggCall(_) => false,
66 ExprImpl::Subquery(_) => false,
67 ExprImpl::CorrelatedInputRef(_) => false,
68 ExprImpl::TableFunction(_) => false,
69 ExprImpl::WindowFunction(_) => false,
70 ExprImpl::UserDefinedFunction(_) => false,
71 ExprImpl::Parameter(_) => false,
72 ExprImpl::Now(_) => false,
73 }
74 }
75
76 fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77 match func_call.func_type() {
78 ExprType::IsNull
80 | ExprType::IsNotNull
81 | ExprType::IsDistinctFrom
82 | ExprType::IsNotDistinctFrom
83 | ExprType::IsTrue
84 | ExprType::QuoteNullable
85 | ExprType::IsNotTrue
86 | ExprType::IsFalse
87 | ExprType::IsNotFalse
88 | ExprType::CheckNotNull => false,
89 ExprType::Not
91 | ExprType::Equal
92 | ExprType::NotEqual
93 | ExprType::LessThan
94 | ExprType::LessThanOrEqual
95 | ExprType::GreaterThan
96 | ExprType::GreaterThanOrEqual
97 | ExprType::Like
98 | ExprType::Add
99 | ExprType::AddWithTimeZone
100 | ExprType::Subtract
101 | ExprType::Multiply
102 | ExprType::Modulus
103 | ExprType::Divide
104 | ExprType::Cast
105 | ExprType::Trim
106 | ExprType::Ltrim
107 | ExprType::Rtrim
108 | ExprType::Ceil
109 | ExprType::Floor
110 | ExprType::Extract
111 | ExprType::L2Distance
112 | ExprType::CosineDistance
113 | ExprType::L1Distance
114 | ExprType::InnerProduct
115 | ExprType::VecConcat
116 | ExprType::L2Norm
117 | ExprType::L2Normalize
118 | ExprType::Subvector
119 | ExprType::Greatest
120 | ExprType::Least => self.any_null(func_call),
121 ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
123 ExprType::In
126 | ExprType::Some
127 | ExprType::All
128 | ExprType::BitwiseAnd
129 | ExprType::BitwiseOr
130 | ExprType::BitwiseXor
131 | ExprType::BitwiseNot
132 | ExprType::BitwiseShiftLeft
133 | ExprType::BitwiseShiftRight
134 | ExprType::DatePart
135 | ExprType::TumbleStart
136 | ExprType::MakeDate
137 | ExprType::MakeTime
138 | ExprType::MakeTimestamp
139 | ExprType::SecToTimestamptz
140 | ExprType::AtTimeZone
141 | ExprType::DateTrunc
142 | ExprType::DateBin
143 | ExprType::CharToTimestamptz
144 | ExprType::CharToDate
145 | ExprType::CastWithTimeZone
146 | ExprType::SubtractWithTimeZone
147 | ExprType::MakeTimestamptz
148 | ExprType::Substr
149 | ExprType::Length
150 | ExprType::ILike
151 | ExprType::SimilarToEscape
152 | ExprType::Upper
153 | ExprType::Lower
154 | ExprType::Replace
155 | ExprType::Position
156 | ExprType::Case
157 | ExprType::ConstantLookup
158 | ExprType::RoundDigit
159 | ExprType::Round
160 | ExprType::Ascii
161 | ExprType::Translate
162 | ExprType::Concat
163 | ExprType::ConcatVariadic
164 | ExprType::ConcatWs
165 | ExprType::ConcatWsVariadic
166 | ExprType::Abs
167 | ExprType::SplitPart
168 | ExprType::ToChar
169 | ExprType::Md5
170 | ExprType::CharLength
171 | ExprType::Repeat
172 | ExprType::ConcatOp
173 | ExprType::ByteaConcatOp
174 | ExprType::BoolOut
175 | ExprType::OctetLength
176 | ExprType::BitLength
177 | ExprType::Overlay
178 | ExprType::RegexpMatch
179 | ExprType::RegexpReplace
180 | ExprType::RegexpCount
181 | ExprType::RegexpSplitToArray
182 | ExprType::RegexpEq
183 | ExprType::Pow
184 | ExprType::Exp
185 | ExprType::Chr
186 | ExprType::StartsWith
187 | ExprType::Initcap
188 | ExprType::Lpad
189 | ExprType::Rpad
190 | ExprType::Reverse
191 | ExprType::Strpos
192 | ExprType::ToAscii
193 | ExprType::ToHex
194 | ExprType::QuoteIdent
195 | ExprType::QuoteLiteral
196 | ExprType::Sin
197 | ExprType::Cos
198 | ExprType::Tan
199 | ExprType::Cot
200 | ExprType::Asin
201 | ExprType::Acos
202 | ExprType::Acosd
203 | ExprType::Atan
204 | ExprType::Atan2
205 | ExprType::Atand
206 | ExprType::Atan2d
207 | ExprType::Sind
208 | ExprType::Cosd
209 | ExprType::Cotd
210 | ExprType::Tand
211 | ExprType::Asind
212 | ExprType::Sqrt
213 | ExprType::Degrees
214 | ExprType::Radians
215 | ExprType::Cosh
216 | ExprType::Tanh
217 | ExprType::Coth
218 | ExprType::Asinh
219 | ExprType::Acosh
220 | ExprType::Atanh
221 | ExprType::Sinh
222 | ExprType::Trunc
223 | ExprType::Ln
224 | ExprType::Log10
225 | ExprType::Cbrt
226 | ExprType::Sign
227 | ExprType::Scale
228 | ExprType::MinScale
229 | ExprType::TrimScale
230 | ExprType::Gamma
231 | ExprType::Lgamma
232 | ExprType::Encode
233 | ExprType::Decode
234 | ExprType::Sha1
235 | ExprType::Sha224
236 | ExprType::Sha256
237 | ExprType::Sha384
238 | ExprType::Sha512
239 | ExprType::GetBit
240 | ExprType::GetByte
241 | ExprType::SetBit
242 | ExprType::SetByte
243 | ExprType::BitCount
244 | ExprType::Hmac
245 | ExprType::SecureCompare
246 | ExprType::Left
247 | ExprType::Right
248 | ExprType::Format
249 | ExprType::FormatVariadic
250 | ExprType::PgwireSend
251 | ExprType::PgwireRecv
252 | ExprType::ConvertFrom
253 | ExprType::ConvertTo
254 | ExprType::Decrypt
255 | ExprType::Encrypt
256 | ExprType::Neg
257 | ExprType::Field
258 | ExprType::Array
259 | ExprType::ArrayAccess
260 | ExprType::Row
261 | ExprType::ArrayToString
262 | ExprType::ArrayRangeAccess
263 | ExprType::ArrayCat
264 | ExprType::ArrayAppend
265 | ExprType::ArrayPrepend
266 | ExprType::FormatType
267 | ExprType::ArrayDistinct
268 | ExprType::ArrayLength
269 | ExprType::Cardinality
270 | ExprType::ArrayRemove
271 | ExprType::ArrayPositions
272 | ExprType::TrimArray
273 | ExprType::StringToArray
274 | ExprType::ArrayPosition
275 | ExprType::ArrayReplace
276 | ExprType::ArrayDims
277 | ExprType::ArrayTransform
278 | ExprType::ArrayMin
279 | ExprType::ArrayMax
280 | ExprType::ArraySum
281 | ExprType::ArraySort
282 | ExprType::ArrayReverse
283 | ExprType::ArrayContains
284 | ExprType::ArrayContained
285 | ExprType::ArrayFlatten
286 | ExprType::HexToInt256
287 | ExprType::JsonbAccess
288 | ExprType::JsonbAccessStr
289 | ExprType::JsonbExtractPath
290 | ExprType::JsonbExtractPathVariadic
291 | ExprType::JsonbExtractPathText
292 | ExprType::JsonbExtractPathTextVariadic
293 | ExprType::JsonbTypeof
294 | ExprType::JsonbArrayLength
295 | ExprType::IsJson
296 | ExprType::JsonbConcat
297 | ExprType::JsonbObject
298 | ExprType::JsonbPretty
299 | ExprType::JsonbContains
300 | ExprType::JsonbContained
301 | ExprType::JsonbExists
302 | ExprType::JsonbExistsAny
303 | ExprType::JsonbExistsAll
304 | ExprType::JsonbDeletePath
305 | ExprType::JsonbStripNulls
306 | ExprType::ToJsonb
307 | ExprType::JsonbBuildArray
308 | ExprType::JsonbBuildArrayVariadic
309 | ExprType::JsonbBuildObject
310 | ExprType::JsonbBuildObjectVariadic
311 | ExprType::JsonbPathExists
312 | ExprType::JsonbPathMatch
313 | ExprType::JsonbPathQueryArray
314 | ExprType::JsonbPathQueryFirst
315 | ExprType::JsonbPopulateRecord
316 | ExprType::JsonbToArray
317 | ExprType::JsonbToRecord
318 | ExprType::JsonbSet
319 | ExprType::JsonbPopulateMap
320 | ExprType::MapFromEntries
321 | ExprType::MapAccess
322 | ExprType::MapKeys
323 | ExprType::MapValues
324 | ExprType::MapEntries
325 | ExprType::MapFromKeyValues
326 | ExprType::MapCat
327 | ExprType::MapContains
328 | ExprType::MapDelete
329 | ExprType::MapFilter
330 | ExprType::MapInsert
331 | ExprType::MapLength
332 | ExprType::Vnode
333 | ExprType::VnodeUser
334 | ExprType::TestFeature
335 | ExprType::License
336 | ExprType::Proctime
337 | ExprType::PgSleep
338 | ExprType::PgSleepFor
339 | ExprType::PgSleepUntil
340 | ExprType::CastRegclass
341 | ExprType::PgGetIndexdef
342 | ExprType::ColDescription
343 | ExprType::PgGetViewdef
344 | ExprType::PgGetUserbyid
345 | ExprType::PgIndexesSize
346 | ExprType::PgRelationSize
347 | ExprType::PgGetSerialSequence
348 | ExprType::PgIndexColumnHasProperty
349 | ExprType::PgIsInRecovery
350 | ExprType::PgTableIsVisible
351 | ExprType::RwRecoveryStatus
352 | ExprType::RwClusterId
353 | ExprType::RwFragmentVnodes
354 | ExprType::RwActorVnodes
355 | ExprType::IcebergTransform
356 | ExprType::HasTablePrivilege
357 | ExprType::HasFunctionPrivilege
358 | ExprType::HasAnyColumnPrivilege
359 | ExprType::HasSchemaPrivilege
360 | ExprType::InetAton
361 | ExprType::InetNtoa
362 | ExprType::CompositeCast
363 | ExprType::RwEpochToTs
364 | ExprType::OpenaiEmbedding
365 | ExprType::HasDatabasePrivilege
366 | ExprType::Random => false,
367 ExprType::Unspecified => unreachable!(),
368 }
369 }
370
371 fn any_null(&self, func_call: &FunctionCall) -> bool {
372 func_call
373 .inputs()
374 .iter()
375 .any(|expr| self.is_null_visit(expr))
376 }
377
378 fn all_null(&self, func_call: &FunctionCall) -> bool {
379 func_call
380 .inputs()
381 .iter()
382 .all(|expr| self.is_null_visit(expr))
383 }
384}
385
386#[cfg(test)]
387mod tests {
388 use risingwave_common::types::DataType;
389
390 use super::*;
391 use crate::expr::ExprImpl::Literal;
392
393 #[test]
394 fn test_literal() {
395 let null_columns = FixedBitSet::with_capacity(1);
396 let expr = Literal(crate::expr::Literal::new(None, DataType::Varchar).into());
397 assert!(Strong::is_null(&expr, null_columns.clone()));
398
399 let expr = Literal(
400 crate::expr::Literal::new(Some("test".to_owned().into()), DataType::Varchar).into(),
401 );
402 assert!(!Strong::is_null(&expr, null_columns));
403 }
404
405 #[test]
406 fn test_input_ref1() {
407 let null_columns = FixedBitSet::with_capacity(2);
408 let expr = InputRef::new(0, DataType::Varchar).into();
409 assert!(!Strong::is_null(&expr, null_columns.clone()));
410
411 let expr = InputRef::new(1, DataType::Varchar).into();
412 assert!(!Strong::is_null(&expr, null_columns));
413 }
414
415 #[test]
416 fn test_input_ref2() {
417 let mut null_columns = FixedBitSet::with_capacity(2);
418 null_columns.insert(0);
419 null_columns.insert(1);
420 let expr = InputRef::new(0, DataType::Varchar).into();
421 assert!(Strong::is_null(&expr, null_columns.clone()));
422
423 let expr = InputRef::new(1, DataType::Varchar).into();
424 assert!(Strong::is_null(&expr, null_columns));
425 }
426
427 #[test]
428 fn test_c1_equal_1_or_c2_is_null() {
429 let mut null_columns = FixedBitSet::with_capacity(2);
430 null_columns.insert(0);
431 let expr = FunctionCall::new_unchecked(
432 ExprType::Or,
433 vec![
434 FunctionCall::new_unchecked(
435 ExprType::Equal,
436 vec![
437 InputRef::new(0, DataType::Int64).into(),
438 Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
439 ],
440 DataType::Boolean,
441 )
442 .into(),
443 FunctionCall::new_unchecked(
444 ExprType::IsNull,
445 vec![InputRef::new(1, DataType::Int64).into()],
446 DataType::Boolean,
447 )
448 .into(),
449 ],
450 DataType::Boolean,
451 )
452 .into();
453 assert!(!Strong::is_null(&expr, null_columns));
454 }
455
456 #[test]
457 fn test_divide() {
458 let mut null_columns = FixedBitSet::with_capacity(2);
459 null_columns.insert(0);
460 null_columns.insert(1);
461 let expr = FunctionCall::new_unchecked(
462 ExprType::Divide,
463 vec![
464 InputRef::new(0, DataType::Decimal).into(),
465 InputRef::new(1, DataType::Decimal).into(),
466 ],
467 DataType::Varchar,
468 )
469 .into();
470 assert!(Strong::is_null(&expr, null_columns));
471 }
472
473 #[test]
475 fn test_multiply_divide() {
476 let mut null_columns = FixedBitSet::with_capacity(2);
477 null_columns.insert(0);
478 let expr = FunctionCall::new_unchecked(
479 ExprType::Multiply,
480 vec![
481 Literal(crate::expr::Literal::new(Some(0.8f64.into()), DataType::Float64).into()),
482 FunctionCall::new_unchecked(
483 ExprType::Divide,
484 vec![
485 InputRef::new(0, DataType::Decimal).into(),
486 InputRef::new(1, DataType::Decimal).into(),
487 ],
488 DataType::Decimal,
489 )
490 .into(),
491 ],
492 DataType::Decimal,
493 )
494 .into();
495 assert!(Strong::is_null(&expr, null_columns));
496 }
497
498 macro_rules! gen_test {
500 ($func:ident, $expr:expr, $expected:expr) => {
501 #[test]
502 fn $func() {
503 let null_columns = FixedBitSet::with_capacity(2);
504 let expr = $expr;
505 assert_eq!(Strong::is_null(&expr, null_columns), $expected);
506 }
507 };
508 }
509
510 gen_test!(
511 test_is_not_null,
512 FunctionCall::new_unchecked(
513 ExprType::IsNotNull,
514 vec![InputRef::new(0, DataType::Varchar).into()],
515 DataType::Varchar
516 )
517 .into(),
518 false
519 );
520 gen_test!(
521 test_is_null,
522 FunctionCall::new_unchecked(
523 ExprType::IsNull,
524 vec![InputRef::new(0, DataType::Varchar).into()],
525 DataType::Varchar
526 )
527 .into(),
528 false
529 );
530 gen_test!(
531 test_is_distinct_from,
532 FunctionCall::new_unchecked(
533 ExprType::IsDistinctFrom,
534 vec![
535 InputRef::new(0, DataType::Varchar).into(),
536 InputRef::new(1, DataType::Varchar).into()
537 ],
538 DataType::Varchar
539 )
540 .into(),
541 false
542 );
543 gen_test!(
544 test_is_not_distinct_from,
545 FunctionCall::new_unchecked(
546 ExprType::IsNotDistinctFrom,
547 vec![
548 InputRef::new(0, DataType::Varchar).into(),
549 InputRef::new(1, DataType::Varchar).into()
550 ],
551 DataType::Varchar
552 )
553 .into(),
554 false
555 );
556 gen_test!(
557 test_is_true,
558 FunctionCall::new_unchecked(
559 ExprType::IsTrue,
560 vec![InputRef::new(0, DataType::Varchar).into()],
561 DataType::Varchar
562 )
563 .into(),
564 false
565 );
566 gen_test!(
567 test_is_not_true,
568 FunctionCall::new_unchecked(
569 ExprType::IsNotTrue,
570 vec![InputRef::new(0, DataType::Varchar).into()],
571 DataType::Varchar
572 )
573 .into(),
574 false
575 );
576 gen_test!(
577 test_is_false,
578 FunctionCall::new_unchecked(
579 ExprType::IsFalse,
580 vec![InputRef::new(0, DataType::Varchar).into()],
581 DataType::Varchar
582 )
583 .into(),
584 false
585 );
586}