risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs
1use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19#[derive(Default)]
38pub struct Strong {
39 null_columns: FixedBitSet,
40}
41
42impl Strong {
43 fn new(null_columns: FixedBitSet) -> Self {
44 Self { null_columns }
45 }
46
47 pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51 let strong = Strong::new(null_columns);
52 strong.is_null_visit(expr)
53 }
54
55 fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56 self.null_columns.contains(input_ref.index())
57 }
58
59 fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60 match expr {
61 ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62 ExprImpl::Literal(literal) => literal.get_data().is_none(),
63 ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64 ExprImpl::FunctionCallWithLambda(_) => false,
65 ExprImpl::AggCall(_) => false,
66 ExprImpl::Subquery(_) => false,
67 ExprImpl::CorrelatedInputRef(_) => false,
68 ExprImpl::TableFunction(_) => false,
69 ExprImpl::WindowFunction(_) => false,
70 ExprImpl::UserDefinedFunction(_) => false,
71 ExprImpl::Parameter(_) => false,
72 ExprImpl::Now(_) => false,
73 }
74 }
75
76 fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77 match func_call.func_type() {
78 ExprType::IsNull
80 | ExprType::IsNotNull
81 | ExprType::IsDistinctFrom
82 | ExprType::IsNotDistinctFrom
83 | ExprType::IsTrue
84 | ExprType::QuoteNullable
85 | ExprType::IsNotTrue
86 | ExprType::IsFalse
87 | ExprType::IsNotFalse
88 | ExprType::CheckNotNull => false,
89 ExprType::Not
91 | ExprType::Equal
92 | ExprType::NotEqual
93 | ExprType::LessThan
94 | ExprType::LessThanOrEqual
95 | ExprType::GreaterThan
96 | ExprType::GreaterThanOrEqual
97 | ExprType::Like
98 | ExprType::Add
99 | ExprType::AddWithTimeZone
100 | ExprType::Subtract
101 | ExprType::Multiply
102 | ExprType::Modulus
103 | ExprType::Divide
104 | ExprType::Cast
105 | ExprType::Trim
106 | ExprType::Ltrim
107 | ExprType::Rtrim
108 | ExprType::Ceil
109 | ExprType::Floor
110 | ExprType::Extract
111 | ExprType::Greatest
112 | ExprType::Least => self.any_null(func_call),
113 ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
115 ExprType::In
118 | ExprType::Some
119 | ExprType::All
120 | ExprType::BitwiseAnd
121 | ExprType::BitwiseOr
122 | ExprType::BitwiseXor
123 | ExprType::BitwiseNot
124 | ExprType::BitwiseShiftLeft
125 | ExprType::BitwiseShiftRight
126 | ExprType::DatePart
127 | ExprType::TumbleStart
128 | ExprType::MakeDate
129 | ExprType::MakeTime
130 | ExprType::MakeTimestamp
131 | ExprType::SecToTimestamptz
132 | ExprType::AtTimeZone
133 | ExprType::DateTrunc
134 | ExprType::CharToTimestamptz
135 | ExprType::CharToDate
136 | ExprType::CastWithTimeZone
137 | ExprType::SubtractWithTimeZone
138 | ExprType::MakeTimestamptz
139 | ExprType::Substr
140 | ExprType::Length
141 | ExprType::ILike
142 | ExprType::SimilarToEscape
143 | ExprType::Upper
144 | ExprType::Lower
145 | ExprType::Replace
146 | ExprType::Position
147 | ExprType::Case
148 | ExprType::ConstantLookup
149 | ExprType::RoundDigit
150 | ExprType::Round
151 | ExprType::Ascii
152 | ExprType::Translate
153 | ExprType::Concat
154 | ExprType::ConcatVariadic
155 | ExprType::ConcatWs
156 | ExprType::ConcatWsVariadic
157 | ExprType::Abs
158 | ExprType::SplitPart
159 | ExprType::ToChar
160 | ExprType::Md5
161 | ExprType::CharLength
162 | ExprType::Repeat
163 | ExprType::ConcatOp
164 | ExprType::BoolOut
165 | ExprType::OctetLength
166 | ExprType::BitLength
167 | ExprType::Overlay
168 | ExprType::RegexpMatch
169 | ExprType::RegexpReplace
170 | ExprType::RegexpCount
171 | ExprType::RegexpSplitToArray
172 | ExprType::RegexpEq
173 | ExprType::Pow
174 | ExprType::Exp
175 | ExprType::Chr
176 | ExprType::StartsWith
177 | ExprType::Initcap
178 | ExprType::Lpad
179 | ExprType::Rpad
180 | ExprType::Reverse
181 | ExprType::Strpos
182 | ExprType::ToAscii
183 | ExprType::ToHex
184 | ExprType::QuoteIdent
185 | ExprType::QuoteLiteral
186 | ExprType::Sin
187 | ExprType::Cos
188 | ExprType::Tan
189 | ExprType::Cot
190 | ExprType::Asin
191 | ExprType::Acos
192 | ExprType::Acosd
193 | ExprType::Atan
194 | ExprType::Atan2
195 | ExprType::Atand
196 | ExprType::Atan2d
197 | ExprType::Sind
198 | ExprType::Cosd
199 | ExprType::Cotd
200 | ExprType::Tand
201 | ExprType::Asind
202 | ExprType::Sqrt
203 | ExprType::Degrees
204 | ExprType::Radians
205 | ExprType::Cosh
206 | ExprType::Tanh
207 | ExprType::Coth
208 | ExprType::Asinh
209 | ExprType::Acosh
210 | ExprType::Atanh
211 | ExprType::Sinh
212 | ExprType::Trunc
213 | ExprType::Ln
214 | ExprType::Log10
215 | ExprType::Cbrt
216 | ExprType::Sign
217 | ExprType::Scale
218 | ExprType::MinScale
219 | ExprType::TrimScale
220 | ExprType::Encode
221 | ExprType::Decode
222 | ExprType::Sha1
223 | ExprType::Sha224
224 | ExprType::Sha256
225 | ExprType::Sha384
226 | ExprType::Sha512
227 | ExprType::Hmac
228 | ExprType::SecureCompare
229 | ExprType::Left
230 | ExprType::Right
231 | ExprType::Format
232 | ExprType::FormatVariadic
233 | ExprType::PgwireSend
234 | ExprType::PgwireRecv
235 | ExprType::ConvertFrom
236 | ExprType::ConvertTo
237 | ExprType::Decrypt
238 | ExprType::Encrypt
239 | ExprType::Neg
240 | ExprType::Field
241 | ExprType::Array
242 | ExprType::ArrayAccess
243 | ExprType::Row
244 | ExprType::ArrayToString
245 | ExprType::ArrayRangeAccess
246 | ExprType::ArrayCat
247 | ExprType::ArrayAppend
248 | ExprType::ArrayPrepend
249 | ExprType::FormatType
250 | ExprType::ArrayDistinct
251 | ExprType::ArrayLength
252 | ExprType::Cardinality
253 | ExprType::ArrayRemove
254 | ExprType::ArrayPositions
255 | ExprType::TrimArray
256 | ExprType::StringToArray
257 | ExprType::ArrayPosition
258 | ExprType::ArrayReplace
259 | ExprType::ArrayDims
260 | ExprType::ArrayTransform
261 | ExprType::ArrayMin
262 | ExprType::ArrayMax
263 | ExprType::ArraySum
264 | ExprType::ArraySort
265 | ExprType::ArrayContains
266 | ExprType::ArrayContained
267 | ExprType::HexToInt256
268 | ExprType::JsonbAccess
269 | ExprType::JsonbAccessStr
270 | ExprType::JsonbExtractPath
271 | ExprType::JsonbExtractPathVariadic
272 | ExprType::JsonbExtractPathText
273 | ExprType::JsonbExtractPathTextVariadic
274 | ExprType::JsonbTypeof
275 | ExprType::JsonbArrayLength
276 | ExprType::IsJson
277 | ExprType::JsonbConcat
278 | ExprType::JsonbObject
279 | ExprType::JsonbPretty
280 | ExprType::JsonbContains
281 | ExprType::JsonbContained
282 | ExprType::JsonbExists
283 | ExprType::JsonbExistsAny
284 | ExprType::JsonbExistsAll
285 | ExprType::JsonbDeletePath
286 | ExprType::JsonbStripNulls
287 | ExprType::ToJsonb
288 | ExprType::JsonbBuildArray
289 | ExprType::JsonbBuildArrayVariadic
290 | ExprType::JsonbBuildObject
291 | ExprType::JsonbBuildObjectVariadic
292 | ExprType::JsonbPathExists
293 | ExprType::JsonbPathMatch
294 | ExprType::JsonbPathQueryArray
295 | ExprType::JsonbPathQueryFirst
296 | ExprType::JsonbPopulateRecord
297 | ExprType::JsonbToRecord
298 | ExprType::JsonbSet
299 | ExprType::JsonbPopulateMap
300 | ExprType::MapFromEntries
301 | ExprType::MapAccess
302 | ExprType::MapKeys
303 | ExprType::MapValues
304 | ExprType::MapEntries
305 | ExprType::MapFromKeyValues
306 | ExprType::MapCat
307 | ExprType::MapContains
308 | ExprType::MapDelete
309 | ExprType::MapInsert
310 | ExprType::MapLength
311 | ExprType::Vnode
312 | ExprType::VnodeUser
313 | ExprType::TestPaidTier
314 | ExprType::License
315 | ExprType::Proctime
316 | ExprType::PgSleep
317 | ExprType::PgSleepFor
318 | ExprType::PgSleepUntil
319 | ExprType::CastRegclass
320 | ExprType::PgGetIndexdef
321 | ExprType::ColDescription
322 | ExprType::PgGetViewdef
323 | ExprType::PgGetUserbyid
324 | ExprType::PgIndexesSize
325 | ExprType::PgRelationSize
326 | ExprType::PgGetSerialSequence
327 | ExprType::PgIndexColumnHasProperty
328 | ExprType::PgIsInRecovery
329 | ExprType::PgTableIsVisible
330 | ExprType::RwRecoveryStatus
331 | ExprType::IcebergTransform
332 | ExprType::HasTablePrivilege
333 | ExprType::HasFunctionPrivilege
334 | ExprType::HasAnyColumnPrivilege
335 | ExprType::HasSchemaPrivilege
336 | ExprType::InetAton
337 | ExprType::InetNtoa
338 | ExprType::RwEpochToTs => false,
339 ExprType::Unspecified => unreachable!(),
340 }
341 }
342
343 fn any_null(&self, func_call: &FunctionCall) -> bool {
344 func_call
345 .inputs()
346 .iter()
347 .any(|expr| self.is_null_visit(expr))
348 }
349
350 fn all_null(&self, func_call: &FunctionCall) -> bool {
351 func_call
352 .inputs()
353 .iter()
354 .all(|expr| self.is_null_visit(expr))
355 }
356}
357
#[cfg(test)]
mod tests {
    use risingwave_common::types::{DataType, Datum};

