risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs1use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19#[derive(Default)]
38pub struct Strong {
39 null_columns: FixedBitSet,
40}
41
42impl Strong {
43 fn new(null_columns: FixedBitSet) -> Self {
44 Self { null_columns }
45 }
46
47 pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51 let strong = Strong::new(null_columns);
52 strong.is_null_visit(expr)
53 }
54
55 fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56 self.null_columns.contains(input_ref.index())
57 }
58
59 fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60 match expr {
61 ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62 ExprImpl::Literal(literal) => literal.get_data().is_none(),
63 ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64 ExprImpl::FunctionCallWithLambda(_) => false,
65 ExprImpl::AggCall(_) => false,
66 ExprImpl::Subquery(_) => false,
67 ExprImpl::CorrelatedInputRef(_) => false,
68 ExprImpl::TableFunction(_) => false,
69 ExprImpl::WindowFunction(_) => false,
70 ExprImpl::UserDefinedFunction(_) => false,
71 ExprImpl::Parameter(_) => false,
72 ExprImpl::Now(_) => false,
73 }
74 }
75
76 fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77 match func_call.func_type() {
78 ExprType::IsNull
80 | ExprType::IsNotNull
81 | ExprType::IsDistinctFrom
82 | ExprType::IsNotDistinctFrom
83 | ExprType::IsTrue
84 | ExprType::QuoteNullable
85 | ExprType::IsNotTrue
86 | ExprType::IsFalse
87 | ExprType::IsNotFalse
88 | ExprType::CheckNotNull => false,
89 ExprType::Not
91 | ExprType::Equal
92 | ExprType::NotEqual
93 | ExprType::LessThan
94 | ExprType::LessThanOrEqual
95 | ExprType::GreaterThan
96 | ExprType::GreaterThanOrEqual
97 | ExprType::Like
98 | ExprType::Add
99 | ExprType::AddWithTimeZone
100 | ExprType::Subtract
101 | ExprType::Multiply
102 | ExprType::Modulus
103 | ExprType::Divide
104 | ExprType::Cast
105 | ExprType::Trim
106 | ExprType::Ltrim
107 | ExprType::Rtrim
108 | ExprType::Ceil
109 | ExprType::Floor
110 | ExprType::Extract
111 | ExprType::L2Distance
112 | ExprType::CosineDistance
113 | ExprType::L1Distance
114 | ExprType::InnerProduct
115 | ExprType::VecConcat
116 | ExprType::L2Norm
117 | ExprType::L2Normalize
118 | ExprType::Greatest
119 | ExprType::Least => self.any_null(func_call),
120 ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
122 ExprType::In
125 | ExprType::Some
126 | ExprType::All
127 | ExprType::BitwiseAnd
128 | ExprType::BitwiseOr
129 | ExprType::BitwiseXor
130 | ExprType::BitwiseNot
131 | ExprType::BitwiseShiftLeft
132 | ExprType::BitwiseShiftRight
133 | ExprType::DatePart
134 | ExprType::TumbleStart
135 | ExprType::MakeDate
136 | ExprType::MakeTime
137 | ExprType::MakeTimestamp
138 | ExprType::SecToTimestamptz
139 | ExprType::AtTimeZone
140 | ExprType::DateTrunc
141 | ExprType::DateBin
142 | ExprType::CharToTimestamptz
143 | ExprType::CharToDate
144 | ExprType::CastWithTimeZone
145 | ExprType::SubtractWithTimeZone
146 | ExprType::MakeTimestamptz
147 | ExprType::Substr
148 | ExprType::Length
149 | ExprType::ILike
150 | ExprType::SimilarToEscape
151 | ExprType::Upper
152 | ExprType::Lower
153 | ExprType::Replace
154 | ExprType::Position
155 | ExprType::Case
156 | ExprType::ConstantLookup
157 | ExprType::RoundDigit
158 | ExprType::Round
159 | ExprType::Ascii
160 | ExprType::Translate
161 | ExprType::Concat
162 | ExprType::ConcatVariadic
163 | ExprType::ConcatWs
164 | ExprType::ConcatWsVariadic
165 | ExprType::Abs
166 | ExprType::SplitPart
167 | ExprType::ToChar
168 | ExprType::Md5
169 | ExprType::CharLength
170 | ExprType::Repeat
171 | ExprType::ConcatOp
172 | ExprType::ByteaConcatOp
173 | ExprType::BoolOut
174 | ExprType::OctetLength
175 | ExprType::BitLength
176 | ExprType::Overlay
177 | ExprType::RegexpMatch
178 | ExprType::RegexpReplace
179 | ExprType::RegexpCount
180 | ExprType::RegexpSplitToArray
181 | ExprType::RegexpEq
182 | ExprType::Pow
183 | ExprType::Exp
184 | ExprType::Chr
185 | ExprType::StartsWith
186 | ExprType::Initcap
187 | ExprType::Lpad
188 | ExprType::Rpad
189 | ExprType::Reverse
190 | ExprType::Strpos
191 | ExprType::ToAscii
192 | ExprType::ToHex
193 | ExprType::QuoteIdent
194 | ExprType::QuoteLiteral
195 | ExprType::Sin
196 | ExprType::Cos
197 | ExprType::Tan
198 | ExprType::Cot
199 | ExprType::Asin
200 | ExprType::Acos
201 | ExprType::Acosd
202 | ExprType::Atan
203 | ExprType::Atan2
204 | ExprType::Atand
205 | ExprType::Atan2d
206 | ExprType::Sind
207 | ExprType::Cosd
208 | ExprType::Cotd
209 | ExprType::Tand
210 | ExprType::Asind
211 | ExprType::Sqrt
212 | ExprType::Degrees
213 | ExprType::Radians
214 | ExprType::Cosh
215 | ExprType::Tanh
216 | ExprType::Coth
217 | ExprType::Asinh
218 | ExprType::Acosh
219 | ExprType::Atanh
220 | ExprType::Sinh
221 | ExprType::Trunc
222 | ExprType::Ln
223 | ExprType::Log10
224 | ExprType::Cbrt
225 | ExprType::Sign
226 | ExprType::Scale
227 | ExprType::MinScale
228 | ExprType::TrimScale
229 | ExprType::Encode
230 | ExprType::Decode
231 | ExprType::Sha1
232 | ExprType::Sha224
233 | ExprType::Sha256
234 | ExprType::Sha384
235 | ExprType::Sha512
236 | ExprType::Hmac
237 | ExprType::SecureCompare
238 | ExprType::Left
239 | ExprType::Right
240 | ExprType::Format
241 | ExprType::FormatVariadic
242 | ExprType::PgwireSend
243 | ExprType::PgwireRecv
244 | ExprType::ConvertFrom
245 | ExprType::ConvertTo
246 | ExprType::Decrypt
247 | ExprType::Encrypt
248 | ExprType::Neg
249 | ExprType::Field
250 | ExprType::Array
251 | ExprType::ArrayAccess
252 | ExprType::Row
253 | ExprType::ArrayToString
254 | ExprType::ArrayRangeAccess
255 | ExprType::ArrayCat
256 | ExprType::ArrayAppend
257 | ExprType::ArrayPrepend
258 | ExprType::FormatType
259 | ExprType::ArrayDistinct
260 | ExprType::ArrayLength
261 | ExprType::Cardinality
262 | ExprType::ArrayRemove
263 | ExprType::ArrayPositions
264 | ExprType::TrimArray
265 | ExprType::StringToArray
266 | ExprType::ArrayPosition
267 | ExprType::ArrayReplace
268 | ExprType::ArrayDims
269 | ExprType::ArrayTransform
270 | ExprType::ArrayMin
271 | ExprType::ArrayMax
272 | ExprType::ArraySum
273 | ExprType::ArraySort
274 | ExprType::ArrayContains
275 | ExprType::ArrayContained
276 | ExprType::ArrayFlatten
277 | ExprType::HexToInt256
278 | ExprType::JsonbAccess
279 | ExprType::JsonbAccessStr
280 | ExprType::JsonbExtractPath
281 | ExprType::JsonbExtractPathVariadic
282 | ExprType::JsonbExtractPathText
283 | ExprType::JsonbExtractPathTextVariadic
284 | ExprType::JsonbTypeof
285 | ExprType::JsonbArrayLength
286 | ExprType::IsJson
287 | ExprType::JsonbConcat
288 | ExprType::JsonbObject
289 | ExprType::JsonbPretty
290 | ExprType::JsonbContains
291 | ExprType::JsonbContained
292 | ExprType::JsonbExists
293 | ExprType::JsonbExistsAny
294 | ExprType::JsonbExistsAll
295 | ExprType::JsonbDeletePath
296 | ExprType::JsonbStripNulls
297 | ExprType::ToJsonb
298 | ExprType::JsonbBuildArray
299 | ExprType::JsonbBuildArrayVariadic
300 | ExprType::JsonbBuildObject
301 | ExprType::JsonbBuildObjectVariadic
302 | ExprType::JsonbPathExists
303 | ExprType::JsonbPathMatch
304 | ExprType::JsonbPathQueryArray
305 | ExprType::JsonbPathQueryFirst
306 | ExprType::JsonbPopulateRecord
307 | ExprType::JsonbToArray
308 | ExprType::JsonbToRecord
309 | ExprType::JsonbSet
310 | ExprType::JsonbPopulateMap
311 | ExprType::MapFromEntries
312 | ExprType::MapAccess
313 | ExprType::MapKeys
314 | ExprType::MapValues
315 | ExprType::MapEntries
316 | ExprType::MapFromKeyValues
317 | ExprType::MapCat
318 | ExprType::MapContains
319 | ExprType::MapDelete
320 | ExprType::MapFilter
321 | ExprType::MapInsert
322 | ExprType::MapLength
323 | ExprType::Vnode
324 | ExprType::VnodeUser
325 | ExprType::TestFeature
326 | ExprType::License
327 | ExprType::Proctime
328 | ExprType::PgSleep
329 | ExprType::PgSleepFor
330 | ExprType::PgSleepUntil
331 | ExprType::CastRegclass
332 | ExprType::PgGetIndexdef
333 | ExprType::ColDescription
334 | ExprType::PgGetViewdef
335 | ExprType::PgGetUserbyid
336 | ExprType::PgIndexesSize
337 | ExprType::PgRelationSize
338 | ExprType::PgGetSerialSequence
339 | ExprType::PgIndexColumnHasProperty
340 | ExprType::PgIsInRecovery
341 | ExprType::PgTableIsVisible
342 | ExprType::RwRecoveryStatus
343 | ExprType::IcebergTransform
344 | ExprType::HasTablePrivilege
345 | ExprType::HasFunctionPrivilege
346 | ExprType::HasAnyColumnPrivilege
347 | ExprType::HasSchemaPrivilege
348 | ExprType::InetAton
349 | ExprType::InetNtoa
350 | ExprType::CompositeCast
351 | ExprType::RwEpochToTs
352 | ExprType::OpenaiEmbedding
353 | ExprType::HasDatabasePrivilege
354 | ExprType::Random => false,
355 ExprType::Unspecified => unreachable!(),
356 }
357 }
358
359 fn any_null(&self, func_call: &FunctionCall) -> bool {
360 func_call
361 .inputs()
362 .iter()
363 .any(|expr| self.is_null_visit(expr))
364 }
365
366 fn all_null(&self, func_call: &FunctionCall) -> bool {
367 func_call
368 .inputs()
369 .iter()
370 .all(|expr| self.is_null_visit(expr))
371 }
372}
373
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;
    use crate::expr::ExprImpl::Literal;

    /// A NULL literal is always null; a non-NULL literal never is.
    #[test]
    fn test_literal() {
        let null_columns = FixedBitSet::with_capacity(1);
        let expr = Literal(crate::expr::Literal::new(None, DataType::Varchar).into());
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = Literal(
            crate::expr::Literal::new(Some("test".to_owned().into()), DataType::Varchar).into(),
        );
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// Column references are not null when no column is marked as null.
    #[test]
    fn test_input_ref1() {
        let null_columns = FixedBitSet::with_capacity(2);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// Column references are null when their column is marked as null.
    #[test]
    fn test_input_ref2() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = InputRef::new(0, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns.clone()));

        let expr = InputRef::new(1, DataType::Varchar).into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `c0 = 1 OR c1 IS NULL` with only `c0` null: `OR` needs every operand to be null,
    /// and `IS NULL` is never null, so the whole expression is not provably null.
    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Or,
            vec![
                FunctionCall::new_unchecked(
                    ExprType::Equal,
                    vec![
                        InputRef::new(0, DataType::Int64).into(),
                        Literal(crate::expr::Literal::new(Some(1.into()), DataType::Int32).into()),
                    ],
                    DataType::Boolean,
                )
                .into(),
                FunctionCall::new_unchecked(
                    ExprType::IsNull,
                    vec![InputRef::new(1, DataType::Int64).into()],
                    DataType::Boolean,
                )
                .into(),
            ],
            DataType::Boolean,
        )
        .into();
        assert!(!Strong::is_null(&expr, null_columns));
    }

    /// `c0 / c1` with both columns null: division is null if any operand is null.
    #[test]
    fn test_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = FunctionCall::new_unchecked(
            ExprType::Divide,
            vec![
                InputRef::new(0, DataType::Decimal).into(),
                InputRef::new(1, DataType::Decimal).into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `0.8 * (c0 / c1)` with only `c0` null: nullness propagates through both strict
    /// operators.
    #[test]
    fn test_multiply_divide() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        let expr = FunctionCall::new_unchecked(
            ExprType::Multiply,
            vec![
                Literal(crate::expr::Literal::new(Some(0.8f64.into()), DataType::Float64).into()),
                FunctionCall::new_unchecked(
                    ExprType::Divide,
                    vec![
                        InputRef::new(0, DataType::Decimal).into(),
                        InputRef::new(1, DataType::Decimal).into(),
                    ],
                    DataType::Decimal,
                )
                .into(),
            ],
            DataType::Decimal,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// `COALESCE(c0, c1)` is null only when every argument is null.
    #[test]
    fn test_coalesce() {
        let expr: ExprImpl = FunctionCall::new_unchecked(
            ExprType::Coalesce,
            vec![
                InputRef::new(0, DataType::Int64).into(),
                InputRef::new(1, DataType::Int64).into(),
            ],
            DataType::Int64,
        )
        .into();

        let mut only_c0_null = FixedBitSet::with_capacity(2);
        only_c0_null.insert(0);
        assert!(!Strong::is_null(&expr, only_c0_null));

        let mut both_null = FixedBitSet::with_capacity(2);
        both_null.insert(0);
        both_null.insert(1);
        assert!(Strong::is_null(&expr, both_null));
    }

    /// `GREATEST(c0, c1)` is null when every argument is null.
    #[test]
    fn test_greatest_all_null() {
        let mut null_columns = FixedBitSet::with_capacity(2);
        null_columns.insert(0);
        null_columns.insert(1);
        let expr = FunctionCall::new_unchecked(
            ExprType::Greatest,
            vec![
                InputRef::new(0, DataType::Int64).into(),
                InputRef::new(1, DataType::Int64).into(),
            ],
            DataType::Int64,
        )
        .into();
        assert!(Strong::is_null(&expr, null_columns));
    }

    /// Generates a test asserting `Strong::is_null(&$expr, ..) == $expected` while input
    /// columns 0 and 1 are both marked as null, so "never null" predicates are really
    /// exercised with NULL inputs.
    macro_rules! gen_test {
        ($func:ident, $expr:expr, $expected:expr) => {
            #[test]
            fn $func() {
                let mut null_columns = FixedBitSet::with_capacity(2);
                null_columns.insert(0);
                null_columns.insert(1);
                let expr = $expr;
                assert_eq!(Strong::is_null(&expr, null_columns), $expected);
            }
        };
    }

    gen_test!(
        test_is_not_null,
        FunctionCall::new_unchecked(
            ExprType::IsNotNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_null,
        FunctionCall::new_unchecked(
            ExprType::IsNull,
            vec![InputRef::new(0, DataType::Varchar).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_distinct_from,
        FunctionCall::new_unchecked(
            ExprType::IsNotDistinctFrom,
            vec![
                InputRef::new(0, DataType::Varchar).into(),
                InputRef::new(1, DataType::Varchar).into()
            ],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_true,
        FunctionCall::new_unchecked(
            ExprType::IsTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_not_true,
        FunctionCall::new_unchecked(
            ExprType::IsNotTrue,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
    gen_test!(
        test_is_false,
        FunctionCall::new_unchecked(
            ExprType::IsFalse,
            vec![InputRef::new(0, DataType::Boolean).into()],
            DataType::Boolean
        )
        .into(),
        false
    );
}