    use super::*;
    use crate::expr::ExprImpl::Literal;

    /// Builds a bitset with `capacity` bits in which exactly `null_indices` are set.
    fn null_columns(capacity: usize, null_indices: &[usize]) -> FixedBitSet {
        let mut bits = FixedBitSet::with_capacity(capacity);
        for &index in null_indices {
            bits.insert(index);
        }
        bits
    }

    /// Reference to input column `index`.
    fn column(index: usize, data_type: DataType) -> ExprImpl {
        InputRef::new(index, data_type).into()
    }

    /// Constant expression holding `value` (`None` is a NULL constant).
    fn constant(value: Datum, data_type: DataType) -> ExprImpl {
        Literal(crate::expr::Literal::new(value, data_type).into())
    }

    /// Unchecked function call expression.
    fn call(func_type: ExprType, inputs: Vec<ExprImpl>, return_type: DataType) -> ExprImpl {
        FunctionCall::new_unchecked(func_type, inputs, return_type).into()
    }

    #[test]
    fn test_literal() {
        let no_nulls = null_columns(1, &[]);

        let null_text = constant(None, DataType::Varchar);
        assert!(Strong::is_null(&null_text, no_nulls.clone()));

        let text = constant(Some("test".to_owned().into()), DataType::Varchar);
        assert!(!Strong::is_null(&text, no_nulls));
    }

    #[test]
    fn test_input_ref1() {
        let no_nulls = null_columns(2, &[]);
        for index in 0..2 {
            let expr = column(index, DataType::Varchar);
            assert!(!Strong::is_null(&expr, no_nulls.clone()));
        }
    }

    #[test]
    fn test_input_ref2() {
        let all_nulls = null_columns(2, &[0, 1]);
        for index in 0..2 {
            let expr = column(index, DataType::Varchar);
            assert!(Strong::is_null(&expr, all_nulls.clone()));
        }
    }

    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let c1_equals_one = call(
            ExprType::Equal,
            vec![
                column(0, DataType::Int64),
                constant(Some(1.into()), DataType::Int32),
            ],
            DataType::Boolean,
        );
        let c2_is_null = call(
            ExprType::IsNull,
            vec![column(1, DataType::Int64)],
            DataType::Boolean,
        );
        let expr = call(
            ExprType::Or,
            vec![c1_equals_one, c2_is_null],
            DataType::Boolean,
        );
        assert!(!Strong::is_null(&expr, null_columns(2, &[0])));
    }

    #[test]
    fn test_divide() {
        let expr = call(
            ExprType::Divide,
            vec![column(0, DataType::Decimal), column(1, DataType::Decimal)],
            DataType::Varchar,
        );
        assert!(Strong::is_null(&expr, null_columns(2, &[0, 1])));
    }

    #[test]
    fn test_multiply_divide() {
        let quotient = call(
            ExprType::Divide,
            vec![column(0, DataType::Decimal), column(1, DataType::Decimal)],
            DataType::Decimal,
        );
        let expr = call(
            ExprType::Multiply,
            vec![
                constant(Some(0.8f64.into()), DataType::Float64),
                quotient,
            ],
            DataType::Decimal,
        );
        assert!(Strong::is_null(&expr, null_columns(2, &[0])));
    }

    /// Generates a test applying `$func_type` to the `Varchar` input columns `$col`
    /// (none of them NULL) and comparing `Strong::is_null` with `$expected`.
    macro_rules! gen_test {
        ($name:ident, $func_type:expr, [$($col:expr),+ $(,)?], $expected:expr) => {
            #[test]
            fn $name() {
                let expr = call(
                    $func_type,
                    vec![$(column($col, DataType::Varchar)),+],
                    DataType::Varchar,
                );
                assert_eq!(Strong::is_null(&expr, null_columns(2, &[])), $expected);
            }
        };
    }

    gen_test!(test_is_not_null, ExprType::IsNotNull, [0], false);
    gen_test!(test_is_null, ExprType::IsNull, [0], false);
    gen_test!(test_is_distinct_from, ExprType::IsDistinctFrom, [0, 1], false);
    gen_test!(
        test_is_not_distinct_from,
        ExprType::IsNotDistinctFrom,
        [0, 1],
        false
    );
    gen_test!(test_is_true, ExprType::IsTrue, [0], false);
    gen_test!(test_is_not_true, ExprType::IsNotTrue, [0], false);
    gen_test!(test_is_false, ExprType::IsFalse, [0], false);
}